Source code for spaic.Monitor.Monitor

# -*- coding: utf-8 -*-
"""
Created on 2020/8/12
@project: SPAIC
@filename: Monitor
@author: Hong Chaofei
@contact: hongchf@gmail.com

@description:
Defines the simulation recording modules for neural population spiking, neuron state variables and connection state variables.
"""
from ..Network.Assembly import BaseModule, Assembly
from ..Network.Connections import Connection
from ..Learning.Learner import Learner
from ..Backend.Backend import Backend
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.image import AxesImage
from mpl_toolkits.axes_grid1 import make_axes_locatable


class Monitor(BaseModule):
    """Base class for objects that record one variable of a network element.

    Subclasses are expected to implement ``build``, ``clear``,
    ``init_record`` and ``update_step``; the base versions raise
    ``NotImplementedError``.

    Parameters
    ----------
    target : Assembly, Connection, Learner or None
        The object whose variable is watched.
    var_name : str
        Name of the variable; it is stored wrapped in braces
        (``'{' + var_name + '}'``).
    index : 'full' or sequence
        Which part of the variable to record; ``'full'`` is the default.
    dt : float or None
        Stored as ``self.dt``; subclasses in this module fall back to the
        backend's ``dt`` when it is ``None``.
    get_grad : bool
        Stored as ``self.get_grad``.
    nbatch : bool or int
        Stored as ``self.nbatch``.

    Raises
    ------
    ValueError
        If ``target`` is not one of the supported types.
    """

    def __init__(self, target, var_name, index='full', dt=None, get_grad=False, nbatch=True):
        super().__init__()
        if isinstance(target, Assembly):
            self.target = target
            self.target_type = 'Assembly'
        elif isinstance(target, Connection):
            self.target = target
            self.target_type = 'Connection'
        elif isinstance(target, Learner):
            self.target = target
            self.target_type = 'Learner'
        elif target is None:
            self.target = None
            self.target_type = None
        else:
            raise ValueError(
                "The target does not belong to types that can be watched "
                "(Assembly, Connection, Learner).")

        self.var_name = '{' + var_name + '}'
        self.index = index
        self.var_container = None
        self.get_grad = get_grad
        self.nbatch = nbatch
        self._nbatch_records = []  # all time window's record
        self._nbatch_times = []
        self._records = []  # single time window's record
        self._times = []
        self.dt = dt
        self.is_recording = True
        self.new_record = True

    def check_var_name(self, var_name):
        """Resolve ``var_name`` to the full variable name to record.

        The name with its first and last character stripped is looked up in
        ``self.backend._variables`` first; otherwise the first name returned
        by ``self.target.get_var_names()`` that contains ``var_name`` as a
        substring is used.

        Parameters
        ----------
        var_name : str
            Original (brace-wrapped) variable name.

        Returns
        -------
        str
            The resolved variable name.

        Raises
        ------
        ValueError
            If no matching variable is found.
        """
        stripped_name = var_name[1:-1]
        if stripped_name in self.backend._variables.keys():
            return stripped_name

        # A monitor built with target=None has no variable names to search;
        # fall through to the ValueError instead of failing on None.
        if self.target is not None:
            for tar_name in self.target.get_var_names():
                if var_name in tar_name:
                    return tar_name

        raise ValueError(" Variable %s is not in the target model" % var_name)

    def get_str(self, level):
        """Placeholder; does nothing and returns ``None``."""
        pass

    def monitor_on(self):
        """Enable recording."""
        self.is_recording = True

    def monitor_off(self):
        """Disable recording."""
        self.is_recording = False

    def clear(self):
        """Drop recorded data. Must be implemented by subclasses."""
        raise NotImplementedError()

    def build(self, backend: Backend):
        """Attach the monitor to ``backend``. Must be implemented by subclasses."""
        raise NotImplementedError()

    def init_record(self):
        """Start a new recording window. Must be implemented by subclasses."""
        raise NotImplementedError()

    def update_step(self, variables=None):
        """Record the current step. Must be implemented by subclasses."""
        raise NotImplementedError()

    def push_data(self, data, time):
        """Append ``data`` and its ``time`` to the current window's buffers."""
        self._records.append(data)
        self._times.append(time)
