Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom opto task and extraction for U19 project #15

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9f9a2f7
add staticTrainingChoiceWorld
mdmelin Jul 10, 2024
e811305
fix folder name
mdmelin Jul 16, 2024
8f02e32
fix duplicated stimulus
mdmelin Jul 16, 2024
5cd1133
Add pulsepal mixin
mdmelin Jul 17, 2024
8b1470b
add task
mdmelin Jul 17, 2024
4ae158e
skeleton extraction code for opto task
mdmelin Jul 23, 2024
51327ad
fix import and abstract property bugs
mdmelin Jul 23, 2024
2ec42f5
laser time is a parameter
mdmelin Jul 25, 2024
3da7353
todos
mdmelin Jul 26, 2024
5cc1641
small tweaks
mdmelin Jul 26, 2024
899fa35
Task running. Need to change on state and fix stim rampdown
mdmelin Jul 26, 2024
8414bde
add GlobalTimer to control the rampdown of opto
mdmelin Aug 5, 2024
e8c0d98
hacky fix for first trial
mdmelin Aug 6, 2024
001e05c
track opto on time
mdmelin Aug 6, 2024
17af9a4
update default params: add 0.5 contrast
mdmelin Aug 8, 2024
a021750
extract opto intervals
mdmelin Aug 8, 2024
35661eb
docs
mdmelin Aug 8, 2024
33647d2
extractor and qc
mdmelin Aug 14, 2024
52c724c
qc logic
mdmelin Aug 14, 2024
5290564
extractor map typo
mdmelin Aug 14, 2024
7d54958
map non-opto task
mdmelin Aug 14, 2024
c9ac637
update non-opto task mapping
mdmelin Aug 14, 2024
5dd6b8d
Merge branch 'int-brain-lab:main' into main
mdmelin Aug 14, 2024
1b878fa
basic LED calibration
mdmelin Aug 19, 2024
dcf015b
Merge branch 'main' of https://github.com/mdmelin/project_extraction
mdmelin Aug 19, 2024
57d2c68
make punish timeout a parameter
mdmelin Sep 5, 2024
1d5f0d3
Merge branch 'int-brain-lab:main' into main
mdmelin Oct 17, 2024
bf2773e
proper path
mdmelin Oct 23, 2024
17d55eb
Merge branch 'main' of https://github.com/mdmelin/project_extraction
mdmelin Oct 23, 2024
b3fba12
Merge branch 'int-brain-lab:main' into main
mdmelin Oct 23, 2024
94f0010
Update pyproject.toml
mdmelin Oct 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/PulsePal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import logging
import sys
from typing import Literal
from abc import ABC, abstractmethod
import numpy as np

from iblrig.base_choice_world import SOFTCODE
from pybpodapi.protocol import StateMachine, Bpod
from pypulsepal import PulsePalObject
from iblrig.base_tasks import BaseSession

log = logging.getLogger('iblrig.task')

SOFTCODE_FIRE_PULSEPAL = max(SOFTCODE).value + 1
SOFTCODE_STOP_PULSEPAL = max(SOFTCODE).value + 2
V_MAX = 5


class PulsePalStateMachine(StateMachine):
"""
This class adds:
1. Hardware or sofware triggering of optogenetic stimulation via a PulsePal (or BPod Analog Output Module)
EITHER
- adds soft-codes for starting and stopping the opto stim
OR
- sets up a TTL to hardware trigger the PulsePal
2. (not yet implemented!!!) sets up a TTL channel for recording opto stim times from the PulsePal
"""
# TODO: define the TTL channel for recording opto stim times?
def __init__(
self,
bpod,
trigger_type: Literal['soft', 'hardware'] = 'soft',
is_opto_stimulation=False,
states_opto_ttls=None,
states_opto_stop=None,
opto_t_max_seconds=None,
):
super().__init__(bpod)
self.trigger_type = trigger_type
self.is_opto_stimulation = is_opto_stimulation
self.states_opto_ttls = states_opto_ttls or []
self.states_opto_stop = states_opto_stop or []

# Set global timer 1 for T_MAX
self.set_global_timer(timer_id=1, timer_duration=opto_t_max_seconds)

def add_state(self, **kwargs):
if self.is_opto_stimulation:
if kwargs['state_name'] in self.states_opto_ttls:
if self.trigger_type == 'soft':
kwargs['output_actions'] += [('SoftCode', SOFTCODE_FIRE_PULSEPAL),]
elif self.trigger_type == 'hardware':
kwargs['output_actions'] += [('BNC2', 255),]
kwargs['output_actions'] += [(Bpod.OutputChannels.GlobalTimerTrig, 1)] # start the global timer when the opto stim comes on
elif kwargs['state_name'] in self.states_opto_stop:
if self.trigger_type == 'soft':
kwargs['output_actions'] += [('SoftCode', SOFTCODE_STOP_PULSEPAL),]
elif self.trigger_type == 'hardware':
kwargs['output_actions'] += [('BNC2', 0),]

