From 7b9c86fffd3f6b1e2c095e75ad1f83f3f3e5a7f7 Mon Sep 17 00:00:00 2001 From: Tyler Lee Date: Sun, 19 Jun 2016 13:10:37 -0700 Subject: [PATCH] Updated full experiment refactor This adds a lot of documentation plus smooths over most all of the remaining rough edges. It should now be much easier to create new behaviors. --- .gitignore | 1 + pyoperant/behavior/base.py | 741 ++++++++++++++---- pyoperant/behavior/go_no_go_interrupt.py | 177 +++++ .../behavior/go_no_go_interrupt_config.yaml | 96 +++ .../behavior/simple_stimulus_playback.py | 127 +++ .../behavior/simple_stimulus_playback.yaml | 36 + pyoperant/blocks.py | 154 ++++ pyoperant/components.py | 368 +++++++-- pyoperant/configure.py | 124 +++ pyoperant/errors.py | 14 +- pyoperant/events.py | 435 ++++++++++ pyoperant/hwio.py | 448 +++++++++-- pyoperant/interfaces/arduino_.py | 173 ++-- pyoperant/interfaces/base_.py | 151 +++- pyoperant/interfaces/console_.py | 32 +- pyoperant/interfaces/nidaq_.py | 658 ++++++++++++++++ pyoperant/interfaces/pyaudio_.py | 103 ++- pyoperant/interfaces/pydaqmx_.py | 66 ++ pyoperant/interfaces/utils.py | 6 + pyoperant/panels.py | 43 +- pyoperant/queues.py | 151 ++-- pyoperant/reinf.py | 10 +- pyoperant/run_experiment.py | 31 + pyoperant/states.py | 403 ++++++++++ pyoperant/stimuli.py | 196 +++++ pyoperant/subjects.py | 157 ++++ pyoperant/trials.py | 151 ++++ pyoperant/utils.py | 47 +- setup.py | 2 +- 29 files changed, 4615 insertions(+), 486 deletions(-) create mode 100644 pyoperant/behavior/go_no_go_interrupt.py create mode 100644 pyoperant/behavior/go_no_go_interrupt_config.yaml create mode 100644 pyoperant/behavior/simple_stimulus_playback.py create mode 100644 pyoperant/behavior/simple_stimulus_playback.yaml create mode 100644 pyoperant/blocks.py create mode 100644 pyoperant/configure.py create mode 100644 pyoperant/events.py create mode 100644 pyoperant/interfaces/nidaq_.py create mode 100644 pyoperant/interfaces/utils.py create mode 100644 pyoperant/run_experiment.py create mode 100644 pyoperant/states.py create mode 100644 pyoperant/stimuli.py create mode 100644 pyoperant/subjects.py create mode 100644 pyoperant/trials.py diff --git a/.gitignore b/.gitignore index 55c74111..1a9a6e9e 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,4 @@ MANIFEST fabfile.py .DS_Store +.idea* diff --git a/pyoperant/behavior/base.py b/pyoperant/behavior/base.py index 54a123ea..37b201c4 100755 --- a/pyoperant/behavior/base.py +++ b/pyoperant/behavior/base.py @@ -1,220 +1,629 @@ -import logging, traceback -import os, sys, socket +import logging +import traceback +import logging.handlers +import os +import sys +import socket import datetime as dt -from pyoperant import utils, components, local, hwio -from pyoperant import ComponentError, InterfaceError -from pyoperant.behavior import shape +from pyoperant import utils, components, local, hwio, configure +from pyoperant import ComponentError, InterfaceError, EndExperiment +from pyoperant import states, subjects, queues +from pyoperant.events import events, EventLogHandler +import pyoperant.blocks as blocks_ +import pyoperant.trials as trials_ +logger = logging.getLogger(__name__) -try: - import simplejson as json -except ImportError: - import json def _log_except_hook(*exc_info): text = "".join(traceback.format_exception(*exc_info)) - logging.error("Unhandled exception: %s", text) + logger.error("Unhandled exception: %s", text) + class BaseExp(object): - """Base class for an experiment. - - Keyword arguments: - name -- name of this experiment - desc -- long description of this experiment - debug -- (bool) flag for debugging (default=False) - light_schedule -- the light schedule for the experiment. either 'sun' or - a tuple of (starttime,endtime) tuples in (hhmm,hhmm) form defining - time intervals for the lights to be on - experiment_path -- path to the experiment - stim_path -- path to stimuli (default = /stims) - subject -- identifier of the subject - panel -- instance of local Panel() object - - Methods: + """ Base class for an experiment. This controls most of the experiment logic + so you only have to implement specifics for your behavior. + + Parameters + ---------- + panel: Panel instance or string + Instance of a panel or the name of a panel from local.py. It must + implement all required attributes for the current experiment. + block_queue: BlockHandler instance or a queue function or Class + The queue for looping over blocks. If this is not defined, the + block_queue will be used. + subject: Subject instance + The subject that will handle data storage. If this is not defined, one + will be created using subject_name, filename, and datastore + parameters. If those aren't provided, you can add a subject using the + "set_subject" method + conditions: list + A list of stimulus conditions instances. These currently need to be + defined explicitly. + blocks: list + A list of Block instances. If this is not provided, a single block will + be created using the "conditions", "queue" and "queue_parameters" + parameters. + sleep: Sleep state instance + Controls when the experiment is in the sleep state. If not provided, + one will be created from the "sleep_schedule" and "poll_interval" + parameters, if set, or can be added using the "add_sleep_schedule" + method. + session: Session state instance + Controls the scheduling and running of an experimental session. If not + provided, one will be created from the "max_trials", "session_duration", + and "session_interval" parameters. The session duration/interval and + trial limit can be set using the "set_session_time_limits" and + "set_session_trial_limit" methods. + idle: Idle state instance + Controls how the experiment behaves while idling. If not provided, one + will be created from the "poll_interval" parameter. + name: string + Name of this experiment + description: string + Long description of this experiment + debug: bool + Flag for debugging, switches the logging stream handler between to DEBUG level + num_sessions: int + The number of sessions to run. + filetime_fmt: string + The format for the time string in the name of the file. + experiment_path: string + Path to the experiment data directory. Provides the default location to store subject data. + log_handlers: list of dictionaries + Currently supported handler types are file and email (in addition to + the default stream handler) + sleep_schedule: string or tuple + The sleep schedule for the experiment. either 'night' or a tuple of + (starttime,endtime) in (hhmm,hhmm) form defining the time interval for + the experiment to sleep. + max_trials: int + The maximum number of trials for each experimental session. Can also be + set using the "set_session_trial_limit" method. + session_duration: int + The maximum duration, in minutes, of each experimental session. Can also + be set using the "set_session_time_limits" method. + session_interval: int + The minimum intersession interval in minutes. Can also be set using the + "set_session_time_limits" method. + poll_interval: int + The number of seconds to wait between successive checks regarding when + to start/stop sleeping or start an experimental session. + queue: a queue name, function, or Class + The queue to use to loop over stimulus conditions within a block. + queue_parameters: dictionary + Any additional parameters to provide the "queue". + reinforcement: a reinforcement name or Class + The type of reinforcement to use for a block + subject_name: string + The name of the subject performing the experiment + datastore: string + The type of file to store data in (e.g. "csv") + filename: string + The name of the file in which to store data. If the full path is not + provided, it will be in the path given by the "experiment_path" + parameter. + + All other key-value pairs get placed into the parameters attribute + + Methods + ------- run() -- runs the experiment + Required Panel Attributes + ------------------------- + sleep - Puts the panel to sleep + reset - Sets the panel back to a nice initial state + ready - Prepares the panel to run the behavior (e.g. turn on the + response_port light and put the feeder down) + idle - Sets the panel into an idle state for when the experiment is not + running + + Fields To Save + -------------- + session - The index of the current session + index - The index of the current trial + time - The start time of the trial """ + + # All panels should have these methods, but it's best to include them in every experiment just in case + req_panel_attr = ["sleep", + "reset", + "idle", + "ready"] + + # All experiments should store at least these fields but probably more + fields_to_save = ['session', + 'index', + 'time'] + def __init__(self, - name='', - description='', + panel, + block_queue=queues.block_queue, + subject=None, + conditions=None, + blocks=None, + sleep=None, + session=None, + idle=None, + name='Experiment', + description='A pyoperant experiment', debug=False, + num_sessions=1, filetime_fmt='%Y%m%d%H%M%S', - light_schedule='sun', - idle_poll_interval = 60.0, experiment_path='', - stim_path='', - subject='', - panel=None, - log_handlers=[], + log_handlers=None, + sleep_schedule=None, + max_trials=None, + session_duration=None, + session_interval=None, + poll_interval=60, + queue=queues.random_queue, + queue_parameters=None, + reinforcement=None, + subject_name=None, + datastore="csv", + filename=None, *args, **kwargs): - super(BaseExp, self).__init__() + super(BaseExp, self).__init__() + + # Initialize the experiment directory to start storing data + if not os.path.exists(experiment_path): + logger.debug("Creating %s" % experiment_path) + os.makedirs(experiment_path) + self.experiment_path = experiment_path + + # Set up logging + if log_handlers is None: + log_handlers = dict() + + # Stream handler is the console and takes level as it's only config + stream_handler = log_handlers.pop("stream", dict()) + # Initialize the logging + self.configure_logging(debug=debug, **stream_handler) + + # Event logging takes filename, format, and component as arguments + event_handler = log_handlers.pop("event", dict()) + self.configure_event_logging(**event_handler) + + # File handler has keywords of filename and level + if "file" in log_handlers: + self.add_file_handler(**log_handlers["file"]) + + # Email handler has keywords of mailhost, toaddrs, fromaddr, subject, credentials, secure, and level + if "email" in log_handlers: + self.add_email_handler(**log_handlers["email"]) + + # Experiment descriptors self.name = name self.description = description - self.debug = debug self.timestamp = dt.datetime.now().strftime(filetime_fmt) - self.parameters = kwargs - self.parameters['filetime_fmt'] = filetime_fmt - self.parameters['light_schedule'] = light_schedule - self.parameters['idle_poll_interval'] = idle_poll_interval - - self.parameters['experiment_path'] = experiment_path - if stim_path == '': - self.parameters['stim_path'] = os.path.join(experiment_path,'stims') + logger.debug("Initializing experiment: %s" % self.name) + logger.debug(self.description) + logger.debug("This experiment will store the following trial " + + "parameters:\n%s" % ", ".join(self.fields_to_save)) + + # Initialize the panel + logger.debug("Panel must support these attributes: " + + "%s" % ", ".join(self.req_panel_attr)) + if isinstance(panel, str) and hasattr(local, panel): + panel = getattr(local, panel) + self.panel = panel + # Verify the panel can support this behavior + self.check_panel_attributes(panel) + logger.debug('Initialized panel: %s' % self.panel.__class__.__name__) + + # Initialize the subject + if subject is not None: + subject = subject_name + logger.info("Preparing subject and data storage") + self.set_subject(subject_name, filename, datastore) + logger.debug("Data will be stored at %s" % self.subject.filename) + + # Initialize blocks and block_queue + logger.debug("Preparing blocks and block_queue") + if isinstance(block_queue, blocks_.BlockHandler): + self.block_queue = block_queue + self.blocks = block_queue.blocks + elif blocks is not None: + if isinstance(blocks, blocks_.Block): + blocks = [blocks] + self.blocks = blocks + logger.debug("Creating block_queue from blocks") + self.block_queue = blocks_.BlockHandler(blocks, queue=block_queue) + elif conditions is not None: + logger.debug("Creating block_queue from stimulus conditions") + if queue_parameters is None: + queue_parameters = dict() + self.blocks = [blocks_.Block(conditions, + queue=queue, + reinforcement=reinforcement, + **queue_parameters)] + self.block_queue = blocks_.BlockHandler(self.blocks, + queue=block_queue) else: - self.parameters['stim_path'] = stim_path - self.parameters['subject'] = subject - - # configure logging - self.parameters['log_handlers'] = log_handlers - self.log_config() + raise ValueError("Could not create blocks for the experiment. " + + "Please provide conditions, blocks, or block_queue") + + # Initialize the states + if sleep is not None: + self.add_sleep_schedule(sleep) + elif sleep_schedule is not None: + self.add_sleep_schedule(sleep_schedule, + poll_interval=poll_interval) + else: + self._sleep = None + + if session is None: + session = states.Session() + self.session = session + self.session.experiment = self + if max_trials is not None: + self.set_session_trial_limit(max_trials) + if session_duration is not None or session_interval is not None: + self.set_session_time_limits(duration=session_duration, + interval=session_interval) + self.num_sessions = num_sessions + + if idle is None: + idle = states.Idle(poll_interval=poll_interval) + self._idle = idle + self._idle.experiment = self - self.req_panel_attr= ['house_light', - 'reset', - ] - self.panel = panel - self.log.debug('panel %s initialized' % self.parameters['panel_name']) + self.parameters = kwargs - if 'shape' not in self.parameters or self.parameters['shape'] not in ['block1', 'block2', 'block3', 'block4', 'block5']: - self.parameters['shape'] = None + # Get ready to run! + self.session_id = 0 + self.finished = False + + def set_subject(self, subject, filename=None, datastore="csv"): + """ Creates a subject for the current experiment. + + Parameters + ---------- + subject: string or instance of Subject class + The name of the subject or an already created Subject instance to use. + filename: string + The path to the file in which to store data. + datastore: string + The type of file in which to store data (e.g. "csv") + """ + if subject is None: + raise ValueError("Subject has not yet been defined. " + + "Provide a value to either the subject " + + "or subject_name parameter for the behavior.") + if not isinstance(subject, subjects.Subject): + subject = subjects.Subject(subject) + + if subject.datastore is None: + if filename is None: + filename = "%s_trialdata_%s.%s" % (subject.name, + self.timestamp, + datastore) + # Add directory if filename is not a full path + if len(os.path.split(filename)[0]) == 0: + filename = os.path.join(self.experiment_path, + filename) + subject.filename = filename + subject.create_datastore(self.fields_to_save) + + logger.debug("Creating subject") + self.subject = subject + + def add_sleep_schedule(self, time_period, poll_interval=60): + """ Add a sleep schedule between start and end times + + Parameters + ---------- + time_period: string, tuple, or instance of Sleep state + Can be a string of "night", a tuple of ("HH:MM", "HH:MM"), or a + pre-created instance of Sleep state + poll_interval: int + The number of seconds to wait between successive checks regarding + when to start/stop sleeping + """ + + if isinstance(start, states.Sleep): + self._sleep = start + else: + self._sleep = states.Sleep(time_period=time_period, + poll_interval=poll_interval) + logger.debug("Adding sleep state") + self._sleep.experiment = self - self.shaper = shape.Shaper(self.panel, self.log, self.parameters, self.log_error_callback) + # Logging configure methods + def configure_logging(self, level=logging.INFO, debug=False): + """ Configures the basic logging for the experiment. This creates a handler for logging to the console, sets it at the appropriate level (info by default unless overridden in the config file or by the debug flag) and creates the default formatting for log messages. + """ - def save(self): - self.snapshot_f = os.path.join(self.parameters['experiment_path'], self.timestamp+'.json') - with open(self.snapshot_f, 'wb') as config_snap: - json.dump(self.parameters, config_snap, sort_keys=True, indent=4) + if debug is True: + self.log_level = logging.DEBUG + else: + self.log_level = level + + sys.excepthook = _log_except_hook # send uncaught exceptions to file + + logging.basicConfig( + level=self.log_level, + format='"%(asctime)s","%(levelname)s","%(message)s"' + ) + + # Make sure that the stream handler has the requested log level. + root_logger = logging.getLogger() + for handler in root_logger.handlers: + if isinstance(handler, logging.StreamHandler): + handler.setLevel(self.log_level) + + def configure_event_logging(self, filename="events.log", format=None, + component=None): + """ Sets up the logging of component events to a file. See events.py for + more details. + + Parameters + ---------- + filename + format + component + + TODO: If one already exists, don't create another! + + """ + + # Add directory if filename is not a full path + if len(os.path.split(filename)[0]) == 0: + filename = os.path.join(self.experiment_path, filename) + log_handler = EventLogHandler(filename=filename, format=format, + component=component) + events.add_handler(log_handler) + + def add_file_handler(self, filename="experiment.log", + format='"%(asctime)s","%(levelname)s","%(message)s"', + level=logging.INFO): + """ Add a file handler to the root logger + + Parameters + ---------- + filename: string + name of the experiment log file + format: string + format for log messages + level: logging level + defaults to logging.INFO, but could be set to logging.DEBUG + """ + + # Add directory if filename is not a full path + if len(os.path.split(filename)[0]) == 0: + filename = os.path.join(self.experiment_path, filename) + + file_handler = logging.FileHandler(filename) + file_handler.setLevel(level) + file_handler.setFormatter(logging.Formatter(format)) + + # Make sure the root logger's level is not too high + root_logger = logging.getLogger() + if root_logger.level > level: + root_logger.setLevel(level) + root_logger.addHandler(file_handler) + logger.debug("File handler added to %s with level %d" % (filename, + level)) + + def add_email_handler(self, toaddrs, mailhost="localhost", + fromaddr="Pyoperant level: + root_logger.setLevel(level) + root_logger.addHandler(email_handler) + logger.debug("Email handler added to %s with level %d" % (",".join(email_handler.toaddrs), level)) + + # Scheduling methods + def check_sleep_schedule(self): + """returns true if the experiment should be sleeping""" + if self._sleep is None: + return False + + to_sleep = self._sleep.check() + logger.debug("Checking sleep schedule: %s" % to_sleep) + return to_sleep - def log_config(self): + def check_session_schedule(self): + """returns True if the subject should be running sessions""" - self.log_file = os.path.join(self.parameters['experiment_path'], self.parameters['subject'] + '.log') + return self.session.check() + + def set_session_time_limits(self, duration=None, interval=None): + """ Sets the duration for the current or next session + + Parameters + ---------- + duration: int + Time, in minutes, that the session should last + interval: int + Time, in minutes, between consecutive sessions + """ + scheduler = states.TimeScheduler(duration=duration, interval=interval) + self.session.schedulers.append(scheduler) + + def set_session_trial_limit(self, max_trials): + """ Sets the number of trials for the current or upcoming session + + Parameters + ---------- + max_trials: int + Maximum number of trials that should be run + """ + + scheduler = states.CountScheduler(max_trials=max_trials) + self.session.schedulers.append(scheduler) + + def end(self): + """ Finish the experiment and put the panel to sleep """ + + # Close the event handlers because they are in separate threads + events.close_handlers() + self.finished = True + self.panel.sleep() + + def shape(self): + """ + This will house a method to run shaping. + """ + + pass + + @classmethod + def check_panel_attributes(cls, panel, raise_on_fail=True): + """ Check if the panel has all required attributes + + Parameters + ---------- + panel: panel instance + The panel to check + raise_on_fail: bool + True causes an AttributeError to be raised if the panel doesn't contain all required attributes + + Returns + ------- + True if panel has all required attributes, False otherwise + """ + + missing_attrs = list() + for attr in cls.req_panel_attr: + logger.debug("Checking that panel has attribute %s" % attr) + if not hasattr(panel, attr): + missing_attrs.append(attr) + + if len(missing_attrs) > 0: + logger.critical("Panel is missing attributes: %s" % ", ".join(missing_attrs)) + if raise_on_fail: + raise AttributeError("Panel is missing attributes: %s" % ", ".join(missing_attrs)) + return False - if self.debug: - self.log_level = logging.DEBUG else: - self.log_level = logging.INFO + return True - sys.excepthook = _log_except_hook # send uncaught exceptions to log file + def run(self): + """ Run shaping and then star the experiment """ - logging.basicConfig(filename=self.log_file, - level=self.log_level, - format='"%(asctime)s","%(levelname)s","%(message)s"') - self.log = logging.getLogger() + logger.info("Preparing to run experiment %s" % self.name) + logger.debug("Resetting panel") + self.panel.reset() - if 'email' in self.parameters['log_handlers']: - from pyoperant.local import SMTP_CONFIG - from logging import handlers - SMTP_CONFIG['toaddrs'] = [self.parameters['experimenter']['email'],] + # This still seems very odd to me. + logger.debug("Running shaping") + self.shape() - email_handler = handlers.SMTPHandler(**SMTP_CONFIG) - email_handler.setLevel(logging.WARNING) + # Run until self.end() is called + while self.finished == False: + # The idle state checks whether it's time to sleep or time to start the session, so start in that state. + self._idle.start() - heading = '%s\n' % (self.parameters['subject']) - formatter = logging.Formatter(heading+'%(levelname)s at %(asctime)s:\n%(message)s') - email_handler.setFormatter(formatter) + ## Session Flow + def session_pre(self): + """ Runs before the session starts. Initializes the block queue and + records the session start time. + """ + logger.debug("Beginning session") - self.log.addHandler(email_handler) + # Reinitialize the block queue + self.block_queue.reset() + self.session_id += 1 + self.session_start_time = dt.datetime.now() + self.panel.ready() - def check_light_schedule(self): - """returns true if the lights should be on""" - return utils.check_time(self.parameters['light_schedule']) + def session_main(self): + """ Runs the session by looping over the block queue and then running + each trial in each block. + """ - def check_session_schedule(self): - """returns True if the subject should be running sessions""" - return False + for self.this_block in self.block_queue: + self.this_block.experiment = self + logger.info("Beginning block #%d" % self.this_block.index) + for trial in self.this_block: + trial.run() - def panel_reset(self): - try: - self.panel.reset() - except components.ComponentError as err: - self.log.error("component error: %s" % str(err)) + def session_post(self): + """ Closes out the sessions + """ - def run(self): + self.panel.idle() + self.session_end_time = dt.datetime.now() + logger.info("Finishing session %d at %s" % (self.session_id, self.session_end_time.ctime())) + if self.session_id >= self.num_sessions: + logger.info("Finished all sessions.") + self.end() - for attr in self.req_panel_attr: - assert hasattr(self.panel,attr) - self.panel_reset() - self.save() - self.init_summary() - - self.log.info('%s: running %s with parameters in %s' % (self.name, - self.__class__.__name__, - self.snapshot_f, - ) - ) - if self.parameters['shape']: - self.shaper.run_shape(self.parameters['shape']) - while True: #is this while necessary - utils.run_state_machine(start_in='idle', - error_state='idle', - error_callback=self.log_error_callback, - idle=self._run_idle, - sleep=self._run_sleep, - session=self._run_session) - - def _run_idle(self): - if self.check_light_schedule() == False: - return 'sleep' - elif self.check_session_schedule(): - return 'session' - else: - self.panel_reset() - self.log.debug('idling...') - utils.wait(self.parameters['idle_poll_interval']) - return 'idle' + # Defining the different trial states. If any of these are not needed by the behavior, just don't define them in your subclass + def trial_pre(self): + pass + def stimulus_pre(self): + pass + def stimulus_main(self): + pass - # defining functions for sleep - def sleep_pre(self): - self.log.debug('lights off. going to sleep...') - return 'main' + def stimulus_post(self): + pass - def sleep_main(self): - """ reset expal parameters for the next day """ - self.log.debug('sleeping...') - self.panel.house_light.off() - utils.wait(self.parameters['idle_poll_interval']) - if self.check_light_schedule() == False: - return 'main' - else: - return 'post' + def response_pre(self): + pass - def sleep_post(self): - self.log.debug('ending sleep') - self.panel.house_light.on() - self.init_summary() - return None + def response_main(self): + pass - def _run_sleep(self): - utils.run_state_machine(start_in='pre', - error_state='post', - error_callback=self.log_error_callback, - pre=self.sleep_pre, - main=self.sleep_main, - post=self.sleep_post) - return 'idle' + def response_post(self): + pass - # session + def reward_pre(self): + pass - def session_pre(self): - return 'main' + def reward_main(self): + pass - def session_main(self): - return 'post' + def reward_post(self): + pass - def session_post(self): - return None + def punish_pre(self): + pass + + def punish_main(self): + pass - def _run_session(self): - utils.run_state_machine(start_in='pre', - error_state='post', - error_callback=self.log_error_callback, - pre=self.session_pre, - main=self.session_main, - post=self.session_post) - return 'idle' + def punish_post(self): + pass + def trial_post(self): + pass # gentner-lab specific functions def init_summary(self): @@ -244,4 +653,4 @@ def write_summary(self): def log_error_callback(self, err): if err.__class__ is InterfaceError or err.__class__ is ComponentError: - self.log.critical(str(err)) + logger.critical(str(err)) diff --git a/pyoperant/behavior/go_no_go_interrupt.py b/pyoperant/behavior/go_no_go_interrupt.py new file mode 100644 index 00000000..9f64a757 --- /dev/null +++ b/pyoperant/behavior/go_no_go_interrupt.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python +import os +import sys +import logging +import csv +import datetime as dt +import random +import numpy as np +from pyoperant.behavior import base +from pyoperant.errors import EndSession +from pyoperant import states, trials, blocks +from pyoperant import components, utils, reinf, queues, configure, stimuli, subjects + +logger = logging.getLogger(__name__) + + +class RewardedCondition(stimuli.StimulusConditionWav): + """ Rewarded stimuli are rewarded if the subject does *not* respond (i.e. + No-Go stimuli). + """ + def __init__(self, file_path="", recursive=False): + super(RewardedCondition, self).__init__(name="Rewarded", + response=False, + is_rewarded=True, + is_punished=False, + file_path=file_path, + recursive=recursive) + + +class UnrewardedCondition(stimuli.StimulusConditionWav): + """ Unrewarded stimuli are not consequated and should be pecked through + (i.e. Go stimuli) + """ + def __init__(self, file_path="", recursive=False): + + super(UnrewardedCondition, self).__init__(name="Unrewarded", + response=True, + is_rewarded=False, + is_punished=False, + file_path=file_path, + recursive=recursive) + + +class GoNoGoInterrupt(base.BaseExp): + """A go no-go interruption experiment + + Additional Parameters + --------------------- + reward_value: int + The value to pass as a reward (e.g. feed duration) + + For all other parameters, see pyoperant.behavior.base.BaseExp + + Required Panel Attributes + ------------------------- + sleep - Puts the panel to sleep + reset - Sets the panel back to a nice initial state + ready - Prepares the panel to run the behavior (e.g. turn on the + response_port light and put the feeder down) + idle - Sets the panel into an idle state for when the experiment is not + running + reward - Method for supplying a reward to the subject. Should take a reward + value as an argument + response_port - The input through which the subject responds + speaker - A speaker for sound output + + Fields To Save + -------------- + session - The index of the current session + index - The index of the current trial + time - The start time of the trial + stimulus_name - The filename of the stimulus + condition_name - The condition of the stimulus + response - Whether or not there was a response + correct - Whether the response was correct + rt - If there was a response, the time from sound playback + max_wait - The duration of the sound and thus maximum rt to be counted as a + response. + """ + + req_panel_attr = ["sleep", + "reset", + "ready", + "idle", + "reward", + "response_port", + "speaker"] + + fields_to_save = ['session', + 'index', + 'time', + 'stimulus_name', + 'condition_name', + 'response', + 'correct', + 'rt', + 'reward', + 'max_wait', + ] + + def __init__(self, reward_value=12, *args, **kwargs): + + super(GoNoGoInterrupt, self).__init__(*args, **kwargs) + self.start_immediately = False + self.reward_value = reward_value + + def trial_pre(self): + """ Initialize the trial and, if necessary, wait for a peck before + starting stimulus playback. + """ + + logger.debug("Starting trial #%d" % self.this_trial.index) + stimulus = self.this_trial.stimulus + condition = self.this_trial.condition.name + self.this_trial.annotate(stimulus_name=stimulus.file_origin, + condition_name=condition, + max_wait=stimulus.duration) + + if not self.start_immediately: + logger.debug("Begin polling for a response") + self.panel.response_port.poll() + + def stimulus_main(self): + """ Queue the stimulus and play it back """ + + logger.info("Trial %d - %s - %s - %s" % ( + self.this_trial.index, + self.this_trial.time.strftime("%H:%M:%S"), + self.this_trial.condition.name, + self.this_trial.stimulus.name)) + self.panel.speaker.queue(self.this_trial.stimulus.file_origin) + self.this_trial.annotate(stimulus_time=dt.datetime.now()) + self.panel.speaker.play() + + def response_main(self): + """ Poll for an interruption for the duration of the stimulus. """ + + self.this_trial.response_time = self.panel.response_port.poll(self.this_trial.stimulus.duration) + logger.debug("Received peck or timeout. Stopping playback") + + self.panel.speaker.stop() + logger.debug("Playback stopped") + + if self.this_trial.response_time is None: + logger.debug("No peck was received") + self.this_trial.response = False + self.start_immediately = False # Next trial will poll for a response before beginning + self.this_trial.rt = np.nan + else: + logger.debug("Peck was received") + self.this_trial.response = True + self.start_immediately = True # Next trial will begin immediately + self.this_trial.rt = self.this_trial.response_time - \ + self.this_trial.annotations["stimulus_time"] + + def reward_main(self): + """ Reward a correct non-interruption """ + + value = self.parameters.get('reward_value', 12) + logger.info("Supplying reward for %3.2f seconds" % value) + reward_event = self.panel.reward(value=value) + if isinstance(reward_event, dt.datetime): # There was a response during the reward period + self.start_immediately = True + + +if __name__ == "__main__": + + # Load config file + config_file = "/path/to/config" + if config_file.lower().endswith(".json"): + parameters = configure.ConfigureJSON.load(config_file) + elif config_file.lower().endswith(".yaml"): + parameters = configure.ConfigureYAML.load(config_file) + + # Create experiment object + exp = GoNoGoInterrupt(**parameters) + exp.run() diff --git a/pyoperant/behavior/go_no_go_interrupt_config.yaml b/pyoperant/behavior/go_no_go_interrupt_config.yaml new file mode 100644 index 00000000..3106f974 --- /dev/null +++ b/pyoperant/behavior/go_no_go_interrupt_config.yaml @@ -0,0 +1,96 @@ +--- +# The main format is as follows: +# Each section contains information that will help to configure your +# experiment. Where possible, values that are not supplied will be filled in +# with default values supplied by the given experiment. The author of a +# particular behavior script should also write a config file that will serve as +# a template and possibly be used to provide the default values. If a section +# requires specification of a python object, the formatting should go something +# like this: +# object_name: !!python/object/apply:module.submodule.Class +# args: [list of arguments] +# kwds: +# param1: value1 +# param2: value2 +# ... +# paramn: valuen +# You can refer to a previously defined object using "&some_name" before +# defining the object and "*some_name" when referencing the object + +# Experiment description +name: Go No-Go Interruption +description: > + Runs a Go No-Go Interruption experiment +experimenter: + name: First Last + email: &def_email "test@somewhere.com" + +# File paths +experiment_path: "/path/to/data/directory" + +# Verbose logging +debug: false + +# Number of sessions to run +num_sessions: 1 + +# Behavior properties +reward_value: 12.0 + +# Subject +subject_name: TestSubject + +# Panel properties +panel: !!python/object/apply:pyoperant.panels.BasePanel {} + +# Stimulus conditions +conditions: + - &NoGo !!python/object/apply:pyoperant.stimuli.StimulusConditionWav + kwds: + name: "No-Go" + file_path: "/path/to/nogo/stimuli" + response: false + is_rewarded: true + is_punished: false + - &Go !!python/object:pyoperant.stimuli.StimulusConditionWav + kwds: + name: "Go" + file_path: "/path/to/go/stimuli" + response: true + is_rewarded: false + is_punished: false + + +# Blocks +blocks: + # Block number 1 + - !!python/object/apply:pyoperant.blocks.Block + kwds: + conditions: + # NoGo stimuli + - *NoGo + # Go stimuli + - *Go + queue: !!python/name:pyoperant.queues.random_queue + # Weights for random queue + weights: + - 0.2 + - 0.8 + reinforcement: !!python/object/apply:pyoperant.reinf.ContinuousReinforcement {} + +# Block Handler +block_queue: !!python/name:pyoperant.queues.block_queue + +# Log handler setup +# Possible values are stream, file, email +log_handlers: + # stream's only option is level. Overrides "debug" parameter for logging + stream: + level: !!python/name:logging.INFO + # file takes options of + # filename: a file under experiment_path + # level: a python logging level, written as "!!python/name:logging.LEVEL" + file: + filename: "experiment.log" + level: !!python/name:logging.DEBUG +... diff --git a/pyoperant/behavior/simple_stimulus_playback.py b/pyoperant/behavior/simple_stimulus_playback.py new file mode 100644 index 00000000..22002d2a --- /dev/null +++ b/pyoperant/behavior/simple_stimulus_playback.py @@ -0,0 +1,127 @@ +import logging +import datetime as dt +import numpy as np +from pyoperant.behavior import base +from pyoperant.errors import EndSession +from pyoperant import utils, stimuli, queues +from pyoperant import blocks as blocks_ + +logger = logging.getLogger(__name__) + + +class SimpleStimulusPlayback(base.BaseExp): + """ Simply plays back stimuli with a random or fixed intertrial interval + + Additional Parameters + --------------------- + intertrial_interval: float or 2-element list + If the value is a float, then the intertrial interval is fixed. If it is a list, then the interval is taken as from a uniform random distribution between the first and second elements. + stimulus_directory: string or list + Full path to the stimulus directory. If given, a stimulus condition will be created and passed in to BaseExp. Can also be a list of dictionaries with name and directory keys. + + For all other parameters, see pyoperant.behavior.base.BaseExp + + Required Panel Attributes + ------------------------- + sleep - Puts the panel to sleep + reset - Sets the panel back to a nice initial state + ready - Prepares the panel to run the behavior (e.g. turn on the + response_port light and put the feeder down) + idle - Sets the panel into an idle state for when the experiment is not + running + speaker - A speaker for sound output + + Fields To Save + -------------- + session - The index of the current session + index - The index of the current trial + time - The start time of the trial + stimulus_name - The filename of the stimulus + intertrial_interval - The intertrial interval preceding the trial + """ + + req_panel_attr = ["sleep", + "reset", + "idle", + "ready", + "speaker"] + + fields_to_save = ['session', + 'index', + 'time', + 'stimulus_name', + 'intertrial_interval'] + + def __init__(self, intertrial_interval=2.0, stimulus_directory=None, + queue=queues.random_queue, reinforcement=None, + queue_parameters=None, *args, **kwargs): + + # let's parse any stimulus directories provided + if stimulus_directory is not None: + # Append to any existing blocks + blocks = kwargs.pop("blocks", list()) + + if queue_parameters is None: + queue_parameters = dict() + + # If a path is given, convert it to the list of dictionaries + if isinstance(stimulus_directory, str): + stimulus_directory = [dict(name="Playback", + directory=stimulus_directory)] + + for ii, stim_dict in enumerate(stimulus_directory): + # Default name is Playback# + name = stim_dict.get("name", "Playback%d" % ii) + directory = stim_dict["directory"] + # Create a stimulus condition for this directory + condition = stimuli.StimulusConditionWav(name=name, + file_path=directory, + is_rewarded=False, + is_punished=False, + response=False) + + # Create a block for this condition + block = blocks_.Block([condition], + queue=queue, + reinforcement=reinforcement, + **queue_parameters) + blocks.append(block) + + self.intertrial_interval = intertrial_interval + + super(SimpleStimulusPlayback, self).__init__(blocks=blocks, + *args, **kwargs) + + + def trial_pre(self): + """ Store data that is specific to this experiment, and compute a wait time for an intertrial interval + """ + + stimulus = self.this_trial.stimulus.file_origin + if isinstance(self.intertrial_interval, (list, tuple)): + iti = np.random.uniform(*self.intertrial_interval) + else: + iti = self.intertrial_interval + + logger.debug("Waiting for %1.3f seconds" % iti) + self.this_trial.annotate(stimulus_name=stimulus, + intertrial_interval=iti) + utils.wait(iti) + + def stimulus_main(self): + """ Queue the sound and play it """ + + logger.info("Trial %d - %s - %s" % ( + self.this_trial.index, + self.this_trial.time.strftime("%H:%M:%S"), + self.this_trial.stimulus.name + )) + + self.panel.speaker.queue(self.this_trial.stimulus.file_origin) + self.panel.speaker.play() + + # Wait for stimulus to finish + utils.wait(self.this_trial.stimulus.duration) + + # Stop the sound + self.panel.speaker.stop() diff --git a/pyoperant/behavior/simple_stimulus_playback.yaml b/pyoperant/behavior/simple_stimulus_playback.yaml new file mode 100644 index 00000000..06dbf406 --- /dev/null +++ b/pyoperant/behavior/simple_stimulus_playback.yaml @@ -0,0 +1,36 @@ +--- +# Experiment description +name: Simple Stimulus Playback +description: > + Loops through a directory of stimuli and plays them back, with some intertrial interval +experimenter: + name: First Last + email: &def_email "test@somewhere.com" +debug: false +num_sessions: 1 +experiment_path: "/path/to/data/directory" + +# Behavior properties +intertrial_interval: [0.5, 1.5] +# a directory or a list of dictionaries with name and directory keys +stimulus_directory: + - name: Playback + directory: "/path/to/stimulus/directory" + +# Additional configuration +subject_name: TestSubject +panel: !!python/object/apply:pyoperant.panels.BasePanel {} + +# Log handler setup +# Possible values are stream, file, email +log_handlers: + # stream's only option is level. Overrides "debug" parameter for logging + stream: + level: !!python/name:logging.INFO + # file takes options of + # filename: a file under experiment_path + # level: a python logging level, written as "!!python/name:logging.LEVEL" + file: + filename: "experiment.log" + level: !!python/name:logging.INFO +... diff --git a/pyoperant/blocks.py b/pyoperant/blocks.py new file mode 100644 index 00000000..35ab63f1 --- /dev/null +++ b/pyoperant/blocks.py @@ -0,0 +1,154 @@ +import logging +from pyoperant import queues, reinf, utils, trials + +logger = logging.getLogger(__name__) + + +class Block(queues.BaseHandler): + """ Class that allows one to iterate over a block of trials according to a + specific queue. + + Parameters + ---------- + conditions: list + A list of StimulusConditions to iterate over according to the queue + index: int + Index of the block + experiment: instance of Experiment class + The experiment of which this block is a part + queue: a queue function or Class + The queue used to iterate over trials for this block + reinforcement: instance of Reinforcement class (ContinuousReinforcement()) + The reinforcement schedule to use for this block. + Additional key-value pairs are used to initialize the trial queue + + Attributes + ---------- + conditions: list + A list of StimulusConditions to iterate over according to the queue + index: int + Index of the block + experiment: instance of Experiment class + The experiment of which this block is a part + queue: queue generator or class instance + The queue that will be iterated over. + reinforcement: instance of Reinforcement class (ContinuousReinforcement()) + The reinforcement schedule to use for this block. + + Examples + -------- + # Initialize a block with a random queue, and at most 200 trials. + trials = Block(conditions, + experiment=e, + queue=queues.random_queue, + max_items=200) + for trial in trials: + trial.run() + """ + + def __init__(self, conditions, index=0, experiment=None, + queue=queues.random_queue, reinforcement=None, + **queue_parameters): + + if conditions is None: + raise ValueError("Block must be called with a list of conditions") + + # Could check to ensure reinforcement is of the correct type + if reinforcement is None: + reinforcement = reinf.ContinuousReinforcement() + + super(Block, self).__init__(queue=queue, + items=conditions, + **queue_parameters) + + self.index = index + self.experiment = experiment + self.conditions = conditions + self.reinforcement = reinforcement + + logger.debug("Initialize block: %s" % self) + + def __str__(self): + + desc = ["Block"] + if self.conditions is not None: + desc.append("%d stimulus conditions" % len(self.conditions)) + if self.queue is not None: + desc.append("queue = %s" % self.queue.__name__) + + return " - ".join(desc) + + def check_completion(self): + + # if self.end is not None: + # if utils.check_time((self.start, self.end)): # Will start ever be none? Shouldn't be. + # logger.debug("Block is complete due to time") + # return True # Block is complete + + # if self.max_trials is not None: + # if self.num_trials >= self.max_trials: + # logger.debug("Block is complete due to trial count") + # return True + + return False + + def __iter__(self): + + # Loop through the queue generator + trial_index = 0 + for condition in self.queue: + # Create a trial instance + trial_index += 1 + trial = trials.Trial(index=trial_index, + experiment=self.experiment, + condition=condition, + block=self) + yield trial + + +class BlockHandler(queues.BaseHandler): + """ Class which enables iterating over blocks of trials + + Parameters + ---------- + blocks: list + A list of Block objects + queue: a queue function or Class + The queue used to iterate over blocks + Additional key-value pairs are used to initialize the queue + + Attributes + ---------- + block_index: int + Index of the current block + blocks: list + A list of Block objects + queue: queue generator or class instance + The queue that will be iterated over. + queue_parameters: dict + All additional parameters used to initialize the queue. + + Example + ------- + # Initialize the BlockHandler + blocks = BlockHandler(blocks, queue=queues.block_queue) + # Loop through the blocks, then loop through all trials in the block + for block in blocks: + for trial in block: + trial.run() + """ + + def __init__(self, blocks, queue=queues.block_queue, **queue_parameters): + + self.blocks = blocks + self.block_index = 0 + super(BlockHandler, self).__init__(queue=queue, + items=blocks, + **queue_parameters) + + def __iter__(self): + + for block in self.queue: + self.block_index += 1 + block.index = self.block_index + yield block diff --git a/pyoperant/components.py b/pyoperant/components.py index cef3f698..9fc7428e 100644 --- a/pyoperant/components.py +++ b/pyoperant/components.py @@ -2,34 +2,61 @@ from pyoperant import hwio, utils, ComponentError class BaseComponent(object): - """Base class for physcal component""" + """Base class for physcal component + + Parameters + ---------- + name: string + The name of the component to be stored in events (defaults to the + class name) + + Attributes + ---------- + name: string + The name of the component to be stored in events (defaults to the + class name) + event: + A dictionary that is sent along with changes to the state of inputs and + outputs. It contains three main keys: name, action, and metadata. name + is set here, action is set when a component method is called, and + metadata is set as an optional argument to some methods. + """ def __init__(self, name=None, *args, **kwargs): + if name is None: + name = self.__class__.__name__ self.name = name + self.event = dict(name=self.name, + action="", + metadata=None) pass ## Hopper ## - class HopperActiveError(ComponentError): """raised when the hopper is up when it shouldn't be""" pass + class HopperInactiveError(ComponentError): """raised when the hopper is down when it shouldn't be""" pass + class HopperAlreadyUpError(HopperActiveError): """raised when the hopper is already up before it goes up""" pass + class HopperWontComeUpError(HopperInactiveError): """raised when the hopper won't come up""" pass + class HopperWontDropError(HopperActiveError): """raised when the hopper won't drop""" pass + class Hopper(BaseComponent): """ Class which holds information about a hopper @@ -37,35 +64,38 @@ class Hopper(BaseComponent): ---------- solenoid : `hwio.BooleanOutput` output channel to activate the solenoid & raise the hopper - IR : :class:`hwio.BooleanInput` - input channel for the IR beam to check if the hopper is up - max_lag : float, optional - time in seconds to wait before checking to make sure the hopper is up (default=0.3) + IR : :class:`hwio.BooleanInput` + input channel for the IR beam to check if the hopper is up (optional) + max_lag : float, optional + time in seconds to wait before checking to make sure the hopper is up + (default=0.3) Attributes ---------- - solenoid : hwio.BooleanOutput + solenoid : hwio.BooleanOutput output channel to activate the solenoid & raise the hopper - IR : hwio.BooleanInput - input channel for the IR beam to check if the hopper is up - max_lag : float + IR : hwio.BooleanInput + input channel for the IR beam to check if the hopper is up (optional) + max_lag : float time in seconds to wait before checking to make sure the hopper is up """ - def __init__(self,IR,solenoid,max_lag=0.3,*args,**kwargs): - super(Hopper, self).__init__(*args,**kwargs) + def __init__(self, solenoid, IR=None, max_lag=0.3, *args, **kwargs): + super(Hopper, self).__init__(*args, **kwargs) self.max_lag = max_lag - if isinstance(IR,hwio.BooleanInput): - self.IR = IR - else: + + if (IR is not None) and (not isinstance(IR, hwio.BooleanInput)): raise ValueError('%s is not an input channel' % IR) - if isinstance(solenoid,hwio.BooleanOutput): + self.IR = IR + + if isinstance(solenoid, hwio.BooleanOutput): self.solenoid = solenoid else: raise ValueError('%s is not an output channel' % solenoid) def check(self): - """reads the status of solenoid & IR beam, then throws an error if they don't match + """ Reads the status of solenoid & IR beam, then throws an error if they + don't match. If IR is None, then trust the solenoid's status. Returns ------- @@ -75,11 +105,16 @@ def check(self): Raises ------ HopperActiveError - The Hopper is up and it shouldn't be. (The IR beam is tripped, but the solenoid is not active.) + The Hopper is up and it shouldn't be. (The IR beam is tripped, but + the solenoid is not active.) HopperInactiveError - The Hopper is down and it shouldn't be. (The IR beam is not tripped, but the solenoid is active.) + The Hopper is down and it shouldn't be. (The IR beam is not tripped, + but the solenoid is active.) """ + if self.IR is None: + return self.solenoid.read() is True + IR_status = self.IR.read() solenoid_status = self.solenoid.read() if IR_status != solenoid_status: @@ -88,17 +123,17 @@ def check(self): elif solenoid_status: raise HopperInactiveError else: - raise ComponentError("the IR & solenoid don't match: IR:%s,solenoid:%s" % (IR_status,solenoid_status)) + raise ComponentError("the IR & solenoid don't match: IR:%s,solenoid:%s" % (IR_status, solenoid_status)) else: return IR_status def up(self): - """Raises the hopper up. + """ Raises the hopper up. Returns ------- - bool - True if the hopper comes up. + datetime + Time at which the hopper came up Raises ------ @@ -106,29 +141,34 @@ def up(self): The Hopper did not raise. """ - self.solenoid.write(True) + self.event["action"] = "up" + self.solenoid.write(True, event=self.event) + if self.IR is None: + return datetime.datetime.now() + time_up = self.IR.poll(timeout=self.max_lag) - if time_up is None: # poll timed out + if time_up is None: # poll timed out self.solenoid.write(False) raise HopperWontComeUpError else: return time_up def down(self): - """Lowers the hopper. + """ Lowers the hopper. Returns ------- - bool - True if the hopper drops. + datetime + Time at which the hopper came down Raises ------ HopperWontDropError The Hopper did not drop. """ - self.solenoid.write(False) + self.event["action"] = "down" + self.solenoid.write(False, event=self.event) time_down = datetime.datetime.now() utils.wait(self.max_lag) try: @@ -137,12 +177,12 @@ def down(self): raise HopperWontDropError(e) return time_down - def feed(self,dur=2.0,error_check=True): - """Performs a feed + def feed(self, dur=2.0, error_check=True): + """ Performs a feed Parameters - --------- - dur : float, optional + ---------- + dur : float, optional duration of feed in seconds Returns @@ -161,7 +201,7 @@ def feed(self,dur=2.0,error_check=True): The Hopper did not drop fater the feed. """ - assert self.max_lag < dur, "max_lag (%ss) must be shorter than duration (%ss)" % (self.max_lag,dur) + assert self.max_lag < dur, "max_lag (%ss) must be shorter than duration (%ss)" % (self.max_lag, dur) try: self.check() except HopperActiveError as e: @@ -171,14 +211,67 @@ def feed(self,dur=2.0,error_check=True): utils.wait(dur) feed_over = self.down() feed_duration = feed_over - feed_time - return (feed_time,feed_duration) + return (feed_time, feed_duration) + + def reward(self, value=2.0): + """ Performs a feed as a reward - def reward(self,value=2.0): - """wrapper for `feed`, passes *value* into *dur* """ + Parameters + ---------- + value : float, optional + duration of feed in seconds + + Returns + ------- + (datetime, float) + Timestamp of the feed and the feed duration + """ return self.feed(dur=value) -## Peck Port ## +class Button(BaseComponent): + """ Class which holds information about buttons with an input but no output. + Could also describe a perch. + + Parameters + ---------- + IR : hwio.BooleanInput + input channel for the IR beam to check for a peck + + Attributes + ---------- + IR : hwio.BooleanInput + input channel for the IR beam to check for a peck + """ + def __init__(self, IR, *args, **kwargs): + super(Button, self).__init__(*args, **kwargs) + if isinstance(IR, hwio.BooleanInput): + self.IR = IR + else: + raise ValueError('%s is not an input channel' % IR) + + def status(self): + """ Reads the status of the IR beam + + Returns + ------- + bool + True if beam is broken + """ + return self.IR.read() + + def poll(self, timeout=None): + """ Polls the peck port until there is a peck + + Returns + ------- + datetime + Timestamp of the IR beam being broken. + """ + return self.IR.poll(timeout=timeout) + + +## Peck Port ## class PeckPort(BaseComponent): """ Class which holds information about peck ports @@ -197,19 +290,20 @@ class PeckPort(BaseComponent): input channel for the IR beam to check for a peck """ - def __init__(self,IR,LED,*args,**kwargs): - super(PeckPort, self).__init__(*args,**kwargs) - if isinstance(IR,hwio.BooleanInput): + def __init__(self, IR, LED, *args, **kwargs): + super(PeckPort, self).__init__(*args, **kwargs) + if isinstance(IR, hwio.BooleanInput): self.IR = IR else: raise ValueError('%s is not an input channel' % IR) - if isinstance(LED,hwio.BooleanOutput): + + if isinstance(LED, hwio.BooleanOutput): self.LED = LED else: raise ValueError('%s is not an output channel' % LED) def status(self): - """reads the status of the IR beam + """ Reads the status of the IR beam Returns ------- @@ -219,29 +313,32 @@ def status(self): return self.IR.read() def off(self): - """ Turns the LED off + """ Turns the LED off Returns ------- bool True if successful """ - self.LED.write(False) + self.event["action"] = "off" + self.LED.write(False, event=self.event) return True def on(self): - """Turns the LED on + """ Turns the LED on Returns ------- bool True if successful """ - self.LED.write(True) + self.event["action"] = "on" + self.LED.write(True, event=self.event) return True - def flash(self,dur=1.0,isi=0.1): - """Flashes the LED on and off with *isi* seconds high and low for *dur* seconds, then revert LED to prior state. + def flash(self, dur=1.0, isi=0.1): + """ Flashes the LED on and off with *isi* seconds high and low for *dur* + seconds, then revert LED to prior state. Parameters ---------- @@ -263,9 +360,9 @@ def flash(self,dur=1.0,isi=0.1): utils.wait(isi) flash_duration = datetime.datetime.now() - flash_time self.LED.write(LED_state) - return (flash_time,flash_duration) + return (flash_time, flash_duration) - def poll(self,timeout=None): + def poll(self, timeout=None): """ Polls the peck port until there is a peck Returns @@ -273,27 +370,27 @@ def poll(self,timeout=None): datetime Timestamp of the IR beam being broken. """ - return self.IR.poll(timeout) + return self.IR.poll(timeout=timeout) + ## House Light ## class HouseLight(BaseComponent): """ Class which holds information about the house light - Keywords - -------- + Parameters + ---------- light : hwio.BooleanOutput output channel to turn the light on and off - Methods: - on() -- - off() -- - timeout(dur) -- turns off the house light for 'dur' seconds (default=10.0) - punish() -- calls timeout() for 'value' as 'dur' + Attributes + ---------- + light : hwio.BooleanOutput + output channel to turn the light on and off """ - def __init__(self,light,*args,**kwargs): - super(HouseLight, self).__init__(*args,**kwargs) - if isinstance(light,hwio.BooleanOutput): + def __init__(self, light, *args, **kwargs): + super(HouseLight, self).__init__(*args, **kwargs) + if isinstance(light, hwio.BooleanOutput): self.light = light else: raise ValueError('%s is not an output channel' % light) @@ -307,7 +404,8 @@ def off(self): True if successful. """ - self.light.write(False) + self.event["action"] = "off" + self.light.write(False, event=self.event) return True def on(self): @@ -318,14 +416,15 @@ def on(self): bool True if successful. """ - self.light.write(True) + self.event["action"] = "on" + self.light.write(True, event=self.event) return True - def timeout(self,dur=10.0): - """Turn off the light for *dur* seconds + def timeout(self, dur=10.0): + """Turn off the light for *dur* seconds - Keywords - ------- + Parameters + ---------- dur : float, optional The amount of time (in seconds) to turn off the light. @@ -336,24 +435,34 @@ def timeout(self,dur=10.0): """ timeout_time = datetime.datetime.now() - self.light.write(False) + self.off() utils.wait(dur) timeout_duration = datetime.datetime.now() - timeout_time - self.light.write(True) - return (timeout_time,timeout_duration) + self.on() + return (timeout_time, timeout_duration) - def punish(self,value=10.0): - """Calls `timeout(dur)` with *value* as *dur* """ + def punish(self, value=10.0): + """ Turns light off as a punishment + + Parameters + --------- + value : float, optional + duration of timeout in seconds + + Returns + ------- + (datetime, float) + Timestamp of the timeout and the timeout duration + """ return self.timeout(dur=value) ## Cue Light ## - class RGBLight(BaseComponent): """ Class which holds information about an RGB cue light - Keywords - -------- + Parameters + ---------- red : hwio.BooleanOutput output channel for the red LED green : hwio.BooleanOutput @@ -361,18 +470,29 @@ class RGBLight(BaseComponent): blue : hwio.BooleanOutput output channel for the blue LED + Attributes + ---------- + _red : hwio.BooleanOutput + output channel for the red LED + _green : hwio.BooleanOutput + output channel for the green LED + _blue : hwio.BooleanOutput + output channel for the blue LED + """ - def __init__(self,red,green,blue,*args,**kwargs): + def __init__(self, red, green, blue, *args, **kwargs): super(RGBLight, self).__init__(*args,**kwargs) - if isinstance(red,hwio.BooleanOutput): + if isinstance(red, hwio.BooleanOutput): self._red = red else: raise ValueError('%s is not an output channel' % red) - if isinstance(green,hwio.BooleanOutput): + + if isinstance(green, hwio.BooleanOutput): self._green = green else: raise ValueError('%s is not an output channel' % green) - if isinstance(blue,hwio.BooleanOutput): + + if isinstance(blue, hwio.BooleanOutput): self._blue = blue else: raise ValueError('%s is not an output channel' % blue) @@ -385,9 +505,11 @@ def red(self): bool `True` if successful. """ + self.event["action"] = "red" self._green.write(False) self._blue.write(False) - return self._red.write(True) + return self._red.write(True, event=self.event) + def green(self): """Turns the cue light to green @@ -396,9 +518,11 @@ def green(self): bool `True` if successful. """ + self.event["action"] = "green" self._red.write(False) self._blue.write(False) - return self._green.write(True) + return self._green.write(True, event=self.event) + def blue(self): """Turns the cue light to blue @@ -407,9 +531,11 @@ def blue(self): bool `True` if successful. """ + self.event["action"] = "blue" self._red.write(False) self._green.write(False) - return self._blue.write(True) + return self._blue.write(True, event=self.event) + def off(self): """Turns the cue light off @@ -418,12 +544,85 @@ def off(self): bool `True` if successful. """ - self._red.write(False) + self.event["action"] = "off" + self._red.write(False, event=self.event) self._green.write(False) self._blue.write(False) return True +class Speaker(BaseComponent): + """ Class which holds information about a speaker + + Parameters + ---------- + output: hwio.AudioOutput + Output of the speaker + + Attributes + ---------- + output: hwio.AudioOutput + Output of the speaker + """ + + def __init__(self, output, *args, **kwargs): + + super(Speaker, self).__init__(*args, **kwargs) + self.output = output + + def queue(self, wav_filename, metadata=None): + + self.event["action"] = "queue" + self.event["metadata"] = metadata + return self.output.queue(wav_filename, event=self.event) + + def play(self): + + self.event["action"] = "play" + return self.output.play(event=self.event) + + def stop(self): + + self.event["action"] = "stop" + return self.output.stop(event=self.event) + + +class Microphone(BaseComponent): + """ Class which holds information about a microphone + + Parameters + ---------- + input_: hwio.AnalogInput + Input to the microphone + + Attributes + ---------- + input: hwio.AnalogInput + Input to the microphone + """ + + def __init__(self, input_, *args, **kwargs): + + super(Speaker, self).__init__(*args, **kwargs) + self.input = input_ + + def record(self, nsamples): + """ Reads from input for a set number of samples + + Parameters + ---------- + nsamples: int + Number of samples to read from input + + Returns + ------- + numpy array + The analog signal recorded by input + """ + # TODO: This should use a stop signal too, I think + self.event["action"] = "rec" + return self.input.read(nsamples, event=self.event) + # ## Perch ## # class Perch(BaseComponent): @@ -435,4 +634,3 @@ def off(self): # """ # def __init__(self,*args,**kwargs): # super(Perch, self).__init__(*args,**kwargs) - diff --git a/pyoperant/configure.py b/pyoperant/configure.py new file mode 100644 index 00000000..c526d547 --- /dev/null +++ b/pyoperant/configure.py @@ -0,0 +1,124 @@ +import os + + +class ConfigureJSON(object): + + @classmethod + def load(cls, config_file): + """ Load experiment parameters from a JSON configuration file + + Parameters + ---------- + config_file: string + path to a JSON configuration file + + Returns + ------- + dictionary (or list of dictionaries) of parameters to pass to a behavior + """ + try: + import simplejson as json + except ImportError: + import json + + with open(config_file, 'rb') as config: + parameters = json.load(config) + + return parameters + + @staticmethod + def save(parameters, filename, overwrite=False): + """ Save a dictionary of parameters to an experiment JSON config file + + Parameters + ---------- + parameters: dictionary + experiment parameters + filename: string + path to output file + overwrite: bool + whether or not to overwrite if the output file already exists + """ + try: + import simplejson as json + except ImportError: + import json + + if os.path.exists(filename) and (overwrite is False): + raise IOError("File %s already exists! To overwrite, set overwrite=True" % filename) + + with open(filename, "w") as json_file: + json.dump(parameters, + json_file, + sort_keys=True, + indent=4, + separators=(",", ":")) + + +class ConfigureYAML(object): + """ Configuration using YAML files. Thanks to pyyaml (http://pyyaml.org/wiki/PyYAMLDocumentation), this type of configuration file can be very flexible. It supports multiple experiments in the form of multiple documents in one file. It also allows for expressing native python objects (including custom classes) directly in the configuration file. + """ + + @classmethod + def load(cls, config_file): + """ Load experiment parameters from a YAML configuration file + + Parameters + ---------- + config_file: string + path to a YAML configuration file + + Returns + ------- + dictionary (or list of dictionaries) of parameters to pass to a behavior + """ + try: + import yaml + except ImportError: + raise ImportError("Pyyaml is required to use a .yaml configuration file") + + parameters = list() + with open(config_file, "rb") as config: + for val in yaml.load_all(config): + parameters.append(val) + + if len(parameters) == 1: + parameters = parameters[0] + + return parameters + + @staticmethod + def save(parameters, filename, overwrite=False): + """ Save a dictionary of parameters to an experiment YAML config file + + Parameters + ---------- + parameters: dictionary + experiment parameters + filename: string + path to output file + overwrite: bool + whether or not to overwrite if the output file already exists + """ + try: + import yaml + except ImportError: + raise ImportError("Pyyaml is required to use a .yaml configuration file") + + if os.path.exists(filename) and (overwrite is False): + raise IOError("File %s already exists! To overwrite, set overwrite=True" % filename) + + with open(filename, "w") as yaml_file: + yaml.dump(parameters, yaml_file, + indent=4, + explicit_start=True, + explicit_end=True) + + +# ## What is this?? +# class ConfigurableYAML(type): +# +# def __new__(cls, *args, **kwargs): +# +# ConfigureYAML.constructors.append(cls) +# return super(ConfigureableYAML, cls, *args, **kwargs) diff --git a/pyoperant/errors.py b/pyoperant/errors.py index 9986471e..09084c18 100644 --- a/pyoperant/errors.py +++ b/pyoperant/errors.py @@ -3,6 +3,10 @@ class GoodNite(Exception): """ exception for when the lights should be off """ pass +class EndExperiment(Exception): + """ exception for when an experiment should terminate""" + pass + class EndSession(Exception): """ exception for when a session should terminate """ pass @@ -19,7 +23,7 @@ class Error(Exception): class InterfaceError(Exception): '''raised for errors with an interface. - this should indicate a software error, like difficulty + this should indicate a software error, like difficulty connecting to an interface ''' pass @@ -27,11 +31,15 @@ class InterfaceError(Exception): class ComponentError(Exception): '''raised for errors with a component. - this should indicate a hardware error in the physical world, + this should indicate a hardware error in the physical world, like a problem with a feeder. - this should be raised by components when doing any internal + this should be raised by components when doing any internal validation that they are working properly ''' pass + +class WriteCannotBeReadError(Exception): + '''raised when an interface configured to write output cannot be read ''' + pass diff --git a/pyoperant/events.py b/pyoperant/events.py new file mode 100644 index 00000000..78cad647 --- /dev/null +++ b/pyoperant/events.py @@ -0,0 +1,435 @@ +import threading +import Queue +# from multiprocessing import Process, Queue +import datetime as dt +import logging +import numpy as np +from pyoperant import hwio + +logger = logging.getLogger(__name__) + +class Events(object): + """ Writes small event dictionaries out to a list of event handlers. + + Attributes + ---------- + handlers: list + The currently configured EventHandler instances + + Methods + ------- + add_handler(handler) - appends the handler to the handlers list + write(event) - Writes the dictionary `event` to each handler + """ + + def __init__(self): + + self.handlers = list() + + def add_handler(self, handler): + """ Adds the handler to the list of handlers, as long as it supports + writing. + + Parameters + ---------- + handler: instance of EventHandler + The handler to be added + """ + + if not hasattr(handler, "queue"): + raise AttributeError("Event handler instance must contain a queue") + + self.handlers.append(handler) + + def close_handlers(self): + """ Closes all of the existing handlers """ + + for handler in self.handlers: + handler.close() + + def write(self, event): + """ Places the event in the queue for each handler to write. + + Parameters + ---------- + event: dict + A dictionary describing the current component event. It should have + 3 keys: name, action, and metadata. A time key will be added + containing the datetime of the event. + """ + if event is None: + return + + event["time"] = dt.datetime.now() + for handler in self.handlers: + logger.debug("Adding to handler %s" % str(handler)) + handler.queue.put(event) + + +class EventHandler(object): + """ Base class for all event handlers. Creates a separate thread that writes each event out using the handler's "write" method. + + Parameters + ---------- + component: string + Optionally argument that allows one to only log events with the + specified component name. + + Attributes + ---------- + thread: threading.Thread instance + The thread that loops infinitely and writes each event placed in its + queue out through the handler. + queue: queue.Queue instance + The queue that handles communicating events between threads. + + Methods + ------- + write(event) - Writes the event using the specified handler + close() - Ends the thread so everything can be properly closed out + """ + STOP_FLAG = 0 + + def __init__(self, component=None, *args, **kwargs): + + super(EventHandler, self).__init__(*args, **kwargs) + + self.component = component + + # Initialize the queue + self.queue = Queue.Queue(maxsize=0) + #self.queue = Queue(maxsize=0) + + # Initialize the thread + self.thread = threading.Thread(target=self.run, name=self.__class__.__name__) + # self.thread = Process(target=self.run, name=self.__class__.__name__) + + # Run the thread + self.thread.start() + + def filter(self, event): + """ Returns True if the event should be written """ + + if self.component is None: + return True + + return self.event["name"] == self.component + + def run(self): + """ Runs inside the separate thread and calls the class' `write` method + on any new events + """ + while True: + event = self.queue.get() + if event is self.STOP_FLAG: + logger.debug("Stopping thread %s" % self.thread.name) + return + if self.filter(event): + self.write(event) + + def close(self): + """ Ends the separate thread """ + + self.queue.put(self.STOP_FLAG) + + def __del__(self): + + self.close() + + def write(self, event): + + raise NotImplementedError("Event handlers must implement a `write` method") + + +class EventInterfaceHandler(EventHandler, hwio.BooleanOutput): + """ Handler to send event information out to a boolean interface. The event + information is sent as a sequence of three chunks of bits, the first + describing the name of the component, the second describing the action, and + the third with any additional metadata. If the event details are too long to + fit in the requested number of bytes, they are truncated first. + + Parameters + ---------- + interface: instance of an Interface + The interface to use to write the bit string out to hardware + params: dictionary + A set of key-value pairs that are sent to the interface when configuring + the boolean write and when writing to it. + name_bytes: int + The number of bytes to use for encoding the name of the component + action_bytes: int + The number of bytes to use for encoding the action being performed + metadata_bytes: int + The number of bytes to use for any additional metadata. + component: string + Optionally argument that allows one to only log events with the + specified component name. + + Attributes + ---------- + thread: threading.Thread instance + The thread that loops infinitely and writes each event placed in its + queue out through the handler. + queue: queue.Queue instance + The queue that handles communicating events between threads. + + Methods + ------- + write(event) - Writes the event using the specified handler + close() - Ends the thread so everything can be properly closed out + to_bit_sequence(event) - Serializes the event details into a string of bits + """ + def __init__(self, interface, params={}, name_bytes=4, action_bytes=4, + metadata_bytes=16, component=None): + + self.name_bytes = name_bytes + self.action_bytes = action_bytes + self.metadata_bytes = metadata_bytes + self.component = component + self.map_to_bit = dict() + super(EventInterfaceHandler, self).__init__(interface=interface, + params=params, + component=component) + + def write(self, event): + """ Writes the event out the boolean output + + Parameters + ---------- + event: dict + A dictionary describing the current component event. It should have + 3 keys: name, action, and metadata. + """ + + try: + key = (event["name"], event["action"], event["metadata"]) + bits = self.map_to_bit[key] + except KeyError: + bits = self.to_bit_sequence(event) + self.interface._write_bool(value=bits, **self.params) + + def to_bit_sequence(self, event): + """ Creates an array of bits containing the details in the event + dictionary. Once created, the array is cached to speed up future writes. + + Parameters + ---------- + event: dict + A dictionary describing the current component event. It should have + 3 keys: name, action, and metadata. + + Returns + ------- + The array of bits + """ + + if event["metadata"] is None: + nbytes = self.action_bytes + self.name_bytes + metadata_array = [] + else: + nbytes = self.metadata_bytes + self.action_bytes + self.name_bytes + try: + metadata_array = np.fromstring(event["metadata"], + dtype=np.uint16).astype(np.uint8)[:self.metadata_bytes] + except TypeError: + metadata_array = np.array(map(ord, + event["metadata"].ljust(self.metadata_bytes)[:self.metadata_bytes]), + dtype=np.uint8) + + int8_array = np.zeros(nbytes, dtype="uint8") + int8_array[:self.name_bytes] = map(ord, event["name"].ljust(self.name_bytes)[:self.name_bytes]) + int8_array[self.name_bytes:self.name_bytes + self.action_bytes] = map(ord, event["action"].ljust(self.action_bytes)[:self.action_bytes]) + int8_array[self.name_bytes + self.action_bytes:] = metadata_array + + sequence = ([True] + + np.unpackbits(int8_array).astype(bool).tolist() + + [False]) + key = (event["name"], event["action"], event["metadata"]) + self.map_to_bit[key] = sequence + + return sequence + + def toggle(self): + pass + + +class EventDToAHandler(EventHandler): + """ Handler to format event information so that it can be sent as a sequence + of bits out an analog output. The event information is returned as a + sequence of three chunks of bits, the first describing the name of the + component, the second describing the action, and the third with any + additional metadata. If the event details are too long to fit in the + requested number of bytes, they are truncated first. Before being returned, + the sequence is upsampled by a certain factor and then converted to float64. + + Parameters + ---------- + name_bytes: int + The number of bytes to use for encoding the name of the component + action_bytes: int + The number of bytes to use for encoding the action being performed + metadata_bytes: int + The number of bytes to use for any additional metadata. + upsample_factor: int + The factor by which the bit sequence should be upsampled. + scaling: float + A scaling factor to scale the analog representation of the digital signal (e.g. send out 3.3 Volts to pass to a digital input) + component: string + Optionally argument that allows one to only log events with the + specified component name. + + All additional key-value pairs are stored for use by the interface + + Methods + ------- + to_bit_sequence(event) - Serializes the event details into a string of bits + """ + def __init__(self, name_bytes=4, action_bytes=4, metadata_bytes=16, + upsample_factor=1, scaling=1.0, component=None, + **interface_params): + + self.name_bytes = name_bytes + self.action_bytes = action_bytes + self.metadata_bytes = metadata_bytes + self.upsample_factor = upsample_factor + self.scaling = scaling + self.component = component + self.map_to_bit = dict() + self.queue = Queue.Queue(maxsize=0) + for key, value in interface_params.items(): + setattr(self, key, value) + + def filter(self, event): + """ Always returns False, as this one should never be called by Events + """ + + return False + + def write(self, event): + """ Does nothing """ + pass + + def to_bit_sequence(self, event): + """ Creates an array of bits containing the details in the event + dictionary. This array is then upsampled and converted to float64 to be + sent down an analog output. Once created, the array is cached to speed + up future calls. + + Parameters + ---------- + event: dict + A dictionary describing the current component event. It should have + 3 keys: name, action, and metadata. + + Returns + ------- + The array of bits expressed as analog values + """ + + key = (event["name"], event["action"], event["metadata"]) + # Check if the bit string is already stored + if key in self.map_to_bit: + return self.map_to_bit[key] + + trim = lambda ss, l: ss.ljust(l)[:l] + # Set up int8 arrays where strings are converted to integers using ord + name_array = np.array(map(ord, trim(event["name"], self.name_bytes)), + dtype=np.uint8) + action_array = np.array(map(ord, trim(event["action"], + self.action_bytes)), + dtype=np.uint8) + + # Add the metadata array if a value was passed + if event["metadata"] is not None: + metadata_array = np.array(map(ord, trim(event["metadata"], + self.metadata_bytes)), + dtype=np.uint8) + else: + metadata_array = np.array([], dtype=np.uint8) + + sequence = ([True] + + np.unpackbits(name_array).astype(bool).tolist() + + np.unpackbits(action_array).astype(bool).tolist() + + np.unpackbits(metadata_array).astype(bool).tolist() + + [False]) + sequence = np.repeat(sequence, self.upsample_factor).astype("float64") + sequence *= self.scaling + + self.map_to_bit[key] = sequence + + return sequence + + def close(self): + """ Nothing needs to be done """ + pass + + +class EventLogHandler(EventHandler): + """ Writes event details out to a file log. + + Parameters + ---------- + filename: string + Path to the output file + format: string + A string that can be formatted with the event dictionary + component: string + Optional argument that allows one to only log events with the + specified component name. + + Attributes + ---------- + thread: threading.Thread instance + The thread that loops infinitely and writes each event placed in its + queue out through the handler. + queue: queue.Queue instance + The queue that handles communicating events between threads. + + Methods + ------- + write(event) - Writes the event to the file + close() - Ends the thread so everything can be properly closed out + """ + def __init__(self, filename, format=None, component=None): + + self.filename = filename + if format is None: + self.format = "\t".join(["{time}", + "{name}", + "{action}", + "{metadata}"]) + super(EventLogHandler, self).__init__(component=component) + + def write(self, event): + """ Writes the event out to the file + + Parameters + ---------- + event: dict + A dictionary describing the current component event. It should have + 4 keys: name, action, and metadata added by the compnent, and time + added by the Events class. + """ + + if "time" not in event: + event["time"] = dt.datetime.now() + + with open(self.filename, "a") as fh: + fh.write(self.format.format(**event) + "\n") + +events = Events() + +if __name__ == "__main__": + + ihandler = EventInterfaceHandler(None) + events.add_handler(ihandler) + for ii in range(100): + events.write({}) + time.sleep(0.1) + + if ihandler.delay_queue.qsize() > 0: + for ii in range(ihandler.delay_queue.qsize()): + ihandler.delays.append(ihandler.delay_queue.get()) + + print("Mean delay was %.4e seconds" % (sum(ihandler.delays) / 100)) + print("Max delay was %.4e seconds" % max(ihandler.delays)) diff --git a/pyoperant/hwio.py b/pyoperant/hwio.py index e5702ab0..b7df506a 100644 --- a/pyoperant/hwio.py +++ b/pyoperant/hwio.py @@ -1,121 +1,431 @@ +import logging +from pyoperant.errors import WriteCannotBeReadError + +logger = logging.getLogger(__name__) + -# Classes of operant components class BaseIO(object): - """any type of IO device. maintains info on interface for query IO device""" - def __init__(self,interface=None,params={},*args,**kwargs): + """ Any type of IO device. Maintains info on the interface and configuration + params for querying the IO device + + Parameters + ---------- + name: string + A name given to the IO. Useful when it is used in logging and error + messages. + interface: subclass of base_.BaseInterface + An instance of an interface through which writes and reads are sent + params: dictionary + A dictionary of parameters for configuration and write/read calls. + Common keys are: subdevice, channel, etc. + + Attributes + ---------- + name: string + A name given to the IO. Useful when it is used in logging and error + messages. + interface: subclass of base_.BaseInterface + An instance of an interface through which writes and reads are sent + params: dictionary + A dictionary of parameters for configuration and write/read calls. + Common keys are: subdevice, channel, etc. + """ + + def __init__(self, name=None, interface=None, params={}, + *args, **kwargs): + + super(BaseIO, self).__init__(*args, **kwargs) + self.name = name self.interface = interface self.params = params + class BooleanInput(BaseIO): - """Class which holds information about inputs and abstracts the methods of - querying their values + """ Class which holds information about boolean inputs and abstracts the + methods of querying their values - Keyword arguments: - interface -- Interface() instance. Must have '_read_bool' method. - params -- dictionary of keyword:value pairs needed by the interface + Parameters + ---------- + interface: a subclass of base_.BaseInterface + Interface through which values are read. Must have '_read_bool' method. + params: dictionary + A dictionary of parameters for configuration and boolean read calls. + Common keys are: subdevice, channel, invert, etc. - Methods: - read() -- reads value of the input. Returns a boolean - poll() -- polls the input until value is True. Returns the time of the change + Attributes + ---------- + name: string + + interface: a subclass of base_.BaseInterface + Interface through which values are read. Must have '_read_bool' method. + params: dictionary + A dictionary of parameters for configuration and boolean read calls. + Common keys are: subdevice, channel, invert, etc. + last_value: bool + Most recently returned value + + Methods + ------- + config() + Configures the boolean input + read() + Reads value of the input. Returns a boolean + poll() + Polls the input until value is True. Returns the time of the change """ - def __init__(self,interface=None,params={},*args,**kwargs): - super(BooleanInput, self).__init__(interface=interface,params=params,*args,**kwargs) - assert hasattr(self.interface,'_read_bool') + def __init__(self, interface=None, params={}, + *args, **kwargs): + super(BooleanInput, self).__init__(interface=interface, + params=params, + *args, + **kwargs) + + assert self.interface.can_read_bool + self.last_value = False self.config() def config(self): - try: - return self.interface._config_read(**self.params) - except AttributeError: + """ Calls the interface's _config_read method with the keyword arguments + in params + + Returns + ------- + bool + True if configuration succeeded + """ + + if not hasattr(self.interface, "_config_read"): return False + return self.interface._config_read(**self.params) + def read(self): - """read status""" - return self.interface._read_bool(**self.params) + """ Read the status of the boolean input + + Returns + ------- + bool + The current status reported by the interface + """ + + self.last_value = self.interface._read_bool(**self.params) + return self.last_value + + def poll(self, timeout=None): + """ Runs a loop, querying for the boolean input to return True. - def poll(self,timeout=None): - """ runs a loop, querying for pecks. returns peck time or "GoodNite" exception """ - return self.interface._poll(timeout=timeout,**self.params) + Parameters + ---------- + timeout: float + + Returns + ------- + datetime or None + peck time or None if timeout + """ + + input_time = self.interface._poll(timeout=timeout, + last_value=self.last_value, + **self.params) + if input_time is not None: + self.last_value = True + else: + self.last_value = False + + return input_time class BooleanOutput(BaseIO): - """Class which holds information about outputs and abstracts the methods of - writing to them + """Class which holds information about boolean outputs and abstracts the + methods of writing to them - Keyword arguments: - interface -- Interface() instance. Must have '_write_bool' method. - params -- dictionary of keyword:value pairs needed by the interface + Parameters + ---------- + interface: subclass of base_.BaseInterface + Interface through which values are written. Must have '_write_bool' + method. + params: dictionary + A dictionary of parameters for configuration and boolean write calls. + Common keys are: subdevice, channel, invert, etc. - Methods: - write(value) -- writes a value to the output. Returns the value - read() -- if the interface supports '_read_bool' for this output, returns + Methods + ------- + config() + Configures the boolean output + write(value) + Writes a value to the output. Returns the value + read() + If the interface supports '_read_bool' for this output, returns the current value of the output from the interface. Otherwise this returns the last passed by write(value) - toggle() -- flips the value from the current value + toggle() + Flips the value from the current value """ - def __init__(self,interface=None,params={},*args,**kwargs): - super(BooleanOutput, self).__init__(interface=interface,params=params,*args,**kwargs) + def __init__(self, interface=None, params={}, *args, **kwargs): + super(BooleanOutput, self).__init__(interface=interface, + params=params, + *args, + **kwargs) - assert hasattr(self.interface,'_write_bool') + assert self.interface.can_write_bool self.last_value = None self.config() def config(self): - try: - return self.interface._config_write(**self.params) - except AttributeError: + """ Calls the interface's _config_write method with the keyword + arguments in params + + Returns + ------- + bool + True if configuration succeeded + """ + if not hasattr(self.interface, "_config_write"): return False + logger.debug("Configuring BooleanOutput to write on interface % s" % self.interface) + return self.interface._config_write(**self.params) + def read(self): - """read status""" - if hasattr(self.interface,'_read_bool'): - return self.interface._read_bool(**self.params) + """ Read the status of the boolean output, if supported + + Returns + ------- + bool + The current status reported by the interface or the last value + written to the interface. + """ + if self.interface.can_read_bool: + try: + value = self.interface._read_bool(**self.params) + except WriteCannotBeReadError: + value = self.last_value else: - return self.last_value + value = self.last_value + + return value - def write(self,value=False): - """write status""" - self.last_value = self.interface._write_bool(value=value,**self.params) + def write(self, value=False, event=None): + """ Writes to the boolean output + + Parameters + ---------- + value: bool + Value to be written to the output + event: dictionary + Dictionary containing event details that are passed along to the + interface. + + Returns + ------- + bool + The value written to the interface + """ + + logger.debug("Setting value to %s" % value) + self.last_value = self.interface._write_bool(value=value, + event=event, + **self.params) return self.last_value - def toggle(self): + def toggle(self, event=None): + """ Toggles the value of the boolean output + + Parameters + ---------- + event: dictionary + Dictionary containing event details that are passed along to the + interface. + + Returns + ------- + bool + The value written to the interface + """ + + # TODO: what will event be here? value = not self.read() - return self.write(value=value) + return self.write(value=value, event=event) -class AudioOutput(BaseIO): - """Class which holds information about audio outputs and abstracts the + +class AnalogInput(BaseIO): + """ Class which holds information about analog inputs and abstracts the + methods of reading from them + + Parameters + ---------- + interface: subclass of base_.BaseInterface + Interface through which values are written. Must have '_read_analog' + method. + params: dictionary + A dictionary of parameters for configuration and analog read calls. + Common keys are: subdevice, channel, etc. + + Methods + ------- + config() + Configures the analog input + read(nsamples) + Reads nsamples values from the input. Returns the values as an array + """ + def __init__(self, interface=None, params={}, *args, **kwargs): + super(AnalogInput, self).__init__(interface=interface, + params=params, + *args, + **kwargs) + assert self.interface.can_read_analog + + self.config() + + def config(self): + """ Calls the interface's _config_read_analog method with the keyword + arguments in params + + Returns + ------- + bool + True if configuration succeeded + """ + if not hasattr(self.interface, "_config_read_analog"): + return False + + logger.debug("Configuring AnalogInput to read on interface % s" % self.interface) + return self.interface._config_read_analog(**self.params) + + def read(self, nsamples): + """ Reads from the analog input + + Parameters + ---------- + nsamples: int + Number of samples to read from input + + Returns + ------- + numpy array + The analog signal recorded by the interface + """ + return self.interface._read_analog(nsamples=nsamples, **self.params) + + +class AnalogOutput(BaseIO): + """ Class which holds information about analog outputs and abstracts the methods of writing to them - Keyword arguments: - interface -- Interface() instance. Must have the methods '_queue_wav', - '_play_wav', '_stop_wav' - params -- dictionary of keyword:value pairs needed by the interface + Parameters + ---------- + interface: subclass of base_.BaseInterface + Interface through which values are written. Must have '_write_analog' + method. + params: dictionary + A dictionary of parameters for configuration and analog write calls. + Common keys are: subdevice, channel, etc. - Methods: - queue(wav_filename) -- queues - read() -- if the interface supports '_read_bool' for this output, returns - the current value of the output from the interface. Otherwise this - returns the last passed by write(value) - toggle() -- flips the value from the current value + Methods + ------- + config() + Configures the analog output + write(values) + Writes an array of values to the output. Returns True if successful. """ - def __init__(self, interface=None,params={},*args,**kwargs): - super(AudioOutput, self).__init__(interface=interface,params=params,*args,**kwargs) + def __init__(self, interface=None, params={}, *args, **kwargs): + super(AnalogOutput, self).__init__(interface=interface, + params=params, + *args, + **kwargs) + assert self.interface.can_write_analog - assert hasattr(self.interface,'_queue_wav') - assert hasattr(self.interface,'_play_wav') - assert hasattr(self.interface,'_stop_wav') + self.config() - def queue(self,wav_filename): - return self.interface._queue_wav(wav_filename) + def config(self): + """ Calls the interface's config_write_analog method with the keyword + arguments in params - def play(self): - return self.interface._play_wav() + Returns + ------- + bool + True if configuration succeeded + """ + if not hasattr(self.interface, "_config_write_analog"): + return False + + logger.debug("Configuring AnalogOutput to write on interface % s" % self.interface) + return self.interface._config_write_analog(**self.params) + + def write(self, values, event=None): + """ Writes to the analog output + + Parameters + ---------- + values: numpy array + Array of float values to be written to the output + event: dictionary + Dictionary containing event details that are passed along to the + interface. + + Returns + ------- + bool + True if the write succeeded + """ + + return self.interface._write_analog(values=values, event=event) + + +class AudioOutput(BaseIO): + """ Class which holds information about audio outputs and abstracts the + methods of writing to them - def stop(self): - return self.interface._stop_wav() + Parameters + ---------- + interface: subclass of base_.BaseInterface() + Must have the methods '_queue_wav', '_play_wav', '_stop_wav' + params: dictionary + A dictionary of parameters for configuration and audio playback calls. + Common keys are: subdevice, channel, etc. + Methods: + config() + Configures the audio interface + queue(wav_filename) + Queues a .wav file for playback + play() + Plays the queued audio file + stop() + Stops the playing audio file + """ + def __init__(self, interface=None, params={}, *args, **kwargs): + super(AudioOutput, self).__init__(interface=interface, + params=params, + *args, + **kwargs) + + assert hasattr(self.interface, '_queue_wav') + assert hasattr(self.interface, '_play_wav') + assert hasattr(self.interface, '_stop_wav') + self.config() + + def config(self): + """ Calls the interface's config_write_analog method with the keyword + arguments in params + + Returns + ------- + bool + True if configuration succeeded + """ + if not hasattr(self.interface, "_config_write_analog"): + return False + logger.debug("Configuring AudioOutput to write on interface % s" % self.interface) + return self.interface._config_write_analog(**self.params) + def queue(self, wav_filename, event=None): + return self.interface._queue_wav(wav_filename, event=event, **self.params) + def play(self, event=None): + return self.interface._play_wav(event=event, **self.params) + def stop(self, event=None): + return self.interface._stop_wav(event=event, **self.params) diff --git a/pyoperant/interfaces/arduino_.py b/pyoperant/interfaces/arduino_.py index 6e34d744..8fb47703 100644 --- a/pyoperant/interfaces/arduino_.py +++ b/pyoperant/interfaces/arduino_.py @@ -4,6 +4,7 @@ import logging from pyoperant.interfaces import base_ from pyoperant import utils, InterfaceError +from pyoperant.events import events logger = logging.getLogger(__name__) @@ -12,8 +13,7 @@ # TODO: Allow device to be connected to through multiple python instances. This kind of works but needs to be tested thoroughly. class ArduinoInterface(base_.BaseInterface): - """Creates a pyserial interface to communicate with an Arduino via the serial connection. - Communication is through two byte messages where the first byte specifies the channel and the second byte specifies the action. + """ Creates a pyserial interface to communicate with an Arduino via the serial connection. Communication is through two byte messages where the first byte specifies the channel and the second byte specifies the action. Valid actions are: 0. Read input value 1. Set output to ON @@ -21,15 +21,50 @@ class ArduinoInterface(base_.BaseInterface): 3. Sets channel as an output 4. Sets channel as an input 5. Sets channel as an input with a pullup resistor (basically inverts the input values) - :param device_name: The address of the device on the local system (e.g. /dev/tty.usbserial) - :param baud_rate: The baud (bits/second) rate for serial communication. If this is changed, then it also needs to be changed in the arduino project code. + + Parameters + ---------- + device_name: string + The address of the device on the local system (e.g. /dev/tty.usbserial) + baud_rate: int + The baud (bits/second) rate for serial communication. If this is changed, then it also needs to be changed in the arduino project code. + + Attributes + ---------- + device_name: string + The address of the device on the local system (e.g. /dev/tty.usbserial) + baud_rate: int + The baud (bits/second) rate for serial communication. If this is changed, then it also needs to be changed in the arduino project code. + device: serial device + + inputs: list + + output: list + + + Methods + ------- + + Examples + -------- + dev = ArduinoInterface("/dev/tty.usbserial") + + # Configure a boolean output on channel 8 + dev._config_write(channel=8) + # Set the output to True + dev._write_bool(channel=8, value=True) + + # Configure a boolean input on channel 4 + dev._config_read(channel=4) + # Read from that input + dev._read_bool(channel=4) """ _default_state = dict(invert=False, held=False, ) - def __init__(self, device_name, baud_rate=19200, inputs=None, outputs=None, *args, **kwargs): + def __init__(self, device_name, baud_rate=19200, *args, **kwargs): super(ArduinoInterface, self).__init__(*args, **kwargs) @@ -37,18 +72,12 @@ def __init__(self, device_name, baud_rate=19200, inputs=None, outputs=None, *arg self.baud_rate = baud_rate self.device = None - self.read_params = ('channel', 'pullup') + self.read_params = ('channel', 'invert') self._state = dict() self.inputs = [] self.outputs = [] self.open() - if inputs is not None: - for input_ in inputs: - self._config_read(*input_) - if outputs is not None: - for output in outputs: - self._config_write(output) def __str__(self): @@ -56,12 +85,11 @@ def __str__(self): def __repr__(self): # Add inputs and outputs to this - return "ArduinoInterface(%s, baud_rate=%d)" % (self.device_name, self.baud_rate) + return "ArduinoInterface(%s, baud_rate=%d)" % (self.device_name, + self.baud_rate) def open(self): - '''Open a serial connection for the device - :return: None - ''' + ''' Open a serial connection for the device ''' logger.debug("Opening device %s" % self) self.device = serial.Serial(port=self.device_name, @@ -76,23 +104,28 @@ def open(self): logger.info("Successfully opened device %s" % self) def close(self): - '''Close a serial connection for the device - :return: None - ''' + ''' Close a serial connection for the device ''' logger.debug("Closing %s" % self) self.device.close() - def _config_read(self, channel, pullup=False, **kwargs): - ''' Configure the channel to act as an input - :param channel: the channel number to configure - :param pullup: the channel should be configured in pullup mode. On the arduino this has the effect of - returning HIGH when unpressed and LOW when pressed. The returned value will have to be inverted. - :return: None + def _config_read(self, channel, invert=False, **kwargs): + ''' Configure the channel to act as a boolean input + + Parameters + ---------- + channel: int + the channel number to configure + invert: bool + the channel should be configured in pullup mode. On the arduino this has the effect of returning HIGH when unpressed and LOW when pressed. The returned value will have to be inverted. + + Returns + ------- + True if configuration succeeded ''' logger.debug("Configuring %s, channel %d as input" % (self.device_name, channel)) - if pullup is False: + if invert is False: self.device.write(self._make_arg(channel, 4)) else: self.device.write(self._make_arg(channel, 5)) @@ -103,13 +136,20 @@ def _config_read(self, channel, pullup=False, **kwargs): self.inputs.append(channel) self._state.setdefault(channel, self._default_state.copy()) - self._state[channel]["invert"] = pullup + self._state[channel]["invert"] = invert def _config_write(self, channel, **kwargs): - ''' Configure the channel to act as an output - :param channel: the channel number to configure - :return: None - ''' + """ Configure the channel to act as a boolean output + + Parameters + ---------- + channel: int + the channel number to configure + + Returns + ------- + True if configuration succeeded + """ logger.debug("Configuring %s, channel %d as output" % (self.device_name, channel)) self.device.write(self._make_arg(channel, 3)) @@ -119,16 +159,26 @@ def _config_write(self, channel, **kwargs): self.outputs.append(channel) self._state.setdefault(channel, self._default_state.copy()) - def _read_bool(self, channel, **kwargs): - ''' Read a value from the specified channel - :param channel: the channel from which to read - :return: value + def _read_bool(self, channel, invert=False, event=None, **kwargs): + """ Read a value from the specified channel + + Parameters + ---------- + channel: int + the channel from which to read + invert: bool + whether or not to invert the read value + + Returns + ------- + bool: + the value read from the hardware Raises ------ ArduinoException Reading from the device failed. - ''' + """ if channel not in self._state: raise InterfaceError("Channel %d is not configured on device %s" % (channel, self.device_name)) @@ -151,53 +201,17 @@ def _read_bool(self, channel, **kwargs): logger.debug("Read value of %d from channel %d on %s" % (v, channel, self)) if v in [0, 1]: - if self._state[channel]["invert"]: + if invert: v = 1 - v - return v == 1 + value = v == 1 + if value: + events.write(event) + return value else: logger.error("Device %s returned unexpected value of %d on reading channel %d" % (self, v, channel)) # raise InterfaceError('Could not read from serial device "%s", channel %d' % (self.device, channel)) - def _poll(self, channel, timeout=None, wait=None, suppress_longpress=True, **kwargs): - """ runs a loop, querying for pecks. returns peck time or None if polling times out - :param channel: the channel from which to read - :param timeout: the time, in seconds, until polling times out. Defaults to no timeout. - :param wait: the time, in seconds, between subsequent reads. Defaults to 0. - :param suppress_longpress: only return a successful read if the previous read was False. This can be helpful when using a button, where a press might trigger multiple times. - - :return: timestamp of True read - """ - - if timeout is not None: - start = time.time() - - logger.debug("Begin polling from device %s" % self.device_name) - while True: - if not self._read_bool(channel): - logger.debug("Polling: %s" % False) - # Read returned False. If the channel was previously "held" then that flag is removed - if self._state[channel]["held"]: - self._state[channel]["held"] = False - else: - logger.debug("Polling: %s" % True) - # As long as the channel is not currently held, or longpresses are not being supressed, register the press - if (not self._state[channel]["held"]) or (not suppress_longpress): - break - - if timeout is not None: - if time.time() - start >= timeout: # Return GoodNite exception? - logger.debug("Polling timed out. Returning") - return None - - # Wait for a specified amount of time before continuing on with the next loop - if wait is not None: - utils.wait(wait) - - self._state[channel]["held"] = True - logger.debug("Input detected. Returning") - return datetime.datetime.now() - - def _write_bool(self, channel, value, **kwargs): + def _write_bool(self, channel, value, event=None, **kwargs): '''Write a value to the specified channel :param channel: the channel to write to :param value: the value to write @@ -208,6 +222,7 @@ def _write_bool(self, channel, value, **kwargs): raise InterfaceError("Channel %d is not configured on device %s" % (channel, self)) logger.debug("Writing %s to device %s, channel %d" % (value, self, channel)) + events.write(event) if value: s = self.device.write(self._make_arg(channel, 1)) else: @@ -226,6 +241,6 @@ def _make_arg(channel, value): return "".join([chr(channel), chr(value)]) -class ArduinoException(Exception): +class ArduinoException(InterfaceError): pass diff --git a/pyoperant/interfaces/base_.py b/pyoperant/interfaces/base_.py index 7fa5dec9..28439a2f 100644 --- a/pyoperant/interfaces/base_.py +++ b/pyoperant/interfaces/base_.py @@ -1,8 +1,23 @@ +import time +import datetime +import logging +import wave +import numpy as np +from pyoperant import InterfaceError + +logger = logging.getLogger(__name__) + class BaseInterface(object): - """docstring for BaseInterface""" + """ + Implements generic interface methods. + Implemented methods: + - _poll + """ + def __init__(self, *args, **kwargs): + super(BaseInterface, self).__init__() - pass + self.device_name = None def open(self): pass @@ -10,6 +25,138 @@ def open(self): def close(self): pass + def _poll(self, channel=None, subdevices=None, invert=False, + last_value=False, suppress_longpress=False, + timeout=None, wait=None, event=None, + *args, **kwargs): + """ Runs a loop, querying for the boolean input to return True. + + Parameters + ---------- + channel: + default channel argument to pass to _read_bool() + subdevices: + default subdevices argument to pass to _read_bool() + invert: bool + whether or not to invert the read value + last_value: bool + if the last read value was True. Necessary to suppress longpresses + suppress_longpress: bool + if True, attempts to suppress returning immediately if the button is still being pressed since the last call. If last_value is True, then it waits until the interface reads a single False value before allowing it to return. + timeout: float + the time, in seconds, until polling times out. Defaults to no timeout. + wait: float + the time, in seconds, to wait between subsequent reads (default no wait). + event: dict + a dictionary of event information to emit just before writing + + Returns + ------- + timestamp of True read or None if timed out + """ + + logger.debug("Begin polling from device %s" % self.device_name) + if timeout is not None: + start = time.time() + while True: + value = self._read_bool(channel=channel, + subdevices=subdevices, + invert=invert, + event=event, + *args, **kwargs) + if not isinstance(value, bool): + raise ValueError("Polling for bool returned something that was not a bool") + if value is True: + if (last_value is False) or (suppress_longpress is False): + logger.debug("Input detected. Returning") + return datetime.datetime.now() + else: + last_value = False + + if timeout is not None: + if time.time() - start >= timeout: + logger.debug("Polling timed out. Returning") + return None + + if wait is not None: + utils.wait(wait) + + def __del__(self): self.close() + @property + def can_read_bool(self): + """ + If the interface is capable of reading boolean values from the device + """ + + return hasattr(self, "_read_bool") + + @property + def can_write_bool(self): + """ + If the interface is capable of writing boolean values to the device + """ + + return hasattr(self, "_write_bool") + + @property + def can_read_analog(self): + """ + If the interface is capable of reading analog values from the device + """ + + return hasattr(self, "_read_analog") + + @property + def can_write_analog(self): + """ + If the interface is capable of writing analog values from the device + """ + + return hasattr(self, "_write_analog") + +class AudioInterface(BaseInterface): + """ + Generic audio interface that implements wavefile handling + Implemented methods: + - validate + - + """ + + def __init__(self, *args, **kwargs): + + super(AudioInterface, self).__init__() + self.wf = None + + def _config_write_analog(self, *args, **kwargs): + + pass + + def validate(self): + """ + Verifies simply that the wav file has been opened. Could do other + checks in the future. + """ + + if self.wf is None: + raise InterfaceError("wavefile is not open, but it should be") + + def _load_wav(self, filename): + """ Loads the .wav file and normalizes it according to its bit depth + """ + + self.wf = wave.open(filename) + self.validate() + sampwidth = self.wf.getsampwidth() + if sampwidth == 2: + max_val = 32768.0 + dtype = np.int16 + elif sampwidth == 4: + max_val = float(2 ** 32) + dtype = np.int32 + + data = np.fromstring(self.wf.readframes(-1), dtype=dtype) + + return (data / max_val).astype(np.float64) diff --git a/pyoperant/interfaces/console_.py b/pyoperant/interfaces/console_.py index 8a1ff6ed..72f29555 100644 --- a/pyoperant/interfaces/console_.py +++ b/pyoperant/interfaces/console_.py @@ -1,18 +1,44 @@ import sys +import datetime as dt from pyoperant.interfaces import base_ class ConsoleInterface(base_.BaseInterface): """docstring for ComediInterface""" def __init__(self,*args,**kwargs): super(ConsoleInterface, self).__init__(*args,**kwargs) + + def _config_read(self, **kwargs): + """ + """ pass - def _read(self,prompt=''): + def _config_write(self, **kwargs): + """ + """ + pass + + def _read(self, default_value=None, prompt='', **kwargs): """ read from keyboard input """ + if default_value is not None: + return default_value + return raw_input(prompt) - def _write(self,value): - """Write to console + def _write(self, value, **kwargs): + """ Write to console """ print value + + def _poll(self, timeout=None, **kwargs): + + if timeout is not None: + prompt = "Timeout?" + else: + prompt = "Press enter" + value = self.read(prompt=prompt, **kwargs) + + if value: + return dt.datetime.now() + else: + return None diff --git a/pyoperant/interfaces/nidaq_.py b/pyoperant/interfaces/nidaq_.py new file mode 100644 index 00000000..984f4ee0 --- /dev/null +++ b/pyoperant/interfaces/nidaq_.py @@ -0,0 +1,658 @@ +import time +import datetime +import logging +import numpy as np +import nidaqmx +import wave +from pyoperant.interfaces import base_ +from pyoperant import utils, InterfaceError +from pyoperant.events import events, EventDToAHandler + +logger = logging.getLogger(__name__) + + +def list_devices(): + """ List the devices currently connected to the system. """ + + return nidaqmx.System().devices + + +def list_analog_inputs(): + """ List the analog inputs for each device """ + + channels = dict() + for dev in nidaqmx.System().devices: + channels[str(dev)] = dev.get_analog_input_channels() + + return channels + + +def list_analog_outputs(): + """ List the analog outputs for each device """ + + channels = dict() + for dev in nidaqmx.System().devices: + channels[str(dev)] = dev.get_analog_output_channels() + + return channels + + +def list_boolean_inputs(): + """ List the boolean inputs for each device """ + + channels = dict() + for dev in nidaqmx.System().devices: + channels[str(dev)] = dev.get_digital_input_lines() + + return channels + + +def list_boolean_outputs(): + """ List the boolean outputs for each device """ + + channels = dict() + for dev in nidaqmx.System().devices: + channels[str(dev)] = dev.get_digital_output_lines() + + return channels + + +# TODO: list clock channels? + +class NIDAQmxError(InterfaceError): + pass + + +class NIDAQmxInterface(base_.BaseInterface): + """ Creates an interface for inputs and outputs to a NIDAQ card using + the pylibnidaqmx library: https://github.com/imrehg/pylibnidaqmx + + Parameters + ---------- + device_name: string + the name of the device on your system (e.g. "Dev1") + samplerate: float + the samplerate for all inputs and outputs. If an external clock is + specified, then this should be the maximum allowed samplerate. + clock_channel: string + the channel name for an external clock signal (e.g. "/Dev1/PFI0") + analog_event_handler: instance of events.EventDToAHandler + an event handler for sending event information down an analog channel. + Should have a channel attribute. This can also be passed when you + configure the analog output. + + Attributes + ---------- + device_name: string + the name of the device on your system (e.g. "Dev1") + samplerate: float + the samplerate for all inputs and outputs. If an external clock is + specified, then this should be the maximum allowed samplerate. + clock_channel: string + the channel name for an external clock signal (e.g. "/Dev1/PFI0") + tasks: dict + a dictionary of all configured tasks. Each task corresponds to inputs or outputs of the same type that will be read or written to together. + analog_event_handler: instance of events.EventDToAHandler + an event handler for sending event information down an analog channel. + + Methods + ------- + _config_write + _config_read + _write_bool + _read_bool + _config_write_analog + _config_read_analog + _write_analog + _read_analog + + Examples + -------- + dev = NIDAQmxInterface("Dev1") # open the device at "Dev1" + # or with an external clock + dev = NIDAQmxInterface("Dev1", clock_channel="/Dev1/PFI0") + + # Configure a boolean output on port 0, line 0 + dev._config_write("Dev1/port0/line0") + # Set the output to True + dev._write_bool("Dev1/port0/line0", True) + + # Configure a boolean input on port 0, line 1 + dev._config_read("Dev1/port0/line1") + # Read from that input + dev._read_bool("Dev1/port0/line1") + + # Configure an analog output on channel ao0 + dev._config_write("Dev1/ao0") + # Set the output to True + dev._write("Dev1/ao0", True) + + # Configure a analog input on channel ai0 + dev._config_read("Dev1/ai0") + # Read from that input + dev._read("Dev1/ai0") + """ + + def __init__(self, device_name, samplerate=30000, + analog_event_handler=None, clock_channel="OnboardClock", + *args, **kwargs): + super(NIDAQmxInterface, self).__init__(*args, **kwargs) + self.device_name = device_name + self.samplerate = samplerate + self.clock_channel = clock_channel + self._analog_event_handler = analog_event_handler + + self.tasks = dict() + self.open() + + def open(self): + """ Opens the nidaqmx device """ + + logger.debug("Opening nidaqmx device named %s" % self.device_name) + self.device = nidaqmx.Device(self.device_name) + + def close(self): + """ Closes the nidaqmx device and deletes all of the tasks """ + + logger.debug("Closing nidaqmx device named %s" % self.device_name) + for task in self.tasks.values(): + logger.debug("Deleting task named %s" % str(task.name)) + task.stop() + task.clear() + del task + self.tasks = dict() + + def _config_read(self, channel, **kwargs): + """Configure a channel or group of channels as a boolean input + + Parameters + ---------- + channel: string + a channel or group of channels that will all be read from at the same time (e.g. "Dev1/port0/line1" or "Dev1/port0/line1-7") + + Returns + ------- + True on successful configuration + """ + # TODO: test multiple channels. What format should channels be in? + + logger.debug("Configuring digital input on channel(s) %s" % str(channel)) + task = nidaqmx.DigitalInputTask() + task.create_channel(channel) + task.configure_timing_sample_clock(source=self.clock_channel, + rate=self.samplerate, + sample_mode="continuous") + task.set_read_relative_to("most_recent") + task.set_read_offset(-1) + self.tasks[channel] = task + + def _config_write(self, channel, **kwargs): + """ Configure a channel or group of channels as a boolean output + + Parameters + ---------- + channel: string + a channel or group of channels that will all be written to at the same time + + Returns + ------- + True on successful configuration + """ + + # TODO: test multiple channels. What format should channels be in? + logger.debug("Configuring digital output on channel(s) %s" % str(channel)) + task = nidaqmx.DigitalOutputTask() + task.create_channel(channel) + task.configure_timing_sample_clock(source=self.clock_channel, + rate=self.samplerate) + task.set_buffer_size(0) + self.tasks[channel] = task + + def _read_bool(self, channel, invert=False, event=None, **kwargs): + """ Read a boolean value from a channel or group of channels + + Parameters + ---------- + channel: string + a channel or group of channels that will all be read at the same time + invert: bool + whether or not to invert the read value + event: dict + a dictionary of event information to emit after a True reading + + Returns + ------- + The value read from the hardware + """ + if channel not in self.tasks: + raise NIDAQmxError("Channel(s) %s not yet configured" % str(channel)) + + task = self.tasks[channel] + task.start() + # while task.get_samples_per_channel_acquired() == 0: + # pass + value, bits_per_sample = task.read(1) + value = value[0, 0] + task.stop() + if invert: + value = 1 - value + value = bool(value == 1) + if value: + events.write(event) + + return value + + def _write_bool(self, channel, value, event=None, **kwargs): + """ Write a boolean value to a channel or group of channels + + Parameters + ---------- + channel: string + a channel or group of channels that will all be written to at the same time + value: bool or boolean array + value to write to the hardware + event: dict + a dictionary of event information to emit just before writing + + Returns + ------- + True + """ + if channel not in self.tasks: + raise NIDAQmxError("Channel(s) %s not yet configured" % str(channel)) + task = self.tasks[channel] + events.write(event) + task.write(value, auto_start=True) + + return True + + def _poll(self, channel=None, invert=False, + last_value=False, suppress_longpress=False, + timeout=None, wait=None, event=None, *args, **kwargs): + """ Runs a loop, querying for the boolean input to return True. + + Parameters + ---------- + channel: + default channel argument to pass to _read_bool() + invert: bool + whether or not to invert the read value + last_value: bool + if the last read value was True. Necessary to suppress longpresses + suppress_longpress: bool + if True, attempts to suppress returning immediately if the button is still being pressed since the last call. If last_value is True, then it waits until the interface reads a single False value before allowing it to return. + timeout: float + the time, in seconds, until polling times out. Defaults to no timeout. + wait: float + the time, in seconds, to wait between subsequent reads (default no wait). + + Returns + ------- + timestamp of True read or None if timed out + """ + + logger.debug("Begin polling from device %s" % self.device_name) + if timeout is not None: + start = time.time() + + if channel not in self.tasks: + raise NIDAQmxError("Channel(s) %s not yet configured" % str(channel)) + + task = self.tasks[channel] + task.start() + while True: + # Read the value - cannot use _read_bool because it must start and stop the task each time. + value, bits_per_sample = task.read(1) + value = value[0, 0] + if invert: + value = 1 - value + value = bool(value == 1) + if value: + events.write(event) + + if not isinstance(value, bool): + task.stop() + raise ValueError("Polling for bool returned something that was not a bool") + if value is True: + + if (last_value is False) or (suppress_longpress is False): + logger.debug("Input detected. Returning") + task.stop() + return datetime.datetime.now() + else: + last_value = False + + if timeout is not None: + if time.time() - start >= timeout: + logger.debug("Polling timed out. Returning") + task.stop() + return None + + if wait is not None: + utils.wait(wait) + + + def _config_read_analog(self, channel, min_val=-10.0, max_val=10.0, + **kwargs): + """ Configure a channel or group of channels as an analog input + + Parameters + ---------- + channel: string + a channel or group of channels that will all be read at the same time + min_val: float + the minimum voltage that can be read + max_val: float + the maximum voltage that can be read + + Returns + ------- + True if configuration succeeded + """ + + logger.debug("Configuring analog input on channel(s) %s" % str(channel)) + task = nidaqmx.AnalogInputTask() + task.create_voltage_channel(channel, min_val=min_val, max_val=max_val) + task.configure_timing_sample_clock(source=selsf.clock_channel, + rate=self.samplerate, + sample_mode="finite") + self.tasks[channel] = task + + return True + + def _config_write_analog(self, channel, analog_event_handler=None, + min_val=-10.0, max_val=10.0, **kwargs): + """ Configure a channel or group of channels as an analog output + + Parameters)) + ---------- + channel: string + a channel or group of channels that will all be written to at the same + analog_event_handler: instance of events.EventDToAHandler + an event handler for sending event information down an analog channel. Should have a channel attribute. + min_val: float + the minimum voltage that can be read + max_val: float + the maximum voltage that can be read + + Returns + ------- + True if configuration succeeded + """ + + logger.debug("Configuring analog output on channel(s) %s" % str(channel)) + task = nidaqmx.AnalogOutputTask() + if self._analog_event_handler is None and \ + analog_event_handler is not None: + if not hasattr(analog_event_handler, "channel"): + raise AttributeError("analog_event_handler must have a channel attribute") + channel = nidaqmx.libnidaqmx.make_pattern([channel, + analog_event_handler.channel]) + logger.debug("Configuring digital to analog output as well.") + self._analog_event_handler = analog_event_handler + + task.create_voltage_channel(channel, min_val=min_val, max_val=max_val) + task.configure_timing_sample_clock(source=self.clock_channel, + rate=self.samplerate, + sample_mode="finite") + self.tasks[channel] = task + + def _read_analog(self, channel, nsamples, event=None, **kwargs): + """ Read from a channel or group of channels for the specified number of + samples. + + Parameters + ---------- + channel: string + a channel or group of channels that will be read at the same time + nsamples: int + the number of samples to read + event: dict + a dictionary of event information to emit after reading + + Returns + ------- + a numpy array of the data that was read + """ + + if channel not in self.tasks: + raise NIDAQmxError("Channel(s) %s not yet configured" % str(channel)) + + task = self.tasks[channel] + task.configure_timing_sample_clock(source=self.clock_channel, + rate=self.samplerate, + sample_mode="finite", + samples_per_channel=nsamples) + values = task.read(nsamples) + events.write(event) + return values + + def _write_analog(self, channel, values, is_blocking=False, event=None, + **kwargs): + """ Write a numpy array of float64 values to the buffer on a channel or + group of channels + + Parameters + ---------- + channel: string + a channel or group of channels that will all be written to at the same time + values: numpy array of float64 values + values to write to the hardware. Should be of dimension nchannels x nsamples. + is_blocking: bool + whether or not to block execution until all samples are written to the hardware + event: dict + a dictionary of event information to emit just before writing + + Returns + ------- + True + """ + + if channel not in self.tasks: + raise NIDAQmxError("Channel(s) %s not yet configured" % str(channel)) + + task = self.tasks[channel] + task.stop() + task.configure_timing_sample_clock(source=self.clock_channel, + rate=self.samplerate, + sample_mode="finite", + samples_per_channel=values.shape[0]) + + if self._analog_event_handler is not None: + # Get the string of (scaled) bits from the event handler + bit_string = self._analog_event_handler.to_bit_sequence(event) + + # multi-channel outputs need to be of shape nsamples x nchannels + if len(values.shape) == 1: + values = values.reshape((-1, 1)) + + # Add a channel of all zeros + values = np.hstack([values, np.zeros((values.shape[0], 1))]) + # Place the bit string at the start + values[:len(bit_string), -1] = bit_string + + # Write the values to the nidaq buffer + # I think we might want to set layout='group_by_scan_number' in .write() + task.write(values, auto_start=False) + events.write(event) + task.start() + if is_blocking: + task.wait_until_done() + task.stop() + + return True + + +class NIDAQmxAudioInterface(NIDAQmxInterface, base_.AudioInterface): + """ Creates an interface for writing audio data to a NIDAQ card using + the pylibnidaqmx library: https://github.com/imrehg/pylibnidaqmx + + Parameters + ---------- + device_name: string + the name of the device on your system (e.g. "Dev1") + samplerate: float + the samplerate for the sound. If an external clock is + specified, then this should be the maximum allowed samplerate. + clock_channel: string + the channel name for an external clock signal (e.g. "/Dev1/PFI0") + + Attributes + ---------- + device_name: string + the name of the device on your system (e.g. "Dev1") + samplerate: float + the samplerate for the sound. If an external clock is + specified, then this should be the maximum allowed samplerate. + clock_channel: string + the channel name for an external clock signal (e.g. "/Dev1/PFI0") + stream: nidaqmx.AnalogOutputTask + the task used for writing out sound data + wf: file handle + the currently playing wavefile handle + + Methods + ------- + _config_write_analog + _get_stream + _queue_wav + _play_wav + _stop_wav + + Examples + -------- + + """ + def __init__(self, device_name, samplerate=30000.0, + clock_channel=None, *args, **kwargs): + + super(NIDAQmxAudioInterface, self).__init__(device_name=device_name, + samplerate=samplerate, + clock_channel=clock_channel, + *args, **kwargs) + self.stream = None + self.wf = None + self._wav_data = None + + def _config_write_analog(self, channel, analog_event_handler=None, + min_val=-10.0, max_val=10.0, **kwargs): + """ Configure a channel or group of channels as an analog output + + Parameters + ---------- + channel: string + a channel or group of channels that will all be written to at the + same time + analog_event_handler: instance of events.EventDToAHandler + an event handler for sending event information down an analog channel. Should have a channel attribute. + min_val: float + the minimum voltage that can be read + max_val: float + the maximum voltage that can be read + + Returns + ------- + True if configuration succeeded + """ + super(NIDAQmxAudioInterface, self)._config_write_analog( + channel, + analog_event_handler=analog_event_handler, + min_val=min_val, + max_val=max_val, + **kwargs) + self.stream = self.tasks.values()[0] + + def _queue_wav(self, wav_file, start=False, event=None, **kwargs): + """ Queue the wav file for playback + + Parameters + ---------- + wav_file: string + Path to the wave file to load + start: bool + Whether or not to immediately start playback + event: dict + a dictionary of event information to emit just before playback + """ + + if self.wf is not None: + self._stop_wav() + + events.write(event) + logger.debug("Queueing wavfile %s" % wav_file) + self._wav_data = self._load_wav(wav_file) + + if self._analog_event_handler is not None: + # Get the string of (scaled) bits from the event handler + bit_string = self._analog_event_handler.to_bit_sequence(event) + + # multi-channel outputs need to be of shape nsamples x nchannels + if len(self._wav_data.shape) == 1: + values = self._wav_data.reshape((-1, 1)) + else: + values = self._wav_data + + # Add a channel of all zeros + self._wav_data = np.hstack([values, np.zeros((values.shape[0], 1))]) + # Place the bit string at the start + self._wav_data[:len(bit_string), -1] = bit_string + self._get_stream(start=start, **kwargs) + + def _get_stream(self, start=False, **kwargs): + """ Writes the stream to the nidaq buffer and optionally starts it. + + Parameters + ---------- + start: bool + Whether or not to immediately start playback + """ + + self.stream.configure_timing_sample_clock(source=self.clock_channel, + rate=self.samplerate, + sample_mode="finite", + samples_per_channel=self._wav_data.shape[0]) + # I think we might want to set layout='group_by_scan_number' in .write() + self.stream.write(self._wav_data, auto_start=False) + if start: + self._play_wav(**kwargs) + + def _play_wav(self, is_blocking=False, event=None, **kwargs): + """ Play the data that is currently in the buffer + + Parameters + ---------- + is_blocking: bool + Whether or not to play the sound in blocking mode + event: dict + a dictionary of event information to emit just before playback + """ + + logger.debug("Playing wavfile") + events.write(event) + self.stream.start() + if is_blocking: + self.wait_until_done() + + def _stop_wav(self, event=None, **kwargs): + """ Stop the current playback and clear the buffer + + Parameters + ---------- + event: dict + a dictionary of event information to emit just before stopping + """ + + try: + logger.debug("Attempting to close stream") + events.write(event) + self.stream.stop() + logger.debug("Stream closed") + except AttributeError: + self.stream = None + + try: + self.wf.close() + except AttributeError: + self.wf = None + + self._wav_data = None diff --git a/pyoperant/interfaces/pyaudio_.py b/pyoperant/interfaces/pyaudio_.py index dae6abcb..6df27d3b 100644 --- a/pyoperant/interfaces/pyaudio_.py +++ b/pyoperant/interfaces/pyaudio_.py @@ -1,9 +1,64 @@ +from ctypes import * +from contextlib import contextmanager import pyaudio import wave +import logging from pyoperant.interfaces import base_ from pyoperant import InterfaceError +from pyoperant.events import events -class PyAudioInterface(base_.BaseInterface): +logger = logging.getLogger(__name__) +# TODO: Clean up _stop_wav logging changes + + +# Modify the alsa error function to suppress needless warnings +# Code derived from answer by Nils Werner at: +# http://stackoverflow.com/questions/7088672/pyaudio-working-but-spits-out-error-messages-each-time +# TODO: Pass actual warnings to logger.debug when logging is fully integrated into master. +@contextmanager +def log_alsa_warnings(): + """ Suppresses ALSA warnings when initializing a PyAudio instance. + + with log_alsa_warnings(): + pa = pyaudio.PyAudio() + """ + # Set up the C error handler for ALSA + ERROR_HANDLER_FUNC = CFUNCTYPE(None, + c_char_p, + c_int, + c_char_p, + c_int, + c_char_p, + c_char_p) + + def py_error_handler(filename, line, function, err, fmt, args): + + # ALSA_STR = "ALSA lib %s:%i:(%s) %s" + + # Try to format fmt with args. As far as I can tell, CFUNCTYPE does not + # support variable number of arguments, so formatting will fail with + # TypeError if fmt has multiple %'s. + # if args is not None: + # try: + # fmt %= args + # except TypeError: + # pass + # logger.debug(ALSA_STR, filename, line, function, fmt) + pass + + c_error_handler = ERROR_HANDLER_FUNC(py_error_handler) + for asound_library in ["libasound.so", "libasound.so.2"]: + try: + asound = cdll.LoadLibrary(asound_library) + break + except OSError: + continue + asound.snd_lib_error_set_handler(c_error_handler) + yield + asound.snd_lib_error_set_handler(None) + + +class PyAudioInterface(base_.AudioInterface): """Class which holds information about an audio device assign a simple callback function that will execute on each frame @@ -26,9 +81,11 @@ def __init__(self,device_name='default',*args,**kwargs): self.open() def open(self): - self.pa = pyaudio.PyAudio() + with log_alsa_warnings(): + self.pa = pyaudio.PyAudio() for index in range(self.pa.get_device_count()): if self.device_name == self.pa.get_device_info_by_index(index)['name']: + logger.debug("Found device %s at index %d" % (self.device_name, index)) self.device_index = index break else: @@ -39,28 +96,23 @@ def open(self): self.device_info = self.pa.get_device_info_by_index(self.device_index) def close(self): + logger.debug("Closing device") try: self.stream.close() except AttributeError: self.stream = None - try: + try: self.wf.close() except AttributeError: self.wf = None self.pa.terminate() - def validate(self): - if self.wf is not None: - return True - else: - raise InterfaceError('there is something wrong with this wav file') - - def _get_stream(self,start=False): + def _get_stream(self, start=False, event=None, **kwargs): """ """ def _callback(in_data, frame_count, time_info, status): try: - cont = self.callback() + cont = self.callback() except TypeError: cont = True @@ -75,23 +127,40 @@ def _callback(in_data, frame_count, time_info, status): rate=self.wf.getframerate(), output=True, output_device_index=self.device_index, - start=start, + start=False, stream_callback=_callback) + if start: + self._play_wav(event=event) - def _queue_wav(self,wav_file,start=False): + def _queue_wav(self, wav_file, start=False, event=None, **kwargs): + logger.debug("Queueing wavfile %s" % wav_file) self.wf = wave.open(wav_file) self.validate() - self._get_stream(start=start) + self._get_stream(start=start, event=event) - def _play_wav(self): + def _play_wav(self, event=None, **kwargs): + logger.debug("Playing wavfile") + events.write(event) self.stream.start_stream() - def _stop_wav(self): + def _stop_wav(self, event=None, **kwargs): try: + logger.debug("Attempting to close pyaudio stream") + events.write(event) self.stream.close() + logger.debug("Stream closed") except AttributeError: self.stream = None - try: + try: self.wf.close() except AttributeError: self.wf = None + +if __name__ == "__main__": + + with log_alsa_warnings(): + pa = pyaudio.PyAudio() + pa.terminate() + print "-" * 40 + pa = pyaudio.PyAudio() + pa.terminate() diff --git a/pyoperant/interfaces/pydaqmx_.py b/pyoperant/interfaces/pydaqmx_.py index e69de29b..60857920 100644 --- a/pyoperant/interfaces/pydaqmx_.py +++ b/pyoperant/interfaces/pydaqmx_.py @@ -0,0 +1,66 @@ +import logging +import numpy as np +from PyDAQmx import * +from pyoperant.interfaces import base_ +from pyoperant import utils, InterfaceError + + +class PyDAQmxInterface(base_.BaseInterface): + + def __init__(self, device_name="default", *args, **kwargs): + # Initialize the device + super(PyDAQmxInterface, self).__init__(*args, **kwargs) + self.device_name = device_name + self.device_index = None + # self.stream = None + # self.wf = None + # self.callback = None + self.open() + + def open(self): + + pass + + def close(self): + + pass + + def _config_read(self): + + pass + + def _config_write(self): + + pass + + def _read_bool(self): + + pass + + def _poll(self): + + pass + + def _write_bool(self): + + pass + + def validate(self): + + pass + + def _get_stream(self): + + pass + + def _queue_wav(self): + + pass + + def _play_wav(self): + + pass + + def _stop_wav(self): + + pass diff --git a/pyoperant/interfaces/utils.py b/pyoperant/interfaces/utils.py new file mode 100644 index 00000000..9124b0d6 --- /dev/null +++ b/pyoperant/interfaces/utils.py @@ -0,0 +1,6 @@ +from contextlib import contextmanager + +@contextmanager +def buffered_analog_output(data, chunk_size, buffer_size): + + pass diff --git a/pyoperant/panels.py b/pyoperant/panels.py index 7593ba82..6587b2cb 100644 --- a/pyoperant/panels.py +++ b/pyoperant/panels.py @@ -5,10 +5,10 @@ class BasePanel(object): This class should be subclassed to define a local panel configuration. - To build a panel, do the following in the __init__() method of your local + To build a panel, do the following in the __init__() method of your local subclass: - 1. add instances of the necessary interfaces to the 'interfaces' dict + 1. add instances of the necessary interfaces to the 'interfaces' dict attribute: >>> self.interfaces['comedi'] = comedi.ComediInterface(device_name='/dev/comedi0') @@ -25,7 +25,7 @@ class BasePanel(object): 4. assign panel methods needed for operant behavior, such as 'reward': >>> self.reward = self.hopper.reward - 5. finally, define a reset() method that will set the entire panel to a + 5. finally, define a reset() method that will set the entire panel to a neutral state: >>> def reset(self): @@ -43,4 +43,39 @@ def __init__(self, *args,**kwargs): self.outputs = [] def reset(self): - raise NotImplementedError + """ + Turn everything off (sleep), then ready + """ + # Should log that nothing is being reset and move on + + self.sleep() + return self.ready() + + def ready(self): + + return True + + def reward(self): + + pass + + def sleep(self): + """ + Turn all boolean outputs off + """ + for output in self.outputs: + if isinstance(output, hwio.BooleanOutput): + self.output.write(False) + + return True + + def idle(self): + + return True + + def wake(self): + """ + Ready the panel + """ + + return self.ready() diff --git a/pyoperant/queues.py b/pyoperant/queues.py index 08d181ed..1fb867f0 100644 --- a/pyoperant/queues.py +++ b/pyoperant/queues.py @@ -2,57 +2,69 @@ from pyoperant.utils import rand_from_log_shape_dist import cPickle as pickle import numpy as np +import logging -def random_queue(conditions,tr_max=100,weights=None): - """ generator which randomly samples conditions +logger = logging.getLogger(__name__) - Args: - conditions (list): The conditions to sample from. - weights (list of ints): Weights of each condition - Kwargs: - tr_max (int): Maximum number of trial conditions to generate. (default: 100) +def random_queue(items, weights=None, max_items=None): + """ Generator which randomly samples items, with replacement - Returns: - whatever the elements of 'conditions' are + Parameters + ---------- + items: list + A list of items to be queued + weights: list + A list of weights, 1 for each item in items + max_items: int + Maximum number of items to generate. (default: None) + Yields + ------ + A single item at each iteration """ - if weights: - conditions_weighted = [] - for cond,w in zip(conditions,weights): - for ww in range(w): - conditions_weighted += cond - conditions = conditions_weighted - - tr_num = 0 - while tr_num < tr_max: - yield random.choice(conditions) - tr_num += 1 - -def block_queue(conditions,reps=1,shuffle=False): - """ generate trial conditions from a block - - Args: - conditions (list): The conditions to sample from. - - Kwargs: - reps (int): number of times each item in conditions will be presented (default: 1) - shuffle (bool): Shuffles the queue (default: False) - - Returns: - whatever the elements of 'conditions' are - + if len(items) == 0: + raise ValueError("Cannot intialize a queue with 0 items") + + if weights is None: + weights = [1.0 / len(items)] * len(items) + else: + weights = [float(ww) / sum(weights) for ww in weights] + + ii = 0 + while True: + if (max_items is not None) and (ii >= max_items): + break + yield np.random.choice(items, p=weights) + ii += 1 + + +def block_queue(items, repetitions=1, shuffle=False): + """ Generator which samples items in blocks + + Parameters + ---------- + items: list + A list of items to be queued + repetitions: int + The number of times each item in items will be presented (default: 1) + shuffle: bool + Shuffles the queue (default: False) + + Yields + ------ + A single item at each iteration """ - conditions_repeated = [] - for rr in range(reps): - conditions_repeated += conditions - conditions = conditions_repeated + items_repeated = [] + for rr in range(repetitions): + items_repeated += items + items = items_repeated if shuffle: - random.shuffle(conditions) - - for cond in conditions: - yield cond + random.shuffle(items) + + for item in items: + yield item class AdaptiveBase(object): """docstring for AdaptiveBase @@ -138,7 +150,7 @@ class KaernbachStaircase(AdaptiveBase): Returns: float """ - def __init__(self, + def __init__(self, start_val=100, stepsize_up=3, stepsize_dn=1, @@ -150,7 +162,7 @@ def __init__(self, super(KaernbachStaircase, self).__init__() self.val = start_val self.stepsize_up = stepsize_up - self.stepsize_dn = stepsize_dn + self.stepsize_dn = stepsize_dn self.min_val = min_val self.max_val = max_val self.crit = crit @@ -160,7 +172,7 @@ def __init__(self, def update(self, correct, no_resp): super(KaernbachStaircase, self).update(correct, no_resp) - + self.val += -1*self.stepsize_dn if correct else self.stepsize_up if self.crit_method=='reversals': @@ -183,7 +195,7 @@ def next(self): class DoubleStaircase(AdaptiveBase): """ - Generates conditions from a list of stims that monotonically vary from most + Generates conditions from a list of stims that monotonically vary from most easily left to most easily right i.e. left is low and right is high @@ -217,7 +229,7 @@ def next(self): super(DoubleStaircase, self).next() if self.high_idx - self.low_idx <= 1: raise StopIteration - + delta = int(np.ceil((self.high_idx - self.low_idx) * self.rate_constant)) if random.random() < .5: # probe low side self.trial['low'] = True @@ -237,7 +249,7 @@ class DoubleStaircaseReinforced(AdaptiveBase): Generates conditions as with DoubleStaircase, but 1-probe_rate proportion of the trials easier/known trials to reduce frustration. - Easier trials are sampled from a log shaped distribution so that more trials + Easier trials are sampled from a log shaped distribution so that more trials are sampled from the edges than near the indices stims: an array of stimuli names ordered from most easily left to most easily right @@ -281,7 +293,7 @@ def next(self): return {'class': 'L', 'stim_name': self.stims[val]} else: # probe right if self.sample_log: - val = self.dblstaircase.high_idx + int(rand_from_log_shape_dist() * (len(self.stims) - self.dblstaircase.high_idx)) + val = self.dblstaircase.high_idx + int(rand_from_log_shape_dist() * (len(self.stims) - self.dblstaircase.high_idx)) else: val = self.dblstaircase.high_idx + random.randrange(len(self.stims) - self.dblstaircase.high_idx) return {'class': 'R', 'stim_name': self.stims[val]} @@ -300,7 +312,7 @@ class MixedAdaptiveQueue(PersistentBase, AdaptiveBase): Generates conditions from multiple adaptive sub queues. Use the generator MixedAdaptiveQueue.load(filename, sub_queues) - to load a previously saved MixedAdaptiveQueue or generate a new one + to load a previously saved MixedAdaptiveQueue or generate a new one if the pkl file doesn't exist. sub_queues: a list of adaptive queues @@ -344,5 +356,44 @@ def on_load(self): pass +class BaseHandler(object): + """ Base class for implementing an iterable queue handler + + Parameters + ---------- + queue: queue function or class + The queue that will be iterated over. All queues must accept an items + argument and implement generator, either through yielding values or a + Class.next() method. + items: list + A list of items to iterate over. + Additional key-value pairs are used to initialize the queue + + Attributes + ---------- + queue: queue generator or class instance + The queue that will be iterated over. + queue_parameters: dict + All additional parameters used to initialize the queue. + """ + + def __init__(self, queue, items, **queue_parameters): + if not hasattr(queue, "__call__"): + raise TypeError("queue must be a callable function or class") + # Store these in case we need to reset + self._queue = queue + self._items = items + + self.queue = queue(items=items, **queue_parameters) + self.queue_parameters = queue_parameters + + def reset(self): + """ Reset the queue """ + + self.queue = self._queue(items=self._items, **self.queue_parameters) + + def __iter__(self): + for item in self.queue: + yield item diff --git a/pyoperant/reinf.py b/pyoperant/reinf.py index e5f68649..e95ae665 100644 --- a/pyoperant/reinf.py +++ b/pyoperant/reinf.py @@ -11,6 +11,7 @@ class BaseSchedule(object): should be consequated. Always returns True. """ + def __init__(self): super(BaseSchedule, self).__init__() @@ -137,4 +138,11 @@ def consequate(self,trial): return True def __unicode__(self): - return "PR%i" % self.prob \ No newline at end of file + return "PR%i" % self.prob + +SCHEDULE_DICT = dict(continuous=ContinuousReinforcement, + fixed=FixedRatioSchedule, + fixedratio=FixedRatioSchedule, + variable=VariableRatioSchedule, + variableratio=VariableRatioSchedule, + percent=PercentReinforcement) diff --git a/pyoperant/run_experiment.py b/pyoperant/run_experiment.py new file mode 100644 index 00000000..2ff3fafb --- /dev/null +++ b/pyoperant/run_experiment.py @@ -0,0 +1,31 @@ +import os +from pyoperant import (configure, states, panels, behavior, stimuli, blocks, subjects) + + +def run(configuration_file): + """ Run an experiment or set of experiments detailed in the configuration file """ + + # Check for configuration file + if not os.path.exists(configuration_file): + raise IOError("Configuration file could not be found: %s" % configuration_file) + + # Load the configuration based on its extension + extension = os.path.splitext(configuration_file)[1].lower() + if extension == ".yaml": + parameters = configure.ConfigureYAML(configuration_file) + elif extension == ".json": + parameters = configure.ConfigureJSON(configuration_file) + else: + raise ValueError("Configuration file must be either yaml or json") + + # Set up subject + + # Set up panel + + # Set up experiments + + # Set up stimlulus conditions + + # Set up blocks + + # Set up BlockHandler diff --git a/pyoperant/states.py b/pyoperant/states.py new file mode 100644 index 00000000..0e4c3868 --- /dev/null +++ b/pyoperant/states.py @@ -0,0 +1,403 @@ +import logging +import datetime as dt +import numpy as np +from pyoperant import (EndSession, + EndExperiment, + ComponentError, + InterfaceError, + utils) + +logger = logging.getLogger(__name__) + + +class State(object): + """ States provide a nice interface for running experiments and transitioning to sleep/idle phases. By implementing __enter__ and __exit__ methods, they use the "with" statement construct that allows for simple error handling (e.g. session ends, keyboard interrupts to stop an experiment, etc.) + + Parameters + ---------- + schedulers: list of scheduler objects + These determine whether or not the state should be running, using their check method + + Methods + ------- + check() - Check if the state should be active according to its schedulers + run() - Run the state (should be used within the "with" statement) + start() - wrapper for run that includes the "with" statement + """ + + def __init__(self, experiment=None, schedulers=None): + + if schedulers is None: + schedulers = list() + if not isinstance(schedulers, list): + schedulers = [schedulers] + self.schedulers = schedulers + self.experiment = experiment + + def check(self): + """ Checks all of the states schedulers to see if the state should be active. + + Returns + ------- + True if the state should be active and False otherwise. + """ + + # If any scheduler says not to run, then don't run + for scheduler in self.schedulers: + if not scheduler.check(): + return False + + return True + + def __enter__(self): + """ Start all of the schedulers """ + + logger.info("Entering %s state" % self.__class__.__name__) + for scheduler in self.schedulers: + scheduler.start() + + return self + + def __exit__(self, type_, value, traceback): + """ Handles KeyboardInterrupt and EndExperiment exceptions to end the experiment, EndSession exceptions to end the session state, and logs all others. + """ + logger.info("Exiting %s state" % self.__class__.__name__) + + # Stop the schedulers + for scheduler in self.schedulers: + scheduler.stop() + + # Handle expected exceptions + if type_ in [KeyboardInterrupt, EndExperiment]: + logger.info("Finishing experiment") + self.experiment.end() + return True + elif type_ is EndSession: + logger.info("Session has ended") + return True + + # Log all other exceptions and raise them + if isinstance(value, Exception): + if type_ in [InterfaceError, ComponentError]: + logger.critical("There was a critical error in communicating with the hardware!") + logger.critical(repr(value)) + + return False + + def run(self): + + pass + + def start(self): + """ Implements the "with" context for this state """ + + with self as state: + state.run() + + +class Session(State): + """ Session state for running an experiment. Should be used with the "with" statement (see Examples). + + Parameters + ---------- + schedulers: list of scheduler objects + These determine whether or not the state should be running, using their check method + experiment: an instance of a Behavior class + The experiment whose session methods should be run. + + Methods + ------- + check() - Check if the state should be active according to its schedulers + run() - Run the experiment's session_main method + start() - wrapper for run that includes the "with" statement + update() - Update schedulers at the end of the trial + + Examples + -------- + with Session(experiment=experiment) as state: # Runs experiment.session_pre + state.run() # Runs experiment.session_main + # Exiting with statement runs experiment.session_post + + # "with" context is also implemented in the start() method + state = Session(experiment=experiment) + state.start() + """ + + def __enter__(self): + + self.experiment.session_pre() + for scheduler in self.schedulers: + scheduler.start() + + return self + + def run(self): + """ Runs session main """ + + self.experiment.session_main() + + def __exit__(self, type_, value, traceback): + + self.experiment.session_post() + + return super(Session, self).__exit__(type_, value, traceback) + + def update(self): + """ Updates all schedulers with information on the current trial """ + + if hasattr(self.experiment, "this_trial"): + for scheduler in self.schedulers: + scheduler.update(self.experiment.this_trial) + + +class Idle(State): + """ A simple idle state. + + Parameters + ---------- + experiment: an instance of a Behavior class + The experiment whose session methods should be run. + poll_interval: int + The interval, in seconds, at which other states should be checked to run + + Methods + ------- + run() - Run the experiment's session_main method + """ + def __init__(self, experiment=None, poll_interval=60): + + super(Idle, self).__init__(experiment=experiment, + schedulers=None) + self.poll_interval = poll_interval + + def run(self): + """ Checks if the experiment should be sleeping or running a session and kicks off those states. """ + + while True: + if self.experiment.check_sleep_schedule(): + return self.experiment._sleep.start() + elif self.experiment.check_session_schedule(): + return self.experiment.session.start() + else: + logger.debug("idling...") + utils.wait(self.poll_interval) + + +class Sleep(State): + """ A panel sleep state. Turns off all outputs, checking every so often if it should wake up + + Parameters + ---------- + experiment: an instance of a Behavior class + The experiment whose session methods should be run. + schedulers: an instance of TimeOfDayScheduler + The time of day scheduler to follow for when to sleep. + poll_interval: int + The interval, in seconds, at which other states should be checked to run + time_period: string or tuple + Either "night" or a tuple of "HH:MM" start and end times. Only used if scheduler is not provided. + + Methods + ------- + run() - Run the experiment's session_main method + """ + def __init__(self, experiment=None, schedulers=None, poll_interval=60, + time_period="night"): + + if schedulers is None: + schedulers = TimeOfDayScheduler(time_period) + self.poll_interval = poll_interval + + super(Sleep, self).__init__(experiment=experiment, + schedulers=schedulers) + + def run(self): + """ Checks every poll interval whether the panel should be sleeping and puts it to sleep """ + + while True: + logger.debug("sleeping") + self.experiment.panel.sleep() + utils.wait(self.poll_interval) + if not self.check(): + break + self.experiment.panel.wake() + + +class BaseScheduler(object): + """ Implements a base class for scheduling states + + Summary + ------- + Schedulers allow the state to be started and stopped based on certain critera. For instance, you can start the sleep state when the sun sets, or stop and session state after 100 trials. + + Methods + ------- + check() - Checks whether the state should be active + start() - Run when the state starts to initialize any variables + stop() - Run when the state finishes to close out any variables + update(trial) - Run after each trial to update the scheduler if necessary + """ + + def __init__(self): + + pass + + def start(self): + + pass + + def stop(self): + + pass + + def update(self, trial): + + pass + + def check(self): + """ This should really be implemented by the subclass """ + + raise NotImplementedError("Scheduler %s does not have a check method" % self.__class__.__name__) + + +class TimeOfDayScheduler(BaseScheduler): + """ Schedule a state to start and stop depending on the time of day + + Parameters + ---------- + time_periods: string or list + The time periods in which this schedule should be active. The value of "sun" can be passed to use the current day-night schedule. Otherwise, pass a list of tuples (start, end) (e.g. [("5:00", "17:00")] for 5am to 5pm) + + Methods + ------- + check() - Returns True if the state should be active according to this schedule + """ + + def __init__(self, time_periods="sun"): + + # Any other sanitizations? + if isinstance(time_periods, tuple): + time_periods = [time_periods] + self.time_periods = time_periods + + def check(self): + """ Returns True if the state should be active according to this schedule + """ + + return utils.check_time(self.time_periods) + + +class TimeScheduler(BaseScheduler): + """ Schedules a state to start and stop based on how long the state has been active and how long since the state was previously active. + + Parameters + ---------- + duration: int + The duration, in minutes, that the state should be active + interval: int + The time since the state was last active before it should become active again. + + Methods + ------- + start() - Stores the start time of the current state + stop() - Stores the end time of the current state + check() - Returns True if the state should activate + """ + def __init__(self, duration=None, interval=None): + + self.duration = duration + self.interval = interval + + self.start_time = None + self.stop_time = None + + def start(self): + """ Stores the start time of the current state """ + + self.start_time = dt.datetime.now() + self.stop_time = None + + def stop(self): + """ Stores the end time of the current state """ + + self.stop_time = dt.datetime.now() + self.start_time = None + + def check(self): + """ Checks if the current time is greater than `duration` minutes after start time or `interval` minutes after stop time """ + + current_time = dt.datetime.now() + # If start_time is None, the state is not active. Should it be? + if self.start_time is None: + # No interval specified, always start + if self.interval is None: + return True + + # The state hasn't activated yet, always start + if self.stop_time is None: + return True + + # Has it been greater than interval minutes since the last time? + time_since = (current_time - self.stop_time).total_seconds() / 60. + if time_since < self.interval: + return False + + # If stop_time is None, the state is currently active. Should it stop? + if self.stop_time is None: + # No duration specified, so do not stop + if self.duration is None: + return True + + # Has the state been active for long enough? + time_since = (current_time - self.start_time).total_seconds() / 60. + if time_since >= self.duration: + return False + + return True + + +class CountScheduler(BaseScheduler): + """ Schedules a state stop after a certain number of trials. + + Parameters + ---------- + max_trials: int + The maximum number of trials + + Methods + ------- + check() - Returns True if the state has not yet reached max_trials + + TODO: This could be expanded to include things like total number of rewards or correct responses. + """ + def __init__(self, max_trials=None): + + self.max_trials = max_trials + self.trial_index = 0 + + def check(self): + """ Returns True if current trial index is less than max_trials """ + + if self.max_trials is None: + return True + + return self.trial_index < self.max_trials + + def stop(self): + """ Resets the trial index since the session is over """ + + self.trial_index = 0 + + def update(self, trial): + """ Updates the current trial index """ + + self.trial_index = trial.index + + +available_states = {"idle": Idle, + "session": Session, + "sleep": Sleep} + +available_schedulers = {"day": TimeOfDayScheduler, + "timeofday": TimeOfDayScheduler, + "time": TimeScheduler} diff --git a/pyoperant/stimuli.py b/pyoperant/stimuli.py new file mode 100644 index 00000000..63e26722 --- /dev/null +++ b/pyoperant/stimuli.py @@ -0,0 +1,196 @@ +import fnmatch +import os +import wave +import logging +import random +from contextlib import closing +from pyoperant.utils import Event, filter_files + +logger = logging.getLogger(__name__) + +# TODO: Integrate this concept of "event" with the one in events.py + +class Stimulus(Event): + """docstring for Stimulus""" + def __init__(self, *args, **kwargs): + super(Stimulus, self).__init__(*args, **kwargs) + if self.label=='': + self.label = 'stimulus' + + +class AuditoryStimulus(Stimulus): + """docstring for AuditoryStimulus""" + def __init__(self, *args, **kwargs): + super(AuditoryStimulus, self).__init__(*args, **kwargs) + if self.label=='': + self.label = 'auditory_stimulus' + + @classmethod + def from_wav(cls, wavfile): + + logger.debug("Attempting to create stimulus object from %s" % wavfile) + with closing(wave.open(wavfile,'rb')) as wf: + (nchannels, sampwidth, framerate, nframes, comptype, compname) = wf.getparams() + + duration = float(nframes)/sampwidth + duration = duration * 2.0 / framerate + stim = cls(time=0.0, + duration=duration, + name=wavfile, + label='wav', + description='', + file_origin=wavfile, + annotations={'nchannels': nchannels, + 'sampwidth': sampwidth, + 'framerate': framerate, + 'nframes': nframes, + 'comptype': comptype, + 'compname': compname, + } + ) + return stim + + +class StimulusCondition(object): + """ Class to represent a single stimulus condition for an operant + conditioning experiment. The name parameter should be meaningful, as it will + be stored with the trial data. The booleans "is_rewarded" and "is_punished" + can be used to state if a stimulus should consequated according to the + experiment's reinforcement schedule. + + Parameters + ---------- + name: string + Name of the stimulus condition used in data storage + response: string, int, or bool + The value of the desired response. Used to determine if the subject's + response was correct. (e.g. "left", True) + is_rewarded: bool + Whether or not a correct response should be rewarded + is_punished: bool + Whether or not an incorrect response should be punished + files: list + A list of files to use for the condition. If files is omitted, the list + will be discovered using the file_path, file_pattern, and recursive + parameters. + file_path: string + Path to directory where stimuli are stored + recursive: bool + Whether or not to search file_path recursively + file_pattern: string + A glob pattern to filter files by + replacement: bool + Whether individual stimuli should be sampled with replacement + shuffle: bool + Whether the list of files should be shuffled before sampling. + + Attributes + ---------- + name: string + Name of the stimulus condition used in data storage + response: string, int, or bool + The value of the desired response. Used to determine if the subject's + response was correct. (e.g. "left", True) + is_rewarded: bool + Whether or not a correct response should be rewarded + is_punished: bool + Whether or not an incorrect response should be punished + files: list + All of the matching files found + replacement: bool + Whether individual stimuli should be sampled with replacement + shuffle: bool + Whether the list of files should be shuffled before sampling. + + Methods + ------- + get() + + Examples + -------- + # Get ".wav" files for a "go" condition of a "Go-NoGo" experiment + condition = StimulusCondition(name="Go", + response=True, + is_rewarded=True, + is_punished=True, + file_path="/path/to/stimulus_directory", + recursive=True, + file_pattern="*.wav", + replacement=False) + + # Get a wavefile + wavefile = condition.get() + """ + + def __init__(self, name="", response=None, is_rewarded=True, + is_punished=True, files=None, file_path="", recursive=False, + file_pattern="*", shuffle=True, replacement=False): + + # These should do something better than printing and returning + if files is None: + if len(file_path) == 0: + raise IOError("No stimulus file_path provided!") + if not os.path.exists(file_path): + raise IOError("Stimulus file_path does not exist! %s" % file_path) + + self.name = name + self.response = response + self.is_rewarded = is_rewarded + self.is_punished = is_punished + self.shuffle = shuffle + self.replacement = replacement + + if files is None: + self.files = filter_files(file_path, + file_pattern=file_pattern, + recursive=recursive) + else: + self.files = files + + self._index_list = range(len(self.files)) + if self.shuffle: + random.shuffle(self._index_list) + + logger.debug("Created new condition: %s" % self) + + def __str__(self): + + return "".join(["Condition %s: " % self.name, + "# files = %d" % len(self.files)]) + + def get(self): + """ Gets a single file from this condition's list of files. If + replacement is True, choose a file randomly with replacement. If + replacement is False, then return files in their (possibly shuffled) + order. + """ + + if len(self._index_list) == 0: + self._index_list = range(len(self.files)) + if self.shuffle: + random.shuffle(self._index_list) + + if self.replacement is True: + index = random.choice(self._index_list) + else: + index = self._index_list.pop(0) + + logger.debug("Selected file %d of %d" % (index + 1, len(self.files))) + return self.files[index] + + +class StimulusConditionWav(StimulusCondition): + """ Modifies StimulusCondition to only include .wav files. For usage + information see StimulusCondition. + """ + + def __init__(self, *args, **kwargs): + + super(StimulusConditionWav, self).__init__(file_pattern="*.wav", + *args, **kwargs) + + def get(self): + """ Gets an AuditoryStimulus instance from a chosen .wav file """ + wavfile = super(StimulusConditionWav, self).get() + + return AuditoryStimulus.from_wav(wavfile) diff --git a/pyoperant/subjects.py b/pyoperant/subjects.py new file mode 100644 index 00000000..421c53a8 --- /dev/null +++ b/pyoperant/subjects.py @@ -0,0 +1,157 @@ +import os +import csv +import logging +logger = logging.getLogger(__name__) + + +class Subject(object): + """ Class which holds information about the subject currently running the + experiment + + Parameters + ---------- + name: string + The name of the subject + filename: string + The name of the output file. The extension is used to determine the + datastore type. + + Attributes + ---------- + name: string + The name of the subject + filename: string + The name of the output file + datastore: Store object instance + The datastore object in use + + Methods + ------- + create_datastore(fields) + Creates a datastore according to filename's extension + store_data(trial) + Stores a trial's data in the datastore + """ + + def __init__(self, name=None, filename=""): + + logger.debug("Creating subject object for %s" % name) + self.name = name + self.filename = filename + logger.info("Created subject object with name %s" % self.name) + self.datastore = None + + def create_datastore(self, fields): + """ Creates a datastore object to store trial data + + Parameters + ---------- + fields: list + A list of field names to store from the trial object + + Returns + ------- + bool + True if the creation was successful + + Raises + ------ + ValueError + If filename extension is of unknown type + """ + ext = os.path.splitext(self.filename)[1].lower() + if ext == ".csv": + self.datastore = CSVStore(fields, self.filename) + else: + raise ValueError("Extension %s is of unknown type" % ext) + + logger.info("Created datastore %s for subject %s" % (self.datastore, + self.name)) + + return True + + def store_data(self, trial): + """ Stores the trial data in the datastore + + Parameters + ---------- + trial: instance of Trial + The trial to store. It should have all fields used in creation of + the datastore as attributes or annotations. + + Returns + ------- + bool + True if store succeeded + """ + trial_dict = {} + for field in self.datastore.fields: + if hasattr(trial, field): + trial_dict[field] = getattr(trial, field) + elif field in trial.annotations: + trial_dict[field] = trial.annotations[field] + else: + trial_dict[field] = None + + logger.debug("Storing data for trial %d" % trial.index) + return self.datastore.store(trial_dict) + + +class CSVStore(object): + """ Class that wraps storing trial data in a CSV file + + Parameters + ---------- + fields: list + A list of columns for the CSV file + filename: string + Full path to the csv file. Appends to the file if it already exists. + + Attributes + ---------- + fields: list + A list of columns for the CSV file + filename: string + Full path to the csv file + + Methods + ------- + store(data) + Appends data to the CSV file + """ + def __init__(self, fields, filename): + + self.filename = filename + self.fields = fields + + with open(self.filename, 'ab') as data_fh: + trialWriter = csv.writer(data_fh) + trialWriter.writerow(self.fields) + + def __str__(self): + + return "CSVStore: filename = %s, fields = %s" % (self.filename, + ", ".join(self.fields)) + + def store(self, data): + """ Appends the data to the CSV file + + Parameters + ---------- + data: dictionary + The data to store. The keys should match the fields specified when + creating the CSVStore. + + Returns + ------- + bool + True if store succeeded + """ + + with open(self.filename, 'ab') as data_fh: + trialWriter = csv.DictWriter(data_fh, + fieldnames=self.fields, + extrasaction='ignore') + trialWriter.writerow(data) + + return True diff --git a/pyoperant/trials.py b/pyoperant/trials.py new file mode 100644 index 00000000..9e791182 --- /dev/null +++ b/pyoperant/trials.py @@ -0,0 +1,151 @@ +import logging +import datetime as dt +from pyoperant import EndSession +from pyoperant.events import events + +logger = logging.getLogger(__name__) + + +class Trial(object): + """ Class that implements all basic functionality of a trial + + Parameters + ---------- + index: int + Index of the trial + experiment: instance of Experiment class + The experiment of which this trial is a part + block: instance of Block class + The block that generated this trial + condition: instance of StimulusCondition + The condition for the current trial. Provides the trial with a stimulus, + as well as reinforcement instructions + + Attributes + ---------- + index: int + Index of the trial + experiment: instance of Experiment class + The experiment of which this trial is a part + stimulus_condition: instance of StimulusCondition + The condition for the current trial. Provides the trial with a stimulus, + as well as reinforcement instructions + time: datetime + The time the trial started + session: int + Index of the current session + + Methods + ------- + run() - Runs the trial + annotate() - Annotates the trial with key-value pairs + """ + def __init__(self, + index=None, + experiment=None, + condition=None, + block=None, + *args, **kwargs): + + super(Trial, self).__init__(*args, **kwargs) + + # Object references + self.experiment = experiment + self.condition = condition + self.block = block + self.annotations = dict() + + # Trial properties + self.index = index + self.session = self.experiment.session_id + self.time = None # Set just after trial_pre + + # Likely trial details + self.stimulus = None + self.response = None + self.rt = None + self.correct = False + self.reward = False + self.punish = False + + # Trial event information + self.event = dict(name="Trial", + action="", + metadata="") + + def annotate(self, **annotations): + """ Annotate the trial with key-value pairs """ + + self.annotations.update(annotations) + + def run(self): + """ Runs the trial + + Summary + ------- + The main structure is as follows: + + Get stimulus -> Initiate trial -> Play stimulus -> Receive response -> + Consequate response -> Finish trial -> Save data. + + The stimulus, response and consequate stages are broken into pre, main, + and post stages. Only use the stages you need in your experiment. + """ + + self.experiment.this_trial = self + + # Get the stimulus + self.stimulus = self.condition.get() + + # Any pre-trial logging / computations + self.experiment.trial_pre() + + # Emit trial event + self.event.update(action="start", metadata=str(self.index)) + events.write(self.event) + + # Record the trial time + self.time = dt.datetime.now() + + # Perform stimulus playback + self.experiment.stimulus_pre() + self.experiment.stimulus_main() + self.experiment.stimulus_post() + + # Evaluate subject's response + self.experiment.response_pre() + self.experiment.response_main() + self.experiment.response_post() + + # Consequate the response with a reward, punishment or neither + if self.response == self.condition.response: + self.correct = True + if self.condition.is_rewarded and self.block.reinforcement.consequate(self): + self.reward = True + self.experiment.reward_pre() + self.experiment.reward_main() + self.experiment.reward_post() + else: + self.correct = False + if self.condition.is_punished and self.block.reinforcement.consequate(self): + self.punish = True + self.experiment.punish_pre() + self.experiment.punish_main() + self.experiment.punish_post() + + # Emit trial end event + self.event.update(action="end", metadata=str(self.index)) + events.write(self.event) + + # Finalize trial + self.experiment.trial_post() + + # Store trial data + self.experiment.subject.store_data(self) + + # Update session schedulers + self.experiment.session.update() + + if self.experiment.check_session_schedule() is False: + logger.debug("Session has run long enough. Ending") + raise EndSession diff --git a/pyoperant/utils.py b/pyoperant/utils.py index bc2ceda3..689e86ec 100644 --- a/pyoperant/utils.py +++ b/pyoperant/utils.py @@ -7,6 +7,7 @@ import traceback import shlex import os +import fnmatch import string import random import datetime as dt @@ -42,6 +43,37 @@ def default(self, obj): return obj.tolist() return json.JSONEncoder.default(self, obj) + +def filter_files(directory, file_pattern="*", recursive=False): + """ Finds all files in directory that match a specific pattern. + + Parameters + ---------- + directory: string + File path to the directory to search + file_pattern: string + A glob to filter on using fnmatch.filter + recursive: bool + Whether or not to search subdirectories + + Returns + ------- + list of matching files + """ + + if not os.path.isdir(directory): + raise IOError("%s is not a directory" % directory) + + files = list() + for rootdir, dirname, fnames in os.walk(directory): + matches = fnmatch.filter(fnames, file_pattern) + files.extend(os.path.join(rootdir, fname) for fname in matches) + if not recursive: + dirname[:] = list() + + return files + + # consider importing this from python-neo class Event(object): """docstring for Event""" @@ -132,27 +164,27 @@ def __init__(self, self.events = [] self.stim_event = None - + class Command(object): """ Enables to run subprocess commands in a different thread with TIMEOUT option. - + via https://gist.github.com/kirpit/1306188 - + Based on jcollado's solution: http://stackoverflow.com/questions/1191374/subprocess-with-timeout/4825933#4825933 - + """ command = None process = None status = None output, error = '', '' - + def __init__(self, command): if isinstance(command, basestring): command = shlex.split(command) self.command = command - + def run(self, timeout=None, **kwargs): """ Run a command then return: (status, output, error). """ def target(**kwargs): @@ -250,6 +282,9 @@ def check_time(schedule,fmt="%H:%M"): if schedule == 'sun': if is_day(): return True + elif schedule == "night": + if not is_day(): + return True else: for epoch in schedule: assert len(epoch) is 2 diff --git a/setup.py b/setup.py index 6e7523f7..6061ddce 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ 'scripts/pyoperantctl', 'scripts/allsummary.py', ], - license = "BSD", + license = "GNU Affero General Public License v3", classifiers = [ "Development Status :: 4 - Beta", "Environment :: Console",