diff --git a/CHANGELOG.md b/CHANGELOG.md index 36ed4044b..4eeef9ea4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,11 @@ Changelog * automated validation of rig components * adaptive reward parameter for trainingPhaseChoiceWorld * add validate_video entry-point +* switch from flake8 to ruff for linting & code-checks +* automatically set correct trigger-mode when setting up the cameras +* support rotary encoder on arbitrary module port +* add ambient sensor reading back to trial log +* allow negative stimulus gain (reverse wheel contingency) 8.18.0 ------ diff --git a/iblrig/base_choice_world.py b/iblrig/base_choice_world.py index 0474a3527..3da2422db 100644 --- a/iblrig/base_choice_world.py +++ b/iblrig/base_choice_world.py @@ -159,8 +159,6 @@ def _run(self): This is the method that runs the task with the actual state machine :return: """ - # make the bpod send spacer signals to the main sync clock for protocol discovery - self.send_spacers() time_last_trial_end = time.time() for i in range(self.task_params.NTRIALS): # Main loop # t_overhead = time.time() @@ -186,6 +184,7 @@ def _run(self): self.bpod.run_state_machine(sma) # Locks until state machine 'exit' is reached time_last_trial_end = time.time() self.trial_completed(self.bpod.session.current_trial.export()) + self.ambient_sensor_table.loc[i] = self.bpod.get_ambient_sensor_reading() self.show_trial_log() # handle pause and stop events @@ -427,15 +426,16 @@ def next_trial(self): pass @property - def reward_amount(self): + def default_reward_amount(self): return self.task_params.REWARD_AMOUNT_UL - def draw_next_trial_info(self, pleft=0.5, contrast=None, position=None): """Draw next trial variables. - This is called by the `next_trial` method before updating the Bpod state machine. This also calls :meth:`send_trial_info_to_bonsai`. + This is called by the `next_trial` method before updating the Bpod state machine. This also """ + + def draw_next_trial_info(self, pleft=0.5, contrast=None, position=None, reward_amount=None): if contrast is None: contrast = misc.draw_contrast(self.task_params.CONTRAST_SET, self.task_params.CONTRAST_SET_PROBABILITY_TYPE) assert len(self.task_params.STIM_POSITIONS) == 2, 'Only two positions are supported' @@ -443,6 +443,7 @@ def draw_next_trial_info(self, pleft=0.5, contrast=None, position=None): quiescent_period = self.task_params.QUIESCENT_PERIOD + misc.truncated_exponential( scale=0.35, min_value=0.2, max_value=0.5 ) + reward_amount = self.default_reward_amount if reward_amount is None else reward_amount self.trials_table.at[self.trial_num, 'quiescent_period'] = quiescent_period self.trials_table.at[self.trial_num, 'contrast'] = contrast self.trials_table.at[self.trial_num, 'stim_phase'] = random.uniform(0, 2 * math.pi) @@ -452,7 +453,7 @@ def draw_next_trial_info(self, pleft=0.5, contrast=None, position=None): self.trials_table.at[self.trial_num, 'stim_freq'] = self.task_params.STIM_FREQ self.trials_table.at[self.trial_num, 'trial_num'] = self.trial_num self.trials_table.at[self.trial_num, 'position'] = position - self.trials_table.at[self.trial_num, 'reward_amount'] = self.reward_amount + self.trials_table.at[self.trial_num, 'reward_amount'] = reward_amount self.trials_table.at[self.trial_num, 'stim_probability_left'] = pleft self.send_trial_info_to_bonsai() @@ -527,13 +528,17 @@ def quiescent_period(self): def position(self): return self.trials_table.at[self.trial_num, 'position'] + @property + def reverse_wheel(self): + return self.task_params.STIM_GAIN < 0 + @property def event_error(self): - return self.device_rotary_encoder.THRESHOLD_EVENTS[self.position] + return self.device_rotary_encoder.THRESHOLD_EVENTS[-self.position if self.reverse_wheel else self.position] @property def event_reward(self): - return self.device_rotary_encoder.THRESHOLD_EVENTS[-self.position] + return self.device_rotary_encoder.THRESHOLD_EVENTS[self.position if self.reverse_wheel else -self.position] class HabituationChoiceWorldSession(ChoiceWorldSession): @@ -628,7 +633,9 @@ def _run(self): # starts online plotting if self.interactive: subprocess.Popen( - ['viewsession', str(self.paths['DATA_FILE_PATH'])], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT + ['view_session', str(self.paths['DATA_FILE_PATH']), str(self.paths['SETTINGS_FILE_PATH'])], + stdout=subprocess.DEVNULL, + stderr=subprocess.STDOUT, ) super()._run() @@ -783,7 +790,7 @@ def __init__(self, training_phase=-1, adaptive_reward=-1.0, adaptive_gain=None, self.trials_table['debias_trial'] = np.zeros(NTRIALS_INIT, dtype=bool) @property - def reward_amount(self): + def default_reward_amount(self): return self.session_info.get('ADAPTIVE_REWARD_AMOUNT_UL', self.task_params.REWARD_AMOUNT_UL) def get_subject_training_info(self): diff --git a/iblrig/base_tasks.py b/iblrig/base_tasks.py index cc4c54bb3..6bfa28535 100644 --- a/iblrig/base_tasks.py +++ b/iblrig/base_tasks.py @@ -34,7 +34,7 @@ import pybpodapi from ibllib.oneibl.registration import IBLRegistrationClient from iblrig import sound -from iblrig.constants import BASE_PATH, BONSAI_EXE +from iblrig.constants import BASE_PATH, BONSAI_EXE, PYSPIN_AVAILABLE from iblrig.frame2ttl import Frame2TTL from iblrig.hardware import SOFTCODE, Bpod, MyRotaryEncoder, sound_device_factory from iblrig.hifi import HiFi @@ -191,6 +191,8 @@ def _init_paths(self, append: bool = False): >>> C:\iblrigv8_data\mainenlab\Subjects\SWC_043\2019-01-01\001\raw_task_data_00 # noqa DATA_FILE_PATH: contains the bpod trials >>> C:\iblrigv8_data\mainenlab\Subjects\SWC_043\2019-01-01\001\raw_task_data_00\_iblrig_taskData.raw.jsonable # noqa + SETTINGS_FILE_PATH: contains the task settings + >>>C:\iblrigv8_data\mainenlab\Subjects\SWC_043\2019-01-01\001\raw_task_data_00\_iblrig_taskSettings.raw.json # noqa """ rig_computer_paths = iblrig.path_helper.get_local_and_remote_paths( local_path=self.iblrig_settings['iblrig_local_data_path'], @@ -230,6 +232,7 @@ def _init_paths(self, append: bool = False): self.session_info.SESSION_NUMBER = int(paths.SESSION_FOLDER.name) paths.SESSION_RAW_DATA_FOLDER = paths.SESSION_FOLDER.joinpath(paths.TASK_COLLECTION) paths.DATA_FILE_PATH = paths.SESSION_RAW_DATA_FOLDER.joinpath('_iblrig_taskData.raw.jsonable') + paths.SETTINGS_FILE_PATH = paths.SESSION_RAW_DATA_FOLDER.joinpath('_iblrig_taskSettings.raw.json') return paths def _setup_loggers(self, level='INFO', level_bpod='WARNING', file=None): @@ -336,18 +339,19 @@ def _make_task_parameters_dict(self): output_dict.update(patch_dict) return output_dict - def save_task_parameters_to_json_file(self, destination_folder=None) -> Path: + def save_task_parameters_to_json_file(self, destination_folder: Path | None = None) -> Path: """ - Given a session object, collects the various settings and parameters of the session and outputs them to a JSON file + Collects the various settings and parameters of the session and outputs them to a JSON file Returns ------- Path to the resultant JSON file """ output_dict = self._make_task_parameters_dict() - destination_folder = destination_folder or self.paths.SESSION_RAW_DATA_FOLDER - # Output dict to json file - json_file = destination_folder.joinpath('_iblrig_taskSettings.raw.json') + if destination_folder: + json_file = destination_folder.joinpath('_iblrig_taskSettings.raw.json') + else: + json_file = self.paths['SETTINGS_FILE_PATH'] json_file.parent.mkdir(parents=True, exist_ok=True) with open(json_file, 'w') as outfile: json.dump(output_dict, outfile, indent=4, sort_keys=True, default=str) # converts datetime objects to string @@ -368,7 +372,10 @@ def one(self): ) try: self._one = ONE( - base_url=str(self.iblrig_settings['ALYX_URL']), username=self.iblrig_settings['ALYX_USER'], mode='remote' + base_url=str(self.iblrig_settings['ALYX_URL']), + username=self.iblrig_settings['ALYX_USER'], + mode='remote', + cache_rest=None, ) log.info('instantiated ' + info_str) except Exception: @@ -453,7 +460,7 @@ def mock(self): def create_session(self): # create the session path and save json parameters in the task collection folder # this will also create the protocol folder - self.save_task_parameters_to_json_file() + self.paths['TASK_PARAMETERS_FILE'] = self.save_task_parameters_to_json_file() # enable file logging logfile = self.paths.SESSION_RAW_DATA_FOLDER.joinpath('_ibl_log.info-acquisition.log') self._setup_loggers(level=self._logger.level, file=logfile) @@ -651,10 +658,13 @@ def start_mixin_bonsai_cameras(self): configuration = self.hardware_settings.device_cameras[self.config] if (workflow_file := self._camera_mixin_bonsai_get_workflow_file(configuration, 'setup')) is None: return - # TODO: Disable Trigger in Bonsai workflow - PySpin won't help here - # if PYSPIN_AVAILABLE: - # from iblrig.video_pyspin import enable_camera_trigger - # enable_camera_trigger(True) + + # enable trigger of cameras (so Bonsai can disable it again ... sigh) + if PYSPIN_AVAILABLE: + from iblrig.video_pyspin import enable_camera_trigger + + enable_camera_trigger(True) + call_bonsai(workflow_file, wait=True) # TODO Parameterize using configuration cameras log.info('Bonsai cameras setup module loaded: OK') @@ -793,7 +803,8 @@ def start_mixin_bpod(self): self.bpod.set_status_led(False) assert self.bpod.is_connected log.info('Bpod hardware module loaded: OK') - # self.send_spacers() + # make the bpod send spacer signals to the main sync clock for protocol discovery + self.send_spacers() def send_spacers(self): log.info('Starting task by sending a spacer signal on BNC1') @@ -971,16 +982,15 @@ def start_mixin_sound(self): match self.hardware_settings.device_sound['OUTPUT']: case 'harp': assert self.bpod.sound_card is not None, 'No harp sound-card connected to Bpod' - module_port = f'Serial{self.bpod.sound_card.serial_port}' sound.configure_sound_card( sounds=[self.sound.GO_TONE, self.sound.WHITE_NOISE], indexes=[self.task_params.GO_TONE_IDX, self.task_params.WHITE_NOISE_IDX], sample_rate=self.sound['samplerate'], ) self.bpod.define_harp_sounds_actions( + module=self.bpod.sound_card, go_tone_index=self.task_params.GO_TONE_IDX, noise_index=self.task_params.WHITE_NOISE_IDX, - sound_port=module_port, ) case 'hifi': module = self.bpod.get_module('^HiFi') @@ -991,12 +1001,10 @@ def start_mixin_sound(self): hifi.load(index=self.task_params.WHITE_NOISE_IDX, data=self.sound.WHITE_NOISE) hifi.push() hifi.close() - module_port = f'Serial{module.serial_port}' self.bpod.define_harp_sounds_actions( + module=module, go_tone_index=self.task_params.GO_TONE_IDX, noise_index=self.task_params.WHITE_NOISE_IDX, - sound_port=module_port, - module=module, ) case _: self.bpod.define_xonar_sounds_actions() diff --git a/iblrig/commands.py b/iblrig/commands.py index 2cba5fa95..0249f6d9b 100644 --- a/iblrig/commands.py +++ b/iblrig/commands.py @@ -338,17 +338,19 @@ def remove_local_sessions(weeks=2, local_path=None, remote_path=None, dry=False, return removed -def viewsession(): +def view_session(): """ Entry point for command line: usage as below - >>> viewsession /full/path/to/jsonable/_iblrig_taskData.raw.jsonable + >>> view_session /full/path/to/jsonable/_iblrig_taskData.raw.jsonable :return: None """ parser = argparse.ArgumentParser() parser.add_argument('file_jsonable', help='full file path to jsonable file') + parser.add_argument('file_settings', help='full file path to settings file', nargs='?', default=None) args = parser.parse_args() - self = OnlinePlots() - self.run(Path(args.file_jsonable)) + + online_plots = OnlinePlots(task_file=args.file_jsonable, settings_file=args.file_settings) + online_plots.run(task_file=args.file_jsonable) def flush(): diff --git a/iblrig/gui/wizard.py b/iblrig/gui/wizard.py index e129d2bf4..5bcf57815 100644 --- a/iblrig/gui/wizard.py +++ b/iblrig/gui/wizard.py @@ -15,6 +15,7 @@ from dataclasses import dataclass from pathlib import Path +import numpy as np import pyqtgraph as pg from pydantic import ValidationError from PyQt5 import QtCore, QtGui, QtWidgets @@ -45,7 +46,7 @@ from iblrig.path_helper import load_pydantic_yaml from iblrig.pydantic_definitions import HardwareSettings, RigSettings from iblrig.tools import alyx_reachable, get_anydesk_id, internet_available -from iblrig.version_management import check_for_updates, get_changelog, is_dirty +from iblrig.version_management import check_for_updates, get_changelog from iblutil.util import setup_logger from one.webclient import AlyxClient from pybpodapi.exceptions.bpod_error import BpodErrorException @@ -825,6 +826,9 @@ def controls_for_extra_parameters(self): ) widget.valueChanged.emit(widget.value()) + case 'reward_set_ul': + label = 'Reward Set, μl' + case 'adaptive_gain': label = 'Stimulus Gain' minimum = 0 @@ -844,6 +848,7 @@ def controls_for_extra_parameters(self): case 'stim_gain': label = 'Stimulus Gain' + widget.setMinimum(-np.inf) widget.wheelEvent = lambda event: None layout.addRow(self.tr(label), widget) diff --git a/iblrig/hardware.py b/iblrig/hardware.py index 072479545..7c7f93956 100644 --- a/iblrig/hardware.py +++ b/iblrig/hardware.py @@ -13,11 +13,12 @@ from collections.abc import Callable from enum import IntEnum from pathlib import Path -from typing import Literal +from typing import Annotated, Literal import numpy as np import serial import sounddevice as sd +from annotated_types import Ge, Le from pydantic import validate_call from serial.tools import list_ports @@ -31,6 +32,10 @@ SOFTCODE = IntEnum('SOFTCODE', ['STOP_SOUND', 'PLAY_TONE', 'PLAY_NOISE', 'TRIGGER_CAMERA']) +# some annotated types +Uint8 = Annotated[int, Ge(0), Le(255)] +ActionIdx = Annotated[int, Ge(1), Le(255)] + log = logging.getLogger(__name__) @@ -99,6 +104,10 @@ def rotary_encoder(self): def sound_card(self): return self.get_module('sound_card') + @property + def ambient_module(self): + return self.get_module('^AmbientModule') + def get_module(self, module_name: str) -> BpodModule | None: """Get module by name @@ -124,7 +133,8 @@ def get_module(self, module_name: str) -> BpodModule | None: if len(modules) > 0: return modules[0] - def _define_message(self, module: BpodModule | int, message: list[int]) -> int: + @validate_call(config={'arbitrary_types_allowed': True}) + def _define_message(self, module: BpodModule | int, message: list[Uint8]) -> ActionIdx: """Define a serial message to be sent to a Bpod module as an output action within a state Parameters @@ -150,17 +160,14 @@ def _define_message(self, module: BpodModule | int, message: list[int]) -> int: will then be used as such in StateMachine: >>> output_actions=[("Serial1", id_msg_bonsai_show_stim)] """ - if isinstance(module, int): - pass - elif isinstance(module, BpodModule): + if isinstance(module, BpodModule): module = module.serial_port - else: - raise TypeError message_id = len(self.serial_messages) + 1 self.load_serial_message(module, message_id, message) self.serial_messages.update({message_id: {'target_module': module, 'message': message}}) return message_id + @validate_call(config={'arbitrary_types_allowed': True}) def define_xonar_sounds_actions(self): self.actions.update( { @@ -170,48 +177,45 @@ def define_xonar_sounds_actions(self): } ) - def define_harp_sounds_actions( - self, go_tone_index: int = 2, noise_index: int = 3, sound_port: str = 'Serial3', module: BpodModule | None = None - ): - if module is None: - module = self.sound_card + def define_harp_sounds_actions(self, module: BpodModule, go_tone_index: int = 2, noise_index: int = 3) -> None: + module_port = f"Serial{module.serial_port if module is not None else ''}" self.actions.update( { - 'play_tone': (sound_port, self._define_message(module, [ord('P'), go_tone_index])), - 'play_noise': (sound_port, self._define_message(module, [ord('P'), noise_index])), - 'stop_sound': (sound_port, ord('X')), + 'play_tone': (module_port, self._define_message(module, [ord('P'), go_tone_index])), + 'play_noise': (module_port, self._define_message(module, [ord('P'), noise_index])), + 'stop_sound': (module_port, ord('X')), } ) - def define_rotary_encoder_actions(self, re_port='Serial1'): - """ - Each output action is a tuple with the port and the message id - :param go_tone_index: - :param noise_index: - :return: - """ + def define_rotary_encoder_actions(self, module: BpodModule | None = None) -> None: + if module is None: + module = self.rotary_encoder + module_port = f"Serial{module.serial_port if module is not None else ''}" self.actions.update( { 'rotary_encoder_reset': ( - re_port, - self._define_message( - self.rotary_encoder, [RotaryEncoder.COM_SETZEROPOS, RotaryEncoder.COM_ENABLE_ALLTHRESHOLDS] - ), + module_port, + self._define_message(module, [RotaryEncoder.COM_SETZEROPOS, RotaryEncoder.COM_ENABLE_ALLTHRESHOLDS]), ), - 'bonsai_hide_stim': (re_port, self._define_message(self.rotary_encoder, [ord('#'), 1])), - 'bonsai_show_stim': (re_port, self._define_message(self.rotary_encoder, [ord('#'), 8])), - 'bonsai_closed_loop': (re_port, self._define_message(self.rotary_encoder, [ord('#'), 3])), - 'bonsai_freeze_stim': (re_port, self._define_message(self.rotary_encoder, [ord('#'), 4])), - 'bonsai_show_center': (re_port, self._define_message(self.rotary_encoder, [ord('#'), 5])), + 'bonsai_hide_stim': (module_port, self._define_message(module, [ord('#'), 1])), + 'bonsai_show_stim': (module_port, self._define_message(module, [ord('#'), 8])), + 'bonsai_closed_loop': (module_port, self._define_message(module, [ord('#'), 3])), + 'bonsai_freeze_stim': (module_port, self._define_message(module, [ord('#'), 4])), + 'bonsai_show_center': (module_port, self._define_message(module, [ord('#'), 5])), } ) def get_ambient_sensor_reading(self): - ambient_module = [x for x in self.modules if x.name == 'AmbientModule1'][0] - ambient_module.start_module_relay() - self.bpod_modules.module_write(ambient_module, 'R') - reply = self.bpod_modules.module_read(ambient_module, 12) - ambient_module.stop_module_relay() + if self.ambient_module is None: + return { + 'Temperature_C': np.NaN, + 'AirPressure_mb': np.NaN, + 'RelativeHumidity': np.NaN, + } + self.ambient_module.start_module_relay() + self.bpod_modules.module_write(self.ambient_module, 'R') + reply = self.bpod_modules.module_read(self.ambient_module, 12) + self.ambient_module.stop_module_relay() return { 'Temperature_C': np.frombuffer(bytes(reply[:4]), np.float32)[0], diff --git a/iblrig/hardware_validation.py b/iblrig/hardware_validation.py index 00591134c..5551527b4 100644 --- a/iblrig/hardware_validation.py +++ b/iblrig/hardware_validation.py @@ -25,11 +25,10 @@ from iblrig.pydantic_definitions import HardwareSettings, RigSettings from iblrig.serial_singleton import SerialSingleton, filter_ports from iblrig.tools import ANSI, get_inheritors, internet_available +from iblrig.version_management import get_branch, is_dirty from pybpodapi.bpod_modules.bpod_module import BpodModule from pybpodapi.state_machine import StateMachine -from iblrig.version_management import is_dirty, get_branch - log = logging.getLogger(__name__) @@ -163,7 +162,11 @@ def _run(self): yield Result(Status.SKIP, f'No serial port defined for {self.name}') return False elif next((p for p in list_ports.comports() if p.device == self.port), None) is None: - yield Result(Status.FAIL, f'{self.port} is not a valid serial port', solution='Double check!') + yield Result( + Status.FAIL, + f'{self.port} is not a valid serial port', + solution='Check serial port setting in hardware_settings.yaml', + ) return False else: try: @@ -175,7 +178,12 @@ def _run(self): f'Serial Number: {self.port_info.serial_number}', ) except SerialException as e: - yield Result(Status.FAIL, f'Serial device on {self.port} cannot be connected to', exception=e) + yield Result( + Status.FAIL, + f'{self.name} on {self.port} cannot be connected to', + solution='Try power-cycling the device', + exception=e, + ) return False # first, test for properties of the serial port without opening the latter (VID, PID, etc) @@ -195,7 +203,11 @@ def _run(self): yield Result(Status.PASS, f'Serial device positively identified as {self.name}') return True else: - yield Result(Status.FAIL, f'Serial device on {self.port} does NOT seem to be a {self.name}') + yield Result( + Status.FAIL, + f'Serial device on {self.port} does NOT seem to be a {self.name}', + solution='Check serial port setting in hardware_settings.yaml', + ) return False @@ -328,7 +340,9 @@ def _run(self): bpod = Bpod(self.hardware_settings.device_bpod.COM_BPOD, skip_initialization=False) yield Result(Status.PASS, 'Successfully connected to Bpod using pybpod') except Exception as e: - yield Result(Status.FAIL, 'Could not connect to Bpod using pybpod', exception=e) + yield Result( + Status.FAIL, 'Could not connect to Bpod using pybpod', solution='Try power-cycling the Bpod', exception=e + ) return False # return connected modules @@ -365,7 +379,11 @@ def _run(self): with Cameras() as cameras: if len(cameras) == 0: - yield Result(Status.FAIL, 'Could not find a camera connected to the computer') + yield Result( + Status.FAIL, + 'Could not find a camera connected to the computer', + solution='Connect a camera on one of the computers USB ports', + ) return False else: yield Result( @@ -390,7 +408,12 @@ def _run(self): bpod.run_state_machine(sma) triggers = [i.host_timestamp for i in bpod.session.current_trial.events_occurrences if i.content == 'Port1In'] if len(triggers) == 0: - yield Result(Status.FAIL, "No TTL detected on Bpod's behavior port #1") + yield Result( + Status.FAIL, + "No TTL detected on Bpod's behavior port #1", + solution='Check the wiring between camera and valve driver board and make sure the latter is connected ' + "to Bpod's behavior port #1", + ) return False else: yield Result(Status.PASS, "Detected camera TTL on Bpod's behavior port #1") @@ -399,7 +422,7 @@ def _run(self): if isclose(trigger_rate, target_rate, rel_tol=0.1): yield Result(Status.PASS, f'Measured TTL rate: {trigger_rate:.1f} Hz') else: - yield Result(Status.WARN, f'Measured TTL rate: {trigger_rate:.1f} Hz') + yield Result(Status.WARN, f'Measured TTL rate: {trigger_rate:.1f} Hz (expecting {target_rate} Hz)') return True @@ -470,12 +493,14 @@ def _run(self): return True else: yield Result( - Status.FAIL, 'Could not find UltraMic 200K microphone', solution='Make sure that the microphone is plugged in' + Status.FAIL, + 'Could not find UltraMic 200K microphone', + solution='Make sure that the microphone is connected to the PC via USB', ) return False -class GitValidator(Validator): +class ValidatorGit(Validator): _name = 'Git' def _run(self): @@ -489,19 +514,19 @@ def _run(self): if this_branch != main_branch: yield Result( Status.WARN, - f'Working tree of IBLRIG is on Git branch `{this_branch}`', - solution=f'Issue `git checkout {main_branch}` to switch to `{main_branch}` branch' + f"Working tree of IBLRIG is on Git branch '{this_branch}'", + solution=f"Issue 'git checkout {main_branch}' to switch to '{main_branch}' branch", ) return_status = False else: - yield Result(Status.PASS, f'Working tree of IBLRIG is on Git branch `{main_branch}`') + yield Result(Status.PASS, f"Working tree of IBLRIG is on Git branch '{main_branch}'") if is_dirty(): yield Result( Status.WARN, "Working tree of IBLRIG contains local changes - don't expect things to work as intended!", - solution='To list files that have been changed locally, issue `git diff --name-only`. ' - 'Issue `git reset --hard` to reset the repository to its default state', + solution="To list files that have been changed locally, issue 'git diff --name-only'. " + "Issue 'git reset --hard' to reset the repository to its default state", ) return_status = False else: diff --git a/iblrig/online_plots.py b/iblrig/online_plots.py index 82185f4b5..0a51cfb4d 100644 --- a/iblrig/online_plots.py +++ b/iblrig/online_plots.py @@ -1,5 +1,6 @@ import datetime import json +import logging import time from pathlib import Path @@ -18,9 +19,10 @@ NTRIALS_INIT = 2000 NTRIALS_PLOT = 20 # do not edit - this is used also to enforce the completion criteria CONTRAST_SET = np.array([0, 1 / 16, 1 / 8, 1 / 4, 1 / 2, 1]) -PROBABILITY_SET = np.array([0.2, 0.5, 0.8]) # if the mouse does less than 400 trials in the first 45mins it's disengaged ENGAGED_CRITIERION = {'secs': 45 * 60, 'trial_count': 400} + +log = logging.getLogger(__name__) sns.set_style('darkgrid') @@ -35,11 +37,21 @@ class DataModel: """ task_settings = None + probability_set = np.array([0.2, 0.5, 0.8]) + ntrials = 0 + ntrials_correct = 0 + ntrials_nan = np.nan + ntrials_engaged = 0 # trials happening within the first 400s + percent_correct = np.nan + percent_error = np.nan + water_delivered = 0.0 + time_elapsed = 0.0 - def __init__(self, task_file): + def __init__(self, settings_file: Path | str, task_file: Path | str): """ - Can be instantiated empty or from an existing jsonable file from any rig version - :param task_file: + + :param task_file: full path to the _iblrig_taskData.raw.jsonable file + :param settings_file:full path to the _iblrig_taskSettings.raw.json file """ self.session_path = one.alf.files.get_session_path(task_file) or '' self.last_trials = pd.DataFrame( @@ -47,23 +59,26 @@ def __init__(self, task_file): index=np.arange(NTRIALS_PLOT), ) - if task_file is None or not Path(task_file).exists(): + if settings_file is None and task_file is not None and task_file.exists(): + settings_file = task_file.parent.joinpath('_iblrig_taskSettings.raw.json') + + if settings_file is not None and settings_file.exists(): + # most of the IBL tasks have a predefined set of probabilities (0.2, 0.5, 0.8), but in the + # case of the advanced choice world task, the probabilities are defined in the task settings + with open(settings_file) as fid: + self.task_settings = json.load(fid) + self.probability_set = [self.task_settings['PROBABILITY_LEFT']] + self.task_settings.get('BLOCK_PROBABILITY_SET', []) + else: + log.warning('Settings file not found - using default settings') + + if task_file is None or not task_file.exists(): self.psychometrics = pd.DataFrame( columns=['count', 'response_time', 'choice', 'response_time_std', 'choice_std'], - index=pd.MultiIndex.from_product([PROBABILITY_SET, np.r_[-np.flipud(CONTRAST_SET[1:]), CONTRAST_SET]]), + index=pd.MultiIndex.from_product([self.probability_set, np.r_[-np.flipud(CONTRAST_SET[1:]), CONTRAST_SET]]), ) self.psychometrics['count'] = 0 self.trials_table = pd.DataFrame(columns=['response_time'], index=np.arange(NTRIALS_INIT)) - self.ntrials = 0 - self.ntrials_correct = 0 - self.ntrials_nan = np.nan - self.percent_correct = np.nan - self.percent_error = np.nan - self.water_delivered = 0 - self.time_elapsed = 0 - self.ntrials_engaged = 0 # those are the trials happening within the first 400s else: - self.get_task_settings(Path(task_file).parent) trials_table, bpod_data = load_task_jsonable(task_file) # here we take the end time of the first trial as reference to avoid factoring in the delay self.time_elapsed = bpod_data[-1]['Trial end timestamp'] - bpod_data[0]['Trial end timestamp'] @@ -74,7 +89,7 @@ def __init__(self, task_file): CategoricalDtype(categories=np.unique(np.r_[-CONTRAST_SET, CONTRAST_SET]), ordered=True) ) trials_table['stim_probability_left'] = trials_table['stim_probability_left'].astype( - CategoricalDtype(categories=PROBABILITY_SET, ordered=True) + CategoricalDtype(categories=self.probability_set, ordered=True) ) self.psychometrics = trials_table.groupby(['stim_probability_left', 'signed_contrast']).agg( count=pd.NamedAgg(column='signed_contrast', aggfunc='count'), @@ -119,13 +134,6 @@ def __init__(self, task_file): self.last_contrasts[ileft, 0] = np.abs(self.last_trials.signed_contrast[ileft]) self.last_contrasts[iright, 1] = np.abs(self.last_trials.signed_contrast[iright]) - def get_task_settings(self, session_directory: str | Path) -> None: - task_settings_file = Path(session_directory).joinpath('_iblrig_taskSettings.raw.json') - if not task_settings_file.exists(): - return - with open(task_settings_file) as fid: - self.task_settings = json.load(fid) - def update_trial(self, trial_data, bpod_data) -> None: # update counters self.time_elapsed = bpod_data['Trial end timestamp'] - bpod_data['Bpod start timestamp'] @@ -140,6 +148,9 @@ def update_trial(self, trial_data, bpod_data) -> None: # update psychometrics using online statistics method indexer = (trial_data.stim_probability_left, signed_contrast) + if indexer not in self.psychometrics.index: + self.psychometrics.loc[indexer, :] = np.NaN + self.psychometrics.loc[indexer, ('count')] = 0 self.psychometrics.loc[indexer, ('count')] += 1 self.psychometrics.loc[indexer, ('response_time')], self.psychometrics.loc[indexer, ('response_time_std')] = online_std( new_sample=trial_data.response_time, @@ -158,10 +169,10 @@ def update_trial(self, trial_data, bpod_data) -> None: i = NTRIALS_PLOT - 1 self.last_trials.at[i, 'correct'] = trial_data.trial_correct self.last_trials.at[i, 'signed_contrast'] = signed_contrast - self.last_trials.at[i, 'stim_on'] = bpod_data['States timestamps']['stim_on'][0][0] - self.last_trials.at[i, 'play_tone'] = bpod_data['States timestamps']['play_tone'][0][0] - self.last_trials.at[i, 'reward_time'] = bpod_data['States timestamps']['reward'][0][0] - self.last_trials.at[i, 'error_time'] = bpod_data['States timestamps']['error'][0][0] + self.last_trials.at[i, 'stim_on'] = bpod_data['States timestamps'].get('stim_on', [[np.nan]])[0][0] + self.last_trials.at[i, 'play_tone'] = bpod_data['States timestamps'].get('play_tone', [[np.nan]])[0][0] + self.last_trials.at[i, 'reward_time'] = bpod_data['States timestamps'].get('reward', [[np.nan]])[0][0] + self.last_trials.at[i, 'error_time'] = bpod_data['States timestamps'].get('error', [[np.nan]])[0][0] self.last_trials.at[i, 'response_time'] = trial_data.response_time # update rgb image self.rgb_background = np.roll(self.rgb_background, -1, axis=0) @@ -208,8 +219,10 @@ class OnlinePlots: >>> OnlinePlots().run(task_file) """ - def __init__(self, task_file=None): - self.data = DataModel(task_file=task_file) + def __init__(self, task_file=None, settings_file=None): + task_file = Path(task_file) if task_file is not None else None + settings_file = Path(settings_file) if settings_file is not None else None + self.data = DataModel(task_file=task_file, settings_file=settings_file) # create figure and axes h = Bunch({}) @@ -241,7 +254,7 @@ def __init__(self, task_file=None): # create psych curves h.curve_psych = {} h.curve_reaction = {} - for p in PROBABILITY_SET: + for p in self.data.probability_set: h.curve_psych[p] = h.ax_psych.plot( self.data.psychometrics.loc[p].index, self.data.psychometrics.loc[p]['choice'], @@ -322,7 +335,7 @@ def update_graphics(self, pupdate: float | None = None): h = self.h h.fig.set_facecolor(background_color) self.update_titles() - for p in PROBABILITY_SET: + for p in self.data.probability_set: if pupdate is not None and p != pupdate: continue # update psychometric curves @@ -363,7 +376,6 @@ def run(self, task_file: Path | str) -> None: :return: """ task_file = Path(task_file) - self.data.get_task_settings(task_file.parent) self._set_session_string() self.update_titles() self.h.fig.canvas.flush_events() diff --git a/iblrig/rig_component.py b/iblrig/rig_component.py new file mode 100644 index 000000000..05a039da1 --- /dev/null +++ b/iblrig/rig_component.py @@ -0,0 +1,45 @@ +from abc import abstractmethod + +from iblrig.hardware_validation import Validator +from iblrig.pydantic_definitions import BunchModel + + +class RigComponent: + @abstractmethod + @property + def pretty_name(self) -> str: + """ + Get the pretty name of the component. + + Returns + ------- + str + A user-friendly name of the component. + """ + ... + + @abstractmethod + @property + def validator(self) -> Validator: + """ + Get the validator for the component. + + Returns + ------- + Validator + The validator instance associated with the component. + """ + ... + + @abstractmethod + @property + def settings(self) -> BunchModel: + """ + Get the settings for the component. + + Returns + ------- + BunchModel + The pydantic model for the component's settings. + """ + ... diff --git a/iblrig/test/tasks/test_advanced_choice_world.py b/iblrig/test/tasks/test_advanced_choice_world.py index ebe5de405..a80d882d9 100644 --- a/iblrig/test/tasks/test_advanced_choice_world.py +++ b/iblrig/test/tasks/test_advanced_choice_world.py @@ -1,29 +1,72 @@ +from unittest import TestCase + import numpy as np +import pandas as pd from iblrig.test.base import TASK_KWARGS, BaseTestCases from iblrig.test.tasks.test_biased_choice_world_family import get_fixtures from iblrig_tasks._iblrig_tasks_advancedChoiceWorld.task import Session as AdvancedChoiceWorldSession +class TestDefaultParameters(TestCase): + def test_params_yaml(self): + # just make sure the parameter file is + task = AdvancedChoiceWorldSession(**TASK_KWARGS) + self.assertEqual(12, task.df_contingencies.shape[0]) + self.assertEqual(task.task_params['PROBABILITY_LEFT'], 0.5) + + class TestInstantiationAdvanced(BaseTestCases.CommonTestInstantiateTask): def setUp(self) -> None: - self.task = AdvancedChoiceWorldSession(**TASK_KWARGS) + self.task = AdvancedChoiceWorldSession( + probability_set=[2, 2, 2, 1, 1, 1], + contrast_set=[1, 0.5, 0, 0, 0.5, 1], + reward_set_ul=[1, 1.5, 2, 2, 2.5, 2.6], + position_set=[-35, -35, -35, 35, 35, 35], + **TASK_KWARGS, + ) def test_task(self): task = self.task task.create_session() + # given the table probabilities above, the left stimulus is twice as likely to be right + self.assertTrue(task.task_params['PROBABILITY_LEFT'] == 2 / 3) + # run a fake task for 800 trials trial_fixtures = get_fixtures() nt = 800 + np.random.seed(65432) + for i in np.arange(nt): task.next_trial() # pc = task.psychometric_curve() trial_type = np.random.choice(['correct', 'error', 'no_go'], p=[0.9, 0.05, 0.05]) - task.trial_completed(trial_fixtures[trial_type]) + task.trial_completed(bpod_data=trial_fixtures[trial_type]) if trial_type == 'correct': assert task.trials_table['trial_correct'][task.trial_num] else: assert not task.trials_table['trial_correct'][task.trial_num] - if i == 245: task.show_trial_log() assert not np.isnan(task.reward_time) + + # check the contrasts and positions by aggregating the trials table + df_contrasts = ( + task.trials_table.iloc[:nt, :] + .groupby(['contrast', 'position']) + .agg( + count=pd.NamedAgg(column='reward_amount', aggfunc='count'), + n_unique_rewards=pd.NamedAgg(column='reward_amount', aggfunc='nunique'), + max_reward=pd.NamedAgg(column='reward_amount', aggfunc='max'), + min_reward=pd.NamedAgg(column='reward_amount', aggfunc='min'), + ) + .reset_index() + ) + # the error trials have 0 reward while the correct trials have their assigned reward amount + np.testing.assert_array_equal(df_contrasts['n_unique_rewards'], 2) + np.testing.assert_array_equal(df_contrasts['min_reward'], 0) + np.testing.assert_array_equal(df_contrasts['max_reward'], [2, 2, 1.5, 2.5, 1, 2.6]) + + n_left = np.sum(df_contrasts['count'][df_contrasts['position'] < 0]) + n_right = np.sum(df_contrasts['count'][df_contrasts['position'] > 0]) + # the left stimulus is twice as likely to be shown + self.assertTrue(n_left > (n_right * 1.5)) diff --git a/iblrig/test/test_alyx.py b/iblrig/test/test_alyx.py index db6a5bf09..da9acfc93 100644 --- a/iblrig/test/test_alyx.py +++ b/iblrig/test/test_alyx.py @@ -13,7 +13,13 @@ from iblrig.test.base import TASK_KWARGS from iblrig_tasks._iblrig_tasks_trainingChoiceWorld.task import Session as TrainingChoiceWorldSession from one.api import ONE -from one.tests import TEST_DB_1 + +TEST_DB = { + 'base_url': 'https://test.alyx.internationalbrainlab.org', + 'username': 'test_user', + 'password': 'TapetesBloc18', + 'silent': True, +} class TestRegisterSession(unittest.TestCase): @@ -26,7 +32,7 @@ def setUp(self): self.tmpdir = Path(tmp.name) # Create a random new subject - self.one = ONE(**TEST_DB_1, cache_rest=None) + self.one = ONE(**TEST_DB, cache_rest=None) self.subject = ''.join(random.choices(string.ascii_letters, k=10)) self.lab = 'mainenlab' self.one.alyx.rest('subjects', 'create', data={'lab': self.lab, 'nickname': self.subject}) diff --git a/iblrig/test/test_video.py b/iblrig/test/test_video.py index e43ecf5ae..dcebeab26 100644 --- a/iblrig/test/test_video.py +++ b/iblrig/test/test_video.py @@ -19,9 +19,10 @@ class TestDownloadFunction(unittest.TestCase): - @patch('one.webclient.AlyxClient.download_file', return_value=('mocked_tmp_file', 'mocked_md5_checksum')) + @patch('iblrig.video.aws.s3_download_file', return_value=Path('mocked_tmp_file')) + @patch('iblrig.video.hashfile.md5', return_value='mocked_md5_checksum') @patch('os.rename', return_value=None) - def test_download_from_alyx_or_flir(self, mock_os_rename, mock_alyx_download): + def test_download_from_alyx_or_flir(self, mock_os_rename, mock_hashfile, mock_aws_download): asset = 123 filename = 'test_file.txt' @@ -31,10 +32,9 @@ def test_download_from_alyx_or_flir(self, mock_os_rename, mock_alyx_download): # Assertions expected_out_file = Path.home().joinpath('Downloads', filename) self.assertEqual(result, expected_out_file) - mock_alyx_download.assert_called_once_with( - f'resources/spinnaker/{filename}', target_dir=Path(expected_out_file.parent), clobber=True, return_md5=True - ) - mock_os_rename.assert_called_once_with('mocked_tmp_file', expected_out_file) + mock_hashfile.assert_called() + mock_aws_download.assert_called_once_with(source=f'resources/{filename}', destination=Path(expected_out_file)) + mock_os_rename.assert_called_once_with(Path('mocked_tmp_file'), expected_out_file) class TestSettings(unittest.TestCase): @@ -172,7 +172,7 @@ def test_prepare_video_session(self, enable_camera_trigger, call_bonsai, session # Test config validation self.assertRaises(ValueError, video.prepare_video_session, self.subject, 'training') - session().hardware_settings = hws.construct() + session().hardware_settings = hws.model_construct() self.assertRaises(ValueError, video.prepare_video_session, self.subject, 'training') @@ -211,7 +211,7 @@ def test_validate_video(self, load_embedded_frame_data, get_video_meta): } self.assertCountEqual(set(x.getMessage() for x in log.records), expected) # Test video meta warnings - config = self.config.copy() + config = self.config.model_copy() config.HEIGHT = config.WIDTH = 160 config.FPS = 150 with self.assertLogs(video.__name__, 30) as log: diff --git a/iblrig/validation.py b/iblrig/validation.py deleted file mode 100644 index a23f3a05a..000000000 --- a/iblrig/validation.py +++ /dev/null @@ -1,133 +0,0 @@ -from iblrig.base_tasks import BpodMixin, SoundMixin -from iblrig.constants import BASE_PATH -from pybpodapi.protocol import StateMachine - - -class _SoundCheckTask(BpodMixin, SoundMixin): - protocol_name = 'hardware_check_harp' - - def __init__(self, *args, **kwargs): - param_file = BASE_PATH.joinpath('iblrig', 'base_choice_world_params.yaml') - super().__init__(*args, task_parameter_file=param_file, **kwargs) - - def start_hardware(self): - self.start_mixin_bpod() - self.start_mixin_sound() - - def get_state_machine(self): - sma = StateMachine(self.bpod) - sma.add_state('tone', 0.5, {'Tup': 'noise'}, [self.bpod.actions.play_tone]) - sma.add_state('noise', 1, {'Tup': 'exit'}, [self.bpod.actions.play_noise]) - return sma - - def play_sounds(self): - sma = self.get_state_machine() - self.bpod.send_state_machine(sma) - self.bpod.run_state_machine(sma) - - def _run(self): - pass - - -def sound_check(): - """ - # TODO: within the task (and actually with this hardware check), we need to test - for the exact same number of pulses than generated by the state machine. - # if it is more, it means it's noise and potentially disconnected, if it is less, - it means the sound card is not sending the pulses properly - - bpod_data_success: - GPIO -- Bpod OR GPIO X- Bpod - Harp -- Bpod - GPIO -- Harp - - bpod_data_success - {'Bpod start timestamp': 0.620611, - 'Trial start timestamp': 0.620611, - 'Trial end timestamp': 2.120613, - 'States timestamps': {'play_tone': [(0, 0.5)], 'error': [(0.5, 1.5)]}, - 'Events timestamps': {'BNC2High': [0.0007, 0.5007], - 'BNC2Low': [0.1107, 1.0007000000000001], - 'Tup': [0.5, 1.5]}} - - bpod_data_failure: - GPIO -- Bpod - Harp -X Bpod - GPIO -X Harp - {'Bpod start timestamp': 0.620611, - 'Trial start timestamp': 232.963811, - 'Trial end timestamp': 234.463814, - 'States timestamps': {'play_tone': [(0, 0.5)], 'error': [(0.5, 1.5)]}, - 'Events timestamps': {'BNC2Low': [0.008400000000000001, - 0.0349, - 0.08990000000000001, - 0.1796, - 0.19360000000000002, - 0.2753, - 0.28150000000000003, - 0.29550000000000004, - 0.33140000000000003, - 0.36100000000000004, - 0.5086, - 0.5457000000000001, - 0.6646000000000001, - 0.6959000000000001, - 0.7241000000000001, - 0.8599, - 0.8823000000000001, - 0.9087000000000001, - 0.9398000000000001, - 1.0050000000000001, - 1.1079, - 1.1265, - 1.1955, - 1.2302, - 1.3635000000000002, - 1.4215], - 'BNC2High': [0.0085, - 0.035, - 0.09000000000000001, - 0.1797, - 0.1937, - 0.27540000000000003, - 0.2816, - 0.29560000000000003, - 0.3315, - 0.36110000000000003, - 0.5087, - 0.5458000000000001, - 0.6647000000000001, - 0.6960000000000001, - 0.7242000000000001, - 0.86, - 0.8824000000000001, - 0.9088, - 0.9399000000000001, - 1.0051, - 1.108, - 1.1266, - 1.1956, - 1.2303, - 1.3636000000000001, - 1.4216], - 'Tup': [0.5, 1.5]}} - - bpod data failure: case if sound card is not wired properly (no feedback from sound card) - bpod_data_failure: - GPIO -X Bpod - Harp -X Bpod - GPIO -- Harp - Out[49]: - {'Bpod start timestamp': 0.620611, - 'Trial start timestamp': 405.619411, - 'Trial end timestamp': 407.119414, - 'States timestamps': {'play_tone': [(0, 0.5)], 'error': [(0.5, 1.5)]}, - 'Events timestamps': {'Tup': [0.5, 1.5]}} - """ - - task = _SoundCheckTask(subject='toto') - task.start_hardware() - task.play_sounds() - - bpod_data = task.bpod.session.current_trial.export() - assert len(bpod_data['Events timestamps']['BNC2High']) == 2 diff --git a/iblrig/video.py b/iblrig/video.py index 28675df46..5b5ec90d8 100644 --- a/iblrig/video.py +++ b/iblrig/video.py @@ -6,7 +6,6 @@ import sys import zipfile from pathlib import Path -from urllib.error import URLError import yaml @@ -25,11 +24,20 @@ ) from iblutil.util import setup_logger from one.converters import ConversionMixin -from one.webclient import AlyxClient, http_download_file # type: ignore +from one.remote import aws +from one.webclient import http_download_file # type: ignore with contextlib.suppress(ImportError): from iblrig import video_pyspin +SPINNAKER_ASSET = 59586 +SPINNAKER_FILENAME = 'SpinnakerSDK_FULL_3.2.0.57_x64.exe' +SPINNAKER_MD5 = 'aafc07c858dc2ab2e2a7d6ef900ca9a7' + +PYSPIN_ASSET = 59584 +PYSPIN_FILENAME = 'spinnaker_python-3.2.0.57-cp310-cp310-win_amd64.zip' +PYSPIN_MD5 = 'f93294208e0ecec042adb2f75cb72609' + log = logging.getLogger(__name__) @@ -60,19 +68,28 @@ def _download_from_alyx_or_flir(asset: int, filename: str, target_md5: str) -> P out_dir = Path.home().joinpath('Downloads') out_file = out_dir.joinpath(filename) options = {'target_dir': out_dir, 'clobber': True, 'return_md5': True} + + # if the file already exists skip all downloads if out_file.exists() and hashfile.md5(out_file) == target_md5: return out_file - try: - tmp_file, md5_sum = AlyxClient().download_file(f'resources/spinnaker/{filename}', **options) - except (OSError, AttributeError, URLError) as e1: + + # first try to download from public s3 bucket + tmp_file = aws.s3_download_file(source=f'resources/{filename}', destination=out_file) + if tmp_file is not None: + md5_sum = hashfile.md5(tmp_file) + + # if that fails try to download from flir server + else: try: url = f'https://flir.netx.net/file/asset/{asset}/original/attachment' tmp_file, md5_sum = http_download_file(url, **options) - except OSError as e2: - raise e2 from e1 + except OSError as e: + raise Exception(f'`{filename}` could not be downloaded - manual intervention is necessary') from e + + # finally os.rename(tmp_file, out_file) if md5_sum != target_md5: - raise Exception(f'`{filename}` does not match the expected MD5 - please try running the script again or') + raise Exception(f'`{filename}` does not match the expected MD5 - manual intervention is necessary') return out_file @@ -99,7 +116,7 @@ def install_spinnaker(): return # Download & install Spinnaker SDK - file_winsdk = _download_from_alyx_or_flir(54386, 'SpinnakerSDK_FULL_3.1.0.79_x64.exe', 'd9d83772f852e5369da2fbcc248c9c81') + file_winsdk = _download_from_alyx_or_flir(SPINNAKER_ASSET, SPINNAKER_FILENAME, SPINNAKER_MD5) print('Installing Spinnaker SDK for Windows ...') input( 'Please select the "Application Development" Installation Profile. Everything else can be left at ' @@ -136,9 +153,7 @@ def install_pyspin(): if HAS_PYSPIN: print('PySpin is already installed.') else: - file_zip = _download_from_alyx_or_flir( - 54396, 'spinnaker_python-3.1.0.79-cp310-cp310-win_amd64.zip', 'e00148800757d0ed7171348d850947ac' - ) + file_zip = _download_from_alyx_or_flir(PYSPIN_ASSET, PYSPIN_FILENAME, PYSPIN_MD5) print('Installing PySpin ...') with zipfile.ZipFile(file_zip, 'r') as f: file_whl = f.extract(file_zip.stem + '.whl', file_zip.parent) diff --git a/iblrig_tasks/_iblrig_tasks_advancedChoiceWorld/task.py b/iblrig_tasks/_iblrig_tasks_advancedChoiceWorld/task.py index 932e70803..9418d2ee2 100644 --- a/iblrig_tasks/_iblrig_tasks_advancedChoiceWorld/task.py +++ b/iblrig_tasks/_iblrig_tasks_advancedChoiceWorld/task.py @@ -1,9 +1,11 @@ from pathlib import Path +import numpy as np +import pandas as pd import yaml import iblrig.misc -from iblrig.base_choice_world import ActiveChoiceWorldSession +from iblrig.base_choice_world import NTRIALS_INIT, ActiveChoiceWorldSession # read defaults from task_parameters.yaml with open(Path(__file__).parent.joinpath('task_parameters.yaml')) as f: @@ -24,18 +26,52 @@ def __init__( self, *args, contrast_set: list[float] = DEFAULTS['CONTRAST_SET'], - contrast_set_probability_type: str = DEFAULTS['CONTRAST_SET_PROBABILITY_TYPE'], - probability_left: float = DEFAULTS['PROBABILITY_LEFT'], - reward_amount_ul: float = DEFAULTS['REWARD_AMOUNT_UL'], + probability_set: list[float] = DEFAULTS['PROBABILITY_SET'], + reward_set_ul: list[float] = DEFAULTS['REWARD_SET_UL'], + position_set: list[float] = DEFAULTS['POSITION_SET'], stim_gain: float = DEFAULTS['STIM_GAIN'], **kwargs, ): super().__init__(*args, **kwargs) + nc = len(contrast_set) + assert len(probability_set) in [nc, 1], 'probability_set must be a scalar or have the same length as contrast_set' + assert len(reward_set_ul) in [nc, 1], 'reward_set_ul must be a scalar or have the same length as contrast_set' + assert len(position_set) == nc, 'position_set must have the same length as contrast_set' self.task_params['CONTRAST_SET'] = contrast_set - self.task_params['CONTRAST_SET_PROBABILITY_TYPE'] = contrast_set_probability_type - self.task_params['PROBABILITY_LEFT'] = probability_left - self.task_params['REWARD_AMOUNT_UL'] = reward_amount_ul + self.task_params['PROBABILITY_SET'] = probability_set + self.task_params['REWARD_SET_UL'] = reward_set_ul + self.task_params['POSITION_SET'] = position_set self.task_params['STIM_GAIN'] = stim_gain + # it is easier to work with parameters as a dataframe + self.df_contingencies = pd.DataFrame(columns=['contrast', 'probability', 'reward_amount_ul', 'position']) + self.df_contingencies['contrast'] = contrast_set + self.df_contingencies['probability'] = probability_set if len(probability_set) == nc else probability_set[0] + self.df_contingencies['reward_amount_ul'] = reward_set_ul if len(reward_set_ul) == nc else reward_set_ul[0] + self.df_contingencies['position'] = position_set + # normalize the probabilities + self.df_contingencies.loc[:, 'probability'] = self.df_contingencies.loc[:, 'probability'] / np.sum( + self.df_contingencies.loc[:, 'probability'] + ) + # update the PROBABILITY LEFT field to reflect the probabilities in the parameters above + self.task_params['PROBABILITY_LEFT'] = np.sum( + self.df_contingencies['probability'] * (self.df_contingencies['position'] < 0) + ) + self.trials_table['debias_trial'] = np.zeros(NTRIALS_INIT, dtype=bool) + + def draw_next_trial_info(self, **kwargs): + nc = self.df_contingencies.shape[0] + ic = np.random.choice(np.arange(nc), p=self.df_contingencies['probability']) + # now calling the super class with the proper parameters + super().draw_next_trial_info( + pleft=self.task_params.PROBABILITY_LEFT, + contrast=self.df_contingencies.at[ic, 'contrast'], + position=self.df_contingencies.at[ic, 'position'], + reward_amount=self.df_contingencies.at[ic, 'reward_amount_ul'], + ) + + @property + def reward_amount(self): + return self.task_params.REWARD_AMOUNTS_UL[0] @staticmethod def extra_parser(): @@ -48,32 +84,34 @@ def extra_parser(): default=DEFAULTS['CONTRAST_SET'], nargs='+', type=float, - help='set of contrasts to present', + help='Set of contrasts to present', ) parser.add_argument( - '--contrast_set_probability_type', - option_strings=['--contrast_set_probability_type'], - dest='contrast_set_probability_type', - default=DEFAULTS['CONTRAST_SET_PROBABILITY_TYPE'], - type=str, - choices=['skew_zero', 'uniform'], - help=f'probability type for contrast set ' f'(default: {DEFAULTS["CONTRAST_SET_PROBABILITY_TYPE"]})', + '--probability_set', + option_strings=['--probability_set'], + dest='probability_set', + default=DEFAULTS['PROBABILITY_SET'], + nargs='+', + type=float, + help='Probabilities of each contrast in contrast_set. If scalar all contrasts are equiprobable', ) parser.add_argument( - '--probability_left', - option_strings=['--probability_left'], - dest='probability_left', - default=DEFAULTS['PROBABILITY_LEFT'], + '--reward_set_ul', + option_strings=['--reward_set_ul'], + dest='reward_set_ul', + default=DEFAULTS['REWARD_SET_UL'], + nargs='+', type=float, - help=f'probability for stimulus to appear on the left ' f'(default: {DEFAULTS["PROBABILITY_LEFT"]:.1f})', + help='Reward for contrast in contrast set.', ) parser.add_argument( - '--reward_amount_ul', - option_strings=['--reward_amount_ul'], - dest='reward_amount_ul', - default=DEFAULTS['REWARD_AMOUNT_UL'], + '--position_set', + option_strings=['--position_set'], + dest='position_set', + default=DEFAULTS['POSITION_SET'], + nargs='+', type=float, - help=f'reward amount (default: {DEFAULTS["REWARD_AMOUNT_UL"]}μl)', + help='Position for each contrast in contrast set.', ) parser.add_argument( '--stim_gain', @@ -81,7 +119,7 @@ def extra_parser(): dest='stim_gain', default=DEFAULTS['STIM_GAIN'], type=float, - help=f'visual angle/wheel displacement ' f'(deg/mm, default: {DEFAULTS["STIM_GAIN"]})', + help=f'Visual angle/wheel displacement ' f'(deg/mm, default: {DEFAULTS["STIM_GAIN"]})', ) return parser diff --git a/iblrig_tasks/_iblrig_tasks_advancedChoiceWorld/task_parameters.yaml b/iblrig_tasks/_iblrig_tasks_advancedChoiceWorld/task_parameters.yaml index c639d9d46..48349d982 100644 --- a/iblrig_tasks/_iblrig_tasks_advancedChoiceWorld/task_parameters.yaml +++ b/iblrig_tasks/_iblrig_tasks_advancedChoiceWorld/task_parameters.yaml @@ -1,7 +1,6 @@ -'CONTRAST_SET': [1.0, 0.5, 0.25, 0.125, 0.0625, 0.0] -'CONTRAST_SET_PROBABILITY_TYPE': skew_zero # uniform, skew_zero -'PROBABILITY_LEFT': 0.5 -'REWARD_AMOUNT_UL': 1.5 +'CONTRAST_SET': [1.0, 0.5, 0.25, 0.125, 0.0625, 0.0, 0.0, 0.0625, 0.125, 0.25, 0.5, 1.0] # signed contrast set +'PROBABILITY_SET': [1] # scalar or list of n signed contrasts values, if scalar all contingencies are equiprobable +'REWARD_SET_UL': [1.5] # scalar or list of Ncontrast values +'POSITION_SET': [-35, -35, -35, -35, -35, -35, 35, 35, 35, 35, 35, 35] # position set 'STIM_GAIN': 4.0 # wheel to stimulus relationship #'DEBIAS': True # Whether to use debiasing rule or not by repeating error trials # todo -#'CONTRAST_SET_PROBABILITIES': [2, 2, 2, 2, 1] # todo diff --git a/iblrig_tasks/_iblrig_tasks_ephysChoiceWorld/task.py b/iblrig_tasks/_iblrig_tasks_ephysChoiceWorld/task.py index ce292a124..3e5ecdf1e 100644 --- a/iblrig_tasks/_iblrig_tasks_ephysChoiceWorld/task.py +++ b/iblrig_tasks/_iblrig_tasks_ephysChoiceWorld/task.py @@ -17,17 +17,26 @@ class Session(BiasedChoiceWorldSession): def __init__(self, *args, session_template_id=0, **kwargs): super().__init__(*args, **kwargs) self.task_params.SESSION_TEMPLATE_ID = session_template_id - trials_table = pd.read_parquet(Path(__file__).parent.joinpath('trials_fixtures.pqt')) - self.trials_table = ( - trials_table.loc[trials_table['session_id'] == session_template_id].reindex().drop(columns=['session_id']) - ) - self.trials_table = self.trials_table.reset_index() + self.trials_table = self.get_session_template(session_template_id) # reconstruct the block dataframe from the trials table self.blocks_table = self.trials_table.groupby('block_num').agg( probability_left=pd.NamedAgg(column='stim_probability_left', aggfunc='first'), block_length=pd.NamedAgg(column='stim_probability_left', aggfunc='count'), ) + @staticmethod + def get_session_template(session_template_id: int) -> pd.DataFrame: + """ + Returns the pre-generated trials dataframe from the 12 fixtures according to the template iD + :param session_template_id: int 0-11 + :return: + """ + trials_table = pd.read_parquet(Path(__file__).parent.joinpath('trials_fixtures.pqt')) + trials_table = ( + trials_table.loc[trials_table['session_id'] == session_template_id].reindex().drop(columns=['session_id']) + ).reset_index() + return trials_table + @staticmethod def extra_parser(): """:return: argparse.parser()""" diff --git a/pyproject.toml b/pyproject.toml index f9b32b1e5..a3875720b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ DEV = [ ] [project.scripts] -viewsession = "iblrig.commands:viewsession" +view_session = "iblrig.commands:view_session" transfer_data = "iblrig.commands:transfer_data_cli" transfer_video_data = "iblrig.commands:transfer_video_data_cli" transfer_ephys_data = "iblrig.commands:transfer_ephys_data_cli" @@ -119,7 +119,6 @@ ignore = [ ] [tool.ruff.format] -line-ending = "lf" quote-style = "single" [tool.ruff.lint.pydocstyle]