Skip to content

Commit

Permalink
feat: trying to set NUMBA_NUM_THREADS before importing numba
Browse files Browse the repository at this point in the history
  • Loading branch information
BrunoLiegiBastonLiegi committed Jan 23, 2025
1 parent 6416d3b commit 60a2d2b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 27 deletions.
29 changes: 16 additions & 13 deletions src/qibojit/backends/cpu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import os
import sys

import numpy as np
import psutil
from qibo.backends.numpy import NumpyBackend
from qibo.config import log
from qibo.gates.abstract import ParametrizedGate
Expand All @@ -22,15 +26,23 @@
"GeneralizedfSim": "apply_fsim",
}

if os.environ.get("NUMBA_NUM_THREADS") is None:
NTHREADS = (
psutil.cpu_count(logical=False)
if sys.platform == "darwin"
else len(psutil.Process().cpu_affinity())
)

os.environ["NUMBA_NUM_THREADS"] = str(NTHREADS)


class NumbaBackend(NumpyBackend):
def __init__(self):
super().__init__()
# import sys
import sys

# import psutil
import psutil
from numba import __version__ as numba_version
from numba import get_num_threads

from qibojit import __version__ as qibojit_version
from qibojit.custom_operators import gates, ops
Expand Down Expand Up @@ -65,32 +77,23 @@ def __init__(self):
4: self.gates.apply_four_qubit_gate_kernel,
5: self.gates.apply_five_qubit_gate_kernel,
}
"""

if sys.platform == "darwin": # pragma: no cover
self.set_threads(psutil.cpu_count(logical=False))
else:
self.set_threads(len(psutil.Process().cpu_affinity()))
"""
self.nthreads = get_num_threads()

def set_precision(self, precision):
if precision != self.precision:
super().set_precision(precision)
if self.custom_matrices:
self.custom_matrices = CustomMatrices(self.dtype)

"""
def set_threads(self, nthreads):
import numba

numba.set_num_threads(nthreads)
self.nthreads = nthreads
"""

def set_threads(self, nthreads):
raise RuntimeError(
"Unable to change the number threads. Use the global variable ``NUMBA_NUM_THREADS`` instead."
)

# def cast(self, x, dtype=None, copy=False): Inherited from ``NumpyBackend``

Expand Down
15 changes: 1 addition & 14 deletions src/qibojit/custom_operators/ops.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
import os
import sys

import numpy as np
import psutil
from numba import njit, prange, set_num_threads

NTHREADS = (
psutil.cpu_count(logical=False)
if sys.platform == "darwin"
else len(psutil.Process().cpu_affinity())
)
MAX_THREADS = os.environ.get("NUMBA_NUM_THREADS")
if MAX_THREADS is not None:
NTHREADS = min(NTHREADS, int(MAX_THREADS))
set_num_threads(NTHREADS)
from numba import njit, prange


@njit(
Expand Down

0 comments on commit 60a2d2b

Please sign in to comment.