Source code for spaic.Neuron.Generators

# encoding: utf-8
"""
@author: Yuan Mengwen
@contact: mwyuan94@gmail.com
@project: PyCharm
@filename: Generators.py
@time:2021/6/21 16:35
@description:
"""
from .Node import Node, Generator
import torch
import numpy as np

class Poisson_Generator(Generator):
    """
    Poisson generator: produces spikes according to the input spike rate.

    Generate a poisson spike train according input rate.
    time: encoding window ms
    dt: time step
    HZ: cycles/s

    Spikes are sampled one simulation step at a time in ``next_stage``: a
    neuron fires when a uniform random number in [0, 1) is less than or equal
    to ``rate * unit_conversion``; fired entries carry the value ``weight``.

    Keyword Args:
        unit_conversion: factor the rate is multiplied by before it is compared
            with the random draw (default 0.1).
        weight: value written for a fired neuron (default 1.0).
        start_time: when set, spikes are only emitted while
            ``start_time < index * dt``.
        end_time: when set, spikes are only emitted while
            ``end_time > index * dt``.
        batch: stored on the instance as ``self.batch`` (default 1).
    """

    def __init__(self, shape=None, num=None, dec_target=None, dt=None,
                 coding_method=('poisson', 'spike_counts', '...'),
                 coding_var_name='O',
                 node_type=('excitatory', 'inhibitory', 'pyramidal', '...'),
                 rate=None, **kwargs):
        """
        Args:
            shape, num, dec_target, dt, coding_method, coding_var_name,
            node_type: forwarded unchanged to ``Generator.__init__``.
            rate: optional firing rate. An iterable is stored as-is in
                ``self.source``; a scalar is wrapped into a 1-element ndarray.
            **kwargs: see the class docstring; also forwarded to the base class.
        """
        super(Poisson_Generator, self).__init__(shape, num, dec_target, dt, coding_method,
                                                coding_var_name, node_type, **kwargs)
        self.num = num
        # the unit of dt is 0.1ms, for each time step, the rate has to be multiplied by 1e-4
        # NOTE(review): the default below is 0.1, not 1e-4 -- confirm the intended unit.
        self.unit_conversion = kwargs.get('unit_conversion', 0.1)
        self.weight = kwargs.get('weight', 1.0)
        self.start_time = kwargs.get('start_time', None)
        self.end_time = kwargs.get('end_time', None)
        if rate is not None:
            if hasattr(rate, '__iter__'):
                self.source = rate
            else:
                self.source = np.array([rate])
            # A rate given at construction time counts as fresh input, so the
            # first next_stage() call encodes it.
            self.new_input = True
        self.batch = kwargs.get('batch', 1)

    def torch_coding(self, source, device):
        """
        Store the input rate for later step-wise sampling.

        No spikes are generated here; the rate is kept in ``self.inp_source``
        and sampled by ``next_stage``.

        Args:
            source: input rates, a numpy ndarray or a torch tensor.
            device: torch device an ndarray input is moved to.

        Returns:
            None.
        """
        # ndarray.size is an int, but Tensor.size is a method: comparing the
        # method with self.num would always be unequal and warn on every call.
        if isinstance(source, torch.Tensor):
            n_values = source.numel()
        else:
            n_values = source.size
        if n_values != self.num:
            import warnings
            warnings.warn("The dimension of input data should be consistent with the number of input neurons.")

        if source.__class__.__name__ == 'ndarray':
            source = torch.tensor(source, device=device)
        self.inp_source = source
        return None

    def next_stage(self):
        """
        Sample the spikes of the current simulation step.

        Returns:
            A tensor of shape ``self.shape`` in the backend data type. Inside
            the ``(start_time, end_time)`` window each entry is ``weight``
            where the neuron fired and 0 elsewhere; outside it is all zeros.
        """
        if self.new_input:
            # presumably get_input() (defined in the base class) ends up calling
            # torch_coding and thereby refreshes self.inp_source -- TODO confirm
            self.get_input()
            # Adopt the leading dimension of the freshly stored input.
            self.shape[0] = self.inp_source.shape[0]
            self.new_input = False

        # Use the same device in both branches (the silent branch previously
        # passed the un-indexed device container to torch.zeros).
        device = self._backend.device[0]
        now = self.index * self.dt
        started = self.start_time is None or self.start_time < now
        not_ended = self.end_time is None or self.end_time > now

        if started and not_ended:
            fired = torch.rand(self.shape, device=device).le(self.inp_source * self.unit_conversion)
            spikes = self.weight * fired
            return spikes.type(self._backend.data_type)
        return torch.zeros(self.shape, dtype=self._backend.data_type, device=device)


