Skip to content

Commit

Permalink
feat: moving numba threads setting out of the init
Browse files Browse the repository at this point in the history
  • Loading branch information
BrunoLiegiBastonLiegi committed Jan 23, 2025
1 parent b782805 commit d4d8816
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
11 changes: 8 additions & 3 deletions src/qibojit/backends/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
class NumbaBackend(NumpyBackend):
def __init__(self):
super().__init__()
import sys
# import sys

import psutil
# import psutil
from numba import __version__ as numba_version
from numba import get_num_threads

from qibojit import __version__ as qibojit_version
from qibojit.custom_operators import gates, ops
Expand Down Expand Up @@ -64,23 +65,27 @@ def __init__(self):
4: self.gates.apply_four_qubit_gate_kernel,
5: self.gates.apply_five_qubit_gate_kernel,
}
"""
if sys.platform == "darwin": # pragma: no cover
self.set_threads(psutil.cpu_count(logical=False))
else:
self.set_threads(len(psutil.Process().cpu_affinity()))
"""
self.nthreads = get_num_threads()

def set_precision(self, precision):
if precision != self.precision:
super().set_precision(precision)
if self.custom_matrices:
self.custom_matrices = CustomMatrices(self.dtype)

"""
def set_threads(self, nthreads):
import numba
numba.set_num_threads(nthreads)
self.nthreads = nthreads

"""
# def cast(self, x, dtype=None, copy=False): Inherited from ``NumpyBackend``

# def to_numpy(self, x): Inherited from ``NumpyBackend``
Expand Down
16 changes: 15 additions & 1 deletion src/qibojit/custom_operators/ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
import os
import sys

import numpy as np
from numba import njit, prange
import psutil
from numba import njit, prange, set_num_threads

NTHREADS = (
psutil.cpu_count(logical=False)
if sys.platform == "darwin"
else len(psutil.Process().cpu_affinity())
)
MAX_THREADS = os.environ.get("NUMBA_NUM_THREADS")
if MAX_THREADS is not None:
NTHREADS = min(NTHREADS, MAX_THREADS)
set_num_threads(NTHREADS)


@njit(
Expand Down

0 comments on commit d4d8816

Please sign in to comment.