From e0ef39b412c92421599fd2609ec1d0d2968dedbe Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 8 Nov 2024 12:47:18 +0100 Subject: [PATCH 01/16] proof of concept of chunkexecutor with thread --- src/spikeinterface/core/job_tools.py | 104 ++++++++++++------ .../core/tests/test_job_tools.py | 25 ++++- 2 files changed, 89 insertions(+), 40 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 7a6172369b..70a4fe2345 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -12,7 +12,7 @@ import sys from tqdm.auto import tqdm -from concurrent.futures import ProcessPoolExecutor +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor import multiprocessing as mp from threadpoolctl import threadpool_limits @@ -329,6 +329,7 @@ def __init__( progress_bar=False, handle_returns=False, gather_func=None, + pool_engine="process", n_jobs=1, total_memory=None, chunk_size=None, @@ -370,6 +371,8 @@ def __init__( self.job_name = job_name self.max_threads_per_process = max_threads_per_process + self.pool_engine = pool_engine + if verbose: chunk_memory = self.chunk_size * recording.get_num_channels() * np.dtype(recording.get_dtype()).itemsize total_memory = chunk_memory * self.n_jobs @@ -402,7 +405,7 @@ def run(self, recording_slices=None): if self.n_jobs == 1: if self.progress_bar: - recording_slices = tqdm(recording_slices, ascii=True, desc=self.job_name) + recording_slices = tqdm(recording_slices, desc=self.job_name, total=len(recording_slices)) worker_ctx = self.init_func(*self.init_args) for segment_index, frame_start, frame_stop in recording_slices: @@ -411,60 +414,89 @@ def run(self, recording_slices=None): returns.append(res) if self.gather_func is not None: self.gather_func(res) + else: n_jobs = min(self.n_jobs, len(recording_slices)) - # parallel - with ProcessPoolExecutor( - max_workers=n_jobs, - initializer=worker_initializer, - 
mp_context=mp.get_context(self.mp_context), - initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process), - ) as executor: - results = executor.map(function_wrapper, recording_slices) + if self.pool_engine == "process": + + # parallel + with ProcessPoolExecutor( + max_workers=n_jobs, + initializer=process_worker_initializer, + mp_context=mp.get_context(self.mp_context), + initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process), + ) as executor: + results = executor.map(process_function_wrapper, recording_slices) + + elif self.pool_engine == "thread": + # only one shared context + + worker_dict = self.init_func(*self.init_args) + thread_func = WorkerFuncWrapper(self.func, worker_dict, self.max_threads_per_process) + + with ThreadPoolExecutor( + max_workers=n_jobs, + ) as executor: + results = executor.map(thread_func, recording_slices) + - if self.progress_bar: - results = tqdm(results, desc=self.job_name, total=len(recording_slices)) + else: + raise ValueError("If n_jobs>1 pool_engine must be 'process' or 'thread'") + + + if self.progress_bar: + results = tqdm(results, desc=self.job_name, total=len(recording_slices)) + + for res in results: + if self.handle_returns: + returns.append(res) + if self.gather_func is not None: + self.gather_func(res) + - for res in results: - if self.handle_returns: - returns.append(res) - if self.gather_func is not None: - self.gather_func(res) return returns + +class WorkerFuncWrapper: + def __init__(self, func, worker_dict, max_threads_per_process): + self.func = func + self.worker_dict = worker_dict + self.max_threads_per_process = max_threads_per_process + + def __call__(self, args): + segment_index, start_frame, end_frame = args + if self.max_threads_per_process is None: + return self.func(segment_index, start_frame, end_frame, self.worker_dict) + else: + with threadpool_limits(limits=self.max_threads_per_process): + return self.func(segment_index, start_frame, end_frame, 
self.worker_dict) + # see # https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool # the tricks is : theses 2 variables are global per worker # so they are not share in the same process -global _worker_ctx -global _func +# global _worker_ctx +# global _func +global _process_func_wrapper -def worker_initializer(func, init_func, init_args, max_threads_per_process): - global _worker_ctx +def process_worker_initializer(func, init_func, init_args, max_threads_per_process): + global _process_func_wrapper if max_threads_per_process is None: - _worker_ctx = init_func(*init_args) + worker_dict = init_func(*init_args) else: with threadpool_limits(limits=max_threads_per_process): - _worker_ctx = init_func(*init_args) - _worker_ctx["max_threads_per_process"] = max_threads_per_process - global _func - _func = func + worker_dict = init_func(*init_args) + _process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process) -def function_wrapper(args): - segment_index, start_frame, end_frame = args - global _func - global _worker_ctx - max_threads_per_process = _worker_ctx["max_threads_per_process"] - if max_threads_per_process is None: - return _func(segment_index, start_frame, end_frame, _worker_ctx) - else: - with threadpool_limits(limits=max_threads_per_process): - return _func(segment_index, start_frame, end_frame, _worker_ctx) +def process_function_wrapper(args): + global _process_func_wrapper + return _process_func_wrapper(args) + # Here some utils copy/paste from DART (Charlie Windolf) diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 2f3aff0023..3ed4272af0 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -139,7 +139,7 @@ def __call__(self, res): gathering_func2 = GatherClass() - # chunk + parallel + gather_func + # process + gather_func processor = ChunkRecordingExecutor( 
recording, func, @@ -148,6 +148,7 @@ def __call__(self, res): verbose=True, progress_bar=True, gather_func=gathering_func2, + pool_engine="process", n_jobs=2, chunk_duration="200ms", job_name="job_name", @@ -157,7 +158,7 @@ def __call__(self, res): assert gathering_func2.pos == num_chunks - # chunk + parallel + spawn + # process spawn processor = ChunkRecordingExecutor( recording, func, @@ -165,6 +166,7 @@ def __call__(self, res): init_args, verbose=True, progress_bar=True, + pool_engine="process", mp_context="spawn", n_jobs=2, chunk_duration="200ms", @@ -172,6 +174,21 @@ def __call__(self, res): ) processor.run() + # thread + processor = ChunkRecordingExecutor( + recording, + func, + init_func, + init_args, + verbose=True, + progress_bar=True, + pool_engine="thread", + n_jobs=2, + chunk_duration="200ms", + job_name="job_name", + ) + processor.run() + def test_fix_job_kwargs(): # test negative n_jobs @@ -224,6 +241,6 @@ def test_split_job_kwargs(): # test_divide_segment_into_chunks() # test_ensure_n_jobs() # test_ensure_chunk_size() - # test_ChunkRecordingExecutor() - test_fix_job_kwargs() + test_ChunkRecordingExecutor() + # test_fix_job_kwargs() # test_split_job_kwargs() From a28c33d5af6f153f1bdfbe2998959ee2139ed250 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 8 Nov 2024 13:12:01 +0100 Subject: [PATCH 02/16] for progress_bar the for res in results need to be inside the with --- src/spikeinterface/core/job_tools.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 70a4fe2345..4e0819d0d9 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -428,6 +428,15 @@ def run(self, recording_slices=None): initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process), ) as executor: results = executor.map(process_function_wrapper, recording_slices) + + if self.progress_bar: + 
results = tqdm(results, desc=self.job_name, total=len(recording_slices)) + + for res in results: + if self.handle_returns: + returns.append(res) + if self.gather_func is not None: + self.gather_func(res) elif self.pool_engine == "thread": # only one shared context @@ -440,19 +449,20 @@ def run(self, recording_slices=None): ) as executor: results = executor.map(thread_func, recording_slices) + if self.progress_bar: + results = tqdm(results, desc=self.job_name, total=len(recording_slices)) + + for res in results: + if self.handle_returns: + returns.append(res) + if self.gather_func is not None: + self.gather_func(res) + else: raise ValueError("If n_jobs>1 pool_engine must be 'process' or 'thread'") - if self.progress_bar: - results = tqdm(results, desc=self.job_name, total=len(recording_slices)) - - for res in results: - if self.handle_returns: - returns.append(res) - if self.gather_func is not None: - self.gather_func(res) From 67b055b946d4878249e48ee1ce56ab3ffb765181 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 8 Nov 2024 13:46:56 +0100 Subject: [PATCH 03/16] wip --- src/spikeinterface/core/job_tools.py | 9 ++++----- src/spikeinterface/core/tests/test_job_tools.py | 2 -- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 4e0819d0d9..db23a78b31 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -39,6 +39,7 @@ job_keys = ( + "pool_engine", "n_jobs", "total_memory", "chunk_size", @@ -292,6 +293,8 @@ class ChunkRecordingExecutor: gather_func : None or callable, default: None Optional function that is called in the main thread and retrieves the results of each worker. This function can be used instead of `handle_returns` to implement custom storage on-the-fly. + pool_engine : "process" | "thread" + If n_jobs>1 then use ProcessPoolExecutor or ThreadPoolExecutor n_jobs : int, default: 1 Number of jobs to be used. 
Use -1 to use as many jobs as number of cores total_memory : str, default: None @@ -383,6 +386,7 @@ def __init__( print( self.job_name, "\n" + f"engine={self.pool_engine} - " f"n_jobs={self.n_jobs} - " f"samples_per_chunk={self.chunk_size:,} - " f"chunk_memory={chunk_memory_str} - " @@ -458,14 +462,9 @@ def run(self, recording_slices=None): if self.gather_func is not None: self.gather_func(res) - else: raise ValueError("If n_jobs>1 pool_engine must be 'process' or 'thread'") - - - - return returns diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 3ed4272af0..c46914ab03 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -97,8 +97,6 @@ def init_func(arg1, arg2, arg3): def test_ChunkRecordingExecutor(): recording = generate_recording(num_channels=2) - # make serializable - recording = recording.save() init_args = "a", 120, "yep" From cecb211b4482af487dd0278fa2fd5e67f2efb0bf Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 12 Nov 2024 13:55:49 +0100 Subject: [PATCH 04/16] wip --- src/spikeinterface/core/job_tools.py | 33 +++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index db23a78b31..c514d4c74e 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -14,6 +14,7 @@ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor import multiprocessing as mp +import threading from threadpoolctl import threadpool_limits @@ -445,13 +446,18 @@ def run(self, recording_slices=None): elif self.pool_engine == "thread": # only one shared context - worker_dict = self.init_func(*self.init_args) - thread_func = WorkerFuncWrapper(self.func, worker_dict, self.max_threads_per_process) + # worker_dict = self.init_func(*self.init_args) + # thread_func = WorkerFuncWrapper(self.func, 
worker_dict, self.max_threads_per_process) + + thread_data = threading.local() with ThreadPoolExecutor( max_workers=n_jobs, + initializer=thread_worker_initializer, + initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, thread_data), ) as executor: - results = executor.map(thread_func, recording_slices) + recording_slices2 = [(thread_data, ) + args for args in recording_slices] + results = executor.map(thread_function_wrapper, recording_slices2) if self.progress_bar: results = tqdm(results, desc=self.job_name, total=len(recording_slices)) @@ -485,7 +491,7 @@ def __call__(self, args): # see # https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool -# the tricks is : theses 2 variables are global per worker +# the tricks is : thiw variables are global per worker # so they are not share in the same process # global _worker_ctx # global _func @@ -501,11 +507,28 @@ def process_worker_initializer(func, init_func, init_args, max_threads_per_proce worker_dict = init_func(*init_args) _process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process) - def process_function_wrapper(args): global _process_func_wrapper return _process_func_wrapper(args) +def thread_worker_initializer(func, init_func, init_args, max_threads_per_process, thread_data): + if max_threads_per_process is None: + worker_dict = init_func(*init_args) + else: + with threadpool_limits(limits=max_threads_per_process): + worker_dict = init_func(*init_args) + thread_data._func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process) + # print("ici", thread_data._func_wrapper) + +def thread_function_wrapper(args): + thread_data = args[0] + args = args[1:] + # thread_data = threading.local() + # print("la", thread_data._func_wrapper) + return thread_data._func_wrapper(args) + + + # Here some utils copy/paste from DART (Charlie Windolf) From 43653213d36d988eb674a42e14596eed94d139a3 Mon Sep 17 00:00:00 2001 
From: Samuel Garcia Date: Wed, 20 Nov 2024 13:50:35 +0100 Subject: [PATCH 05/16] Move worker_index to job_tools.py --- src/spikeinterface/core/job_tools.py | 108 ++++++++--- .../core/tests/test_job_tools.py | 59 +++++- .../core/tests/test_waveform_tools.py | 16 +- src/spikeinterface/core/waveform_tools.py | 170 ++++++++---------- 4 files changed, 218 insertions(+), 135 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index c514d4c74e..2a4af1288c 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -13,7 +13,7 @@ from tqdm.auto import tqdm from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor -import multiprocessing as mp +import multiprocessing import threading from threadpoolctl import threadpool_limits @@ -289,6 +289,8 @@ class ChunkRecordingExecutor: If True, output is verbose job_name : str, default: "" Job name + progress_bar : bool, default: False + If True, a progress bar is printed to monitor the progress of the process handle_returns : bool, default: False If True, the function can return values gather_func : None or callable, default: None @@ -313,9 +315,8 @@ class ChunkRecordingExecutor: Limit the number of thread per process using threadpoolctl modules. This used only when n_jobs>1 If None, no limits. - progress_bar : bool, default: False - If True, a progress bar is printed to monitor the progress of the process - + need_worker_index : bool, default False + If True then each worker will also have a "worker_index" injected in the local worker dict. 
Returns ------- @@ -342,6 +343,7 @@ def __init__( mp_context=None, job_name="", max_threads_per_process=1, + need_worker_index=False, ): self.recording = recording self.func = func @@ -377,6 +379,8 @@ def __init__( self.pool_engine = pool_engine + self.need_worker_index = need_worker_index + if verbose: chunk_memory = self.chunk_size * recording.get_num_channels() * np.dtype(recording.get_dtype()).itemsize total_memory = chunk_memory * self.n_jobs @@ -412,9 +416,12 @@ def run(self, recording_slices=None): if self.progress_bar: recording_slices = tqdm(recording_slices, desc=self.job_name, total=len(recording_slices)) - worker_ctx = self.init_func(*self.init_args) + worker_dict = self.init_func(*self.init_args) + if self.need_worker_index: + worker_dict["worker_index"] = 0 + for segment_index, frame_start, frame_stop in recording_slices: - res = self.func(segment_index, frame_start, frame_stop, worker_ctx) + res = self.func(segment_index, frame_start, frame_stop, worker_dict) if self.handle_returns: returns.append(res) if self.gather_func is not None: @@ -425,12 +432,21 @@ def run(self, recording_slices=None): if self.pool_engine == "process": + if self.need_worker_index: + lock = multiprocessing.Lock() + array_pid = multiprocessing.Array("i", n_jobs) + for i in range(n_jobs): + array_pid[i] = -1 + else: + lock = None + array_pid = None + # parallel with ProcessPoolExecutor( max_workers=n_jobs, initializer=process_worker_initializer, - mp_context=mp.get_context(self.mp_context), - initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process), + mp_context=multiprocessing.get_context(self.mp_context), + initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, self.need_worker_index, lock, array_pid), ) as executor: results = executor.map(process_function_wrapper, recording_slices) @@ -444,29 +460,41 @@ def run(self, recording_slices=None): self.gather_func(res) elif self.pool_engine == "thread": - # only one shared 
context + # this is need to create a per worker local dict where the initializer will push the func wrapper + thread_local_data = threading.local() - # worker_dict = self.init_func(*self.init_args) - # thread_func = WorkerFuncWrapper(self.func, worker_dict, self.max_threads_per_process) + global _thread_started + _thread_started = 0 - thread_data = threading.local() + if self.progress_bar: + # here the tqdm threading do not work (maybe collision) so we need to create a pbar + # before thread spawning + pbar = tqdm(desc=self.job_name, total=len(recording_slices)) + + if self.need_worker_index: + lock = threading.Lock() + thread_started = 0 with ThreadPoolExecutor( max_workers=n_jobs, initializer=thread_worker_initializer, - initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, thread_data), + initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, thread_local_data, self.need_worker_index, lock), ) as executor: - recording_slices2 = [(thread_data, ) + args for args in recording_slices] - results = executor.map(thread_function_wrapper, recording_slices2) - if self.progress_bar: - results = tqdm(results, desc=self.job_name, total=len(recording_slices)) + + recording_slices2 = [(thread_local_data, ) + args for args in recording_slices] + results = executor.map(thread_function_wrapper, recording_slices2) for res in results: + if self.progress_bar: + pbar.update(1) if self.handle_returns: returns.append(res) if self.gather_func is not None: self.gather_func(res) + if self.progress_bar: + pbar.close() + del pbar else: raise ValueError("If n_jobs>1 pool_engine must be 'process' or 'thread'") @@ -476,6 +504,11 @@ def run(self, recording_slices=None): class WorkerFuncWrapper: + """ + small wraper that handle: + * local worker_dict + * max_threads_per_process + """ def __init__(self, func, worker_dict, max_threads_per_process): self.func = func self.worker_dict = worker_dict @@ -498,36 +531,57 @@ def __call__(self, 
args): global _process_func_wrapper -def process_worker_initializer(func, init_func, init_args, max_threads_per_process): +def process_worker_initializer(func, init_func, init_args, max_threads_per_process, need_worker_index, lock, array_pid): global _process_func_wrapper if max_threads_per_process is None: worker_dict = init_func(*init_args) else: with threadpool_limits(limits=max_threads_per_process): worker_dict = init_func(*init_args) + + if need_worker_index: + child_process = multiprocessing.current_process() + lock.acquire() + worker_index = None + for i in range(len(array_pid)): + if array_pid[i] == -1: + worker_index = i + array_pid[i] = child_process.ident + break + worker_dict["worker_index"] = worker_index + lock.release() + _process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process) def process_function_wrapper(args): global _process_func_wrapper return _process_func_wrapper(args) -def thread_worker_initializer(func, init_func, init_args, max_threads_per_process, thread_data): + +# use by thread at init +global _thread_started + +def thread_worker_initializer(func, init_func, init_args, max_threads_per_process, thread_local_data, need_worker_index, lock): if max_threads_per_process is None: worker_dict = init_func(*init_args) else: with threadpool_limits(limits=max_threads_per_process): worker_dict = init_func(*init_args) - thread_data._func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process) - # print("ici", thread_data._func_wrapper) -def thread_function_wrapper(args): - thread_data = args[0] - args = args[1:] - # thread_data = threading.local() - # print("la", thread_data._func_wrapper) - return thread_data._func_wrapper(args) + if need_worker_index: + lock.acquire() + global _thread_started + worker_index = _thread_started + _thread_started += 1 + worker_dict["worker_index"] = worker_index + lock.release() + thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process) 
+def thread_function_wrapper(args): + thread_local_data = args[0] + args = args[1:] + return thread_local_data.func_wrapper(args) diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index c46914ab03..5a32898411 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -1,6 +1,8 @@ import pytest import os +import time + from spikeinterface.core import generate_recording, set_global_job_kwargs, get_global_job_kwargs from spikeinterface.core.job_tools import ( @@ -77,22 +79,22 @@ def test_ensure_chunk_size(): assert end_frame == recording.get_num_frames(segment_index=segment_index) -def func(segment_index, start_frame, end_frame, worker_ctx): +def func(segment_index, start_frame, end_frame, worker_dict): import os import time - #  print('func', segment_index, start_frame, end_frame, worker_ctx, os.getpid()) + #  print('func', segment_index, start_frame, end_frame, worker_dict, os.getpid()) time.sleep(0.010) # time.sleep(1.0) return os.getpid() def init_func(arg1, arg2, arg3): - worker_ctx = {} - worker_ctx["arg1"] = arg1 - worker_ctx["arg2"] = arg2 - worker_ctx["arg3"] = arg3 - return worker_ctx + worker_dict = {} + worker_dict["arg1"] = arg1 + worker_dict["arg2"] = arg2 + worker_dict["arg3"] = arg3 + return worker_dict def test_ChunkRecordingExecutor(): @@ -235,10 +237,51 @@ def test_split_job_kwargs(): assert "other_param" not in job_kwargs and "n_jobs" in job_kwargs and "progress_bar" in job_kwargs + + +def func2(segment_index, start_frame, end_frame, worker_dict): + time.sleep(0.010) + # print(os.getpid(), worker_dict["worker_index"]) + return worker_dict["worker_index"] + + +def init_func2(): + # this leave time for other thread/process to start + time.sleep(0.010) + worker_dict = {} + return worker_dict + + +def test_worker_index(): + recording = generate_recording(num_channels=2) + init_args = tuple() + + for i in range(2): + # making this 2 
times ensure to test that global variables are correctly reset + for pool_engine in ("process", "thread"): + processor = ChunkRecordingExecutor( + recording, + func2, + init_func2, + init_args, + progress_bar=False, + gather_func=None, + pool_engine=pool_engine, + n_jobs=2, + handle_returns=True, + chunk_duration="200ms", + need_worker_index=True + ) + res = processor.run() + # we should have a mix of 0 and 1 + assert 0 in res + assert 1 in res + if __name__ == "__main__": # test_divide_segment_into_chunks() # test_ensure_n_jobs() # test_ensure_chunk_size() - test_ChunkRecordingExecutor() + # test_ChunkRecordingExecutor() # test_fix_job_kwargs() # test_split_job_kwargs() + test_worker_index() diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index 845eaf1310..d0e9358164 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -176,17 +176,25 @@ def test_estimate_templates_with_accumulator(): templates = estimate_templates_with_accumulator( recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs ) - print(templates.shape) + # print(templates.shape) assert templates.shape[0] == sorting.unit_ids.size assert templates.shape[1] == nbefore + nafter assert templates.shape[2] == recording.get_num_channels() assert np.any(templates != 0) + job_kwargs = dict(n_jobs=1, progress_bar=True, chunk_duration="1s") + templates_loop = estimate_templates_with_accumulator( + recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs + ) + np.testing.assert_almost_equal(templates, templates_loop, decimal=4) + # import matplotlib.pyplot as plt # fig, ax = plt.subplots() # for unit_index, unit_id in enumerate(sorting.unit_ids): - # ax.plot(templates[unit_index, :, :].T.flatten()) + # ax.plot(templates[unit_index, :, :].T.flatten()) + # ax.plot(templates_loop[unit_index, :, :].T.flatten(), 
color="k", ls="--") + # ax.plot((templates - templates_loop)[unit_index, :, :].T.flatten(), color="k", ls="--") # plt.show() @@ -225,6 +233,6 @@ def test_estimate_templates(): if __name__ == "__main__": - test_waveform_tools() + # test_waveform_tools() test_estimate_templates_with_accumulator() - test_estimate_templates() + # test_estimate_templates() diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index 3affd7f0ec..8a7b15f886 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -296,17 +296,17 @@ def _init_worker_distribute_buffers( recording, unit_ids, spikes, arrays_info, nbefore, nafter, return_scaled, inds_by_unit, mode, sparsity_mask ): # create a local dict per worker - worker_ctx = {} + worker_dict = {} if isinstance(recording, dict): from spikeinterface.core import load_extractor recording = load_extractor(recording) - worker_ctx["recording"] = recording + worker_dict["recording"] = recording if mode == "memmap": # in memmap mode we have the "too many open file" problem with linux # memmap file will be open on demand and not globally per worker - worker_ctx["arrays_info"] = arrays_info + worker_dict["arrays_info"] = arrays_info elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory @@ -321,33 +321,33 @@ def _init_worker_distribute_buffers( waveforms_by_units[unit_id] = arr # we need a reference to all sham otherwise we get segment fault!!! 
shms[unit_id] = shm - worker_ctx["shms"] = shms - worker_ctx["waveforms_by_units"] = waveforms_by_units + worker_dict["shms"] = shms + worker_dict["waveforms_by_units"] = waveforms_by_units - worker_ctx["unit_ids"] = unit_ids - worker_ctx["spikes"] = spikes + worker_dict["unit_ids"] = unit_ids + worker_dict["spikes"] = spikes - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["return_scaled"] = return_scaled - worker_ctx["inds_by_unit"] = inds_by_unit - worker_ctx["sparsity_mask"] = sparsity_mask - worker_ctx["mode"] = mode + worker_dict["nbefore"] = nbefore + worker_dict["nafter"] = nafter + worker_dict["return_scaled"] = return_scaled + worker_dict["inds_by_unit"] = inds_by_unit + worker_dict["sparsity_mask"] = sparsity_mask + worker_dict["mode"] = mode - return worker_ctx + return worker_dict # used by ChunkRecordingExecutor -def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker - recording = worker_ctx["recording"] - unit_ids = worker_ctx["unit_ids"] - spikes = worker_ctx["spikes"] - nbefore = worker_ctx["nbefore"] - nafter = worker_ctx["nafter"] - return_scaled = worker_ctx["return_scaled"] - inds_by_unit = worker_ctx["inds_by_unit"] - sparsity_mask = worker_ctx["sparsity_mask"] + recording = worker_dict["recording"] + unit_ids = worker_dict["unit_ids"] + spikes = worker_dict["spikes"] + nbefore = worker_dict["nbefore"] + nafter = worker_dict["nafter"] + return_scaled = worker_dict["return_scaled"] + inds_by_unit = worker_dict["inds_by_unit"] + sparsity_mask = worker_dict["sparsity_mask"] seg_size = recording.get_num_samples(segment_index=segment_index) @@ -383,12 +383,12 @@ def _worker_distribute_buffers(segment_index, start_frame, end_frame, worker_ctx if in_chunk_pos.size == 0: continue - if worker_ctx["mode"] == "memmap": + if worker_dict["mode"] == "memmap": # open file in demand 
(and also autoclose it after) - filename = worker_ctx["arrays_info"][unit_id] + filename = worker_dict["arrays_info"][unit_id] wfs = np.load(str(filename), mmap_mode="r+") - elif worker_ctx["mode"] == "shared_memory": - wfs = worker_ctx["waveforms_by_units"][unit_id] + elif worker_dict["mode"] == "shared_memory": + wfs = worker_dict["waveforms_by_units"][unit_id] for pos in in_chunk_pos: sample_index = spikes[inds[pos]]["sample_index"] @@ -548,50 +548,50 @@ def extract_waveforms_to_single_buffer( def _init_worker_distribute_single_buffer( recording, spikes, wf_array_info, nbefore, nafter, return_scaled, mode, sparsity_mask ): - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["wf_array_info"] = wf_array_info - worker_ctx["spikes"] = spikes - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["return_scaled"] = return_scaled - worker_ctx["sparsity_mask"] = sparsity_mask - worker_ctx["mode"] = mode + worker_dict = {} + worker_dict["recording"] = recording + worker_dict["wf_array_info"] = wf_array_info + worker_dict["spikes"] = spikes + worker_dict["nbefore"] = nbefore + worker_dict["nafter"] = nafter + worker_dict["return_scaled"] = return_scaled + worker_dict["sparsity_mask"] = sparsity_mask + worker_dict["mode"] = mode if mode == "memmap": filename = wf_array_info["filename"] all_waveforms = np.load(str(filename), mmap_mode="r+") - worker_ctx["all_waveforms"] = all_waveforms + worker_dict["all_waveforms"] = all_waveforms elif mode == "shared_memory": from multiprocessing.shared_memory import SharedMemory shm_name, dtype, shape = wf_array_info["shm_name"], wf_array_info["dtype"], wf_array_info["shape"] shm = SharedMemory(shm_name) all_waveforms = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - worker_ctx["shm"] = shm - worker_ctx["all_waveforms"] = all_waveforms + worker_dict["shm"] = shm + worker_dict["all_waveforms"] = all_waveforms # prepare segment slices segment_slices = [] for segment_index in 
range(recording.get_num_segments()): s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append((s0, s1)) - worker_ctx["segment_slices"] = segment_slices + worker_dict["segment_slices"] = segment_slices - return worker_ctx + return worker_dict # used by ChunkRecordingExecutor -def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_ctx): +def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker - recording = worker_ctx["recording"] - segment_slices = worker_ctx["segment_slices"] - spikes = worker_ctx["spikes"] - nbefore = worker_ctx["nbefore"] - nafter = worker_ctx["nafter"] - return_scaled = worker_ctx["return_scaled"] - sparsity_mask = worker_ctx["sparsity_mask"] - all_waveforms = worker_ctx["all_waveforms"] + recording = worker_dict["recording"] + segment_slices = worker_dict["segment_slices"] + spikes = worker_dict["spikes"] + nbefore = worker_dict["nbefore"] + nafter = worker_dict["nafter"] + return_scaled = worker_dict["return_scaled"] + sparsity_mask = worker_dict["sparsity_mask"] + all_waveforms = worker_dict["all_waveforms"] seg_size = recording.get_num_samples(segment_index=segment_index) @@ -630,7 +630,7 @@ def _worker_distribute_single_buffer(segment_index, start_frame, end_frame, work wf = wf[:, mask] all_waveforms[spike_index, :, : wf.shape[1]] = wf - if worker_ctx["mode"] == "memmap": + if worker_dict["mode"] == "memmap": all_waveforms.flush() @@ -843,12 +843,6 @@ def estimate_templates_with_accumulator( waveform_squared_accumulator_per_worker = None shm_squared_name = None - # trick to get the work_index given pid arrays - lock = multiprocessing.Lock() - array_pid = multiprocessing.Array("i", num_worker) - for i in range(num_worker): - array_pid[i] = -1 - func = _worker_estimate_templates init_func = _init_worker_estimate_templates @@ -862,14 +856,12 @@ def estimate_templates_with_accumulator( nbefore, 
nafter, return_scaled, - lock, - array_pid, ) if job_name is None: job_name = "estimate_templates_with_accumulator" processor = ChunkRecordingExecutor( - recording, func, init_func, init_args, job_name=job_name, verbose=verbose, **job_kwargs + recording, func, init_func, init_args, job_name=job_name, verbose=verbose, need_worker_index=True, **job_kwargs ) processor.run() @@ -920,15 +912,13 @@ def _init_worker_estimate_templates( nbefore, nafter, return_scaled, - lock, - array_pid, ): - worker_ctx = {} - worker_ctx["recording"] = recording - worker_ctx["spikes"] = spikes - worker_ctx["nbefore"] = nbefore - worker_ctx["nafter"] = nafter - worker_ctx["return_scaled"] = return_scaled + worker_dict = {} + worker_dict["recording"] = recording + worker_dict["spikes"] = spikes + worker_dict["nbefore"] = nbefore + worker_dict["nafter"] = nafter + worker_dict["return_scaled"] = return_scaled from multiprocessing.shared_memory import SharedMemory import multiprocessing @@ -936,48 +926,36 @@ def _init_worker_estimate_templates( shm = SharedMemory(shm_name) waveform_accumulator_per_worker = np.ndarray(shape=shape, dtype=dtype, buffer=shm.buf) - worker_ctx["shm"] = shm - worker_ctx["waveform_accumulator_per_worker"] = waveform_accumulator_per_worker + worker_dict["shm"] = shm + worker_dict["waveform_accumulator_per_worker"] = waveform_accumulator_per_worker if shm_squared_name is not None: shm_squared = SharedMemory(shm_squared_name) waveform_squared_accumulator_per_worker = np.ndarray(shape=shape, dtype=dtype, buffer=shm_squared.buf) - worker_ctx["shm_squared"] = shm_squared - worker_ctx["waveform_squared_accumulator_per_worker"] = waveform_squared_accumulator_per_worker + worker_dict["shm_squared"] = shm_squared + worker_dict["waveform_squared_accumulator_per_worker"] = waveform_squared_accumulator_per_worker # prepare segment slices segment_slices = [] for segment_index in range(recording.get_num_segments()): s0, s1 = np.searchsorted(spikes["segment_index"], [segment_index, 
segment_index + 1]) segment_slices.append((s0, s1)) - worker_ctx["segment_slices"] = segment_slices - - child_process = multiprocessing.current_process() - - lock.acquire() - num_worker = None - for i in range(len(array_pid)): - if array_pid[i] == -1: - num_worker = i - array_pid[i] = child_process.ident - break - worker_ctx["worker_index"] = num_worker - lock.release() + worker_dict["segment_slices"] = segment_slices - return worker_ctx + return worker_dict # used by ChunkRecordingExecutor -def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_ctx): +def _worker_estimate_templates(segment_index, start_frame, end_frame, worker_dict): # recover variables of the worker - recording = worker_ctx["recording"] - segment_slices = worker_ctx["segment_slices"] - spikes = worker_ctx["spikes"] - nbefore = worker_ctx["nbefore"] - nafter = worker_ctx["nafter"] - waveform_accumulator_per_worker = worker_ctx["waveform_accumulator_per_worker"] - waveform_squared_accumulator_per_worker = worker_ctx.get("waveform_squared_accumulator_per_worker", None) - worker_index = worker_ctx["worker_index"] - return_scaled = worker_ctx["return_scaled"] + recording = worker_dict["recording"] + segment_slices = worker_dict["segment_slices"] + spikes = worker_dict["spikes"] + nbefore = worker_dict["nbefore"] + nafter = worker_dict["nafter"] + waveform_accumulator_per_worker = worker_dict["waveform_accumulator_per_worker"] + waveform_squared_accumulator_per_worker = worker_dict.get("waveform_squared_accumulator_per_worker", None) + worker_index = worker_dict["worker_index"] + return_scaled = worker_dict["return_scaled"] seg_size = recording.get_num_samples(segment_index=segment_index) From e929820b06c0c2bf881742cb4e459e88e96be4cf Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Nov 2024 14:19:35 +0100 Subject: [PATCH 06/16] change max_threads_per_process to max_threads_per_worker --- doc/get_started/quickstart.rst | 2 +- src/spikeinterface/core/globals.py | 2 +- 
src/spikeinterface/core/job_tools.py | 52 ++++++++++++------- src/spikeinterface/core/tests/test_globals.py | 6 +-- .../core/tests/test_job_tools.py | 4 +- .../postprocessing/principal_component.py | 16 +++--- .../tests/test_principal_component.py | 2 +- .../qualitymetrics/pca_metrics.py | 10 ++-- .../qualitymetrics/tests/test_pca_metrics.py | 6 +-- .../sortingcomponents/clustering/merge.py | 12 ++--- .../sortingcomponents/clustering/split.py | 10 ++-- 11 files changed, 67 insertions(+), 55 deletions(-) diff --git a/doc/get_started/quickstart.rst b/doc/get_started/quickstart.rst index 3d45606a78..d1bf311340 100644 --- a/doc/get_started/quickstart.rst +++ b/doc/get_started/quickstart.rst @@ -287,7 +287,7 @@ available parameters are dictionaries and can be accessed with: 'detect_threshold': 5, 'freq_max': 5000.0, 'freq_min': 400.0, - 'max_threads_per_process': 1, + 'max_threads_per_worker': 1, 'mp_context': None, 'n_jobs': 20, 'nested_params': None, diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index 38f39c5481..195440c061 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -97,7 +97,7 @@ def is_set_global_dataset_folder() -> bool: ######################################## -_default_job_kwargs = dict(n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) +_default_job_kwargs = dict(pool_engine="thread", n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1) global global_job_kwargs global_job_kwargs = _default_job_kwargs.copy() diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 2a4af1288c..b37c9b7d69 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -48,7 +48,7 @@ "chunk_duration", "progress_bar", "mp_context", - "max_threads_per_process", + "max_threads_per_worker", ) # theses key are the same and should not be in th final 
dict @@ -65,6 +65,17 @@ def fix_job_kwargs(runtime_job_kwargs): job_kwargs = get_global_job_kwargs() + # deprecation with backward compatibility + # this can be removed in 0.104.0 + if "max_threads_per_process" in runtime_job_kwargs: + runtime_job_kwargs = runtime_job_kwargs.copy() + runtime_job_kwargs["max_threads_per_worker"] = runtime_job_kwargs.pop("max_threads_per_process") + warnings.warn( + "job_kwargs: max_threads_per_worker was changed to max_threads_per_worker", + DeprecationWarning, + stacklevel=2, + ) + for k in runtime_job_kwargs: assert k in job_keys, ( f"{k} is not a valid job keyword argument. " f"Available keyword arguments are: {list(job_keys)}" @@ -311,7 +322,7 @@ class ChunkRecordingExecutor: mp_context : "fork" | "spawn" | None, default: None "fork" or "spawn". If None, the context is taken by the recording.get_preferred_mp_context(). "fork" is only safely available on LINUX systems. - max_threads_per_process : int or None, default: None + max_threads_per_worker : int or None, default: None Limit the number of thread per process using threadpoolctl modules. This used only when n_jobs>1 If None, no limits. 
@@ -342,7 +353,7 @@ def __init__( chunk_duration=None, mp_context=None, job_name="", - max_threads_per_process=1, + max_threads_per_worker=1, need_worker_index=False, ): self.recording = recording @@ -375,7 +386,7 @@ def __init__( n_jobs=self.n_jobs, ) self.job_name = job_name - self.max_threads_per_process = max_threads_per_process + self.max_threads_per_worker = max_threads_per_worker self.pool_engine = pool_engine @@ -446,7 +457,7 @@ def run(self, recording_slices=None): max_workers=n_jobs, initializer=process_worker_initializer, mp_context=multiprocessing.get_context(self.mp_context), - initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, self.need_worker_index, lock, array_pid), + initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_worker, self.need_worker_index, lock, array_pid), ) as executor: results = executor.map(process_function_wrapper, recording_slices) @@ -473,12 +484,13 @@ def run(self, recording_slices=None): if self.need_worker_index: lock = threading.Lock() - thread_started = 0 + else: + lock = None with ThreadPoolExecutor( max_workers=n_jobs, initializer=thread_worker_initializer, - initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_process, thread_local_data, self.need_worker_index, lock), + initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_worker, thread_local_data, self.need_worker_index, lock), ) as executor: @@ -507,19 +519,19 @@ class WorkerFuncWrapper: """ small wraper that handle: * local worker_dict - * max_threads_per_process + * max_threads_per_worker """ - def __init__(self, func, worker_dict, max_threads_per_process): + def __init__(self, func, worker_dict, max_threads_per_worker): self.func = func self.worker_dict = worker_dict - self.max_threads_per_process = max_threads_per_process + self.max_threads_per_worker = max_threads_per_worker def __call__(self, args): segment_index, start_frame, end_frame = args - if 
self.max_threads_per_process is None: + if self.max_threads_per_worker is None: return self.func(segment_index, start_frame, end_frame, self.worker_dict) else: - with threadpool_limits(limits=self.max_threads_per_process): + with threadpool_limits(limits=self.max_threads_per_worker): return self.func(segment_index, start_frame, end_frame, self.worker_dict) # see @@ -531,12 +543,12 @@ def __call__(self, args): global _process_func_wrapper -def process_worker_initializer(func, init_func, init_args, max_threads_per_process, need_worker_index, lock, array_pid): +def process_worker_initializer(func, init_func, init_args, max_threads_per_worker, need_worker_index, lock, array_pid): global _process_func_wrapper - if max_threads_per_process is None: + if max_threads_per_worker is None: worker_dict = init_func(*init_args) else: - with threadpool_limits(limits=max_threads_per_process): + with threadpool_limits(limits=max_threads_per_worker): worker_dict = init_func(*init_args) if need_worker_index: @@ -551,7 +563,7 @@ def process_worker_initializer(func, init_func, init_args, max_threads_per_proce worker_dict["worker_index"] = worker_index lock.release() - _process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process) + _process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker) def process_function_wrapper(args): global _process_func_wrapper @@ -561,11 +573,11 @@ def process_function_wrapper(args): # use by thread at init global _thread_started -def thread_worker_initializer(func, init_func, init_args, max_threads_per_process, thread_local_data, need_worker_index, lock): - if max_threads_per_process is None: +def thread_worker_initializer(func, init_func, init_args, max_threads_per_worker, thread_local_data, need_worker_index, lock): + if max_threads_per_worker is None: worker_dict = init_func(*init_args) else: - with threadpool_limits(limits=max_threads_per_process): + with threadpool_limits(limits=max_threads_per_worker): 
worker_dict = init_func(*init_args) if need_worker_index: @@ -576,7 +588,7 @@ def thread_worker_initializer(func, init_func, init_args, max_threads_per_proces worker_dict["worker_index"] = worker_index lock.release() - thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_process) + thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker) def thread_function_wrapper(args): thread_local_data = args[0] diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 9677378fc5..2b21cd8978 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -36,7 +36,7 @@ def test_global_tmp_folder(create_cache_folder): def test_global_job_kwargs(): - job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1) + job_kwargs = dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1) global_job_kwargs = get_global_job_kwargs() # test warning when not setting n_jobs and calling fix_job_kwargs @@ -44,7 +44,7 @@ def test_global_job_kwargs(): job_kwargs_split = fix_job_kwargs({}) assert global_job_kwargs == dict( - n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs @@ -59,7 +59,7 @@ def test_global_job_kwargs(): set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() assert global_job_kwargs == dict( - n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_process=1 + n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 ) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=cpu_count()) 
diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 5a32898411..8872a259bf 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -281,7 +281,7 @@ def test_worker_index(): # test_divide_segment_into_chunks() # test_ensure_n_jobs() # test_ensure_chunk_size() - # test_ChunkRecordingExecutor() + test_ChunkRecordingExecutor() # test_fix_job_kwargs() # test_split_job_kwargs() - test_worker_index() + # test_worker_index() diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index 809f2c5bba..84fbfc5965 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -316,13 +316,13 @@ def _run(self, verbose=False, **job_kwargs): job_kwargs = fix_job_kwargs(job_kwargs) n_jobs = job_kwargs["n_jobs"] progress_bar = job_kwargs["progress_bar"] - max_threads_per_process = job_kwargs["max_threads_per_process"] + max_threads_per_worker = job_kwargs["max_threads_per_worker"] mp_context = job_kwargs["mp_context"] # fit model/models # TODO : make parralel for by_channel_global and concatenated if mode == "by_channel_local": - pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_process, mp_context) + pca_models = self._fit_by_channel_local(n_jobs, progress_bar, max_threads_per_worker, mp_context) for chan_ind, chan_id in enumerate(self.sorting_analyzer.channel_ids): self.data[f"pca_model_{mode}_{chan_id}"] = pca_models[chan_ind] pca_model = pca_models @@ -415,7 +415,7 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs): ) processor.run() - def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, mp_context): + def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_worker, mp_context): from sklearn.decomposition import 
IncrementalPCA p = self.params @@ -444,10 +444,10 @@ def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, m pca = pca_models[chan_ind] pca.partial_fit(wfs[:, :, wf_ind]) else: - # create list of args to parallelize. For convenience, the max_threads_per_process is passed + # create list of args to parallelize. For convenience, the max_threads_per_worker is passed # as last argument items = [ - (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_process) + (chan_ind, pca_models[chan_ind], wfs[:, :, wf_ind], max_threads_per_worker) for wf_ind, chan_ind in enumerate(channel_inds) ] n_jobs = min(n_jobs, len(items)) @@ -687,12 +687,12 @@ def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafte def _partial_fit_one_channel(args): - chan_ind, pca_model, wf_chan, max_threads_per_process = args + chan_ind, pca_model, wf_chan, max_threads_per_worker = args - if max_threads_per_process is None: + if max_threads_per_worker is None: pca_model.partial_fit(wf_chan) return chan_ind, pca_model else: - with threadpool_limits(limits=int(max_threads_per_process)): + with threadpool_limits(limits=int(max_threads_per_worker)): pca_model.partial_fit(wf_chan) return chan_ind, pca_model diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 7a509c410f..ecfc39f2c6 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -27,7 +27,7 @@ def test_multi_processing(self): ) sorting_analyzer.compute("principal_components", mode="by_channel_local", n_jobs=2) sorting_analyzer.compute( - "principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_process=4, mp_context="spawn" + "principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_worker=4, mp_context="spawn" ) def test_mode_concatenated(self): diff 
--git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 4c68dfea59..55f91fd87f 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -58,7 +58,7 @@ def compute_pc_metrics( n_jobs=1, progress_bar=False, mp_context=None, - max_threads_per_process=None, + max_threads_per_worker=None, ) -> dict: """ Calculate principal component derived metrics. @@ -147,7 +147,7 @@ def compute_pc_metrics( pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) - func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process) + func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_worker) items.append(func_args) if not run_in_parallel and non_nn_metrics: @@ -977,12 +977,12 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int): def pca_metrics_one_unit(args): - (pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args + (pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_worker) = args - if max_threads_per_process is None: + if max_threads_per_worker is None: return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) else: - with threadpool_limits(limits=int(max_threads_per_process)): + with threadpool_limits(limits=int(max_threads_per_worker)): return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index f2e912c6b4..ba8dae4619 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -31,13 +31,13 @@ def 
test_pca_metrics_multi_processing(small_sorting_analyzer): print(f"Computing PCA metrics with 1 thread per process") res1 = compute_pc_metrics( - sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=1, progress_bar=True + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=1, progress_bar=True ) print(f"Computing PCA metrics with 2 thread per process") res2 = compute_pc_metrics( - sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True ) print("Computing PCA metrics with spawn context") res2 = compute_pc_metrics( - sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True + sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True ) diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 4a7b722aea..e618cfbfb6 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -261,7 +261,7 @@ def find_merge_pairs( **job_kwargs, # n_jobs=1, # mp_context="fork", - # max_threads_per_process=1, + # max_threads_per_worker=1, # progress_bar=True, ): """ @@ -299,7 +299,7 @@ def find_merge_pairs( n_jobs = job_kwargs["n_jobs"] mp_context = job_kwargs.get("mp_context", None) - max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) + max_threads_per_worker = job_kwargs.get("max_threads_per_worker", 1) progress_bar = job_kwargs["progress_bar"] Executor = get_poolexecutor(n_jobs) @@ -316,7 +316,7 @@ def find_merge_pairs( templates, method, method_kwargs, - max_threads_per_process, + max_threads_per_worker, ), ) as pool: jobs = [] @@ -354,7 +354,7 @@ def find_pair_worker_init( templates, method, method_kwargs, - max_threads_per_process, 
+ max_threads_per_worker, ): global _ctx _ctx = {} @@ -366,7 +366,7 @@ def find_pair_worker_init( _ctx["method"] = method _ctx["method_kwargs"] = method_kwargs _ctx["method_class"] = find_pair_method_dict[method] - _ctx["max_threads_per_process"] = max_threads_per_process + _ctx["max_threads_per_worker"] = max_threads_per_worker # if isinstance(features_dict_or_folder, dict): # _ctx["features"] = features_dict_or_folder @@ -380,7 +380,7 @@ def find_pair_worker_init( def find_pair_function_wrapper(label0, label1): global _ctx - with threadpool_limits(limits=_ctx["max_threads_per_process"]): + with threadpool_limits(limits=_ctx["max_threads_per_worker"]): is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( label0, label1, diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index 15917934a8..3c2e878c39 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -65,7 +65,7 @@ def split_clusters( n_jobs = job_kwargs["n_jobs"] mp_context = job_kwargs.get("mp_context", None) progress_bar = job_kwargs["progress_bar"] - max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) + max_threads_per_worker = job_kwargs.get("max_threads_per_worker", 1) original_labels = peak_labels peak_labels = peak_labels.copy() @@ -77,7 +77,7 @@ def split_clusters( max_workers=n_jobs, initializer=split_worker_init, mp_context=get_context(method=mp_context), - initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process), + initargs=(recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_worker), ) as pool: labels_set = np.setdiff1d(peak_labels, [-1]) current_max_label = np.max(labels_set) + 1 @@ -133,7 +133,7 @@ def split_clusters( def split_worker_init( - recording, features_dict_or_folder, original_labels, method, 
method_kwargs, max_threads_per_process + recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_worker ): global _ctx _ctx = {} @@ -144,14 +144,14 @@ def split_worker_init( _ctx["method"] = method _ctx["method_kwargs"] = method_kwargs _ctx["method_class"] = split_methods_dict[method] - _ctx["max_threads_per_process"] = max_threads_per_process + _ctx["max_threads_per_worker"] = max_threads_per_worker _ctx["features"] = FeaturesLoader.from_dict_or_folder(features_dict_or_folder) _ctx["peaks"] = _ctx["features"]["peaks"] def split_function_wrapper(peak_indices, recursion_level): global _ctx - with threadpool_limits(limits=_ctx["max_threads_per_process"]): + with threadpool_limits(limits=_ctx["max_threads_per_worker"]): is_split, local_labels = _ctx["method_class"].split( peak_indices, _ctx["peaks"], _ctx["features"], recursion_level, **_ctx["method_kwargs"] ) From d4a6e95d1c9f6a7d5cb9d5f4ca017a2240c187ad Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Nov 2024 14:35:06 +0100 Subject: [PATCH 07/16] implement get_best_job_kwargs() --- src/spikeinterface/core/__init__.py | 2 +- src/spikeinterface/core/job_tools.py | 39 +++++++++++++++++++ .../core/tests/test_job_tools.py | 9 ++++- 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index ead7007920..bea77decfc 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -90,7 +90,7 @@ write_python, normal_pdf, ) -from .job_tools import ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs +from .job_tools import get_best_job_kwargs, ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs from .recording_tools import ( write_binary_recording, write_to_h5_dataset_format, diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index b37c9b7d69..b12ad7fc4d 
100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -59,6 +59,45 @@ "chunk_duration", ) +def get_best_job_kwargs(): + """ + Given best possible job_kwargs for the platform. + """ + + n_cpu = os.cpu_count() + + if platform.system() == "Linux": + # maybe we should test this more but with linux the fork is still faster than threading + pool_engine = "process" + mp_context = "fork" + + # this is totally empiricat but this is a good start + if n_cpu <= 16: + # for small n_cpu lets make many process + n_jobs = n_cpu + max_threads_per_worker = 1 + else: + # lets have less process with more thread each + n_cpu = int(n_cpu / 4) + max_threads_per_worker = 8 + + else: # windows and mac + # on windows and macos the fork is forbidden and process+spwan is super slow at startup + # so lets go to threads + pool_engine = "thread" + mp_context = None + n_jobs = n_cpu + max_threads_per_worker = 1 + + return dict( + pool_engine=pool_engine, + mp_context=mp_context, + n_jobs=n_jobs, + max_threads_per_worker=max_threads_per_worker, + ) + + + def fix_job_kwargs(runtime_job_kwargs): from .globals import get_global_job_kwargs, is_set_global_job_kwargs_set diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 8872a259bf..3918fe8ec0 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -3,7 +3,7 @@ import time -from spikeinterface.core import generate_recording, set_global_job_kwargs, get_global_job_kwargs +from spikeinterface.core import generate_recording, set_global_job_kwargs, get_global_job_kwargs, get_best_job_kwargs from spikeinterface.core.job_tools import ( divide_segment_into_chunks, @@ -277,11 +277,16 @@ def test_worker_index(): assert 0 in res assert 1 in res +def test_get_best_job_kwargs(): + job_kwargs = get_best_job_kwargs() + print(job_kwargs) + if __name__ == "__main__": # test_divide_segment_into_chunks() 
# test_ensure_n_jobs() # test_ensure_chunk_size() - test_ChunkRecordingExecutor() + # test_ChunkRecordingExecutor() # test_fix_job_kwargs() # test_split_job_kwargs() # test_worker_index() + test_get_best_job_kwargs() From 423801b2d99d3a1cead66f9a9150c4b344568b04 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Nov 2024 14:38:30 +0100 Subject: [PATCH 08/16] oups --- src/spikeinterface/core/tests/test_globals.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 2b21cd8978..896b737c88 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -44,7 +44,7 @@ def test_global_job_kwargs(): job_kwargs_split = fix_job_kwargs({}) assert global_job_kwargs == dict( - n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 + pool_engine="thread", n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs @@ -80,6 +80,6 @@ def test_global_job_kwargs(): if __name__ == "__main__": - test_global_dataset_folder() - test_global_tmp_folder() + # test_global_dataset_folder() + # test_global_tmp_folder() test_global_job_kwargs() From 1736b65ceb13adcae78eb161b7dbbc52ad666400 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Nov 2024 14:41:02 +0100 Subject: [PATCH 09/16] oups --- src/spikeinterface/core/tests/test_globals.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 896b737c88..580287eb21 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -36,7 +36,7 @@ def test_global_tmp_folder(create_cache_folder): def test_global_job_kwargs(): - job_kwargs = 
dict(n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1) + job_kwargs = dict(pool_engine="thread", n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1) global_job_kwargs = get_global_job_kwargs() # test warning when not setting n_jobs and calling fix_job_kwargs @@ -59,7 +59,7 @@ def test_global_job_kwargs(): set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() assert global_job_kwargs == dict( - n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 + pool_engine="thread", n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 ) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=cpu_count()) From f0ec139fdd52b43048becd9323d7475a6d41097e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Nov 2024 10:06:37 +0100 Subject: [PATCH 10/16] Feedback from Zach and Alessio better test for waveforms_tools --- src/spikeinterface/core/job_tools.py | 25 ++++---- .../core/tests/test_job_tools.py | 1 - .../core/tests/test_waveform_tools.py | 58 ++++++++++++------- 3 files changed, 48 insertions(+), 36 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index b12ad7fc4d..64a5c6cdbf 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -110,7 +110,7 @@ def fix_job_kwargs(runtime_job_kwargs): runtime_job_kwargs = runtime_job_kwargs.copy() runtime_job_kwargs["max_threads_per_worker"] = runtime_job_kwargs.pop("max_threads_per_process") warnings.warn( - "job_kwargs: max_threads_per_worker was changed to max_threads_per_worker", + "job_kwargs: max_threads_per_process was changed to max_threads_per_worker, max_threads_per_process will be removed in 0.104", DeprecationWarning, stacklevel=2, ) @@ -346,7 +346,7 @@ class ChunkRecordingExecutor: gather_func : None or callable, 
default: None Optional function that is called in the main thread and retrieves the results of each worker. This function can be used instead of `handle_returns` to implement custom storage on-the-fly. - pool_engine : "process" | "thread" + pool_engine : "process" | "thread", default: "thread" If n_jobs>1 then use ProcessPoolExecutor or ThreadPoolExecutor n_jobs : int, default: 1 Number of jobs to be used. Use -1 to use as many jobs as number of cores @@ -384,7 +384,7 @@ def __init__( progress_bar=False, handle_returns=False, gather_func=None, - pool_engine="process", + pool_engine="thread", n_jobs=1, total_memory=None, chunk_size=None, @@ -400,12 +400,13 @@ def __init__( self.init_func = init_func self.init_args = init_args - if mp_context is None: - mp_context = recording.get_preferred_mp_context() - if mp_context is not None and platform.system() == "Windows": - assert mp_context != "fork", "'fork' mp_context not supported on Windows!" - elif mp_context == "fork" and platform.system() == "Darwin": - warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') + if pool_engine == "process": + if mp_context is None: + mp_context = recording.get_preferred_mp_context() + if mp_context is not None and platform.system() == "Windows": + assert mp_context != "fork", "'fork' mp_context not supported on Windows!" 
+ elif mp_context == "fork" and platform.system() == "Darwin": + warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') self.mp_context = mp_context @@ -572,13 +573,9 @@ def __call__(self, args): else: with threadpool_limits(limits=self.max_threads_per_worker): return self.func(segment_index, start_frame, end_frame, self.worker_dict) - # see # https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool -# the tricks is : thiw variables are global per worker -# so they are not share in the same process -# global _worker_ctx -# global _func +# the tricks is : this variable are global per worker (so not shared in the same process) global _process_func_wrapper diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 3918fe8ec0..824532a11e 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -81,7 +81,6 @@ def test_ensure_chunk_size(): def func(segment_index, start_frame, end_frame, worker_dict): import os - import time #  print('func', segment_index, start_frame, end_frame, worker_dict, os.getpid()) time.sleep(0.010) diff --git a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index d0e9358164..ed27815758 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -173,29 +173,45 @@ def test_estimate_templates_with_accumulator(): job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - templates = estimate_templates_with_accumulator( - recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs - ) - # print(templates.shape) - assert templates.shape[0] == sorting.unit_ids.size - assert templates.shape[1] == nbefore + nafter - assert templates.shape[2] == recording.get_num_channels() + # here we compare the result with 
the same mechanism with several worker pool sizes + # this means that the accumulators are split and then agglomerated back + # this should lead to very small diff + # n_jobs=1 is done in loop + templates_by_worker = [] + + if platform.system() == "Linux": + engine_loop = ["thread", "process"] + else: + engine_loop = ["thread"] + + for pool_engine in engine_loop: + for n_jobs in (1, 2, 8): + job_kwargs = dict(pool_engine=pool_engine, n_jobs=n_jobs, progress_bar=True, chunk_duration="1s") + templates = estimate_templates_with_accumulator( + recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs + ) + assert templates.shape[0] == sorting.unit_ids.size + assert templates.shape[1] == nbefore + nafter + assert templates.shape[2] == recording.get_num_channels() + assert np.any(templates != 0) + + templates_by_worker.append(templates) + if len(templates_by_worker) > 1: + templates_loop = templates_by_worker[0] + np.testing.assert_almost_equal(templates, templates_loop, decimal=4) + + # import matplotlib.pyplot as plt + # fig, axs = plt.subplots(nrows=2, sharex=True) + # for unit_index, unit_id in enumerate(sorting.unit_ids): + # ax = axs[0] + # ax.set_title(f"{pool_engine} {n_jobs}") + # ax.plot(templates[unit_index, :, :].T.flatten()) + # ax.plot(templates_loop[unit_index, :, :].T.flatten(), color="k", ls="--") + # ax = axs[1] + # ax.plot((templates - templates_loop)[unit_index, :, :].T.flatten(), color="k", ls="--") + # plt.show() - assert np.any(templates != 0) - job_kwargs = dict(n_jobs=1, progress_bar=True, chunk_duration="1s") - templates_loop = estimate_templates_with_accumulator( - recording, spikes, sorting.unit_ids, nbefore, nafter, return_scaled=True, **job_kwargs - ) - np.testing.assert_almost_equal(templates, templates_loop, decimal=4) - - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # for unit_index, unit_id in enumerate(sorting.unit_ids): - # ax.plot(templates[unit_index, :, :].T.flatten()) - #
ax.plot(templates_loop[unit_index, :, :].T.flatten(), color="k", ls="--") - # ax.plot((templates - templates_loop)[unit_index, :, :].T.flatten(), color="k", ls="--") - # plt.show() def test_estimate_templates(): From cc8b4c4a976d7f60dc7c70358f229e49f034dec8 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Nov 2024 11:04:33 +0100 Subject: [PATCH 11/16] fix tests --- .../qualitymetrics/tests/conftest.py | 7 +++++-- .../qualitymetrics/tests/test_pca_metrics.py | 14 +++++++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 01fa16c8d7..9878adf142 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -8,8 +8,8 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -@pytest.fixture(scope="module") -def small_sorting_analyzer(): + +def make_small_analyzer(): recording, sorting = generate_ground_truth_recording( durations=[2.0], num_units=10, @@ -34,6 +34,9 @@ def small_sorting_analyzer(): return sorting_analyzer +@pytest.fixture(scope="module") +def small_sorting_analyzer(): + return make_small_analyzer() @pytest.fixture(scope="module") def sorting_analyzer_simple(): diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index ba8dae4619..897c2837cc 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -19,7 +19,14 @@ def test_calculate_pc_metrics(small_sorting_analyzer): assert not np.all(np.isnan(res1[metric_name].values)) assert not np.all(np.isnan(res2[metric_name].values)) - assert np.array_equal(res1[metric_name].values, res2[metric_name].values) + # import matplotlib.pyplot as plt + # fig, ax = plt.subplots() + # ax.plot(res1[metric_name].values) + # 
ax.plot(res2[metric_name].values) + # ax.plot(res2[metric_name].values - res1[metric_name].values) + # plt.show() + + np.testing.assert_almost_equal(res1[metric_name].values, res2[metric_name].values, decimal=4) def test_pca_metrics_multi_processing(small_sorting_analyzer): @@ -41,3 +48,8 @@ def test_pca_metrics_multi_processing(small_sorting_analyzer): res2 = compute_pc_metrics( sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True ) + +if __name__ == "__main__": + from spikeinterface.qualitymetrics.tests.conftest import make_small_analyzer + small_sorting_analyzer = make_small_analyzer() + test_calculate_pc_metrics(small_sorting_analyzer) \ No newline at end of file From c16ca722e057671a503c2a38a4d80a1b0cd6b7cd Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Nov 2024 12:11:46 +0100 Subject: [PATCH 12/16] oups --- .../qualitymetrics/tests/test_pca_metrics.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 897c2837cc..312c3949b3 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -15,18 +15,25 @@ def test_calculate_pc_metrics(small_sorting_analyzer): res2 = pd.DataFrame(res2) for metric_name in res1.columns: - if metric_name != "nn_unit_id": - assert not np.all(np.isnan(res1[metric_name].values)) - assert not np.all(np.isnan(res2[metric_name].values)) - - # import matplotlib.pyplot as plt - # fig, ax = plt.subplots() - # ax.plot(res1[metric_name].values) - # ax.plot(res2[metric_name].values) - # ax.plot(res2[metric_name].values - res1[metric_name].values) - # plt.show() + values1 = res1[metric_name].values + values2 = res1[metric_name].values - np.testing.assert_almost_equal(res1[metric_name].values, res2[metric_name].values, decimal=4) + if metric_name != 
"nn_unit_id": + assert not np.all(np.isnan(values1)) + assert not np.all(np.isnan(values2)) + + if values1.dtype.kind == "f": + np.testing.assert_almost_equal(values1, values2, decimal=4) + # import matplotlib.pyplot as plt + # fig, axs = plt.subplots(nrows=2, share=True) + # ax =a xs[0] + # ax.plot(res1[metric_name].values) + # ax.plot(res2[metric_name].values) + # ax =a xs[1] + # ax.plot(res2[metric_name].values - res1[metric_name].values) + # plt.show() + else: + assert np.array_equal(values1, values2) def test_pca_metrics_multi_processing(small_sorting_analyzer): From 155ab31b45f119a56a5318c3a0e030d17c36af07 Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Mon, 25 Nov 2024 08:35:12 +0100 Subject: [PATCH 13/16] merci zach Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/job_tools.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 64a5c6cdbf..ce7eb05dbc 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -61,7 +61,8 @@ def get_best_job_kwargs(): """ - Given best possible job_kwargs for the platform. + Gives best possible job_kwargs for the platform. + Currently this function is from developer experience, but may be adapted in the future. 
""" n_cpu = os.cpu_count() @@ -71,19 +72,19 @@ def get_best_job_kwargs(): pool_engine = "process" mp_context = "fork" - # this is totally empiricat but this is a good start + # this is totally empirical but this is a good start if n_cpu <= 16: - # for small n_cpu lets make many process + # for small n_cpu let's make many process n_jobs = n_cpu max_threads_per_worker = 1 else: - # lets have less process with more thread each + # let's have fewer processes with more threads each n_cpu = int(n_cpu / 4) max_threads_per_worker = 8 else: # windows and mac # on windows and macos the fork is forbidden and process+spwan is super slow at startup - # so lets go to threads + # so let's go to threads pool_engine = "thread" mp_context = None n_jobs = n_cpu @@ -557,7 +558,7 @@ def run(self, recording_slices=None): class WorkerFuncWrapper: """ - small wraper that handle: + small wrapper that handles: * local worker_dict * max_threads_per_worker """ @@ -575,7 +576,7 @@ def __call__(self, args): return self.func(segment_index, start_frame, end_frame, self.worker_dict) # see # https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool -# the tricks is : this variable are global per worker (so not shared in the same process) +# the trick is : this variable is global per worker (so not shared in the same process) global _process_func_wrapper From 6ecbb014ec9a83fa1c867c06f5f8572a83194d5f Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 8 Jan 2025 08:44:36 +0100 Subject: [PATCH 14/16] quick benchmark --- .../core/tests/test_job_tools.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 824532a11e..552fe7b00b 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -280,6 +280,48 @@ def test_get_best_job_kwargs(): job_kwargs = get_best_job_kwargs() 
print(job_kwargs) + +# def quick_benchmark(): +# # keep this commented do not remove + + +# from spikeinterface.generation import generate_drifting_recording +# from spikeinterface.sortingcomponents.peak_detection import detect_peaks +# from spikeinterface import get_noise_levels +# import time + +# all_job_kwargs = [ +# dict(pool_engine="process", n_jobs=2, mp_context="spawn", max_threads_per_worker=2), +# dict(pool_engine="process", n_jobs=4, mp_context="spawn", max_threads_per_worker=1), +# dict(pool_engine="thread", n_jobs=4, mp_context=None, max_threads_per_worker=1), +# dict(pool_engine="thread", n_jobs=2, mp_context=None, max_threads_per_worker=2), +# dict(n_jobs=1), +# ] + + + +# rec, _, sorting = generate_drifting_recording( +# num_units=50, +# duration=120.0, +# sampling_frequency=30000.0, +# probe_name="Neuropixel-128", + +# ) +# # print(rec) + +# noise_levels = get_noise_levels(rec, return_scaled=False) +# for job_kwargs in all_job_kwargs: +# print() +# print(job_kwargs) +# t0 = time.perf_counter() +# peaks = detect_peaks(rec, method="locally_exclusive", noise_levels=noise_levels, **job_kwargs) +# t1 = time.perf_counter() +# print("time included the spawn:", t1-t0) + + + + + if __name__ == "__main__": # test_divide_segment_into_chunks() # test_ensure_n_jobs() @@ -289,3 +331,5 @@ def test_get_best_job_kwargs(): # test_split_job_kwargs() # test_worker_index() test_get_best_job_kwargs() + + # quick_benchmark() From 9329243e87c8634ca60b2fd0cd7e189d9096cc29 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 8 Jan 2025 08:56:51 +0100 Subject: [PATCH 15/16] Pierre suggestion --- src/spikeinterface/core/job_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index ce7eb05dbc..b8970eaf59 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -535,7 +535,7 @@ def run(self, recording_slices=None): ) as executor: -
recording_slices2 = [(thread_local_data, ) + args for args in recording_slices] + recording_slices2 = [(thread_local_data, ) + tuple(args) for args in recording_slices] results = executor.map(thread_function_wrapper, recording_slices2) for res in results: From 61f8509d4c0cb082dc574390d8fb6e64d71890d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Jan 2025 08:19:20 +0000 Subject: [PATCH 16/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/__init__.py | 9 +++- src/spikeinterface/core/globals.py | 4 +- src/spikeinterface/core/job_tools.py | 48 +++++++++++++------ src/spikeinterface/core/tests/test_globals.py | 23 +++++++-- .../core/tests/test_job_tools.py | 9 +--- .../core/tests/test_waveform_tools.py | 4 +- .../qualitymetrics/tests/conftest.py | 3 +- .../qualitymetrics/tests/test_pca_metrics.py | 4 +- 8 files changed, 73 insertions(+), 31 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index bea77decfc..f68b70b895 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -90,7 +90,14 @@ write_python, normal_pdf, ) -from .job_tools import get_best_job_kwargs, ensure_n_jobs, ensure_chunk_size, ChunkRecordingExecutor, split_job_kwargs, fix_job_kwargs +from .job_tools import ( + get_best_job_kwargs, + ensure_n_jobs, + ensure_chunk_size, + ChunkRecordingExecutor, + split_job_kwargs, + fix_job_kwargs, +) from .recording_tools import ( write_binary_recording, write_to_h5_dataset_format, diff --git a/src/spikeinterface/core/globals.py b/src/spikeinterface/core/globals.py index 195440c061..e9974adff7 100644 --- a/src/spikeinterface/core/globals.py +++ b/src/spikeinterface/core/globals.py @@ -97,7 +97,9 @@ def is_set_global_dataset_folder() -> bool: ######################################## -_default_job_kwargs = 
dict(pool_engine="thread", n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1) +_default_job_kwargs = dict( + pool_engine="thread", n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 +) global global_job_kwargs global_job_kwargs = _default_job_kwargs.copy() diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index b8970eaf59..ed8a26683c 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -59,6 +59,7 @@ "chunk_duration", ) + def get_best_job_kwargs(): """ Gives best possible job_kwargs for the platform. @@ -82,7 +83,7 @@ def get_best_job_kwargs(): n_cpu = int(n_cpu / 4) max_threads_per_worker = 8 - else: # windows and mac + else: # windows and mac # on windows and macos the fork is forbidden and process+spwan is super slow at startup # so let's go to threads pool_engine = "thread" @@ -98,8 +99,6 @@ def get_best_job_kwargs(): ) - - def fix_job_kwargs(runtime_job_kwargs): from .globals import get_global_job_kwargs, is_set_global_job_kwargs_set @@ -498,7 +497,15 @@ def run(self, recording_slices=None): max_workers=n_jobs, initializer=process_worker_initializer, mp_context=multiprocessing.get_context(self.mp_context), - initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_worker, self.need_worker_index, lock, array_pid), + initargs=( + self.func, + self.init_func, + self.init_args, + self.max_threads_per_worker, + self.need_worker_index, + lock, + array_pid, + ), ) as executor: results = executor.map(process_function_wrapper, recording_slices) @@ -510,7 +517,7 @@ def run(self, recording_slices=None): returns.append(res) if self.gather_func is not None: self.gather_func(res) - + elif self.pool_engine == "thread": # this is need to create a per worker local dict where the initializer will push the func wrapper thread_local_data = threading.local() @@ -522,7 +529,7 @@ def run(self, 
recording_slices=None): # here the tqdm threading do not work (maybe collision) so we need to create a pbar # before thread spawning pbar = tqdm(desc=self.job_name, total=len(recording_slices)) - + if self.need_worker_index: lock = threading.Lock() else: @@ -531,11 +538,18 @@ def run(self, recording_slices=None): with ThreadPoolExecutor( max_workers=n_jobs, initializer=thread_worker_initializer, - initargs=(self.func, self.init_func, self.init_args, self.max_threads_per_worker, thread_local_data, self.need_worker_index, lock), + initargs=( + self.func, + self.init_func, + self.init_args, + self.max_threads_per_worker, + thread_local_data, + self.need_worker_index, + lock, + ), ) as executor: - - recording_slices2 = [(thread_local_data, ) + tuple(args) for args in recording_slices] + recording_slices2 = [(thread_local_data,) + tuple(args) for args in recording_slices] results = executor.map(thread_function_wrapper, recording_slices2) for res in results: @@ -551,9 +565,8 @@ def run(self, recording_slices=None): else: raise ValueError("If n_jobs>1 pool_engine must be 'process' or 'thread'") - - return returns + return returns class WorkerFuncWrapper: @@ -562,11 +575,12 @@ class WorkerFuncWrapper: * local worker_dict * max_threads_per_worker """ + def __init__(self, func, worker_dict, max_threads_per_worker): self.func = func self.worker_dict = worker_dict self.max_threads_per_worker = max_threads_per_worker - + def __call__(self, args): segment_index, start_frame, end_frame = args if self.max_threads_per_worker is None: @@ -574,6 +588,8 @@ def __call__(self, args): else: with threadpool_limits(limits=self.max_threads_per_worker): return self.func(segment_index, start_frame, end_frame, self.worker_dict) + + # see # https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool # the trick is : this variable is global per worker (so not shared in the same process) @@ -602,6 +618,7 @@ def process_worker_initializer(func, init_func, 
init_args, max_threads_per_worke _process_func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker) + def process_function_wrapper(args): global _process_func_wrapper return _process_func_wrapper(args) @@ -610,7 +627,10 @@ def process_function_wrapper(args): # use by thread at init global _thread_started -def thread_worker_initializer(func, init_func, init_args, max_threads_per_worker, thread_local_data, need_worker_index, lock): + +def thread_worker_initializer( + func, init_func, init_args, max_threads_per_worker, thread_local_data, need_worker_index, lock +): if max_threads_per_worker is None: worker_dict = init_func(*init_args) else: @@ -627,13 +647,13 @@ def thread_worker_initializer(func, init_func, init_args, max_threads_per_worker thread_local_data.func_wrapper = WorkerFuncWrapper(func, worker_dict, max_threads_per_worker) + def thread_function_wrapper(args): thread_local_data = args[0] args = args[1:] return thread_local_data.func_wrapper(args) - # Here some utils copy/paste from DART (Charlie Windolf) diff --git a/src/spikeinterface/core/tests/test_globals.py b/src/spikeinterface/core/tests/test_globals.py index 580287eb21..3f86558303 100644 --- a/src/spikeinterface/core/tests/test_globals.py +++ b/src/spikeinterface/core/tests/test_globals.py @@ -36,7 +36,14 @@ def test_global_tmp_folder(create_cache_folder): def test_global_job_kwargs(): - job_kwargs = dict(pool_engine="thread", n_jobs=4, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1) + job_kwargs = dict( + pool_engine="thread", + n_jobs=4, + chunk_duration="1s", + progress_bar=True, + mp_context=None, + max_threads_per_worker=1, + ) global_job_kwargs = get_global_job_kwargs() # test warning when not setting n_jobs and calling fix_job_kwargs @@ -44,7 +51,12 @@ def test_global_job_kwargs(): job_kwargs_split = fix_job_kwargs({}) assert global_job_kwargs == dict( - pool_engine="thread", n_jobs=1, chunk_duration="1s", progress_bar=True, mp_context=None, 
max_threads_per_worker=1 + pool_engine="thread", + n_jobs=1, + chunk_duration="1s", + progress_bar=True, + mp_context=None, + max_threads_per_worker=1, ) set_global_job_kwargs(**job_kwargs) assert get_global_job_kwargs() == job_kwargs @@ -59,7 +71,12 @@ def test_global_job_kwargs(): set_global_job_kwargs(**partial_job_kwargs) global_job_kwargs = get_global_job_kwargs() assert global_job_kwargs == dict( - pool_engine="thread", n_jobs=2, chunk_duration="1s", progress_bar=True, mp_context=None, max_threads_per_worker=1 + pool_engine="thread", + n_jobs=2, + chunk_duration="1s", + progress_bar=True, + mp_context=None, + max_threads_per_worker=1, ) # test that fix_job_kwargs grabs global kwargs new_job_kwargs = dict(n_jobs=cpu_count()) diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py index 552fe7b00b..88d52ebb1f 100644 --- a/src/spikeinterface/core/tests/test_job_tools.py +++ b/src/spikeinterface/core/tests/test_job_tools.py @@ -236,8 +236,6 @@ def test_split_job_kwargs(): assert "other_param" not in job_kwargs and "n_jobs" in job_kwargs and "progress_bar" in job_kwargs - - def func2(segment_index, start_frame, end_frame, worker_dict): time.sleep(0.010) # print(os.getpid(), worker_dict["worker_index"]) @@ -269,13 +267,14 @@ def test_worker_index(): n_jobs=2, handle_returns=True, chunk_duration="200ms", - need_worker_index=True + need_worker_index=True, ) res = processor.run() # we should have a mix of 0 and 1 assert 0 in res assert 1 in res + def test_get_best_job_kwargs(): job_kwargs = get_best_job_kwargs() print(job_kwargs) @@ -298,7 +297,6 @@ def test_get_best_job_kwargs(): # dict(n_jobs=1), # ] - # rec, _, sorting = generate_drifting_recording( # num_units=50, @@ -319,9 +317,6 @@ def test_get_best_job_kwargs(): # print("time included the spawn:", t1-t0) - - - if __name__ == "__main__": # test_divide_segment_into_chunks() # test_ensure_n_jobs() diff --git 
a/src/spikeinterface/core/tests/test_waveform_tools.py b/src/spikeinterface/core/tests/test_waveform_tools.py index ed27815758..a516e6d42b 100644 --- a/src/spikeinterface/core/tests/test_waveform_tools.py +++ b/src/spikeinterface/core/tests/test_waveform_tools.py @@ -199,7 +199,7 @@ def test_estimate_templates_with_accumulator(): if len(templates_by_worker) > 1: templates_loop = templates_by_worker[0] np.testing.assert_almost_equal(templates, templates_loop, decimal=4) - + # import matplotlib.pyplot as plt # fig, axs = plt.subplots(nrows=2, sharex=True) # for unit_index, unit_id in enumerate(sorting.unit_ids): @@ -212,8 +212,6 @@ def test_estimate_templates_with_accumulator(): # plt.show() - - def test_estimate_templates(): recording, sorting = get_dataset() diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index fb65338a1b..39bc62ae12 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -8,7 +8,6 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") - def make_small_analyzer(): recording, sorting = generate_ground_truth_recording( durations=[2.0], @@ -39,10 +38,12 @@ def make_small_analyzer(): return sorting_analyzer + @pytest.fixture(scope="module") def small_sorting_analyzer(): return make_small_analyzer() + @pytest.fixture(scope="module") def sorting_analyzer_simple(): # we need high firing rate for amplitude_cutoff diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 312c3949b3..287439a4f7 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -56,7 +56,9 @@ def test_pca_metrics_multi_processing(small_sorting_analyzer): sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_worker=2, progress_bar=True ) + if 
__name__ == "__main__": from spikeinterface.qualitymetrics.tests.conftest import make_small_analyzer + small_sorting_analyzer = make_small_analyzer() - test_calculate_pc_metrics(small_sorting_analyzer) \ No newline at end of file + test_calculate_pc_metrics(small_sorting_analyzer)