Generator.register('poisson_generator', Poisson_Generator)
class Poisson_Generator2(Generator):
    """
    Poisson generator: produces spikes according to the input spike rate.

    Generate a poisson spike train according input rate.
    time: encoding window ms
    dt: time step

    Unlike a step-wise generator, the whole spike train for the encoding
    window is sampled at once in ``torch_coding``.
    """

    def __init__(self, shape=None, num=None, dec_target=None, dt=None,
                 coding_method=('poisson', 'spike_counts', '...'),
                 coding_var_name='O',
                 node_type=('excitatory', 'inhibitory', 'pyramidal', '...'),
                 **kwargs):
        """Forward every argument to ``Generator.__init__`` and remember ``num``."""
        super().__init__(shape, num, dec_target, dt, coding_method,
                         coding_var_name, node_type, **kwargs)
        self.num = num

    def torch_coding(self, source, device):
        """
        Sample a full Poisson spike train from the given rates.

        Args:
            source: input rates (numpy ndarray or torch tensor). Negative
                values only trigger a warning.
            device: torch device used for the random draw and for an ndarray
                input.

        Returns:
            Tensor in the backend data type in which the 0/1 spike tensor and
            an all-zero tensor of the same shape are stacked along dim 2.
        """
        non_negative = (source >= 0).all()
        if not non_negative:
            import warnings
            warnings.warn('Input rate shall be non-negative')

        if source.__class__.__name__ == 'ndarray':
            source = torch.tensor(source, device=device)

        # A 0-d input is treated as a single sample; otherwise the leading
        # dimension of the input is the batch size.
        batch_size = source.shape[0] if source.ndim != 0 else 1

        dims = list(self.shape)
        dims[0] = batch_size
        train_shape = [self.time_step, *dims]

        # A neuron fires when the uniform draw does not exceed rate * dt.
        threshold = source * self.dt
        fired = torch.rand(train_shape, device=device).le(threshold).float()
        zero_times = torch.zeros_like(fired)
        stacked = torch.stack([fired, zero_times], dim=2)
        return stacked.type(self._backend.data_type)


Generator.register('poisson_generator2', Poisson_Generator2)
class CC_Generator(Generator):
    """
    Constant current generator.

    Generate a constant current input.
    time: encoding window ms
    dt: time step
    """

    def __init__(self, shape=None, num=None, dec_target=None, dt=None,
                 coding_method=('poisson_generator', 'cc_generator', '...'),
                 coding_var_name='O',
                 node_type=('excitatory', 'inhibitory', 'pyramidal', '...'),
                 **kwargs):
        """Forward every argument unchanged to ``Generator.__init__``."""
        super().__init__(shape, num, dec_target, dt, coding_method,
                         coding_var_name, node_type, **kwargs)

    def torch_coding(self, source, device):
        """
        Repeat the input value over every step of the encoding window.

        Args:
            source: current value(s) (numpy ndarray or torch tensor). Negative
                values only trigger a warning.
            device: torch device the output is created on.

        Returns:
            ``source`` multiplied by a ones tensor of shape
            ``[time_step] + list(self.shape)``, cast to the backend data type.
        """
        if not (source >= 0).all():
            import warnings
            warnings.warn('Input rate shall be non-negative')

        if source.__class__.__name__ == 'ndarray':
            source = torch.tensor(source, dtype=self._backend.data_type, device=device)

        window_shape = [self.time_step, *self.shape]
        constant = source * torch.ones(window_shape, device=device)
        return constant.type(self._backend.data_type)


