# Source code for spaic.Learning.RSTDP

# encoding: utf-8
"""
@author: Yuan Mengwen
@contact: mwyuan94@gmail.com
@project: PyCharm
@filename: RSTDP.py
@time:2021/4/8 10:46
@description: Reward-modulated STDP learners (RSTDP, RSTDPET and RewardSwitchSTDP).
"""

from .Learner import Learner
import numpy as np
import torch

class RSTDP(Learner):
    """
    Reward-modulated STDP.

    Adapted from `(Florian 2007) <https://florian.io/papers/2007_Florian_Modulated_STDP.pdf>`.

    Args:
        trainable: It can be network or neurongroups.

    Keyword Args:
        lr (int or float): learning rate (default 0.1).
        tau_plus (int or float): Time constant of the presynaptic firing trace ``p_plus``; it
            determines the range of interspike intervals over which synaptic changes occur
            (default 20.0).
        tau_minus (int or float): Time constant of the postsynaptic firing trace ``p_minus``
            (default 20.0).
        A_plus (float): Amplitude added to ``p_plus`` for each presynaptic spike (default 1.0).
        A_minus (float): Amplitude added to ``p_minus`` for each postsynaptic spike
            (default -1.0).
    """

    def __init__(self, trainable=None, **kwargs):
        super(RSTDP, self).__init__(trainable=trainable)
        self.prefered_backend = ['pytorch']
        self.name = 'RSTDP'
        self.learning_rate = kwargs.get('lr', 0.1)
        # Time constants; build() registers each one as the per-step decay factor exp(-dt / tau).
        self._tau_constant_variables = dict()
        self._tau_constant_variables['tau_plus'] = kwargs.get('tau_plus', 20.0)
        self._tau_constant_variables['tau_minus'] = kwargs.get('tau_minus', 20.0)
        # Trace amplitudes; build() registers them unchanged.
        self._constant_variables = dict()
        self._constant_variables['A_plus'] = kwargs.get('A_plus', 1.0)
        self._constant_variables['A_minus'] = kwargs.get('A_minus', -1.0)

    def weight_update(self, weight, eligibility, reward):
        """
        RSTDP learning rule for ``Connection`` subclass of ``AbstractConnection`` class.

        Adds ``lr * reward * eligibility`` to ``weight`` in place, outside autograd.

        Args:
            weight: weight between pre and post neurongroup; modified in place.
            eligibility: a decaying memory of the relationships between recent pairs of
                pre- and postsynaptic spikes; it is allocated with the shape of ``weight``.
            reward: reward signal. When it is at least 2-D and ``reward.shape[1]`` equals
                ``eligibility.shape[0]``, it is transposed and repeated
                ``eligibility.shape[1]`` times along dim 1. That only reproduces
                ``eligibility``'s shape for a single-row reward (batch size 1) — TODO confirm
                the intended behaviour for larger batches. Any other reward is broadcast.

        Returns:
            The same ``weight`` tensor object, after the in-place update.
        """
        # Compute weight update based on the eligibility value of the past timestep.
        with torch.no_grad():
            if len(reward.shape) > 1 and reward.shape[1] == eligibility.shape[0]:
                reward = reward.transpose(1, 0)
                reward = reward.repeat(1, eligibility.shape[1])
            weight.add_(self.learning_rate * reward * eligibility)
        return weight

    def build(self, backend):
        """
        Register this learner's variables and per-step operations with ``backend``.

        For every trainable connection this allocates the presynaptic trace ``p_plus``, the
        postsynaptic trace ``p_minus`` and an ``eligibility`` variable. It then schedules the
        operations that update them, followed by :meth:`weight_update`. A 4-D presynaptic
        output is flattened to ``(dim0, dim1 * dim2 * dim3)`` first.

        Args:
            backend: simulation backend; ``backend.dt`` and ``backend._variables`` are read.
        """
        super(RSTDP, self).build(backend)
        self.dt = backend.dt

        # Each time constant is registered as the per-step decay factor exp(-dt / tau).
        for key, tau in self._tau_constant_variables.items():
            self.variable_to_backend(key, (), value=np.exp(-self.dt / tau))

        # Scalars get shape (); arrays/lists with more than one element get a leading axis of 1.
        for key, var in self._constant_variables.items():
            if isinstance(var, np.ndarray) and var.size > 1:
                shape = (1, *var.shape)
            elif isinstance(var, list) and len(var) > 1:
                shape = (1, len(var))
            else:
                shape = ()
            self.variable_to_backend(key, shape, value=var)

        # Constant dimension order [1, 0] used by the 'permute' ops to transpose 2-D tensors.
        permute_name = 'rstdp_permute_dim'
        self.variable_to_backend(permute_name, shape=None, value=[1, 0], is_constant=True)
        reward_name = 'Output_Reward[updated]'

        # Traverse all trainable connections
        for conn in self.trainable_connections.values():
            preg = conn.pre
            postg = conn.post
            pre_name = conn.get_input_name(preg, postg)
            post_name = conn.get_group_name(postg, 'O')
            weight_name = conn.get_link_name(preg, postg, 'weight')

            # p_plus tracks the influence of presynaptic spikes; p_minus tracks the influence of postsynaptic spikes
            p_plus_name = pre_name + '_{p_plus}'
            p_minus_name = post_name + '_{p_minus}'
            eligibility_name = weight_name + '_{eligibility}'

            pre_shape_temp = backend._variables[pre_name].shape
            # A 4-D presynaptic output (a 2-D feature map) is flattened to 2-D before use.
            is_feature_map = len(pre_shape_temp) == 4
            if is_feature_map:
                pre_shape = [pre_shape_temp[0], pre_shape_temp[1] * pre_shape_temp[2] * pre_shape_temp[3]]
            else:
                pre_shape = pre_shape_temp

            self.variable_to_backend(p_plus_name, pre_shape, value=0.0)
            self.variable_to_backend(p_minus_name, backend._variables[post_name].shape, value=0.0)
            self.variable_to_backend(eligibility_name, backend._variables[weight_name].shape, value=0.0)

            # Equations
            # p_plus *= np.exp(-dt / tau_plus)
            # p_plus += A_plus * pre
            # p_minus *= np.exp(-dt / tau_minus)
            # p_minus += A_minus * post
            # eligibility = torch.matmul(post.transpose(1, 0), p_plus) + torch.matmul(p_minus.transpose(1, 0), pre)

            # Update p_plus values
            self.op_to_backend('p_plus_temp', 'var_mult', ['tau_plus', p_plus_name])
            if is_feature_map:
                self.op_to_backend('pre_name_temp', 'feature_map_flatten', pre_name)
                flat_pre_name = 'pre_name_temp'
            else:
                flat_pre_name = pre_name
            self.op_to_backend(p_plus_name, 'var_linear', ['A_plus', flat_pre_name, 'p_plus_temp'])

            # Update p_minus values
            self.op_to_backend('p_minus_temp', 'var_mult', ['tau_minus', p_minus_name])
            self.op_to_backend(p_minus_name, 'var_linear', ['A_minus', post_name, 'p_minus_temp'])

            # Calculate point eligibility value
            self.op_to_backend('post_permute', 'permute', [post_name, permute_name])
            self.op_to_backend('pre_post', 'mat_mult', ['post_permute', p_plus_name + '[updated]'])
            self.op_to_backend('p_minus_permute', 'permute', [p_minus_name + '[updated]', permute_name])
            self.op_to_backend('post_pre', 'mat_mult', ['p_minus_permute', flat_pre_name])
            self.op_to_backend(eligibility_name, 'add', ['pre_post', 'post_pre'])

            # eligibility_name is passed without '[updated]'; per the comment in weight_update the
            # update is meant to use the previous timestep's eligibility — TODO confirm.
            self.op_to_backend(weight_name, self.weight_update, [weight_name, eligibility_name, reward_name])
