Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): fragment of a larger function body — the enclosing def (which supplies
# owner, input_states, reference_value and context) is not visible here.
# This allows method to be called by Mechanism.add_input_states() with set of user-specified input_states,
# while calls from init_methods continue to use owner.input_states (i.e., FeatureInputState specifications
# assigned in the **input_states** argument of the Mechanism's constructor).
# Because `or` is used, an explicitly passed empty list also falls back to owner.input_states.
input_states = input_states or owner.input_states
# Instantiate states of type FeatureInputState from the specifications; when no
# reference_value is supplied, owner.defaults.variable is used as the reference.
state_list = _instantiate_state_list(owner=owner,
state_list=input_states,
state_type=FeatureInputState,
state_param_identifier=INPUT_STATE,
reference_value=reference_value if reference_value is not None
else owner.defaults.variable,
# reference_value=reference_value,
reference_value_name=VALUE,
context=context)
# When called from a method or the command line (e.g. Mechanism.add_states), append to the
# existing input_states instead of replacing them; otherwise assign the new list outright.
if context & (ContextFlags.METHOD | ContextFlags.COMMAND_LINE):
owner.input_states.extend(state_list)
else:
owner._input_states = state_list
# Check that number of input_states and their variables are consistent with owner.defaults.variable,
# and adjust the latter if not (the adjustment itself is not visible in this fragment).
# The loop stops at the first InputState whose value is incompatible with the matching item
# of owner.defaults.variable, or when owner.defaults.variable has fewer items than there are
# InputStates (IndexError). If owner.input_states is empty, variable_item_is_OK stays False.
variable_item_is_OK = False
for i, input_state in enumerate(owner.input_states):
try:
variable_item_is_OK = iscompatible(owner.defaults.variable[i], input_state.value)
if not variable_item_is_OK:
break
except IndexError:
variable_item_is_OK = False
break
"""Set flags based on a string of ContextFlags keywords
If context is already a ContextFlags mask, return that
Otherwise, return mask with flags set corresponding to keywords in context
"""
# FIX: 3/23/18 UPDATE WITH NEW FLAGS
if isinstance(context, ContextFlags):
return context
if isinstance(context, Context):
context = context.string
context_flag = ContextFlags.UNSET
if VALIDATE in context:
context_flag |= ContextFlags.VALIDATING
if EXECUTING in context:
context_flag |= ContextFlags.EXECUTING
if CONTROL in context:
context_flag |= ContextFlags.CONTROL
if LEARNING in context:
context_flag |= ContextFlags.LEARNING
# if context == ContextFlags.TRIAL.name: # cxt-test
# context_flag |= ContextFlags.TRIAL
# if context == ContextFlags.RUN.name:
# context_flag |= ContextFlags.RUN
if context == ContextFlags.COMMAND_LINE.name:
context_flag |= ContextFlags.COMMAND_LINE
return context_flag
def add_features(self, feature_predictors):
    """Add InputStates (and their Projections) for *feature_predictors*.

    The new InputStates are added to the ModelFreeOptimizationControlMechanism
    and are used to predict its `net_outcome`.

    Arguments
    ---------
    feature_predictors
        Any of the forms of specification allowed for InputState(s), or a dict
        containing an entry whose key is *SHADOW_EXTERNAL_INPUTS* and whose
        value is a list of `ORIGIN` Mechanisms and/or their InputStates.
    """
    # Normalize the specs under a COMMAND_LINE context before adding the states.
    parsed_specs = self._parse_feature_specs(
        feature_predictors=feature_predictors,
        context=ContextFlags.COMMAND_LINE,
    )
    self.add_states(InputState, parsed_specs)
# standard logging
# NOTE(review): fragment — the `if` that this `else:` belongs to is not visible here.
else:
# Nothing is logged when no log condition is set or it is explicitly OFF.
if self.log_condition is None or self.log_condition is LogCondition.OFF:
return
# Fall back to the most recent context of self._owner._owner.
if context is None:
context = self._owner._owner.most_recent_context
time = _get_time(self._owner._owner, context)
context_str = ContextFlags._get_context_string(context.flags)
# Log if any bit of the log condition matches the current context flags ...
log_condition_satisfied = self.log_condition & context.flags
# ... or if INITIALIZATION logging is requested and self._owner._owner is still initializing.
if (
not log_condition_satisfied
and self.log_condition & LogCondition.INITIALIZATION
and self._owner._owner.initialization_status is ContextFlags.INITIALIZING
):
log_condition_satisfied = True
if log_condition_satisfied:
# Non-stateful entries are all filed under the key None; stateful ones go under the
# context's execution_id.
if not self.stateful:
execution_id = None
else:
execution_id = context.execution_id
# Create the per-execution_id deque on first use, then append the new entry.
if execution_id not in self.log:
self.log[execution_id] = collections.deque([])
self.log[execution_id].append(
LogEntry(time, context_str, value)
)
def _validate_params(self, request_set, target_set=None, context=None):
    """Validate params, then check that the receiver is a gatable type.

    General validation is delegated to the parent class. While the projection
    is INITIALIZING, this also verifies that ``self.receiver`` is an
    InputPort, OutputPort or Mechanism.

    Raises
    ------
    GatingProjectionError
        If, during initialization, the receiver is not an InputPort,
        OutputPort or Mechanism.
    """
    super()._validate_params(request_set=request_set, target_set=target_set, context=context)

    # The receiver check only applies while the projection is being initialized.
    if self.initialization_status != ContextFlags.INITIALIZING:
        return

    # Imported locally, as in the original; presumably to avoid a circular import — TODO confirm.
    from psyneulink.core.components.ports.inputport import InputPort
    from psyneulink.core.components.ports.outputport import OutputPort

    if not isinstance(self.receiver, (InputPort, OutputPort, Mechanism)):
        # Name the projection first and the offending receiver second. The previous argument
        # order rendered as "Receiver specified for <receiver> <name> is not a ...".
        raise GatingProjectionError(
            f"Receiver specified for {self.name} ({self.receiver}) is not a "
            f"Mechanism, InputPort or OutputPort"
        )
# NOTE(review): fragment — the loop header and the code that initializes flagged_items,
# string, fields and condition_flags are not visible here.
flagged_items.append(ContextFlags.IDLE.name)
break
if c & condition_flags:
flagged_items.append(c.name)
# Source section: report each source flag present in condition_flags, or NONE if no bit
# of ContextFlags.SOURCE_MASK is set.
if SOURCE in fields:
for c in SOURCE_FLAGS:
# Loop-invariant test: when no source bit is set, record NONE once and stop.
if not condition_flags & ContextFlags.SOURCE_MASK:
flagged_items.append(ContextFlags.NONE.name)
break
if c & condition_flags:
flagged_items.append(c.name)
# NOTE(review): SOURCE_FLAGS is a set, so the order of the joined names may not be stable
# across runs — confirm whether callers depend on ordering.
string += ", ".join(flagged_items)
return string
# Groupings of ContextFlags members by category.
# Initialization-status flags.
INITIALIZATION_STATUS_FLAGS = {ContextFlags.DEFERRED_INIT,
ContextFlags.INITIALIZING,
ContextFlags.VALIDATING,
ContextFlags.INITIALIZED,
ContextFlags.RESET,
ContextFlags.UNINITIALIZED}
# Execution-phase flags.
EXECUTION_PHASE_FLAGS = {ContextFlags.PREPARING,
ContextFlags.PROCESSING,
ContextFlags.LEARNING,
ContextFlags.CONTROL,
ContextFlags.IDLE
}
# Source flags (presumably identifying where a call originated — confirm).
# NOTE(review): this set literal is cut off in this view; its remaining members are not visible.
SOURCE_FLAGS = {ContextFlags.COMMAND_LINE,
ContextFlags.CONSTRUCTOR,
ContextFlags.INSTANTIATE,
ContextFlags.COMPONENT,
# NOTE(review): fragment of a constructor — the def line and the parameters forwarded below
# (owner, default_allocation, modulates, ...) are not visible here.
# Forward to the parent constructor; default_allocation is passed as the default variable
# and modulates as the projections.
super().__init__(owner=owner,
reference_value=reference_value,
variable=default_allocation,
size=size,
projections=modulates,
index=index,
assign=assign,
function=function,
modulation=modulation,
params=params,
name=name,
prefs=prefs,
**kwargs)
# Assign the default port name only once initialization_status is INITIALIZED.
if self.initialization_status == ContextFlags.INITIALIZED:
self._assign_default_port_Name()
def __init__(self,
agent_rep=None,
function=None,
features: tc.optional(tc.optional(tc.any(Iterable, Mechanism, OutputPort, InputPort))) = None,
feature_function: tc.optional(tc.optional(tc.any(is_function_type))) = None,
num_estimates = None,
search_function: tc.optional(tc.optional(tc.any(is_function_type))) = None,
search_termination_function: tc.optional(tc.optional(tc.any(is_function_type))) = None,
search_statefulness=None,
context=None,
**kwargs):
"""Implement OptimizationControlMechanism.

If **agent_rep** is None and the call comes from the command line, the constructor
stores its arguments, marks the object as DEFERRED_INIT, and returns before calling
the parent constructor. An internal call without **agent_rep** is treated as a
programming error.
"""
# NOTE(review): the nested tc.optional(tc.optional(...)) annotations look redundant; a single
# tc.optional(...) appears to express the same constraint — confirm with typecheck docs.
# NOTE(review): context defaults to None, yet context.source is read below; presumably a
# decorator or the caller always supplies a Context — confirm.
# If agent_rep hasn't been specified, put into deferred init
if agent_rep is None:
if context.source==ContextFlags.COMMAND_LINE:
# Give the object a temporary name derived from its class name until it is fully initialized
self._assign_deferred_init_name(self.__class__.__name__, context)
# Store args for deferred initialization (locals() here includes self, every named
# parameter, and the kwargs dict itself under the name 'kwargs')
self._store_deferred_init_args(**locals())
# Flag for deferred initialization
self.initialization_status = ContextFlags.DEFERRED_INIT
return
# If constructor is called internally (i.e., for controller of Composition),
# agent_rep needs to be specified
else:
# NOTE(review): `assert` is stripped under `python -O`, in which case execution falls through
# to super().__init__ with agent_rep=None; raising an explicit exception would be safer.
assert False, f"PROGRAM ERROR: 'agent_rep' arg should have been specified " \
f"in internal call to constructor for {self.name}."
super().__init__(
function=function,
def _get_context(context:tc.any(ContextFlags, Context, str)):
"""Set flags based on a string of ContextFlags keywords.

If context is already a ContextFlags mask, return that.
Otherwise, return a mask with flags set corresponding to keywords in context;
a Context object is first reduced to its ``string`` attribute.
"""
# NOTE(review): this definition is cut off in this view — any remaining checks and the
# final return of context_flag are not visible here.
# FIX: 3/23/18 UPDATE WITH NEW FLAGS
if isinstance(context, ContextFlags):
return context
if isinstance(context, Context):
context = context.string
context_flag = ContextFlags.UNSET
# Each keyword is tested with `in` (a substring check for string input), so several flags
# can be combined into one mask.
if VALIDATE in context:
context_flag |= ContextFlags.VALIDATING
if EXECUTING in context:
context_flag |= ContextFlags.EXECUTING
if CONTROL in context:
context_flag |= ContextFlags.CONTROL
if LEARNING in context:
context_flag |= ContextFlags.LEARNING
# if context == ContextFlags.TRIAL.name: # cxt-test
# context_flag |= ContextFlags.TRIAL
# if context == ContextFlags.RUN.name:
# context_flag |= ContextFlags.RUN
def _execute(self,
variable=None,
context=None,
function_variable=None,
runtime_params=None,
):
"""Execute the Mechanism (only the beginning of this method is visible here).

While the Mechanism is INITIALIZING, its activity parameters are seeded from the
socket_template of the RECURRENT InputPort. When _target_included is set,
output_activity is set from the TARGET InputPort's socket_template.
"""
if self.initialization_status == ContextFlags.INITIALIZING:
# Set minus_phase activity, plus_phase, current_activity and initial_value
# all to zeros with size of Mechanism's array
# (assumes socket_template is a zero array shaped like the RECURRENT input — TODO confirm)
# Should be OK to use attributes here because initialization should only occur during None context
self._set_multiple_parameter_values(
context,
initial_value=self.input_ports[RECURRENT].socket_template,
current_activity=self.input_ports[RECURRENT].socket_template,
minus_phase_activity=self.input_ports[RECURRENT].socket_template,
plus_phase_activity=self.input_ports[RECURRENT].socket_template,
execution_phase=None,
)
# NOTE(review): whether this check is nested inside the INITIALIZING branch cannot be
# recovered from this view (indentation was stripped).
if self._target_included:
self.parameters.output_activity._set(self.input_ports[TARGET].socket_template, context)
# Initialize execution_phase as minus_phase
if self.parameters.execution_phase._get(context) is None: