###
### Add the exc learning rules to the connection, and the error ensemble to the learning rule ###
###
EtoERulesDict = { 'PES' : nengo.PES(learning_rate=PES_learning_rate_rec,
                                    pre_tau=tau) }#,
                                    #clipType=excClipType,
                                    #decay_rate_x_dt=excPES_weightsDecayRate*dt,
                                    #integral_tau=excPES_integralTau) }
plasticConnEE.learning_rule_type = EtoERulesDict
#plasticConnEE.learning_rule['PES'].learning_rate=0
                                    # NOTE: setting learning_rate to zero here has no effect
                                    #  (learning still proceeds); the rate only takes effect
                                    #  when passed to the PES() constructor above
                                    #  (see the minimal sketch after this snippet)
# feedforward learning rule
InEtoERulesDict = { 'PES' : nengo.PES(learning_rate=PES_learning_rate_FF,
                                    pre_tau=tau) }#,
                                    #clipType=excClipType,
                                    #decay_rate_x_dt=excPES_weightsDecayRate*dt,
                                    #integral_tau=excPES_integralTau) }
InEtoE.learning_rule_type = InEtoERulesDict
if learnIfNoInput:                              # obsolete, no support for trialClamp
    print("Obsolete flag learnIfNoInput")
    sys.exit(1)
    errorWt = nengo.Node( size_in=Nobs, output = lambda timeval,errWt: \
                    zeros2N if (timeval%Tperiod) < rampT else errWt*(np.abs(errWt)>weightErrorCutoff) )
                                                # only learn when there is no input,
                                                #  using the local (input+err) current
                                                #  thus, only the error is used & input doesn't interfere
    nengo.Connection(errorOff,errorWt,synapse=weightErrorTau)
                                                # error to errorWt ensemble, filter for weight learning
                                                # tau-filtered expectOut must be compared to tau-filtered ratorOut (post above)
else:
    rateEvolve = nengo.Node(rateEvolveFn)
    # Error = post - desired_output
    rateEvolve2error = nengo.Connection(rateEvolve,error,synapse=tau,transform=-np.eye(N))
    #rateEvolve2error = nengo.Connection(rateEvolve,error,synapse=None,transform=-np.eye(N))
                                                # - desired output here (post above)
                                                # unfiltered non-spiking reference is compared to tau-filtered spiking ratorOut (post above)
plasticConnEE = EtoE
rateEvolve_probe = nengo.Probe(rateEvolve2error, 'output')
                                                # save the filtered/unfiltered reference as this is the 'actual' reference
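# --- Minimal illustrative sketch (separate from the model above): how a PES rule
#     and an error population are wired together in plain Nengo. The names and
#     values here (pre, post, 100 neurons, learning_rate=1e-4) are assumptions
#     for illustration only, not the parameters of the model above.
import numpy as np
import nengo

with nengo.Network() as _pes_sketch:
    stim = nengo.Node(lambda t: np.sin(2*np.pi*t))
    pre = nengo.Ensemble(100, dimensions=1)
    post = nengo.Ensemble(100, dimensions=1)
    error = nengo.Ensemble(100, dimensions=1)
    nengo.Connection(stim, pre)
    # the learning rate is fixed in the PES() constructor
    plastic = nengo.Connection(pre, post,
                    learning_rule_type={'PES': nengo.PES(learning_rate=1e-4)})
    # error = post - desired (here the desired output is just the stimulus)
    nengo.Connection(post, error)
    nengo.Connection(stim, error, transform=-1)
    # drive the learning rule with the error population
    nengo.Connection(error, plastic.learning_rule['PES'])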
###
### Add the exc learning rules to the connection, and the error ensemble to the learning rule ###
###
EtoERulesDict = { 'PES' : nengo.PES(learning_rate=PES_learning_rate_rec,
                                    pre_tau=tau) }
plasticConnEE.learning_rule_type = EtoERulesDict
#plasticConnEE.learning_rule['PES'].learning_rate=0
                                    # NOTE: setting learning_rate to zero here has no effect
                                    #  (learning still proceeds); the rate only takes effect
                                    #  when passed to the PES() constructor above
if learnIfNoInput:                              # obsolete, no support for trialClamp
    errorWt = nengo.Node( size_in=N, output = lambda timeval,errWt: \
                    zerosN if (timeval%Tperiod) < rampT else errWt*(np.abs(errWt)>weightErrorCutoff) )
                                                # only learn when there is no input,
                                                #  using the local (input+err) current
                                                #  thus, only the error is used & input doesn't interfere
    nengo.Connection(errorOff,errorWt,synapse=weightErrorTau)
                                                # error to errorWt ensemble, filter for weight learning
else:
    # Error = post - desired_output
    if copycatLayer:                            # copy another network's behaviour
        rateEvolve2error = nengo.Connection(expectOut,error,synapse=tau,transform=-np.eye(Nobs))
                                                # - desired output (post above)
    else:                                       # copy rate evolution behaviour
        #nengo.Connection(rateEvolve,error[:Nobs],synapse=tau,transform=-np.eye(Nobs))
        #nengo.Connection(nodeIn,error[Nobs:],synapse=tau,transform=-np.eye(N//2))
        rateEvolve2error = nengo.Connection(rateEvolve,error,synapse=None,transform=-np.eye(Nobs))
                                                # - desired output (post above)
plasticConnEE = EtoE
rateEvolve_probe = nengo.Probe(rateEvolve2error, 'output')
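# --- Illustrative sketch (not part of the model above) of the errorWt gating Node:
#     the error is zeroed during the ramp at the start of each Tperiod, and components
#     below weightErrorCutoff are suppressed. The constants below are assumed
#     placeholder values, not the model's actual parameters.
import numpy as np
import nengo

_Tperiod, _rampT, _cutoff, _N = 1.0, 0.25, 0.01, 2

def _gate_error(t, err):
    # learn only after the initial ramp of each period, and only from large-enough errors
    if (t % _Tperiod) < _rampT:
        return np.zeros(_N)
    return err * (np.abs(err) > _cutoff)

with nengo.Network() as _gate_sketch:
    errorWt_demo = nengo.Node(_gate_error, size_in=_N, size_out=_N)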
# function = lambda err: np.mean(np.abs(err)))
# instead of the above gain[err(t)], I just have gain(t)
autoGainControl = nengo.Node(size_in=1,size_out=1,\
                    output = lambda timeval,x: \
                        -errorFeedbackGain*(2.*Tmax-timeval)/2./Tmax)
                                            # multiply error with this calculated gain
errorGain = nengo.Node(size_in=N+1,size_out=N,
                    output = lambda timeval,x: x[:N]*x[-1])
nengo.Connection(errorOff,errorGain[:N],synapse=0.001)
nengo.Connection(autoGainControl,errorGain[-1],synapse=0.001)
                                            # feedback the error multiplied by the calculated gain
errorFeedbackConn = nengo.Connection(\
                    errorGain,ratorOut,synapse=errorFeedbackTau)
if errorGainDecay and spikingNeurons:       # decaying gain, works only if error is computed from spiking neurons
    errorFeedbackConn.learning_rule_type = \
                    {'wtDecayRule':nengo.PES(decay_rate_x_dt=errorGainDecayRate*dt)}
                                            # PES with error unconnected, so only decay
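# --- Illustrative sketch (separate toy network) of the gain-scheduled error feedback
#     above: one Node computes gain(t) = -errorFeedbackGain*(2*Tmax - t)/(2*Tmax), i.e.
#     the gain decays linearly from -errorFeedbackGain at t=0 to -errorFeedbackGain/2
#     at t=Tmax, and a second Node multiplies the error by that gain. All constants and
#     names here are assumed placeholders.
import nengo

_Tmax, _fbGain, _N2 = 10.0, 3.0, 2

with nengo.Network() as _fb_sketch:
    err_in = nengo.Node([0.1, -0.2])                 # stand-in for the (filtered) error signal
    auto_gain = nengo.Node(lambda t: -_fbGain*(2.*_Tmax - t)/2./_Tmax)
    # the last input element carries the gain; output = error * gain
    err_gain = nengo.Node(lambda t, x: x[:_N2]*x[-1], size_in=_N2+1, size_out=_N2)
    nengo.Connection(err_in, err_gain[:_N2], synapse=None)
    nengo.Connection(auto_gain, err_gain[-1], synapse=None)
    target = nengo.Ensemble(100, dimensions=_N2)
    # feed the gain-scaled error back to the output population
    nengo.Connection(err_gain, target, synapse=0.02)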
###
### error and weight probes ###
###
errorOn_p = nengo.Probe(error, synapse=None, label='errorOn')
error_p = nengo.Probe(errorWt, synapse=None, label='error')
if saveWeightsEvolution:
    learnedWeightsProbe = nengo.Probe(\
                    plasticConnEE,'weights',sample_every=weightdt,label='EEweights')
    # feedforward weights probe
    learnedInWeightsProbe = nengo.Probe(\
                    InEtoE,'weights',sample_every=weightdt,label='InEEweights')
if initLearned:
    if not plastDecoders:
#if not OCL and Nexc<=4000:                 # GPU mem is not large enough to probe large weight matrices
#    learnedInWeightsProbe = nengo.Probe(\
#                    InEtoE,'weights',sample_every=weightdt,label='InEEweights')
#    learnedWeightsProbe = nengo.Probe(\
#                    plasticConnEE,'weights',sample_every=weightdt,label='EEweights')
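# --- Illustrative sketch (toy network) of probing learned connection weights as in
#     the probes above: the 'weights' attribute is probed with sample_every so the
#     potentially large matrix is only saved at a coarse interval. The 0.1 s interval
#     and ensemble sizes are assumed values, not those of the model above.
import nengo

with nengo.Network() as _probe_sketch:
    a = nengo.Ensemble(50, dimensions=1)
    b = nengo.Ensemble(50, dimensions=1)
    conn = nengo.Connection(a, b, learning_rule_type=nengo.PES())
    weights_p = nengo.Probe(conn, 'weights', sample_every=0.1, label='weights')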
#################################
### Initialize weights if requested
#################################
#errorWt = nengo.Node( size_in=N, output = lambda timeval,errWt: \
#                zerosN if (timeval%Tperiod)<rampT else \
#                errWt*(np.abs(errWt)>weightErrorCutoff) / (1.+N*2.5e-3/np.linalg.norm(errWt)) )
errorWt = nengo.Node( size_in=N, output = lambda timeval,errWt: \
                errWt*(np.abs(errWt)>weightErrorCutoff) )
nengo.Connection(errorOff,errorWt,synapse=weightErrorTau)
                                            # error to errorWt ensemble, filter for weight learning
error_conn = nengo.Connection(\
                errorWt,plasticConnEE.learning_rule['PES'],synapse=dt)
                                            # feedforward error connection to learning rule
if not (copycatLayer and not copycatPreLearned): # don't learn ff weights if copycatLayer and not copycatPreLearned
    # feedforward learning rule
    InEtoERulesDict = { 'PES' : nengo.PES(learning_rate=PES_learning_rate_FF,
                                          pre_tau=tau) }
    InEtoE.learning_rule_type = InEtoERulesDict
    nengo.Connection(\
                errorWt,InEtoE.learning_rule['PES'],synapse=dt)
###
### feed the error back to force output to follow the input (for both recurrent and feedforward learning) ###
###
if errorFeedback and not testLearned:       # no error feedback if testing learned weights
    #np.random.seed(1)
    if not errorGainProportion:             # default error feedback
        errorFeedbackConn = nengo.Connection(errorOff,ratorOut,\
                                synapse=errorFeedbackTau,\
                                transform=-errorFeedbackGain)#*(np.random.uniform(-0.1,0.1,size=(N,N))+np.eye(N)))
    else:
        ## calculate the gain from the filtered mean(abs(error))