class SpikeMonitor(Monitor):
    """Monitor that records a spike variable and exposes spike times/indices.

    Raw per-step values are kept in ``_records`` / ``_times``; the
    ``spk_times`` and ``spk_index`` properties convert them lazily into one
    array per batch entry. Spike trains can also be injected directly with
    ``push_spike_train``.
    """

    def __init__(self, target, var_name='O', index='full', dt=None, get_grad=False, nbatch=False):
        super().__init__(target=target, var_name=var_name, index=index, dt=dt,
                         get_grad=get_grad, nbatch=nbatch)
        # Number of records already converted by _spike_transform.
        self._transform_len = 0
        self._nbatch_index = []  # all time window's record
        self._nbatch_times = []
        self._spk_index = []
        self._spk_times = []
        self._records = []  # single time window's record
        self._times = []

    def build(self, backend: Backend):
        """Register with ``backend`` and resolve the variable name, shape and dt."""
        self.backend = backend
        self.backend._monitors.append(self)
        self.var_name = self.check_var_name(self.var_name)
        self.shape = self.backend._variables[self.var_name].shape
        if self.dt is None:
            self.dt = self.backend.dt

    def clear(self):
        """Discard all recorded and archived data."""
        self._transform_len = -1
        self._nbatch_index = []  # all time window's record
        self._nbatch_times = []
        self._spk_index = []
        self._spk_times = []
        self._records = []  # single time window's record
        self._times = []

    def init_record(self):
        """Start a new recording window, archiving spike indices if ``nbatch`` is set.

        ``nbatch is True`` keeps every archived window; a positive integer
        keeps only the most recent ``nbatch`` windows.
        """
        self.new_record = True
        if len(self._spk_index) > 0:
            keep_all = self.nbatch is True
            if keep_all or self.nbatch > 0:
                # NOTE(review): the first entry of _spk_index/_times is skipped
                # here exactly as in the original code; the reason is not
                # evident from this file - confirm the intent.
                if isinstance(self._spk_index[0], torch.Tensor):
                    stacked = torch.stack(self._spk_index[1:], dim=-1).cpu().detach().numpy()
                else:
                    stacked = np.stack(self._spk_index[1:], axis=-1)
                self._nbatch_index.append(stacked)
                self._nbatch_times.append(self._times[1:])
                if not keep_all and len(self._nbatch_times) > self.nbatch:
                    self._nbatch_index = self._nbatch_index[-self.nbatch:]
                    self._nbatch_times = self._nbatch_times[-self.nbatch:]
        self._records = []  # single time window's record
        self._times = []
        self._transform_len = -1

    def push_spike_train(self, spk_times, spk_index, batch_index=0):
        """Store externally supplied spikes for batch entry ``batch_index``.

        Lists/tuples extend the per-batch lists; any other value is appended
        as a single item.
        """
        missing = batch_index + 1 - len(self._spk_index)
        for _ in range(missing):
            self._spk_index.append([])
            self._spk_times.append([])

        if isinstance(spk_times, (list, tuple)):
            self._spk_times[batch_index].extend(spk_times)
            self._spk_index[batch_index].extend(spk_index)
        else:
            self._spk_times[batch_index].append(spk_times)
            self._spk_index[batch_index].append(spk_index)

        # to override the _spike_transform function when getting spk_times and spk_index
        self._transform_len = 1

    def _is_record_step(self):
        """Return True when ``backend.time / dt`` rounds to a whole number.

        The ratio is quantized to the resolution ``min(dt, 0.1)`` with
        half-up rounding and then checked for a zero fractional part at
        four decimal places.
        """
        from decimal import Decimal
        acttime = Decimal(self.backend.time / self.dt).quantize(
            Decimal(str(min(self.dt, 0.1))), rounding="ROUND_HALF_UP")
        return int(10000 * float(acttime)) % 10000 == 0

    def update_step(self, variables):
        """Record the value of the monitored variable for the current step.

        Parameters
        ----------
        variables : mapping
            Maps variable names to their current values.
        """
        if self.is_recording is False:
            return
        if not self._is_record_step():
            return

        record_value = variables[self.var_name]
        # retain_grad() raises on tensors that do not require grad, so guard
        # it the same way StateMonitor does.
        if self.get_grad and getattr(record_value, 'requires_grad', False) is True:
            record_value.retain_grad()

        if self.index == 'full':
            self._records.append(record_value)
            self._times.append(self.backend.time)
            return

        if len(self.index) == record_value.ndim:
            self._records.append(record_value[tuple(self.index)])
            self._times.append(self.backend.time)
            return

        # The index has one entry fewer than the value has dimensions: move
        # axis 0 to the end, index the remaining axes, then move it back.
        assert len(self.index) == record_value.ndim - 1
        if self.backend.backend_name == 'pytorch':
            record_value = torch.movedim(record_value, 0, -1)
            indexed_value = record_value[tuple(self.index)]
            indexed_value = torch.movedim(indexed_value, -1, 0)
        else:
            record_value = np.array(record_value)
            record_value = np.moveaxis(record_value, 0, -1)
            indexed_value = record_value[tuple(self.index)]
            indexed_value = np.moveaxis(indexed_value, -1, 0)
        self._records.append(indexed_value)
        self._times.append(self.backend.time)

    def _spike_transform(self):
        """Convert the raw records into per-batch spike time/index arrays.

        Runs only when more records exist than were converted last time.
        Only torch tensor records are converted; other record types leave
        the spike lists empty.
        """
        batch_size = self.backend.get_batch_size()
        if len(self._records) <= self._transform_len:
            return

        self._transform_len = len(self._records)
        self._spk_index = []
        self._spk_times = []
        # clear()/init_record() set _transform_len to -1, so this point can be
        # reached with no records at all; there is nothing to convert then.
        if not self._records:
            return
        if not isinstance(self._records[0], torch.Tensor):
            return

        step = len(self._records)
        rec_spikes = torch.stack(self._records, dim=-1).detach()
        times = torch.tensor(self._times)
        paired = '{[2]' in self.var_name

        for ii in range(batch_size):
            if paired:
                # Sub-entry 0 is used as the spike mask and sub-entry 1 is
                # subtracted from the recording time of each spike.
                spike_mask = rec_spikes[ii, 0, ...].bool().reshape(-1).cpu()
                # Moved to the CPU so it can be combined with the CPU time grid.
                time_shift = rec_spikes[ii, 1, ...].reshape(-1).cpu()
            else:
                sample = rec_spikes[ii, ...]
                if rec_spikes.dtype.is_complex:
                    spike_mask = (sample.imag.bool() * sample.real.gt(0.0)).reshape(-1).cpu()
                else:
                    spike_mask = sample.bool().reshape(-1).cpu()
                time_shift = None

            num = int(spike_mask.size(0) / step)
            time_seq = times.unsqueeze(dim=0).expand(num, -1).reshape(-1)
            indx_seq = torch.arange(0, num).unsqueeze(dim=1).expand(-1, step).reshape(-1)
            if time_shift is not None:
                time_seq = time_seq - time_shift

            self._spk_index.append(torch.masked_select(indx_seq, spike_mask).numpy())
            self._spk_times.append(torch.masked_select(time_seq, spike_mask).numpy())

    @property
    def spk_times(self):
        """List with one array of spike times per batch entry."""
        self._spike_transform()
        return self._spk_times

    @property
    def spk_index(self):
        """List with one array of spiking-unit indices per batch entry."""
        self._spike_transform()
        return self._spk_index

    @property
    def spk_grad(self):
        """Not implemented; always ``None``."""
        return None

    def _stacked_spikes(self):
        """Stack the records along a new last axis.

        Returns a ``(spikes, is_torch)`` pair. Torch records are moved to the
        CPU and detached. When the variable name contains ``'{[2]'`` only
        sub-entry 0 of axis 1 is kept.
        """
        is_torch = isinstance(self._records[0], torch.Tensor)
        if is_torch:
            spike = torch.stack(self._records, dim=-1).cpu().detach()
        else:
            spike = np.stack(self._records, axis=-1)
        if '{[2]' in self.var_name:
            spike = spike[:, 0, ...]
        return spike, is_torch

    @property
    def time_spk_rate(self):
        """Mean of the stacked records over axis 0, as a NumPy array."""
        spike, is_torch = self._stacked_spikes()
        if is_torch:
            return torch.mean(spike, dim=0).numpy()
        return np.mean(spike, axis=0)

    @property
    def time_pop_rate(self):
        """Mean of the stacked records over axis 1, as a NumPy array."""
        spike, is_torch = self._stacked_spikes()
        if is_torch:
            return torch.mean(spike, dim=1).numpy()
        return np.mean(spike, axis=1)

    @property
    def spk_count(self):
        """Number of recorded steps with a value greater than zero, per element."""
        if isinstance(self._records[0], torch.Tensor):
            spike = torch.stack(self._records, dim=-1).cpu()
            if spike.dtype.is_complex:
                spike = spike.real
            return torch.sum(spike.gt(0.0), dim=-1).numpy()
        spike = np.stack(self._records, axis=-1) > 0.0
        return np.sum(spike, axis=-1)

    @property
    def time(self):
        """Recording times of the current window, stacked into one array."""
        return np.stack(self._times, axis=-1)