# Register the RSTDP class with Learner under the key 'rstdp'.
Learner.register('rstdp', RSTDP)
class RSTDPET(Learner):
    """
    Reward-modulated STDP with eligibility trace.

    Adapted from `(Florian 2007) <https://florian.io/papers/2007_Florian_Modulated_STDP.pdf>`.

    Args:
        lr (int or float): learning rate.
        trainable: It can be network or neurongroups.

    Keyword Args:
        tau_plus (int or float): Time constant of the presynaptic firing trace ``p_plus``; it
            determines the range of interspike intervals over which synaptic changes occur
            (default 20.0).
        tau_minus (int or float): Time constant of the postsynaptic firing trace ``p_minus``
            (default 20.0).
        tau_e_trace (int or float): Time constant of the eligibility-trace decay (default 25.0).
        tau_e (int or float): Time constant scaling how much of each step's eligibility enters
            the eligibility trace; per the equations in :meth:`build` the trace gains
            ``eligibility / tau_e`` (default 25.0).
        A_plus (float): Amplitude added to ``p_plus`` for each presynaptic spike (default 1.0).
        A_minus (float): Amplitude added to ``p_minus`` for each postsynaptic spike
            (default -1.0).

    Notes:
        Batch_size for network using RSTDPET as learning algorithm must be 1.
    """

    def __init__(self, lr, trainable=None, **kwargs):
        super(RSTDPET, self).__init__(trainable=trainable)
        self.prefered_backend = ['pytorch']
        self.name = 'RSTDPET'
        self.learning_rate = lr
        self._tau_constant_variables = dict()
        self._tau_constant_variables['tau_plus'] = kwargs.get('tau_plus', 20.0)
        self._tau_constant_variables['tau_minus'] = kwargs.get('tau_minus', 20.0)
        self._tau_constant_variables['tau_e_trace'] = kwargs.get('tau_e_trace', 25.0)
        self._tau_membrane_variables = dict()
        self._tau_membrane_variables['tau_e'] = kwargs.get('tau_e', 25.0)
        self._constant_variables = dict()
        self._constant_variables['A_plus'] = kwargs.get('A_plus', 1.0)
        self._constant_variables['A_minus'] = kwargs.get('A_minus', -1.0)

    def weight_update(self, weight, eligibility_trace, reward):
        """
        RSTDPET learning rule for ``Connection`` subclass of ``AbstractConnection`` class.

        Adds ``lr * dt * reward * eligibility_trace`` to ``weight`` in place, outside autograd.

        Notes:
            The batch of pre and post should be 1.

        Args:
            weight: weight between pre and post neurongroup; modified in place.
            eligibility_trace: a decaying memory of the relationships between recent pairs of
                pre- and postsynaptic spikes; it is allocated with the shape of ``weight``.
            reward: reward signal. When it is at least 2-D and ``reward.shape[1]`` equals
                ``eligibility_trace.shape[0]``, it is transposed and repeated along dim 1;
                otherwise it is broadcast.

        Returns:
            The same ``weight`` tensor object, after the in-place update.
        """
        with torch.no_grad():
            # Keep parameters consistent through inplace operation
            if len(reward.shape) > 1 and reward.shape[1] == eligibility_trace.shape[0]:
                reward = reward.transpose(1, 0)
                reward = reward.repeat(1, eligibility_trace.shape[1])
            weight.add_(self.learning_rate * self.dt * reward * eligibility_trace)
        return weight

    def build(self, backend):
        """
        Register this learner's variables and per-step operations with ``backend``.

        For every trainable connection this allocates the traces ``p_plus``/``p_minus``, an
        ``eligibility`` and an ``eligibility_trace`` variable. It then schedules the operations
        that update them, followed by :meth:`weight_update`.

        Args:
            backend: simulation backend; ``backend.dt`` and ``backend._variables`` are read.
        """
        super(RSTDPET, self).build(backend)
        # weight_update() reads self.dt. Set it explicitly, as the other learners in this
        # module do, instead of relying on the base class to provide it.
        self.dt = backend.dt
        # NOTE(review): unlike RSTDP.build, the tau/constant variables used below ('tau_plus',
        # 'tau_minus', 'tau_e', 'tau_e_trace', 'A_plus', 'A_minus') are not registered here.
        # Presumably the base Learner.build registers them — confirm.

        # Constant shape argument [-1] for the 'view' ops, which flatten the spike tensors.
        view_name = 'rstdpet_view_dim'
        self.variable_to_backend(view_name, shape=None, value=[-1], is_constant=True)
        reward_name = 'Output_Reward[updated]'

        # Traverse all trainable connections
        for conn in self.trainable_connections.values():
            preg = conn.pre
            postg = conn.post
            pre_name = conn.get_input_name(preg, postg)
            post_name = conn.get_group_name(postg, 'O')
            weight_name = conn.get_link_name(preg, postg, 'weight')

            # p_plus tracks the influence of presynaptic spikes; p_minus tracks the influence of postsynaptic spikes
            p_plus_name = pre_name + '_{p_plus}'
            p_minus_name = post_name + '_{p_minus}'
            eligibility_name = weight_name + '_{eligibility}'
            eligibility_trace_name = weight_name + '_{eligibility_trace}'

            # The traces drop the leading dimension (the batch, which must be 1), matching the
            # flattened view(-1) spikes they are combined with.
            self.variable_to_backend(p_plus_name, backend._variables[pre_name].shape[1:], value=0.0)
            self.variable_to_backend(p_minus_name, backend._variables[post_name].shape[1:], value=0.0)
            self.variable_to_backend(eligibility_name, backend._variables[weight_name].shape, value=0.0)
            self.variable_to_backend(eligibility_trace_name, backend._variables[weight_name].shape, value=0.0)

            # Equations
            # pre = pre.view(-1)
            # post = post.view(-1)
            # p_plus *= np.exp(-dt / tau_plus)
            # p_plus += A_plus * pre
            # p_minus *= np.exp(-dt / tau_minus)
            # p_minus += A_minus * post
            # eligibility = torch.ger(post, p_plus) + torch.ger(p_minus, pre)
            # eligibility_trace *= np.exp(-dt / tau_e_trace)
            # eligibility_trace += eligibility / tau_e

            # Update p_plus values
            self.op_to_backend('pre_view', 'view', [pre_name, view_name])
            self.op_to_backend('p_plus_temp', 'var_mult', ['tau_plus', p_plus_name])
            self.op_to_backend(p_plus_name, 'var_linear', ['A_plus', 'pre_view', 'p_plus_temp'])

            # Update p_minus values
            self.op_to_backend('post_view', 'view', [post_name, view_name])
            self.op_to_backend('p_minus_temp', 'var_mult', ['tau_minus', p_minus_name])
            self.op_to_backend(p_minus_name, 'var_linear', ['A_minus', 'post_view', 'p_minus_temp'])

            # Calculate point eligibility value
            self.op_to_backend('pre_post', 'ger', ['post_view', p_plus_name + '[updated]'])
            self.op_to_backend('post_pre', 'ger', [p_minus_name + '[updated]', 'pre_view'])
            self.op_to_backend(eligibility_name, 'add', ['pre_post', 'post_pre'])

            # Decay the eligibility trace and add this step's (scaled) eligibility
            self.op_to_backend('eligibility_trace_temp', 'var_mult', ['tau_e', eligibility_name + '[updated]'])
            self.op_to_backend(eligibility_trace_name, 'var_linear',
                               ['tau_e_trace', eligibility_trace_name, 'eligibility_trace_temp'])

            self.op_to_backend(weight_name, self.weight_update, [weight_name, eligibility_trace_name, reward_name])
