Source code for village.pybpodapi.state_machine.state_machine_base

# mypy: ignore-errors
import logging

from village.pybpodapi.bpod.hardware.events import EventName
from village.pybpodapi.bpod.hardware.output_channels import OutputChannel
from village.pybpodapi.state_machine.conditions import Conditions
from village.pybpodapi.state_machine.global_counters import GlobalCounters
from village.pybpodapi.state_machine.global_timers import GlobalTimers

logger = logging.getLogger(__name__)


class StateMachineBase(object):
    """
    Each Bpod trial is programmed as a virtual finite state machine.

    This ensures precise timing of events - for any state machine you
    program, state transitions will be completed in less than 250
    microseconds - so inefficient coding won't reduce the precision of events
    in your data.

    .. warning:: A lot of data structures are kept here for compatibility
        with original matlab library which are not so python-like. Anyone is
        welcome to enhance this class but keep in mind that it will affect
        the whole pybpodapi library.

    :ivar Hardware hardware: bpod box hardware description associated with
        this state machine
    :ivar list(str) state_names: list that holds state names added to this
        state machine
    :ivar list(float) state_timers: list that holds state timers
    :ivar int total_states_added: holds all states added, even if name is
        repeated
    :ivar list(int) state_timer_matrix: per state, the destination taken when
        the state timer elapses (initialised to the state's own index)
    :ivar Conditions conditions: holds conditions
    :ivar GlobalCounters global_counters: holds global counters
    :ivar GlobalTimers global_timers: holds global timers
    :ivar list(list(tuple(int))) input_matrix: per state, the
        ``(event_code, destination)`` pairs for plain input events
    :ivar list(str) manifest: list of states names that have been added to
        the state machine
    :ivar list(str) undeclared: list of states names that have been
        referenced but not yet added
    :ivar list(list(tuple(int))) output_matrix: per state, the
        ``(output_code, output_value)`` pairs of its output actions
    :ivar bool use_255_back_signal: set once any transition targets ``back``
    :ivar bool is_running: whether this state machine is being run on bpod
        box
    """

    def __init__(self, bpod):
        """
        Create an empty state machine sized from ``bpod.hardware``.

        :param bpod: object whose ``hardware`` attribute describes the box;
            ``max_states``, ``n_conditions``, ``n_global_counters`` and
            ``n_global_timers`` are read from it here
        """
        self.hardware = bpod.hardware
        max_states = self.hardware.max_states

        self.state_names = []
        self.state_timers = [0] * max_states  # list(float)
        self.total_states_added = 0

        # state change conditions
        self.state_timer_matrix = [0] * max_states
        self.conditions = Conditions(max_states, self.hardware.n_conditions)
        self.global_counters = GlobalCounters(
            max_states, self.hardware.n_global_counters
        )
        self.global_timers = GlobalTimers(max_states, self.hardware.n_global_timers)
        self.input_matrix = [[] for _ in range(max_states)]

        # should be incremented whenever the user uses a timer
        self.n_global_timers_used = 0
        # should be incremented whenever the user uses a counter
        self.n_global_counters_used = 0
        # should be incremented whenever the user uses a condition
        self.n_global_conditions_used = 0

        # if active uses the state 255 to store the previous state,
        # so the user can go back in the state machine
        self.use_255_back_signal = False

        # List of states that have been added to the state machine
        self.manifest = []
        # List of states that have been referenced but not yet added
        self.undeclared = []

        # output actions
        self.output_matrix = [[] for _ in range(max_states)]

        self.is_running = False

    def _resolve_destination(self, target):
        """Map a transition target to the number stored in the matrices."""
        if target in self.manifest:
            return self.manifest.index(target)
        if target in ["exit", ">exit"]:
            return float("NaN")
        if target in ["back", ">back"]:
            self.use_255_back_signal = True
            return 255
        # Send to an undeclared state (replaced later with actual state
        # in myBpod.sendStateMachine)
        self.undeclared.append(target)
        return (len(self.undeclared) - 1) + 10000

    def add_state(
        self,
        state_name,
        state_timer=0,
        state_change_conditions=None,
        output_actions=(),
    ):
        """
        Adds a state to an existing state matrix.

        :param str state_name: A character string containing the unique name
            of the state. The state will automatically be assigned a number
            for internal use and state synchronization via the sync port
        :param float state_timer: The state timer value, given in seconds.
            This value must be zero or positive, and can range between
            0-3600s. If set to 0s and linked to a state transition, the state
            will still take ~100us to execute the state's output actions
            before the transition completes
        :param dict state_change_conditions: Dictionary whose keys are names
            of a valid input event (state change) and values are names of
            states to enter if the previously listed event occurs (or 'exit'
            to exit the matrix and return all captured data). ``None``
            (the default) means no transitions
        :param list(tuple) output_actions: a list of binary tuples where
            first value should contain the name of a valid output action and
            the second value should contain the value of the previously
            listed output action (see output actions for valid values).
        :raises SMAError: if an event name or an output action name is not
            known to ``self.hardware.channels``

        Example:

        .. code-block:: python

            sma.add_state(
                state_name='Port1Lit',
                state_timer=.25,
                state_change_conditions={'_Tup': 'Port3Lit'},
                output_actions=[('PWM1', 255)])
        """
        # A fresh dict per call instead of a shared mutable default.
        if state_change_conditions is None:
            state_change_conditions = {}

        # TODO: WHY DO WE NEED THIS IF-ELSE?
        if state_name not in self.manifest:
            self.state_names.append(state_name)
            self.manifest.append(state_name)
            state_name_idx = len(self.manifest) - 1
        else:
            state_name_idx = self.manifest.index(state_name)
            self.state_names[state_name_idx] = state_name

        # Until a state-timer transition is given, the timer loops to itself.
        self.state_timer_matrix[state_name_idx] = state_name_idx
        self.state_timers[state_name_idx] = state_timer

        for event_name, event_state_transition in state_change_conditions.items():
            try:
                event_code = self.hardware.channels.event_names.index(event_name)
            except ValueError as err:
                raise SMAError(
                    f"Error creating state: {state_name}. "
                    f"{event_name} is an invalid event name."
                ) from err
            logger.debug("Event code: %s", event_code)

            destination_state_number = self._resolve_destination(
                event_state_transition
            )
            transition = (event_code, destination_state_number)

            if EventName.is_state_timer(event_name):
                self.state_timer_matrix[state_name_idx] = destination_state_number
            elif EventName.is_condition(event_name):
                self.conditions.matrix[state_name_idx].append(transition)
            elif EventName.is_global_counter_end(event_name):
                self.global_counters.matrix[state_name_idx].append(transition)
            elif EventName.is_global_timer_trigger(
                event_name
            ) or EventName.is_global_timer_cancel(event_name):
                # A string value is read as a string of bits; anything else
                # is stored unchanged.
                # NOTE(review): both trigger and cancel events overwrite the
                # end_matrix entry with this value, whereas the timer-end
                # branch below appends tuples to the same entry. Behaviour is
                # kept as-is - confirm the intended target matrix.
                if isinstance(event_state_transition, str):
                    timer_bits = int(event_state_transition, 2)
                else:
                    timer_bits = event_state_transition
                self.global_timers.end_matrix[state_name_idx] = timer_bits
            elif EventName.is_global_timer_end(event_name):
                self.global_timers.end_matrix[state_name_idx].append(transition)
            elif EventName.is_global_timer_start(event_name):
                self.global_timers.start_matrix[state_name_idx].append(transition)
            else:
                self.input_matrix[state_name_idx].append(transition)

        output_names = self.hardware.channels.output_channel_names
        for action_name, action_value in output_actions:
            if action_name == "Valve":
                # 'Valve' addresses the channel named after the valve number
                # and always writes 1.
                output_code = output_names.index(
                    OutputChannel.Valve + str(action_value)
                )
                output_value = 1
            else:
                try:
                    output_code = output_names.index(action_name)
                except ValueError as err:
                    raise SMAError(
                        f"Error creating state: {state_name}. "
                        f"{action_name} is an invalid output name."
                    ) from err
                output_value = action_value

            if action_name == OutputChannel.GlobalCounterReset:
                self.global_counters.reset_matrix[output_value] = 1

            # For backwards compatibility, integers specifying global timers
            # convert to equivalent binary decimals.
            # To specify binary, use a string of bits.
            events_positions = self.hardware.channels.events_positions
            if output_code == events_positions.globalTimerTrigger:
                self.global_timers.triggers_matrix[state_name_idx] = 2 ** (
                    output_value - 1
                )
            if output_code == events_positions.globalTimerCancel:
                self.global_timers.cancels_matrix[output_value - 1] = 1

            self.output_matrix[state_name_idx].append((output_code, output_value))

        self.total_states_added += 1

    def set_global_timer_legacy(self, timer_id=None, timer_duration=None):
        """
        Set global timer (legacy version)

        :param int timer_id: 1-based number of the timer to set
        :param float timer_duration: timer duration in seconds
        """
        self.global_timers.timers[timer_id - 1] = timer_duration

    def set_global_timer(
        self,
        timer_id,
        timer_duration,
        on_set_delay=0,
        channel=None,
        on_message=1,
        off_message=0,
        loop_mode=0,
        loop_intervals=0,
        send_events=1,
        oneset_triggers=None,
    ):
        """
        Sets the duration of a global timer.

        Unlike state timers, global timers can be triggered from any state
        (as an output action), and handled from any state (by causing a state
        change).

        :param int timer_id: the number of the timer you are setting
            (1-based integer).
        :param float timer_duration: the duration of the timer, following
            timer start (0-3600 seconds)
        :param float on_set_delay: stored in ``on_set_delays`` for this timer
        :param str channel: channel/port name Ex: 'PWM2'. When ``None`` the
            channel index 255 is stored instead
        :param int on_message: stored in ``on_messages`` for this timer
        :param int off_message: stored in ``off_messages`` for this timer
        :param int loop_mode: stored in ``loop_mode`` for this timer
        :param loop_intervals: stored in ``loop_intervals`` for this timer
        :param int send_events: stored in ``send_events`` for this timer
        :param oneset_triggers: when not ``None``, stored in ``onset_matrix``
            for this timer
        :raises SMAError: if ``channel`` is not a known output channel name
        """
        timer_channel_idx = 255
        if channel is not None:
            try:
                timer_channel_idx = self.hardware.channels.output_channel_names.index(
                    channel
                )
            except ValueError as err:
                raise SMAError(
                    "Error: {0} is an invalid output channel name.".format(channel)
                ) from err

        index = timer_id - 1
        timers = self.global_timers
        timers.timers[index] = timer_duration
        timers.on_set_delays[index] = on_set_delay
        timers.channels[index] = timer_channel_idx
        timers.on_messages[index] = on_message
        timers.off_messages[index] = off_message
        timers.loop_mode[index] = loop_mode
        timers.loop_intervals[index] = loop_intervals
        timers.send_events[index] = send_events

        # Pad with zeros so that onset_matrix[index] exists. The previous
        # ``len(...) < index`` test skipped the padding when the list was
        # exactly one entry short, making the assignment below fail.
        missing = index + 1 - len(timers.onset_matrix)
        if missing > 0:
            timers.onset_matrix.extend([0] * missing)

        if oneset_triggers is not None:
            timers.onset_matrix[index] = oneset_triggers

    def set_global_counter(
        self, counter_number=None, target_event=None, threshold=None
    ):
        """
        Sets the threshold and monitored event for one of the global counters.

        Global counters can count instances of events, and handle when the
        count exceeds a threshold from any state (by triggering a state
        change).

        :param int counter_number: the number of the counter you are setting
            (1-based integer).
        :param str target_event: name of the event to count; it is looked up
            in ``self.hardware.channels.event_names`` and an unknown name
            propagates the lookup error unchanged
        :param int threshold: number of times that should be count until
            trigger
        """
        event_code = self.hardware.channels.event_names.index(target_event)
        slot = counter_number - 1
        self.global_counters.attached_events[slot] = event_code
        self.global_counters.thresholds[slot] = threshold

    def set_condition(self, condition_number, condition_channel, channel_value):
        """
        Set condition

        :param int condition_number: 1-based number of the condition to set
        :param str condition_channel: input channel name; it is looked up in
            ``self.hardware.channels.input_channel_names`` and an unknown
            name propagates the lookup error unchanged
        :param int channel_value: value stored for this condition
        """
        channel_code = self.hardware.channels.input_channel_names.index(
            condition_channel
        )
        slot = condition_number - 1
        self.conditions.channels[slot] = channel_code
        self.conditions.values[slot] = channel_value
class SMAError(Exception):
    """Raised when a state machine is given an invalid event, output or channel name."""