Generator.register('cc_generator', CC_Generator)
Generator.register('constant_current', CC_Generator)
class Sin_Generator(Generator):
    """
    Generate a sin current input.
    time: encoding window ms
    dt: time step
    """

    def __init__(self, shape=None, num=None, dec_target=None, dt=None,
                 coding_method=('poisson_generator', 'cc_generator', '...'),
                 coding_var_name='O',
                 node_type=('excitatory', 'inhibitory', 'pyramidal', '...'),
                 **kwargs):
        """Forward every argument to ``Generator.__init__`` and remember ``num``."""
        super(Sin_Generator, self).__init__(shape, num, dec_target, dt, coding_method,
                                            coding_var_name, node_type, **kwargs)
        self.num = num

    def torch_coding(self, source, device):
        """
        Build a sinusoidal signal over the encoding window.

        Args:
            source: indexable input (numpy ndarray or torch tensor) where
                ``source[0]`` is the amplitude and ``source[1]`` the period;
                the angular frequency is ``2 * pi / source[1]``.
            device: torch device the time axis is created on.

        Returns:
            ``amp * sin(omega * t)`` with ``t = 0, dt, 2*dt, ...`` for
            ``time_step`` steps. The time axis has shape
            ``[time_step, 1, ..., 1]`` (one singleton dimension per entry of
            ``self.shape``) so it broadcasts against the amplitude/period.
        """
        if source.__class__.__name__ == 'ndarray':
            source = torch.tensor(source, dtype=torch.float, device=device)
        amp = source[0]
        omg = 2 * np.pi / source[1]

        t_shape = [self.time_step] + [1] * len(self.shape)
        # Build the time axis from an integer step index. A float-step
        # torch.arange(0, time_step*dt, dt) can return time_step + 1 elements
        # because of floating point rounding, which makes the view() fail.
        steps = torch.arange(self.time_step, dtype=torch.get_default_dtype(), device=device)
        t = (steps * self.dt).view(t_shape)

        spikes = amp * torch.sin(omg * t)
        return spikes


Generator.register('sin_generator', Sin_Generator)
Generator.register('sin', Sin_Generator)
class Ramp_Generator(Generator):
    """
    Generate a linearly changing (ramp) input.

    The output at time ``t`` is ``amp * (slope * t + base)`` where the slope
    is the encoded input value.

    Keyword Args:
        base: offset added to ``slope * t`` (default 0.0); an ndarray is
            converted to a tensor on first use.
        end_time: when set, the output is zeroed from step
            ``int(end_time / dt)`` onward.
        amp: overall scale factor (default 0.001).
    """

    def __init__(self, shape=None, num=None, dec_target=None, dt=None,
                 coding_method=('poisson_generator', 'cc_generator', '...'),
                 coding_var_name='O',
                 node_type=('excitatory', 'inhibitory', 'pyramidal', '...'),
                 **kwargs):
        """Forward the arguments to ``Generator.__init__`` and read the ramp options from kwargs."""
        super(Ramp_Generator, self).__init__(shape, num, dec_target, dt, coding_method,
                                             coding_var_name, node_type, **kwargs)
        self.base = kwargs.get('base', 0.0)
        self.end_time = kwargs.get('end_time', None)
        self.amp = kwargs.get('amp', 0.001)

    def torch_coding(self, source, device):
        """
        Build the ramp signal over the encoding window.

        Args:
            source: slope of the ramp (numpy ndarray or torch tensor).
            device: torch device the time axis is created on.

        Returns:
            ``amp * (slope * t + base)`` with ``t = 0, dt, 2*dt, ...`` for
            ``time_step`` steps; entries from step ``int(end_time / dt)``
            onward are set to 0 when ``end_time`` is given.
        """
        if source.__class__.__name__ == 'ndarray':
            source = torch.tensor(source, dtype=torch.float, device=device)
        if self.base.__class__.__name__ == 'ndarray':
            # Converted once and kept on the instance for later calls.
            self.base = torch.tensor(self.base, dtype=torch.float, device=device)
        slope = source

        t_shape = [self.time_step] + [1] * len(self.shape)
        # Build the time axis from an integer step index. A float-step
        # torch.arange(0, time_step*dt, dt) can return time_step + 1 elements
        # because of floating point rounding, which makes the view() fail.
        steps = torch.arange(self.time_step, dtype=torch.get_default_dtype(), device=device)
        t = (steps * self.dt).view(t_shape)

        spikes = self.amp * (slope * t + self.base)
        if self.end_time is not None:
            time_step = int(self.end_time / self.dt)
            spikes[time_step:, ...] = 0.0
        return spikes


Generator.register('ramp_generator', Ramp_Generator)
Generator.register('ramp', Ramp_Generator)