super().add_state(**kwargs)

class PulsePalMixin(ABC):
"""
A mixin class that adds optogenetic stimulation capabilities to a task via the
PulsePal module (or a Analog Output module running PulsePal firmware). It is used
in conjunction with the PulsePalStateMachine class rather than the StateMachine class.

The user must define the arm_opto_stim method to define the parameters for optogenetic stimulation.
PulsePalMixin supports soft-code triggering via the start_opto_stim and stop_opto_stim methods.
Hardware triggering is also supported by defining trigger channels in the arm_opto_stim method.

The opto stim is currently hard-coded on output channel 1.
A TTL pulse is hard-coded on output channel 2 for accurately recording trigger times. This TTL
will rise when the opto stim starts and fall when it stops, thus accurately recording software trigger times.
"""

def start_opto_hardware(self):
self.pulsepal_connection = PulsePalObject('COM13') # TODO: get port from hardware params
log.warning('Connected to PulsePal')
# TODO: get the calibration value for this specific cannula
#super().start_hardware() # TODO: move this out

# add the softcodes for the PulsePal
soft_code_dict = self.bpod.softcodes
soft_code_dict.update({SOFTCODE_STOP_PULSEPAL: self.stop_opto_stim})
soft_code_dict.update({SOFTCODE_FIRE_PULSEPAL: self.start_opto_stim})
self.bpod.register_softcodes(soft_code_dict)

@abstractmethod
def arm_opto_stim(self, ttl_output_channel):
raise NotImplementedError("User must define the stimulus and trigger type to deliver with pulsepal")
# Define the pulse sequence and load it to the desired output channel here
# This method should not fire the pulse train, that is handled by start_opto_stim() (soft-trigger) or a hardware trigger
# See https://github.com/sanworks/PulsePal/blob/master/Python/Python3/PulsePalExample.py for examples
# you should also define the max_stim_seconds property here to set the maximum duration of the pulse train

##############################
# Example code to define a sine wave lasting 5 seconds
voltages = list(range(0, 1000))
for i in voltages:
voltages[i] = math.sin(voltages[i]/float(10))*10 # Set 1,000 voltages to create a 20V peak-to-peak sine waveform
times = np.linspace(0, 5, len(voltages)) # Create a time vector for the waveform
self.stim_length_seconds = times[-1] # it is essential to get this property right so that the TTL for recording stim pulses is correcty defined
self.pulsepal_connection.sendCustomPulseTrain(1, times, voltages)
self.pulsepal_connection.programOutputChannelParam('customTrainID', 1, 1)
##############################

@property
@abstractmethod
def stim_length_seconds():
# this should be set within the arm_opto_stim method
pass

def arm_ttl_stim(self):
# a TTL pulse from channel 2 that rises when the opto stim starts and falls when it stops
log.warning('Arming TTL signal')
self.pulsepal_connection.programOutputChannelParam('phase1Duration', 2, self.stim_length_seconds)
self.pulsepal_connection.sendCustomPulseTrain(2, [0,], [V_MAX,])
self.pulsepal_connection.programOutputChannelParam('customTrainID', 2, 2)

def start_opto_stim(self):
self.pulsepal_connection.triggerOutputChannels(1, 1, 0, 0)
log.warning('Started opto stim')

def stop_opto_stim(self):
# this will stop the pulse train instantly (and the corresponding TTL pulse)
# To avoid rebound spiking in the case of GtACR, a ramp down is recommended
self.pulsepal_connection.abortPulseTrains()

def compute_vmax_from_calibration(self, calibration_value):
# TODO: implement this method to convert the calibration value to a voltage for the opto stim
pass

def __del__(self):
del self.pulsepal_connection
Empty file.
212 changes: 212 additions & 0 deletions iblrig_custom_tasks/max_optoStaticTrainingChoiceWorld/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
"""
This task is a replica of max_staticTrainingChoiceWorld with the addition of optogenetic stimulation
An `opto_stimulation` column is added to the trials_table, which is a boolean array of length NTRIALS_INIT
The PROBABILITY_OPTO_STIMULATION parameter is used to determine the probability of optogenetic stimulation
for each trial

Additionally the state machine is modified to add output TTLs for optogenetic stimulation
"""

import logging
import random
import sys
from importlib.util import find_spec
from pathlib import Path
from typing import Literal
import pandas as pd

import numpy as np
import yaml
import time

import iblrig
from iblrig.base_choice_world import SOFTCODE
from pybpodapi.protocol import StateMachine
from iblrig_custom_tasks.max_staticTrainingChoiceWorld.task import Session as StaticTrainingChoiceSession
from iblrig_custom_tasks.max_optoStaticTrainingChoiceWorld.PulsePal import PulsePalMixin, PulsePalStateMachine

