Custom algorithm

Surrogate Gradient Algorithms

The key idea of surrogate gradient algorithms is to replace the gradient of the non-differentiable spike function with a custom (surrogate) gradient function during backpropagation. Here we use STCA and STBP as examples to show how to implement a custom gradient formula.

In the STCA [1] learning algorithm, the gradient function is:

h(V)=\frac{1}{\alpha}sign(|V-\theta|<\alpha)
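As a quick illustration, the STCA surrogate can be evaluated directly on a tensor of membrane potentials. This is a minimal sketch: the function name stca_surrogate and the sample values of θ and α are illustrative assumptions. Inside the window the gradient is the constant 1/α, and outside it is zero.

```python
import torch


def stca_surrogate(v: torch.Tensor, theta: float, alpha: float) -> torch.Tensor:
    """Rectangular STCA surrogate: 1/alpha where |v - theta| < alpha, else 0."""
    inside = (v - theta).abs() < alpha  # boolean mask of the gradient window
    return inside.float() / alpha


v = torch.tensor([0.0, 0.6, 1.0, 1.4, 2.0])
print(stca_surrogate(v, theta=1.0, alpha=0.5))
```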

In the STBP [2] learning algorithm, the gradient function is:

h_4(V)=\frac{1}{\sqrt{2\pi a_4}}\, e^{-\frac{(V-V_{th})^2}{2a_4}}
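h_4(V) is a normalized Gaussian centered on the threshold with variance a_4, so its peak is 1/\sqrt{2\pi a_4} and its integral over V is 1. The sketch below checks this numerically; the function name stbp_surrogate and the chosen values V_th = 1 and a_4 = 0.25 are assumptions for illustration.

```python
import math

import torch


def stbp_surrogate(v: torch.Tensor, v_th: float, a4: float) -> torch.Tensor:
    """Gaussian STBP surrogate h_4(V) with variance a4, centered on v_th."""
    return torch.exp(-(v - v_th) ** 2 / (2 * a4)) / math.sqrt(2 * math.pi * a4)


# Integrate over a range far wider than the Gaussian (std = 0.5 here).
v = torch.linspace(-5.0, 7.0, 12001)
h = stbp_surrogate(v, v_th=1.0, a4=0.25)
area = torch.trapz(h, v)
print(float(area))  # close to 1.0
```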

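Both surrogates can be implemented as a custom torch.autograd.Function: the forward pass emits a spike when the membrane potential exceeds the threshold, and the backward pass replaces the gradient of this step function with h(V). In the code, thresh plays the role of θ (V_th) and alpha plays the role of α (a_4):

import math

import torch


class SurrogateSpike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, thresh, alpha):
        ctx.thresh = thresh
        ctx.alpha = alpha
        ctx.save_for_backward(input)
        # Heaviside step: spike where the membrane potential exceeds the threshold
        output = input.gt(thresh).float()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # STCA surrogate: 1/alpha inside the window |V - thresh| < alpha, 0 outside
        temp = (abs(input - ctx.thresh) < ctx.alpha).float() / ctx.alpha
        # STBP surrogate (Gaussian with variance alpha); use instead of the line above:
        # temp = (torch.exp(-(input - ctx.thresh) ** 2 / (2 * ctx.alpha))
        #         / math.sqrt(2 * math.pi * ctx.alpha))
        result = grad_input * temp
        # thresh and alpha are plain numbers, so they receive no gradient
        return result, None, None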
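To use such a function, call its apply method inside a network's forward pass; the gradient of the spike output with respect to the membrane potential then follows the surrogate instead of the (zero almost everywhere) step derivative. The self-contained sketch below re-declares the STCA variant under the hypothetical name StcaSpike, scaled by 1/α as in the formula h(V) above, and backpropagates through it.

```python
import torch


class StcaSpike(torch.autograd.Function):
    """Heaviside spike forward, STCA rectangular surrogate backward."""

    @staticmethod
    def forward(ctx, input, thresh, alpha):
        ctx.thresh = thresh
        ctx.alpha = alpha
        ctx.save_for_backward(input)
        return input.gt(thresh).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        window = (input - ctx.thresh).abs() < ctx.alpha
        # Only the tensor input gets a gradient; thresh and alpha get None.
        return grad_output * window.float() / ctx.alpha, None, None


v = torch.tensor([0.2, 0.9, 1.3, 2.5], requires_grad=True)
spikes = StcaSpike.apply(v, 1.0, 0.5)
spikes.sum().backward()
print(v.grad)  # 1/alpha = 2 inside the window around the threshold, 0 outside
```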

Synaptic Plasticity Algorithms

We have implemented two kinds of STDP learning algorithms. The first is based on global synaptic plasticity and is called full_online_STDP [3]; the second is based on nearest synaptic plasticity and is called nearest_online_STDP [4].
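The practical difference between the two variants lies in how spike traces accumulate. The sketch below uses a common textbook formulation, which is an assumption rather than this library's code: a full trace decays and adds every spike (the same recursion as the input_trace update used later in this section), while a nearest trace decays and is reset to 1 on each spike, so only the most recent spike is remembered.

```python
import torch


def full_trace_step(trace: torch.Tensor, spike: torch.Tensor, decay: float) -> torch.Tensor:
    # Every spike adds to the decayed trace, so closely spaced spikes accumulate.
    return trace * decay + spike


def nearest_trace_step(trace: torch.Tensor, spike: torch.Tensor, decay: float) -> torch.Tensor:
    # A spike resets the trace to 1, so only the nearest (latest) spike counts.
    return torch.where(spike > 0, torch.ones_like(trace), trace * decay)


spike_train = [1.0, 1.0, 0.0, 1.0]
full = torch.zeros(1)
nearest = torch.zeros(1)
for s in spike_train:
    spike = torch.tensor([s])
    full = full_trace_step(full, spike, decay=0.5)
    nearest = nearest_trace_step(nearest, spike, decay=0.5)
print(full.item(), nearest.item())
```

With the same spike train, the full trace grows past 1 when spikes arrive in quick succession, whereas the nearest trace never exceeds 1.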

Full Synaptic Plasticity STDP learning algorithm

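The weight update and weight normalization formulas of this algorithm [2] are as follows, where w_norm is the target sum of the absolute incoming weights of each postsynaptic neuron: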

dw &= A_{post} \cdot (output\_spike \cdot input\_trace) - A_{pre} \cdot (output\_trace \cdot input\_spike) \\
weight &= weight + dw \\
weight_{ij} &= w\_norm \cdot \frac{weight_{ij}}{\sum_{k} |weight_{ik}|}
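For a single sample, these three lines can be written directly in PyTorch. In this sketch the shapes and values are assumptions: spikes and traces are 1-D vectors, the pairing terms are outer products giving a (post, pre) matrix, and normalization is applied per postsynaptic neuron (row), as the normalization code later in this section does with torch.sum(..., 1, keepdim=True). The parameter names a_post, a_pre and their defaults are illustrative.

```python
import torch


def stdp_step(weight, input_spike, output_spike, input_trace, output_trace,
              a_post=0.01, a_pre=0.012, w_norm=1.0):
    """One full-online STDP update for a weight matrix of shape (post, pre)."""
    # Potentiation: postsynaptic spikes paired with presynaptic traces.
    ltp = a_post * torch.outer(output_spike, input_trace)
    # Depression: postsynaptic traces paired with presynaptic spikes.
    ltd = a_pre * torch.outer(output_trace, input_spike)
    dw = ltp - ltd
    weight = weight + dw
    # Rescale each row so that its absolute values sum to w_norm.
    return w_norm * weight / weight.abs().sum(dim=1, keepdim=True)


weight = torch.full((2, 3), 0.3)  # 2 postsynaptic x 3 presynaptic neurons
new_w = stdp_step(weight,
                  input_spike=torch.tensor([1.0, 0.0, 0.0]),
                  output_spike=torch.tensor([0.0, 1.0]),
                  input_trace=torch.tensor([1.0, 0.5, 0.0]),
                  output_trace=torch.tensor([0.2, 1.0]))
print(new_w.abs().sum(dim=1))  # each row sums to w_norm
```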

First, get the presynaptic and postsynaptic NeuronGroups of the trainable connection conn:

preg = conn.pre
postg = conn.post

Then, get the variable IDs (names) of the input spikes, the output spikes and the weight:

pre_name = conn.get_input_name(preg, postg)
post_name = conn.get_group_name(postg, 'O')
weight_name = conn.get_link_name(preg, postg, 'weight')

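Register the required variables with the Backend (input_trace_name, output_trace_name and dw_name are the names chosen for the input trace, the output trace and the weight change):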

self.variable_to_backend(input_trace_name, backend._variables[pre_name].shape, value=0.0)
self.variable_to_backend(output_trace_name, backend._variables[post_name].shape, value=0.0)
self.variable_to_backend(dw_name, backend._variables[weight_name].shape, value=0.0)

Append the update computations to the Backend:

# input_trace = input_trace * trace_decay + input_spike
self.op_to_backend('input_trace_temp', 'var_mult', [input_trace_name, 'trace_decay'])
self.op_to_backend(input_trace_name, 'add', [pre_name, 'input_trace_temp'])

# output_trace = output_trace * trace_decay + output_spike
self.op_to_backend('output_trace_temp', 'var_mult', [output_trace_name, 'trace_decay'])
self.op_to_backend(output_trace_name, 'add', [post_name, 'output_trace_temp'])

# dw = Apost * (output_spike * input_trace) - Apre * (output_trace * input_spike),
# using the traces already updated in this step ('[updated]')
self.op_to_backend('pre_post_temp', 'mat_mult_pre', [post_name, input_trace_name+'[updated]'])
self.op_to_backend('pre_post', 'var_mult', ['Apost', 'pre_post_temp'])
self.op_to_backend('post_pre_temp', 'mat_mult_pre', [output_trace_name+'[updated]', pre_name])
self.op_to_backend('post_pre', 'var_mult', ['Apre', 'post_pre_temp'])
self.op_to_backend(dw_name, 'minus', ['pre_post', 'post_pre'])

# Apply dw to the weight, then normalize it
self.op_to_backend(weight_name, self.full_online_stdp_weightupdate, [dw_name, weight_name])
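To see what this op list computes at each simulation step, the same sequence can be replayed in plain PyTorch. This sketch assumes that var_mult, add and minus are elementwise operations and that mat_mult_pre(A, B) contracts the batch dimension as A.T @ B to produce a (post, pre) matrix; these semantics are assumptions about the Backend, not taken from its documentation. The final weight-update call is left out, and the helper stdp_ops_step is hypothetical.

```python
import torch


def stdp_ops_step(state, pre_spike, post_spike, trace_decay, a_post, a_pre):
    """Replay the op list for one step; spikes have shape (batch, n)."""
    # input_trace = input_trace * trace_decay + pre_spike
    state["input_trace"] = state["input_trace"] * trace_decay + pre_spike
    # output_trace = output_trace * trace_decay + post_spike
    state["output_trace"] = state["output_trace"] * trace_decay + post_spike
    # Both products use the traces updated in this step.
    pre_post = a_post * (post_spike.T @ state["input_trace"])
    post_pre = a_pre * (state["output_trace"].T @ pre_spike)
    return pre_post - post_pre  # dw, shape (post, pre)


batch, n_pre, n_post = 1, 3, 2
state = {"input_trace": torch.zeros(batch, n_pre),
         "output_trace": torch.zeros(batch, n_post)}
pre_train = torch.tensor([[[1.0, 0.0, 0.0]], [[0.0, 0.0, 0.0]]])  # (T, batch, pre)
post_train = torch.tensor([[[0.0, 0.0]], [[0.0, 1.0]]])          # (T, batch, post)
for pre_spike, post_spike in zip(pre_train, post_train):
    dw = stdp_ops_step(state, pre_spike, post_spike,
                       trace_decay=0.5, a_post=0.01, a_pre=0.012)
print(dw)  # presynaptic spike before postsynaptic spike -> potentiation
```

Here presynaptic neuron 0 fires one step before postsynaptic neuron 1, so only the entry dw[1, 0] becomes positive.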

Weight update part:

with torch.no_grad():
    weight.add_(dw)

Weight normalization part:

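with torch.no_grad():
    # Rescale each row (postsynaptic neuron) so its absolute weights sum to w_norm
    weight[...] = self.w_norm * torch.div(weight, torch.sum(torch.abs(weight), 1, keepdim=True))
    # Clip the normalized weights to [0, 1]
    weight.clamp_(0.0, 1.0)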
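Putting the two parts together, a weight-update callback such as full_online_stdp_weightupdate could look like the sketch below. The standalone signature with an explicit w_norm argument is an assumption (in the text above it is a method that reads self.w_norm). Both in-place steps run under torch.no_grad(), which lets them modify a weight tensor that requires gradients without autograd raising an error.

```python
import torch


def full_online_stdp_weightupdate(dw: torch.Tensor, weight: torch.Tensor,
                                  w_norm: float = 1.0) -> torch.Tensor:
    """Apply dw, rescale each row so sum |w| = w_norm, then clip to [0, 1]."""
    with torch.no_grad():
        weight.add_(dw)
        row_sums = torch.sum(torch.abs(weight), 1, keepdim=True)
        weight[...] = w_norm * torch.div(weight, row_sums)
        weight.clamp_(0.0, 1.0)
    return weight


weight = torch.nn.Parameter(torch.rand(4, 10))
dw = 0.01 * torch.randn(4, 10)
full_online_stdp_weightupdate(dw, weight, w_norm=2.0)
print(weight.min().item(), weight.max().item())  # both within [0, 1]
```

Because clipping follows normalization, a row's absolute sum can end up slightly below w_norm if some weights were clipped at 0.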