# Source code for spaic.Learning.STBP_Learner

# -*- coding: utf-8 -*-
"""
Created on 2021/3/30
@project: SPAIC
@filename: STBP_Learner
@author: Mengxiao Zhang
@contact: mxzhangice@gmail.com

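@description: STBP (spatio-temporal backpropagation) learner for spiking neural networks; replaces the threshold operation of trainable neuron groups with a spike function that has a Gaussian surrogate gradient.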
"""
from .Learner import Learner


class STBP(Learner):
    '''
    STBP (Spatio-Temporal BackPropagation) learning rule.

    During ``build`` the threshold operation of every trainable neuron group
    is replaced by a spike function whose forward pass is a hard threshold
    (``x > v_th``) and whose backward pass is a Gaussian-shaped surrogate
    gradient, so the network can be trained with ordinary backpropagation.

    Args:
        trainable: The object(s) to be trained; forwarded unchanged to
            ``Learner.__init__``.
        alpha (number, keyword argument): Width (variance) of the Gaussian
            surrogate gradient. Defaults to 0.5.
        **kwargs: Stored as ``self.parameters`` and forwarded to ``Learner``.

    Methods:
        build(self, backend): Create the surrogate firing function and install
            it as the threshold operation of all trainable neuron groups.
        threshold(self, x, v_th): Emit spikes where ``x > v_th`` with a
            differentiable surrogate gradient.

    Example:
        Net._learner = STBP(trainable=Net, alpha=0.5)

    Reference:
        Yujie Wu et al. "Spatio-Temporal Backpropagation for Training
        High-performance Spiking Neural Networks"
        In: Frontiers in Neuroscience, 2018. Volume 12. pp. 331.
        doi:10.3389/fnins.2018.00331
        url:https://www.frontiersin.org/article/10.3389/fnins.2018.00331
    '''

    def __init__(self, trainable=None, **kwargs):
        super(STBP, self).__init__(trainable=trainable, **kwargs)
        # Width of the Gaussian surrogate; converted to a tensor on the
        # backend device in ``build``.
        self.alpha = kwargs.get('alpha', 0.5)
        # Attribute spelling kept as-is: it may be read by code outside this file.
        self.prefered_backend = ['pytorch']
        self.name = 'STBP'
        # Set by ``build``; stays ``None`` unless the pytorch backend is built.
        self.firing_func = None
        self.parameters = kwargs

    def build(self, backend):
        '''
        Build the surrogate firing function and install it in the network.

        For the pytorch backend, a ``torch.autograd.Function`` is created whose
        forward pass is a hard threshold and whose backward pass is a Gaussian
        surrogate gradient, and ``alpha`` is moved to ``backend.device0``.
        Afterwards every operation of each trainable neuron group whose key
        contains ``'threshold'`` is redirected to ``self.threshold``.

        Args:
            backend: The backend used for computation.
        '''
        super(STBP, self).build(backend)
        self.device = backend.device0
        if backend.backend_name == 'pytorch':
            import torch
            import math

            class ActFun(torch.autograd.Function):
                """Heaviside spike function with a Gaussian surrogate gradient."""

                @staticmethod
                def forward(ctx, inp, thresh, alpha):
                    ctx.thresh = thresh
                    ctx.alpha = alpha
                    ctx.save_for_backward(inp)
                    # 1 where the input exceeds the threshold, else 0, in inp's dtype.
                    return inp.gt(thresh).type_as(inp)

                @staticmethod
                def backward(ctx, grad_output):
                    inp, = ctx.saved_tensors
                    # Match alpha's dtype/device to the saved input without
                    # mutating ctx.
                    alpha = torch.as_tensor(ctx.alpha).to(inp)
                    # Gaussian surrogate centred on the threshold.
                    # NOTE(review): Wu et al. (2018) normalise this Gaussian by
                    # sqrt(2*pi*alpha); here it is divided by 2*pi*alpha, which
                    # only rescales the gradient. Kept unchanged to preserve
                    # existing training behaviour -- confirm this is intended.
                    surrogate = (
                        torch.exp(-(inp - ctx.thresh) ** 2 / (2 * alpha))
                        / (2 * math.pi * alpha)
                    )
                    # No gradients flow to ``thresh`` or ``alpha``.
                    return grad_output * surrogate.type_as(inp), None, None

            # Autograd functions are used through the class (``apply`` is a
            # classmethod); instantiating them is deprecated in recent PyTorch.
            self.firing_func = ActFun
            # as_tensor avoids the copy-construct warning if build() runs again
            # after alpha has already been converted to a tensor.
            self.alpha = torch.as_tensor(self.alpha, device=self.device)

        # Replace the threshold operation in all trainable neuron groups.
        for neuron in self.trainable_groups.values():
            for key in neuron._operations.keys():
                if 'threshold' in key:
                    neuron._operations[key].func = self.threshold

    def threshold(self, x, v_th):
        '''
        Spike function installed in place of the neuron groups' threshold op.

        Args:
            x: Value compared against ``v_th`` (presumably the membrane
                potential -- confirm against the neuron model).
            v_th: Firing threshold.

        Returns:
            A tensor with the dtype of ``x`` holding 1 where ``x > v_th`` and 0
            elsewhere; its gradient with respect to ``x`` is the Gaussian
            surrogate defined in ``build``.

        Raises:
            RuntimeError: If ``build`` has not created a firing function
                (only the 'pytorch' backend is supported).
        '''
        if self.firing_func is None:
            raise RuntimeError(
                "STBP.threshold called before build() created a firing "
                "function; STBP only supports the 'pytorch' backend."
            )
        return self.firing_func.apply(x, v_th, self.alpha)
# Register STBP with the Learner registry under the key "stbp"
# (presumably the name used to select this algorithm -- confirm in Learner).
Learner.register("stbp", STBP)