stim_location_history = []

log = logging.getLogger('iblrig.task')

NTRIALS_INIT = 2000
SOFTCODE_FIRE_LED = max(SOFTCODE).value + 1
SOFTCODE_RAMP_DOWN_LED = max(SOFTCODE).value + 2
RAMP_SECONDS = .25 # time to ramp down the opto stim # TODO: make this a parameter
LED_V_MAX = 5 # maximum voltage for LED control # TODO: make this a parameter

# read defaults from task_parameters.yaml
with open(Path(__file__).parent.joinpath('task_parameters.yaml')) as f:
DEFAULTS = yaml.safe_load(f)

class Session(StaticTrainingChoiceSession, PulsePalMixin):
protocol_name = 'max_optoStaticTrainingChoiceWorld'
extractor_tasks = ['PulsePalTrials']

def __init__(
self,
*args,
probability_opto_stim: float = DEFAULTS['PROBABILITY_OPTO_STIM'],
opto_ttl_states: list[str] = DEFAULTS['OPTO_TTL_STATES'],
opto_stop_states: list[str] = DEFAULTS['OPTO_STOP_STATES'],
max_laser_time: float = DEFAULTS['MAX_LASER_TIME'],
estimated_led_power_mW: float = DEFAULTS['ESTIMATED_LED_POWER_MW'],
**kwargs,
):
super().__init__(*args, **kwargs)
self.task_params['OPTO_TTL_STATES'] = opto_ttl_states
self.task_params['OPTO_STOP_STATES'] = opto_stop_states
self.task_params['PROBABILITY_OPTO_STIM'] = probability_opto_stim
self.task_params['MAX_LASER_TIME'] = max_laser_time
self.task_params['LED_POWER'] = estimated_led_power_mW
# generates the opto stimulation for each trial
opto = np.random.choice(
[0, 1],
p=[1 - probability_opto_stim, probability_opto_stim],
size=NTRIALS_INIT,
).astype(bool)

opto[0] = False
self.trials_table['opto_stimulation'] = opto

# get the calibration values for the LED
# TODO: do a calibration curve instead
dat = pd.read_csv(r'Y:/opto_fiber_calibration_values.csv')
l_cannula = f'{kwargs["subject"]}L' #TODO: where is SUBJECT defined?
r_cannula = f'{kwargs["subject"]}R'
l_cable = 0
r_cable = 1
l_cal_power = dat[(dat['Cannula'] == l_cannula) & (dat['cable_ID'] == l_cable)].cable_power.values[0]
r_cal_power = dat[(dat['Cannula'] == r_cannula) & (dat['cable_ID'] == r_cable)].cable_power.values[0]

mean_cal_power = np.mean([l_cal_power, r_cal_power])
vmax = LED_V_MAX * self.task_params['LED_POWER'] / mean_cal_power
log.warning(f'Using VMAX: {vmax}V for target LED power {self.task_params["LED_POWER"]}mW')
self.task_params['VMAX_LED'] = vmax

def _instantiate_state_machine(self, trial_number=None):
"""
We override this using the custom class PulsePalStateMachine that appends TTLs for optogenetic stimulation where needed
:param trial_number:
:return:
"""
# PWM1 is the LED OUTPUT for port interface board
# Input is PortIn1
# TODO: enable input port?
log.warning('Instantiating state machine')
is_opto_stimulation = self.trials_table.at[trial_number, 'opto_stimulation']
if is_opto_stimulation:
self.arm_opto_stim()
self.arm_ttl_stim()
return PulsePalStateMachine(
self.bpod,
trigger_type='soft', # software trigger
is_opto_stimulation=is_opto_stimulation,
states_opto_ttls=self.task_params['OPTO_TTL_STATES'],
states_opto_stop=self.task_params['OPTO_STOP_STATES'],
opto_t_max_seconds=self.task_params['MAX_LASER_TIME'],
)

def arm_opto_stim(self):
# define a contant offset voltage with a ramp down at the end to avoid rebound excitation
log.warning('Arming opto stim')
ramp = np.linspace(self.task_params['VMAX_LED'], 0, 1000) # SET POWER
t = np.linspace(0, RAMP_SECONDS, 1000)
v = np.concatenate((np.array([self.task_params['VMAX_LED']]), ramp)) # SET POWER
t = np.concatenate((np.array([0]), t + self.task_params['MAX_LASER_TIME']))

self.pulsepal_connection.programOutputChannelParam('phase1Duration', 1, self.task_params['MAX_LASER_TIME'])
self.pulsepal_connection.sendCustomPulseTrain(1, t, v)
self.pulsepal_connection.programOutputChannelParam('customTrainID', 1, 1)

def start_opto_stim(self):
super().start_opto_stim()
self.opto_start_time = time.time()