class StateMonitor(Monitor):
    """Monitor that records the value of a state variable at each recording step."""

    def __init__(self, target, var_name, index='full', dt=None, get_grad=False, nbatch=False):
        # TODO: initialisation is somewhat cumbersome because the recorded
        # variable must be known in advance; consider a more direct
        # monitoring function.
        super().__init__(target=target, var_name=var_name, index=index, dt=dt,
                         get_grad=get_grad, nbatch=nbatch)
        self._nbatch_records = []  # all time window's record
        self._nbatch_times = []
        self._records = []  # single time window's record
        self._times = []

    def build(self, backend: Backend):
        """Register with ``backend`` and resolve the variable name, index and dt."""
        self.backend = backend
        self.backend._monitors.append(self)
        self.var_name = self.check_var_name(self.var_name)
        if self.index != 'full':
            self.index = tuple(self.index)
        if self.dt is None:
            self.dt = self.backend.dt

    def clear(self):
        """Discard all recorded and archived data."""
        self._nbatch_records = []  # all time window's record
        self._nbatch_times = []
        self._records = []  # single time window's record
        self._times = []

    def init_record(self):
        """Start a new recording window, archiving the current one if ``nbatch`` is set.

        ``nbatch is True`` keeps every archived window; a positive integer
        keeps only the most recent ``nbatch`` windows.
        """
        self.new_record = True
        self._last_step_time = 0
        if len(self._records) > 0:
            keep_all = self.nbatch is True
            if keep_all or self.nbatch > 0:
                if isinstance(self._records[0], torch.Tensor):
                    stacked = torch.stack(self._records, dim=-1).cpu().detach().numpy()
                else:
                    stacked = np.stack(self._records, axis=-1)
                self._nbatch_records.append(stacked)
                self._nbatch_times.append(self._times)
                if not keep_all and len(self._nbatch_times) > self.nbatch:
                    self._nbatch_records = self._nbatch_records[-self.nbatch:]
                    self._nbatch_times = self._nbatch_times[-self.nbatch:]
        self._records = []
        self._times = []

    def _is_record_step(self):
        """Return True when ``backend.time / dt`` rounds to a whole number.

        The ratio is quantized to the resolution ``min(dt, 0.1)`` with
        half-up rounding and then checked for a zero fractional part at
        four decimal places.
        """
        from decimal import Decimal
        acttime = Decimal(self.backend.time / self.dt).quantize(
            Decimal(str(min(self.dt, 0.1))), rounding="ROUND_HALF_UP")
        return int(10000 * float(acttime)) % 10000 == 0

    def update_step(self, variables):
        """Record the value of the monitored variable for the current step.

        Parameters
        ----------
        variables : mapping
            Maps variable names to their current values. Only data in this
            mapping can be recorded.
        """
        if self.is_recording is False:
            return

        # The original printed backend.time on every non-recording step; that
        # debug output has been dropped.
        if self._is_record_step():
            record_value = variables[self.var_name]
            if self.get_grad and record_value.requires_grad is True:
                record_value.retain_grad()

            if self.index == 'full':
                recorded = record_value
            elif len(self.index) == record_value.ndim:
                recorded = record_value[self.index]
            else:
                # The index has one entry fewer than the value has dimensions:
                # move axis 0 to the end, index the rest, then move it back.
                assert len(self.index) == record_value.ndim - 1
                if self.backend.backend_name == 'pytorch':
                    record_value = torch.movedim(record_value, 0, -1)
                    recorded = record_value[tuple(self.index)]
                    recorded = torch.movedim(recorded, -1, 0)
                else:
                    record_value = np.array(record_value)
                    record_value = np.moveaxis(record_value, 0, -1)
                    recorded = record_value[tuple(self.index)]
                    recorded = np.moveaxis(recorded, -1, 0)
            self._records.append(recorded)
            self._times.append(self.backend.time)

        # Not read anywhere in this class; kept for external users.
        self._last_step_time = self.backend.time

    def _refresh_nbatch_cache(self):
        """Rebuild the cached archive-plus-current-window lists if needed.

        The cache is rebuilt only while ``new_record`` is True, which
        ``init_record`` sets; it then stays fixed until the next window.
        """
        if not self.new_record:
            return
        records = list(self._nbatch_records)
        times = list(self._nbatch_times)
        # An empty current window contributes nothing (stacking an empty list
        # would raise); `values` handles both torch and NumPy records.
        if self._records:
            records.append(self.values)
            times.append(self._times)
        self._nbatch_records_ = records
        self._nbatch_times_ = times
        self.new_record = False

    @property
    def nbatch_values(self):
        """Archived windows plus the current one, as a single NumPy array.

        Each window's array is re-stacked so that its first axis becomes the
        last axis of that window's entry.
        """
        self._refresh_nbatch_cache()
        return np.array([np.stack(records, axis=-1) for records in self._nbatch_records_])

    @property
    def nbatch_times(self):
        """Recording times of the archived windows plus the current one."""
        self._refresh_nbatch_cache()
        return np.array([np.stack(times, axis=-1) for times in self._nbatch_times_])

    @property
    def values(self):
        """Current window's records stacked along a new last axis (NumPy array)."""
        if isinstance(self._records[0], torch.Tensor):
            return torch.stack(self._records, dim=-1).cpu().detach().numpy()
        return np.stack(self._records, axis=-1)

    @property
    def tensor_values(self):
        """Current window's records stacked along a new last axis (torch tensor)."""
        assert isinstance(self._records[0], torch.Tensor)
        return torch.stack(self._records, dim=-1)

    @property
    def grads(self):
        """Gradients of the recorded tensors, or ``None`` if ``get_grad`` is off.

        Records without a gradient contribute zeros.
        """
        if not self.get_grad:
            return None
        grads = []
        for v in self._records:
            if v.grad is not None:
                grads.append(v.grad.cpu().numpy())
            else:
                grads.append(torch.zeros_like(v).cpu().numpy())
        # NOTE(review): the first record's gradient is dropped, as in the
        # original code; the reason is not evident here - confirm the intent.
        return np.stack(grads[1:], axis=-1)

    @property
    def times(self):
        """Recording times of the current window as a NumPy array."""
        if isinstance(self._times[0], torch.Tensor):
            return torch.stack(self._times, dim=-1).cpu().detach().numpy()
        return np.stack(self._times, axis=-1)

    def plot_weight(self, **kwargs):
        """Show the recorded values at one time index as an image.

        Keyword Arguments
        -----------------
        time_id : index into the time axis of the recorded values.
        batch_id : optional; when given, the data is taken from
            ``nbatch_values[batch_id, :, time_id, :]`` instead of
            ``values[:, :, time_id]``.
        reshape : when truthy, the data is rearranged into an
            ``n_sqrt x n_sqrt`` grid of ``side x side`` tiles.
        n_sqrt, side : grid and tile sizes used when ``reshape`` is set.
        figsize : figure size for a new figure, default ``(5, 5)``.
        cmap : colormap name, default ``'hot_r'``. The historical misspelt
            key ``camp`` is still accepted.
        wmin, wmax : colour limits, defaults ``0`` and ``128``.
        im : existing image to update in place instead of creating a figure.
        neuron_id, new_shape, Axes, AxesImage : accepted but ignored.

        Returns
        -------
        The image object that was created or updated.
        """
        time_id = kwargs.get('time_id')
        batch_id = kwargs.get('batch_id', None)
        reshape = kwargs.get('reshape')
        n_sqrt = kwargs.get('n_sqrt', None)
        side = kwargs.get('side', None)
        figsize = kwargs.get('figsize', (5, 5))
        # 'camp' was a typo for 'cmap'; keep reading it so old callers still work.
        cmap = kwargs.get('cmap', kwargs.get('camp', 'hot_r'))
        wmin = kwargs.get('wmin', 0)
        wmax = kwargs.get('wmax', 128)
        im = kwargs.get('im', None)

        if batch_id is None:
            square_weights = self.values[:, :, time_id]
        else:
            square_weights = self.nbatch_values[batch_id, :, time_id, :]

        if reshape:
            # Lay the tiles out on an n_sqrt x n_sqrt grid of side x side blocks.
            square_weights = square_weights.reshape(n_sqrt, n_sqrt, side, side)
            square_weights = square_weights.transpose(0, 2, 1, 3)
            square_weights = square_weights.reshape(n_sqrt * side, n_sqrt * side)

        if not im:
            fig, ax = plt.subplots(figsize=figsize)
            im = ax.imshow(square_weights, cmap=cmap, vmin=wmin, vmax=wmax)
            div = make_axes_locatable(ax)
            cax = div.append_axes("right", size="5%", pad=0.05)
            ax.set_xticks(())
            ax.set_yticks(())
            ax.set_aspect("auto")
            plt.colorbar(im, cax=cax)
            fig.tight_layout()
        else:
            im.set_data(square_weights)
        plt.pause(0.1)
        return im