# Register the RSTDPET class with Learner under the key 'rstdpet'.
Learner.register('rstdpet', RSTDPET)
class RewardSwitchSTDP(Learner):
    """
    STDP whose potentiation/depression amplitudes switch with the sign of the reward.

    Where ``reward > 0`` the ``*_pos`` amplitudes are used; elsewhere the ``*_neg`` ones.

    Args:
        trainable: It can be network or neurongroups.

    Keyword Args:
        Apost_pos, Apre_pos (float): amplitudes used when reward > 0
            (defaults 5.0e-3 and -4.0e-3).
        Apost_neg, Apre_neg (float): amplitudes used when reward <= 0
            (defaults -5.0e-3 and 4.0e-3).
        pre_decay, post_decay (float): per-step decay factors of the input/output traces
            (default ``np.exp(-1/20.0)``).
        homoestatic (float): homeostasis coefficient (default 1.0e-10); divided by the scaled
            ``m_rate`` in :meth:`build`.
        m_rate (float): presumably a target firing rate (default 20.0); scaled by
            ``(1 - post_decay) / (1000 * post_decay)`` here and by ``dt`` in :meth:`build`.
        w_min, w_max (float): stored weight bounds (defaults 0.0 and 0.5); not read by
            :meth:`update`.
        lr (float): stored learning rate (default 0.1); not read by :meth:`update`.
        reward_name (str): backend variable read as the reward
            (default ``'Output_Reward[updated]'``).
    """

    def __init__(self, trainable=None, **kwargs):
        super(RewardSwitchSTDP, self).__init__(trainable=trainable, **kwargs)
        self.trainable = trainable
        self.prefered_backend = ['pytorch']
        self.name = 'Reward_Switch_STDP'
        self._constant_variables = dict()
        self._constant_variables['Apost_pos'] = kwargs.get('Apost_pos', 5.0e-3)
        self._constant_variables['Apre_pos'] = kwargs.get('Apre_pos', -4.0e-3)
        self._constant_variables['Apost_neg'] = kwargs.get('Apost_neg', -5.0e-3)
        self._constant_variables['Apre_neg'] = kwargs.get('Apre_neg', 4.0e-3)
        self._constant_variables['pre_decay'] = kwargs.get('pre_decay', np.exp(-1/20.0))
        post_decay = kwargs.get('post_decay', np.exp(-1/20.0))
        self._constant_variables['post_decay'] = post_decay
        self._constant_variables['homoestatic'] = kwargs.get('homoestatic', 1.0e-10)
        m_rate = kwargs.get("m_rate", 20.0)
        self._constant_variables['m_rate'] = m_rate*(1-post_decay)/(1000.0*post_decay)
        self.w_min = kwargs.get('w_min', 0.0)
        self.w_max = kwargs.get('w_max', 0.5)
        self.w_norm = 1.2
        self.w_mean = None
        self.lr = kwargs.get('lr', 0.1)
        self.param_run_update = True
        self.reward_name = kwargs.get('reward_name', 'Output_Reward[updated]')

    def update(self, input, output, reward, input_trace, output_trace, pre_decay, post_decay,
               Apost_pos, Apost_neg, Apre_pos, Apre_neg, m_rate, homoestatic, weight):
        """
        One step of reward-switched STDP for a single connection.

        Args:
            input: presynaptic spikes, 2-D; presumably (batch, pre) — TODO confirm.
            output: postsynaptic spikes, 2-D; presumably (batch, post) — TODO confirm.
            reward: reward; its sign selects the ``*_pos`` or ``*_neg`` amplitudes. It must
                broadcast against both traces.
            input_trace, output_trace: traces from the previous step.
            pre_decay, post_decay: per-step trace decay factors.
            Apost_pos, Apost_neg, Apre_pos, Apre_neg: STDP amplitudes.
            m_rate, homoestatic: not used by the active code (the homeostatic term is disabled).
            weight: current weight; with the matmuls below it is (post, pre).

        Returns:
            ``(input_trace, output_trace, weight)``; returned unchanged when not training.
        """
        if self.w_mean is None:
            # One-time snapshot of the initial per-row mean weight. It is only used by a
            # currently disabled normalisation step.
            self.w_mean = torch.mean(weight, dim=1, keepdim=True).detach()
            self.aw_mean = torch.mean(self.w_mean)
        if self.training:
            # Where a neuron spikes (value > 0) its trace is reset to the spike value; elsewhere
            # the trace decays by the given factor.
            input_trace = pre_decay * input_trace * input.le(0.0) + input
            output_trace = post_decay * output_trace * output.le(0.0) + output
            # The sign of the reward selects which amplitude set applies.
            Apost = reward.gt(0) * Apost_pos + reward.le(0) * Apost_neg
            Apre = reward.gt(0) * Apre_pos + reward.le(0) * Apre_neg
            pre_post = torch.matmul(output.permute(1, 0), Apost * input_trace)
            post_pre = torch.matmul((Apre * output_trace).permute(1, 0), input)
            # Average over the first (batch) dimension. A homeostatic term
            # torch.mean((m_rate - output_trace) * homoestatic, dim=0).unsqueeze(dim=1)
            # is currently disabled.
            dw = (pre_post + post_pre) / (1.0 * input.shape[0])
            weight = weight + dw
            # NOTE(review): the bounds are hard-coded to [0, 0.1]; self.w_min / self.w_max
            # (0.0 / 0.5 by default) are not used — confirm which bounds are intended.
            weight = torch.clamp(weight, 0, 0.1)
        return input_trace, output_trace, weight

    def build(self, backend):
        """
        Register per-connection trace variables and the :meth:`update` operation with ``backend``.

        Before the base build, ``m_rate`` is multiplied by ``backend.dt`` and ``homoestatic``
        is divided by that scaled ``m_rate``.

        Args:
            backend: simulation backend; ``dt``, ``runtime`` and ``_variables`` are read.
        """
        # Snapshot the pre-build constants once and derive the dt-dependent values from that
        # snapshot. Rescaling the dict entries in place would compound on every further
        # build() call.
        if getattr(self, '_base_m_rate', None) is None:
            self._base_m_rate = self._constant_variables['m_rate']
            self._base_homoestatic = self._constant_variables['homoestatic']
        scaled_m_rate = self._base_m_rate * backend.dt
        self._constant_variables['m_rate'] = scaled_m_rate
        self._constant_variables['homoestatic'] = self._base_homoestatic / scaled_m_rate
        super(RewardSwitchSTDP, self).build(backend)
        self.dt = backend.dt
        self.run_time = backend.runtime
        for conn in self.trainable_connections.values():
            preg = conn.pre
            postg = conn.post
            pre_name = conn.get_input_name(preg, postg)
            post_name = conn.get_group_name(postg, 'O')
            weight_name = conn.get_link_name(preg, postg, 'weight')
            # Trace names are keyed by conn.id, so each connection gets its own traces.
            input_trace_name = conn.id + '_{input_trace}'
            output_trace_name = conn.id + '_{output_trace[stay]}'
            self.variable_to_backend(input_trace_name, backend._variables[pre_name].shape, value=0.0)
            self.variable_to_backend(output_trace_name, backend._variables[post_name].shape, value=0.0)
            self.op_to_backend([input_trace_name, output_trace_name, weight_name], self.update,
                               [pre_name, post_name, self.reward_name, input_trace_name, output_trace_name,
                                'pre_decay', 'post_decay', 'Apost_pos', 'Apost_neg', 'Apre_pos', 'Apre_neg',
                                'm_rate', 'homoestatic', weight_name])
# Register the RewardSwitchSTDP class with Learner under the key 'switch_rstdp'.
Learner.register('switch_rstdp', RewardSwitchSTDP)