@property
def stim_length_seconds(self):
return self.task_params['MAX_LASER_TIME']

def stop_opto_stim(self):
if time.time() - self.opto_start_time >= self.task_params['MAX_LASER_TIME']:
# the LED should have turned off by now, we don't need to force the ramp down
log.warning('Stopped opto stim - hit opto timeout')
return

# we will modify this function to ramp down the opto stim rather than abruptly stopping it
# send instructions to set the TTL back to 0
self.pulsepal_connection.programOutputChannelParam('phase1Duration', 2, self.task_params['MAX_LASER_TIME'])
self.pulsepal_connection.sendCustomPulseTrain(2, [0,], [0,])
self.pulsepal_connection.programOutputChannelParam('customTrainID', 2, 2)

# send instructions to ramp the opto stim down to 0
v = np.linspace(self.task_params['VMAX_LED'], 0, 1000)
t = np.linspace(0, RAMP_SECONDS, 1000)
self.pulsepal_connection.programOutputChannelParam('phase1Duration', 1, self.task_params['MAX_LASER_TIME'])
self.pulsepal_connection.sendCustomPulseTrain(1, t, v)
self.pulsepal_connection.programOutputChannelParam('customTrainID', 1, 1)

# trigger these instructions
self.pulsepal_connection.triggerOutputChannels(1, 1, 0, 0)
log.warning('Stopped opto stim - hit a stop opto state')

def start_hardware(self):
super().start_hardware()
super().start_opto_hardware()


@staticmethod
def extra_parser():
""":return: argparse.parser()"""
parser = super(Session, Session).extra_parser()
parser.add_argument(
'--probability_opto_stim',
option_strings=['--probability_opto_stim'],
dest='probability_opto_stim',
default=DEFAULTS['PROBABILITY_OPTO_STIM'],
type=float,
help=f'probability of opto-genetic stimulation (default: {DEFAULTS["PROBABILITY_OPTO_STIM"]})',
)

parser.add_argument(
'--opto_ttl_states',
option_strings=['--opto_ttl_states'],
dest='opto_ttl_states',
default=DEFAULTS['OPTO_TTL_STATES'],
nargs='+',
type=str,
help='list of the state machine states where opto stim should be delivered',
)
parser.add_argument(
'--opto_stop_states',
option_strings=['--opto_stop_states'],
dest='opto_stop_states',
default=DEFAULTS['OPTO_STOP_STATES'],
nargs='+',
type=str,
help='list of the state machine states where opto stim should be stopped',
)
parser.add_argument(
'--max_laser_time',
option_strings=['--max_laser_time'],
dest='max_laser_time',
default=DEFAULTS['MAX_LASER_TIME'],
type=float,
help='Maximum laser duration in seconds',
)
parser.add_argument(
'--estimated_led_power_mW',
option_strings=['--estimated_led_power_mW'],
dest='estimated_led_power_mW',
default=DEFAULTS['ESTIMATED_LED_POWER_MW'],
type=float,
help='The estimated LED power in mW. Computed from a calibration curve'
)

return parser


if __name__ == '__main__': # pragma: no cover
kwargs = iblrig.misc.get_task_arguments(parents=[Session.extra_parser()])
sess = Session(**kwargs)
sess.run()
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
'CONTRAST_SET': [1.0, 0.25, 0.125, 0.0625, 0.0, 0.0, 0.0625, 0.125, 0.25, 1.0] # signed contrast set
'PROBABILITY_SET': [2, 2, 2, 2, 1, 1, 2, 2, 2, 2] # scalar or list of n signed contrasts values, if scalar all contingencies are equiprobable
'REWARD_SET_UL': [1.5] # scalar or list of Ncontrast values
'POSITION_SET': [-35, -35, -35, -35, -35, 35, 35, 35, 35, 35] # position set
'STIM_GAIN': 4.0 # wheel to stimulus relationship
'STIM_REVERSE': False
#'DEBIAS': True # Whether to use debiasing rule or not by repeating error trials # todo

# Opto parameters
'OPTO_TTL_STATES': # list of the state machine states where opto stim should be delivered
- trial_start
'OPTO_STOP_STATES':
- no_go
- error
- reward
'PROBABILITY_OPTO_STIM': 0.2 # probability of optogenetic stimulation
'MAX_LASER_TIME': 6.0
'ESTIMATED_LED_POWER_MW': 2.5
#'MASK_TTL_STATES': # list of the state machine states where mask stim should be delivered
# - trial_start
# - delay_initiation
# - reset_rotary_encoder
# - quiescent_period
# - stim_on
# - interactive_delay
# - play_tone
# - reset2_rotary_encoder
# - closed_loop
# - no_go
# - freeze_error
# - error
# - freeze_reward
# - reward
Empty file.
Loading
Loading