diff --git a/CHANGELOG.md b/CHANGELOG.md index c8754a9fc..a1b4f4674 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,13 @@ Changelog --------- +8.12.3 +------ +* bugfix: getting training status of subject not present on local server +* skipping of bpod initialization now optional (used in GUI) +* disable button for status LED if not supported by hardware +* tests, type-hints, removal of dead code + 8.12.2 ------ * bugfix: rollback skipping of bpod initialization (possible source of integer overflow) diff --git a/iblrig/base_choice_world.py b/iblrig/base_choice_world.py index 8b15c4f07..355d1f80c 100644 --- a/iblrig/base_choice_world.py +++ b/iblrig/base_choice_world.py @@ -430,7 +430,9 @@ def draw_next_trial_info(self, pleft=0.5, contrast=None, position=None): contrast = misc.draw_contrast(self.task_params.CONTRAST_SET, self.task_params.CONTRAST_SET_PROBABILITY_TYPE) assert len(self.task_params.STIM_POSITIONS) == 2, "Only two positions are supported" position = position or int(np.random.choice(self.task_params.STIM_POSITIONS, p=[pleft, 1 - pleft])) - quiescent_period = self.task_params.QUIESCENT_PERIOD + misc.texp(factor=0.35, min_=0.2, max_=0.5) + quiescent_period = self.task_params.QUIESCENT_PERIOD + misc.truncated_exponential(scale=0.35, + min_value=0.2, + max_value=0.5) self.trials_table.at[self.trial_num, 'quiescent_period'] = quiescent_period self.trials_table.at[self.trial_num, 'contrast'] = contrast self.trials_table.at[self.trial_num, 'stim_phase'] = random.uniform(0, 2 * math.pi) @@ -711,10 +713,10 @@ def new_block(self): if self.task_params.BLOCK_INIT_5050 and self.block_num == 0: block_len = 90 else: - block_len = int(misc.texp( - factor=self.task_params.BLOCK_LEN_FACTOR, - min_=self.task_params.BLOCK_LEN_MIN, - max_=self.task_params.BLOCK_LEN_MAX + block_len = int(misc.truncated_exponential( + scale=self.task_params.BLOCK_LEN_FACTOR, + min_value=self.task_params.BLOCK_LEN_MIN, + max_value=self.task_params.BLOCK_LEN_MAX )) if self.block_num == 0: if self.task_params.BLOCK_INIT_5050: diff --git a/iblrig/gui/wizard.py b/iblrig/gui/wizard.py index 2752d7524..d982a91a1 100644 --- a/iblrig/gui/wizard.py +++ b/iblrig/gui/wizard.py @@ -232,6 +232,10 @@ def __init__(self, *args, **kwargs): # anydesk_worker.signals.result.connect(lambda var: print(f'Your AnyDesk ID: {var:s}')) # QThreadPool.globalInstance().tryStart(anydesk_worker) + # disable control of LED if Bpod does not have the respective capability + bpod = Bpod(self.model.hardware_settings['device_bpod']['COM_BPOD'], skip_initialization=True) + self.uiPushStatusLED.setEnabled(bpod.can_control_led) + # check for update update_worker = Worker(check_for_updates) update_worker.signals.result.connect(self._on_check_update_result) @@ -651,7 +655,7 @@ def flush(self): self.enable_UI_elements() try: - bpod = Bpod(self.model.hardware_settings['device_bpod']['COM_BPOD']) # bpod is a singleton + bpod = Bpod(self.model.hardware_settings['device_bpod']['COM_BPOD'], skip_initialization=True) bpod.manual_override(bpod.ChannelTypes.OUTPUT, bpod.ChannelNames.VALVE, 1, self.uiPushFlush.isChecked()) except (OSError, exceptions.bpod_error.BpodErrorException): print(traceback.format_exc()) @@ -668,7 +672,7 @@ def toggle_status_led(self, is_toggled: bool): self.enable_UI_elements() try: - bpod = Bpod(self.model.hardware_settings['device_bpod']['COM_BPOD']) + bpod = Bpod(self.model.hardware_settings['device_bpod']['COM_BPOD'], skip_initialization=True) bpod.set_status_led(is_toggled) except (OSError, exceptions.bpod_error.BpodErrorException, AttributeError): self.uiPushStatusLED.setChecked(False) diff --git a/iblrig/hardware.py b/iblrig/hardware.py index 96032b7e0..b0f6614af 100644 --- a/iblrig/hardware.py +++ b/iblrig/hardware.py @@ -28,10 +28,10 @@ class Bpod(BpodIO): + can_control_led = True _instances = {} _lock = threading.Lock() _is_initialized = False - _can_control_led = True def __new__(cls, *args, **kwargs): serial_port = args[0] if len(args) > 0 else '' @@ -44,10 +44,11 @@ def __new__(cls, *args, **kwargs): Bpod._instances[serial_port] = instance return instance - def __init__(self, *args, **kwargs): - # # skip initialization if it has already been performed before - # if self._is_initialized: - # return + def __init__(self, *args, skip_initialization: bool = False, **kwargs): + # skip initialization if it has already been performed before + # IMPORTANT: only use this for non-critical tasks (e.g., flushing valve from GUI) + if skip_initialization and self._is_initialized: + return # try to instantiate once for nothing try: @@ -65,7 +66,7 @@ def __init__(self, *args, **kwargs): "Please unplug the Bpod USB cable from the computer and plug it back in to start the task. ") from e self.default_message_idx = 0 self.actions = Bunch({}) - self._can_control_led = self.set_status_led(True) + self.can_control_led = self.set_status_led(True) self._is_initialized = True def close(self) -> None: @@ -182,7 +183,7 @@ def toggle_valve(self, duration=None): @static_vars(supported=True) def set_status_led(self, state: bool) -> bool: - if self._can_control_led and self._arcom is not None: + if self.can_control_led and self._arcom is not None: try: log.info(f'{"en" if state else "dis"}abling Bpod Status LED') command = struct.pack("cB", b":", state) @@ -193,7 +194,7 @@ def set_status_led(self, state: bool) -> bool: pass self._arcom.serial_object.reset_input_buffer() self._arcom.serial_object.reset_output_buffer() - log.error('Bpod device does not support control of the status LED. Please update firmware.') + log.warning('Bpod device does not support control of the status LED. Please update firmware.') return False def valve(self, valve_id: int, state: bool): diff --git a/iblrig/misc.py b/iblrig/misc.py index 094ce2aee..0a2493379 100644 --- a/iblrig/misc.py +++ b/iblrig/misc.py @@ -7,14 +7,12 @@ """ import argparse import datetime -import json import logging from pathlib import Path -from typing import Optional, Union +from typing import Optional, Union, Literal import numpy as np -from iblrig.raw_data_loaders import load_settings FLAG_FILE_NAMES = [ "transfer_me.flag", @@ -82,13 +80,19 @@ def get_task_arguments(parents=None): return _post_parse_arguments(**kwargs) -def _isdatetime(x: str) -> Optional[bool]: +def _is_datetime(x: str) -> bool: """ - Check if string is a date in the format YYYY-MM-DD. + Check if a string is a date in the format YYYY-MM-DD. - :param x: The string to check - :return: True if the string matches the date format, False otherwise. - :rtype: Optional[bool] + Parameters + ---------- + x : str + The string to check. + + Returns + ------- + bool or None + True if the string matches the date format, False otherwise, or None if there's an exception. """ try: datetime.strptime(x, "%Y-%m-%d") @@ -106,83 +110,12 @@ def get_session_path(path: Union[str, Path]) -> Optional[Path]: path = Path(path) sess = None for i, p in enumerate(path.parts): - if p.isdigit() and _isdatetime(path.parts[i - 1]): + if p.isdigit() and _is_datetime(path.parts[i - 1]): sess = Path().joinpath(*path.parts[: i + 1]) return sess -def smooth_rolling_window(x, window_len=11, window="blackman"): - """ - Smooth the data using a window with requested size. - - This method is based on the convolution of a scaled window with the signal. - The signal is prepared by introducing reflected copies of the signal - (with the window size) in both ends so that transient parts are minimized - in the beginning and end part of the output signal. - - :param x: The input signal - :type x: list or numpy.array - :param window_len: The dimension of the smoothing window, - should be an **odd** integer, defaults to 11 - :type window_len: int, optional - :param window: The type of window from ['flat', 'hanning', 'hamming', - 'bartlett', 'blackman'] - flat window will produce a moving average smoothing, - defaults to 'blackman' - :type window: str, optional - :raises ValueError: Smooth only accepts 1 dimension arrays. - :raises ValueError: Input vector needs to be bigger than window size. - :raises ValueError: Window is not one of 'flat', 'hanning', 'hamming', - 'bartlett', 'blackman' - :return: Smoothed array - :rtype: numpy.array - """ - # **NOTE:** length(output) != length(input), to correct this: - # return y[(window_len/2-1):-(window_len/2)] instead of just y. - if isinstance(x, list): - x = np.array(x) - - if x.ndim != 1: - raise ValueError("smooth only accepts 1 dimension arrays.") - - if x.size < window_len: - raise ValueError("Input vector needs to be bigger than window size.") - - if window_len < 3: - return x - - if window not in ["flat", "hanning", "hamming", "bartlett", "blackman"]: - raise ValueError( - "Window is not one of 'flat', 'hanning', 'hamming',\ -'bartlett', 'blackman'" - ) - - s = np.r_[x[window_len - 1: 0: -1], x, x[-1:-window_len:-1]] - # print(len(s)) - if window == "flat": # moving average - w = np.ones(window_len, "d") - else: - w = eval("np." + window + "(window_len)") - - y = np.convolve(w / w.sum(), s, mode="valid") - return y[round((window_len / 2 - 1)): round(-(window_len / 2))] - - -def checkerboard(shape): - return np.indices(shape).sum(axis=0) % 2 - - -def make_square_dvamat(size, dva): - c = np.arange(size) - int(size / 2) - x = np.array([c] * 15) - y = np.rot90(x) - dvamat = np.array(list(zip(y.ravel() * dva, x.ravel() * dva)), dtype="int, int").reshape( - x.shape - ) - return dvamat - - def get_port_events(events: dict, name: str = "") -> list: out: list = [] for k in events: @@ -193,183 +126,143 @@ def get_port_events(events: dict, name: str = "") -> list: return out -def update_buffer(buffer: list, val) -> list: - buffer = np.roll(buffer, -1, axis=0) - buffer[-1] = val - return buffer.tolist() - - -def texp(factor: float = 0.35, min_: float = 0.2, max_: float = 0.5) -> float: - """Truncated exponential - mean = 0.35 - min = 0.2 - max = 0.5 +def truncated_exponential(scale: float = 0.35, min_value: float = 0.2, max_value: float = 0.5) -> float: """ - x = np.random.exponential(factor) - if min_ <= x <= max_: + Generate a truncated exponential random variable within a specified range. + + Parameters + ---------- + scale : float, optional + Scale of the exponential distribution (inverse of rate parameter). Defaults to 0.35. + min_value : float, optional + Minimum value for the truncated range. Defaults to 0.2. + max_value : float, optional + Maximum value for the truncated range. Defaults to 0.5. + + Returns + ------- + float + Truncated exponential random variable. + + Notes + ----- + This function generates a random variable from an exponential distribution + with the specified `scale`. It then checks if the generated value is within + the specified range `[min_value, max_value]`. If it is within the range, it returns + the generated value; otherwise, it recursively generates a new value until it falls + within the specified range. + + The `scale` should typically be greater than or equal to the `min_value` to avoid + potential issues with infinite recursion. + """ + x = np.random.exponential(scale) + if min_value <= x <= max_value: return x else: - return texp(factor=factor, min_=min_, max_=max_) + return truncated_exponential(scale=scale, min_value=min_value, max_value=max_value) -def get_biased_probs(n: int, idx: int = -1, prob: float = 0.5) -> list: +def get_biased_probs(n: int, idx: int = -1, p_idx: float = 0.5) -> list[float]: """ - get_biased_probs [summary] - - Calculate the biased probability for all elements of an array so that - the value has probability of being drawn in respect to the - remaining values. - https://github.com/int-brain-lab/iblrig/issues/74 - For prob == 0.5 - p = [2 / (2 * len(contrast_set) - 1) for x in contrast_set] - p[-1] *= 1 / 2 - For arbitrary probs - p = [1/(n-1 + 0.5)] * (n - 1) - - e.g. get_biased_probs(3, idx=-1, prob=0.5) - >>> [0.4, 0.4, 0.2] - - :param n: The length of the array, i.e. the num of probas to generate - :type n: int - :param idx: The index of the value that has the biased probability, - defaults to -1 - :type idx: int, optional - :param prob: The probability of the idxth value relative top the rest, - defaults to 0.5 - :type prob: float, optional - :return: List of biased probabilities - :rtype: list - + Calculate biased probabilities for all elements of an array such that the + `i`th value has probability `p_i` for being drawn relative to the remaining + values. + + See: https://github.com/int-brain-lab/iblrig/issues/74 + + Parameters + ---------- + n : int + The length of the array, i.e., the number of probabilities to generate. + idx : int, optional + The index of the value that has the biased probability. Defaults to -1. + p_idx : float, optional + The probability of the `idx`-th value relative to the rest. Defaults to 0.5. + + Returns + ------- + List[float] + List of biased probabilities. + + Raises + ------ + ValueError + If `idx` is outside the valid range [-1, n), or if `p_idx` is 0. """ - n_1 = n - 1 - z = n_1 + prob - p = [1 / z] * (n_1 + 1) - p[idx] *= prob + if idx < -1 or idx >= n: + raise ValueError("Invalid index. Index should be in the range [-1, n).") + if n == 1: + return [1.0] + if p_idx == 0: + raise ValueError("Probability must be larger than 0.") + z = n - 1 + p_idx + p = [1 / z] * n + p[idx] *= p_idx return p -def draw_contrast(contrast_set: list, prob_type: str = "biased", idx: int = -1, idx_prob: float = 0.5) -> float: - - if prob_type in ["skew_zero", "biased"]: - p = get_biased_probs(len(contrast_set), idx=idx, prob=idx_prob) +def draw_contrast(contrast_set: list[float], + probability_type: Literal["skew_zero", "biased", "uniform"] = "biased", + idx: int = -1, + idx_probability: float = 0.5) -> float: + """ + Draw a contrast value from a given iterable based to the specified probability type + + Parameters + ---------- + contrast_set : list[float] + The set of contrast values from which to draw. + probability_type : Literal["skew_zero", "biased", "uniform"], optional + The type of probability distribution to use. + - "skew_zero" or "biased": Draws with a biased probability distribution based on idx and idx_probability, + - "uniform": Draws with a uniform probability distribution. + Defaults to "biased". + idx : int, optional + Index for probability manipulation (with "skew_zero" or "biased"), default: -1. + idx_probability : float, optional + Probability for the specified index (with "skew_zero" or "biased"), default: 0.5. + + Returns + ------- + float + The drawn contrast value. + + Raises + ------ + ValueError + If an unsupported `probability_type` is provided. + """ + if probability_type in ["skew_zero", "biased"]: + p = get_biased_probs(n=len(contrast_set), idx=idx, p_idx=idx_probability) return np.random.choice(contrast_set, p=p) - elif prob_type == "uniform": + elif probability_type == "uniform": return np.random.choice(contrast_set) - - -def check_stop_criterions(init_datetime, rt_buffer, trial_num) -> int: - # STOPPING CRITERIONS - # < than 400 trials in 45 minutes - time_up = init_datetime + datetime.timedelta(minutes=45) - if time_up <= datetime.datetime.now() and trial_num <= 400: - return 1 - - # Median response time of latest N = 20 trials > than 5 times - # the median response time and more than 400 trials performed - N, T = 20, 400 - if len(rt_buffer) >= N and trial_num > T: - latest_median = np.median(rt_buffer[-N:]) - all_median = np.median(rt_buffer) - - if latest_median > all_median * 5: - return 2 - - end_time = init_datetime + datetime.timedelta(minutes=90) - if end_time <= datetime.datetime.now(): - return 3 - - return False - - -def create_flag(session_folder_path: str, flag: str) -> None: - if not flag.endswith(".flag"): - flag = flag + ".flag" - if flag not in FLAG_FILE_NAMES: - log.warning(f"Creating unknown flag file {flag} in {session_folder_path}") - - path = Path(session_folder_path) / flag - open(path, "a").close() - - -def draw_session_order(): - first = list(range(0, 4)) - second = list(range(4, 8)) - third = list(range(8, 12)) - for x in [first, second, third]: - np.random.shuffle(x) - first.extend(second) - first.extend(third) - - return first - - -def patch_settings_file(sess_or_file: str, patch: dict) -> None: - sess_or_file = Path(sess_or_file) - if sess_or_file.is_file() and sess_or_file.name.endswith("_iblrig_taskSettings.raw.json"): - session = sess_or_file.parent.parent - file = sess_or_file - elif sess_or_file.is_dir() and sess_or_file.name.isdecimal(): - file = sess_or_file / "raw_behavior_data" / "_iblrig_taskSettings.raw.json" - session = sess_or_file else: - print("not a settings file or a session folder") - return + raise ValueError("Unsupported probability_type. Use 'skew_zero', 'biased', or 'uniform'.") - settings = load_settings(session) - settings.update(patch) - # Rename file on disk keeps pathlib ref to "file" intact - file.rename(file.with_suffix(".json_bk")) - with open(file, "w") as f: - f.write(json.dumps(settings, indent=1)) - f.write("\n") - f.flush() - # Check if properly saved - saved_settings = load_settings(session) - if settings == saved_settings: - file.with_suffix(".json_bk").unlink() - return - - -# TODO: Consider migrating this to ephys_session_file_creator -def generate_position_contrasts( - contrasts: list = [1.0, 0.25, 0.125, 0.0625], - positions: list = [-35, 35], - cp_repeats: int = 20, - shuffle: bool = True, - to_string: bool = False, -): - """generate_position_contrasts generate contrasts and positions - - :param contrasts: Set of contrasts in floats, defaults to [1.0, 0.25, 0.125, 0.0625] - :type contrasts: list, optional - :param positions: Set of positions in int, defaults to [-35, 35] - :type positions: list, optional - :param cp_repeats: Number of repetitions for each contrast position pair, defaults to 20 - :type cp_repeats: int, optional - :param shuffle: Shuffle the result or return sorted, defaults to True - :type shuffle: bool, optional - :param to_string: Return strings instead of int/float pairs, defaults to False - :type to_string: bool, optional - :return: 2D array with positions and contrasts - :rtype: numpy.array() + +def online_std(new_sample: float, new_count: int, old_mean: float, old_std: float) -> tuple[float, float]: + """ + Updates the mean and standard deviation of a group of values after a sample update + + Parameters + ---------- + new_sample : float + The new sample to be included. + new_count : int + The new count of samples (including new_sample). + old_mean : float + The previous mean (N - 1). + old_std : float + The previous standard deviation (N - 1). + + Returns + ------- + tuple[float, float] + Updated mean and standard deviation. """ - # Generate a set of positions and contrasts - pos = sorted(positions * len(contrasts) * cp_repeats) - cont = contrasts * cp_repeats * 2 - - data = np.array([[int(p), c] for p, c in zip(pos, cont)]) - if shuffle: - np.random.shuffle(data) - if to_string: - data = np.array([[str(int(p)), str(c)] for p, c in data]) - return data - - -if __name__ == "__main__": - get_biased_probs(4) - print(draw_contrast([1, 2, 3])) - print(draw_contrast([1, 2, 3, 4, 5])) - print(draw_contrast([1, 2, 3, 4, 5, 6, 7])) - print(draw_contrast([1, 2, 3], prob=0.3, idx=0)) - print(draw_contrast([1, 2, 3, 4, 5], prob=0.5, idx=0)) - print(draw_contrast([1, 2, 3, 4, 5, 6, 7], prob=0.3, idx=-1)) + if new_count == 1: + return new_sample, 0.0 + new_mean = (old_mean * (new_count - 1) + new_sample) / new_count + new_std = np.sqrt((old_std ** 2 * (new_count - 1) + (new_sample - old_mean) * (new_sample - new_mean)) / new_count) + return new_mean, new_std diff --git a/iblrig/online_plots.py b/iblrig/online_plots.py index 5bdfd9e75..511d8247b 100644 --- a/iblrig/online_plots.py +++ b/iblrig/online_plots.py @@ -9,6 +9,8 @@ from pandas.api.types import CategoricalDtype import one.alf.io + +from iblrig.misc import online_std from iblrig.raw_data_loaders import load_task_jsonable from iblutil.util import Bunch @@ -22,22 +24,6 @@ sns.set_style('white') -def online_std(new_sample, count, mean, std): - """ - Updates the mean and standard deviation of a group of values after a sample update - :param new: new sample value - :param count: number of samples after the new addition - :param mu: (N - 1) mean - :param std: (N - 1) standard deviation - :return: - """ - if count == 1: - return new_sample, 0.0 - mean_ = (mean * (count - 1) + new_sample) / count - std_ = np.sqrt((std ** 2 * (count - 1) + (new_sample - mean) * (new_sample - mean_)) / count) - return mean_, std_ - - class DataModel(object): """ The data model is a pure numpy / pandas container for the choice world task. @@ -116,7 +102,7 @@ def __init__(self, task_file): self.last_contrasts[ileft, 0] = np.abs(self.last_trials.signed_contrast[ileft]) self.last_contrasts[iright, 1] = np.abs(self.last_trials.signed_contrast[iright]) - def update_trial(self, trial_data, bpod_data): + def update_trial(self, trial_data, bpod_data) -> None: # update counters self.time_elapsed = bpod_data['Trial end timestamp'] - bpod_data['Bpod start timestamp'] if self.time_elapsed <= (ENGAGED_CRITIERION['secs']): @@ -133,15 +119,15 @@ def update_trial(self, trial_data, bpod_data): self.psychometrics.loc[indexer, ('count')] += 1 self.psychometrics.loc[indexer, ('response_time')], self.psychometrics.loc[indexer, ('response_time_std')] = online_std( new_sample=trial_data.response_time, - count=self.psychometrics.loc[indexer, ('count')], - mean=self.psychometrics.loc[indexer, ('response_time')], - std=self.psychometrics.loc[indexer, ('response_time_std')] + new_count=self.psychometrics.loc[indexer, ('count')], + old_mean=self.psychometrics.loc[indexer, ('response_time')], + old_std=self.psychometrics.loc[indexer, ('response_time_std')] ) self.psychometrics.loc[indexer, ('choice')], self.psychometrics.loc[indexer, ('choice_std')] = online_std( new_sample=float(choice), - count=self.psychometrics.loc[indexer, ('count')], - mean=self.psychometrics.loc[indexer, ('choice')], - std=self.psychometrics.loc[indexer, ('choice_std')] + new_count=self.psychometrics.loc[indexer, ('count')], + old_mean=self.psychometrics.loc[indexer, ('choice')], + old_std=self.psychometrics.loc[indexer, ('choice_std')] ) # update last trials table self.last_trials = self.last_trials.shift(-1) @@ -268,7 +254,7 @@ def update_trial(self, trial_data, bpod_data): self.data.update_trial(trial_data, bpod_data) self.update_graphics(pupdate=trial_data.stim_probability_left) - def update_graphics(self, pupdate=None): + def update_graphics(self, pupdate: float | None = None): background_color = self.data.compute_end_session_criteria() h = self.h h.fig.set_facecolor(background_color) @@ -293,10 +279,10 @@ def update_graphics(self, pupdate=None): h.ax_performance.set(ylim=[0, (self.data.ntrials // 50 + 1) * 50]) @property - def _session_string(self): + def _session_string(self) -> str: return ' - '.join(self.data.session_path.parts[-3:]) if self.data.session_path != "" else "" - def run(self, task_file): + def run(self, task_file: Path | str) -> None: """ This methods is for online use, it will watch for a file in conjunction with an iblrigv8 running task :param task_file: diff --git a/iblrig/path_helper.py b/iblrig/path_helper.py index 7d59b1edd..73760bd54 100644 --- a/iblrig/path_helper.py +++ b/iblrig/path_helper.py @@ -36,7 +36,8 @@ def iterate_previous_sessions(subject_name, task_name, n=1, **kwargs): if rig_paths.remote_subjects_folder is not None: remote_sessions = _iterate_protocols( rig_paths.remote_subjects_folder.joinpath(subject_name), task_name=task_name, n=n) - sessions.extend(remote_sessions) + if remote_sessions is not None: + sessions.extend(remote_sessions) _, ises = np.unique([s['session_stub'] for s in sessions], return_index=True) sessions = [sessions[i] for i in ises] return sessions diff --git a/iblrig/session_creator.py b/iblrig/session_creator.py index a3c4870f8..bca61b018 100644 --- a/iblrig/session_creator.py +++ b/iblrig/session_creator.py @@ -13,7 +13,7 @@ def draw_position(position_set, stim_probability_left) -> int: def draw_block_len(factor, min_=20, max_=100): - return int(misc.texp(factor=factor, min_=min_, max_=max_)) + return int(misc.truncated_exponential(scale=factor, min_value=min_, max_value=max_)) # EPHYS CHOICE WORLD @@ -38,7 +38,7 @@ def make_ephysCW_pc(prob_type='biased'): len_block.append(draw_block_len(60, min_=20, max_=100)) for x in range(len_block[-1]): p = draw_position([-35, 35], prob_left) - c = misc.draw_contrast(contrasts, prob_type=prob_type) + c = misc.draw_contrast(contrasts, probability_type=prob_type) pc = np.append(pc, np.array([[p, c, prob_left]]), axis=0) # do this in PC space prob_left = np.round(np.abs(1 - prob_left), 1) @@ -52,7 +52,7 @@ def make_ephysCW_pcqs(pc): qperiod = [] for i in pc: sphase.append(np.random.uniform(0, 2 * math.pi)) - qperiod.append(qperiod_base + misc.texp(factor=0.35, min_=0.2, max_=0.5)) + qperiod.append(qperiod_base + misc.truncated_exponential(scale=0.35, min_value=0.2, max_value=0.5)) qs = np.array([qperiod, sphase]).T pcqs = np.append(pc, qs, axis=1) perm = [0, 1, 3, 4, 2] diff --git a/iblrig/test/test_misc.py b/iblrig/test/test_misc.py new file mode 100644 index 000000000..926cebfd3 --- /dev/null +++ b/iblrig/test/test_misc.py @@ -0,0 +1,41 @@ +import unittest + +import numpy as np +from scipy import stats + +from iblrig import misc +from iblrig.misc import online_std + + +class TestMisc(unittest.TestCase): + def test_draw_contrast(self): + n_draws = 1000 + n_contrasts = 10 + contrast_set = np.linspace(0, 1, n_contrasts) + + def assert_distribution(values: list[int], f_exp: list[float] | None = None) -> None: + f_obs = np.unique(values, return_counts=True)[1] + assert stats.chisquare(f_obs, f_exp).pvalue > 0.05 + + # uniform distribution + contrasts = [misc.draw_contrast(contrast_set, "uniform") for i in range(n_draws)] + assert_distribution(contrasts) + + # biased distribution + for p_idx in [0.25, 0.5, 0.75, 1.25]: + contrasts = [misc.draw_contrast(contrast_set, "biased", 0, p_idx) for i in range(n_draws)] + expected = np.ones(n_contrasts) + expected[0] = p_idx + expected = expected / expected.sum() * n_draws + assert_distribution(contrasts, expected) + + self.assertRaises(ValueError, misc.draw_contrast, [], "incorrect_type") # assert exception for incorrect type + self.assertRaises(ValueError, misc.draw_contrast, [0, 1], "biased", 2) # assert exception for out-of-range index + + def test_online_std(self): + n = 41 + b = np.random.rand(n) + a = b[:-1] + mu, std = online_std(new_sample=b[-1], new_count=n, old_mean=np.mean(a), old_std=np.std(a)) + np.testing.assert_almost_equal(std, np.std(b)) + np.testing.assert_almost_equal(mu, np.mean(b)) diff --git a/iblrig/test/test_online_plots.py b/iblrig/test/test_online_plots.py index cbcf08a08..f37a05bc0 100644 --- a/iblrig/test/test_online_plots.py +++ b/iblrig/test/test_online_plots.py @@ -11,17 +11,6 @@ matplotlib.use('Agg') # avoid pyqt testing issues -class TestOnlineStd(unittest.TestCase): - - def test_online_std(self): - n = 41 - b = np.random.rand(n) - a = b[:-1] - mu, std = op.online_std(new_sample=b[-1], count=n, mean=np.mean(a), std=np.std(a)) - np.testing.assert_almost_equal(std, np.std(b)) - np.testing.assert_almost_equal(mu, np.mean(b)) - - class TestOnlinePlots(unittest.TestCase): @classmethod def setUpClass(cls) -> None: diff --git a/iblrig/tools.py b/iblrig/tools.py index 08f6b6536..787d5108f 100644 --- a/iblrig/tools.py +++ b/iblrig/tools.py @@ -48,10 +48,10 @@ def get_anydesk_id(silent: bool = False) -> Optional[str]: anydesk_id = None try: if cmd := shutil.which('anydesk'): - cmd = Path(cmd) + pass elif os.name == 'nt': - cmd = Path(os.environ["ProgramFiles(x86)"], 'AnyDesk', 'anydesk.exe') - if not cmd.exists(): + cmd = str(Path(os.environ["ProgramFiles(x86)"], 'AnyDesk', 'anydesk.exe')) + if cmd is None or not Path(cmd).exists(): raise FileNotFoundError("AnyDesk executable not found") proc = subprocess.Popen([cmd, '--get-id'], stdout=subprocess.PIPE, stderr=subprocess.STDOUT) diff --git a/iblrig/version_management.py b/iblrig/version_management.py index d92754ebd..afcdbc98d 100644 --- a/iblrig/version_management.py +++ b/iblrig/version_management.py @@ -36,7 +36,7 @@ def check_for_updates() -> tuple[bool, str]: v_local = get_local_version() v_remote = get_remote_version() - if all((v_remote, v_local)): + if v_remote is not None and v_local is not None: v_remote_base = version.parse(v_remote.base_version) v_local_base = version.parse(v_local.base_version) @@ -144,7 +144,7 @@ def get_branch() -> Union[str, None]: ----- This method will only work with installations managed through Git. """ - if get_branch.branch: + if get_branch.branch is not None: return get_branch.branch if not IS_GIT: log.error('This installation of iblrig is not managed through git') @@ -205,7 +205,7 @@ def get_changelog() -> str: This method relies on the presence of a CHANGELOG.md file either in the repository or locally. """ - if get_changelog.changelog: + if get_changelog.changelog is not None: return get_changelog.changelog try: changelog = requests.get(f'https://raw.githubusercontent.com/int-brain-lab/iblrig/{get_branch()}/CHANGELOG.md', @@ -235,7 +235,7 @@ def get_remote_version() -> Union[version.Version, None]: ----- This method will only work with installations managed through Git. """ - if get_remote_version.remote_version: + if get_remote_version.remote_version is not None: log.debug(f'Using cached remote version: {get_remote_version.remote_version}') return get_remote_version.remote_version @@ -328,6 +328,6 @@ def upgrade() -> int: check_call(["git", "reset", "--hard"], cwd=BASE_DIR) check_call(["git", "pull", "--tags"], cwd=BASE_DIR) - check_call(["pip", "install", "-U", "pip"]) - check_call(["pip", "install", "-U", "-e", "."]) + check_call([sys.executable, "-m", "pip", "install", "-U", "pip"], cwd=BASE_DIR) + check_call([sys.executable, "-m", "pip", "install", "-U", "-e", BASE_DIR], cwd=BASE_DIR) return 0 diff --git a/iblrig_tasks/_iblrig_tasks_ImagingChoiceWorld/task.py b/iblrig_tasks/_iblrig_tasks_ImagingChoiceWorld/task.py index 908a47d41..b88a1cff4 100644 --- a/iblrig_tasks/_iblrig_tasks_ImagingChoiceWorld/task.py +++ b/iblrig_tasks/_iblrig_tasks_ImagingChoiceWorld/task.py @@ -11,7 +11,7 @@ def draw_quiescent_period(self): For this task we double the quiescence period texp draw and remove the absolute offset of 200ms. The resulting is a truncated exp distribution between 400ms and 1 sec """ - return iblrig.misc.texp(factor=0.35 * 2, min_=0.2 * 2, max_=0.5 * 2) + return iblrig.misc.truncated_exponential(factor=0.35 * 2, min_value=0.2 * 2, max_value=0.5 * 2) if __name__ == "__main__": # pragma: no cover diff --git a/pyproject.toml b/pyproject.toml index 1e63f133c..e71b10a41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,8 @@ DEV = [ "sphinx-autobuild", "sphinx_lesson", "sphinx_rtd_theme", - "pre-commit" + "pre-commit", + "mypy" ] [project.scripts] @@ -67,3 +68,6 @@ version = {attr = "iblrig.__version__"} [tool.setuptools.packages] find = {} + +[tool.mypy] +ignore_missing_imports = true