Source code for spaic.Network.BaseModule

# -*- coding: utf-8 -*-
"""
Created on 2020/9/9
@project: SPAIC
@filename: BaseModule
@author: Hong Chaofei
@contact: hongchf@gmail.com

@description: 
"""
from abc import abstractmethod
from collections import OrderedDict
import spaic
from typing import Optional,Any,List
from dataclasses import dataclass,field
from copy import copy
import uuid



class BaseModule:
    """Base class for all SNN modules (assemblies, connections, learners, monitors, pipelines).

    Holds naming/id bookkeeping plus helpers that register this module's
    variables and operations with the backend stored in ``self._backend``.
    """

    _Module_Count = 0
    # Suffix appended to ``name`` when set_id builds ``id``; subclasses may override it.
    _class_label = '<bm>'

    def __init__(self):
        self.id = None
        self.name = None
        self.enabled = True
        self.training = True
        self._backend: spaic.Backend = None
        # Enclosing modules; set_id prefixes their ids onto this module's id.
        self._supers = []
        # Variable names registered via variable_to_backend (no duplicates).
        self._var_names = []
        # Variable name -> object returned by self._backend.add_variable.
        self._var_dict = dict()
        self._ops = list()
        self.prefer_device = None

    @abstractmethod
    def build(self, backend):
        """Build this module on *backend*; subclasses are expected to override it.

        NOTE: ``@abstractmethod`` is not enforced here because this class does not
        use ``abc.ABCMeta``. The base implementation is a no-op returning ``None``;
        it is deliberately kept non-raising to preserve existing behavior for
        subclasses that do not override it.
        TODO(review): switch to ``raise NotImplementedError`` once every subclass
        is confirmed to override this method.
        """
        return None

    @abstractmethod
    def get_str(self, level):
        """Return a printable description of this module at indentation *level*.

        Same caveat as ``build``: the decorator is not enforced, and the base
        implementation is a no-op returning ``None`` to preserve existing behavior.
        """
        return None

    def set_name(self, given_name):
        """Assign this module's name and return it.

        A str *given_name* is adopted only when no name is set yet or the current
        name is auto-generated (contains ``'autoname'``). Any non-str value
        (e.g. ``None``) always assigns a fresh ``'autoname<N>'`` name, where N
        comes from incrementing ``spaic.global_module_name_count``.
        """
        if isinstance(given_name, str):
            if self.name is None or 'autoname' in self.name:
                self.name = given_name
        else:
            spaic.global_module_name_count += 1
            self.name = 'autoname' + str(spaic.global_module_name_count)
        return self.name

    def set_id(self):
        """Compute, store and return this module's id.

        The own part is ``name + _class_label``. With a single enclosing module
        the id is ``'<super_id>_<own>'``. With several it is
        ``'/<id1>,<id2>,.../_<own>'``. Enclosing modules without an id get
        theirs computed first, recursively.
        """
        if not self._supers:
            self.id = self.name + self.__class__._class_label
            return self.id
        super_ids = [
            parent.id if parent.id is not None else parent.set_id()
            for parent in self._supers
        ]
        own_id = self.name + self.__class__._class_label
        if len(super_ids) == 1:
            self.id = super_ids[0] + '_' + own_id
        else:
            # Every super id is followed by ',' (including the last one).
            prefix = '/' + ''.join(sid + ',' for sid in super_ids) + '/'
            self.id = prefix + '_' + own_id
        return self.id

    def set_build_level(self, level):
        """Lower ``build_level`` to *level* (a negative or missing level counts as unset)."""
        # build_level is not initialized in __init__, so treat a missing attribute as unset.
        current = getattr(self, 'build_level', -1)
        if current < 0 or current > level:
            self.build_level = level

    def variable_to_backend(self, name, shape, value=None, is_parameter=False, is_sparse=False,
                            init=None, init_param=None, min=None, max=None, is_constant=False,
                            prefer_device=None):
        """Register variable *name* with the backend and remember it on this module.

        All arguments are forwarded positionally to ``self._backend.add_variable``.
        Registering the same name again replaces the stored object but does not
        duplicate the name in ``_var_names``. A duplicate would make
        ``get_full_name`` raise a spurious "multiple variable" error.

        Returns:
            The object returned by ``self._backend.add_variable``.
        """
        if name not in self._var_names:
            self._var_names.append(name)
        self._var_dict[name] = self._backend.add_variable(
            self, name, shape, value, is_parameter, is_sparse, init, init_param,
            min, max, is_constant, prefer_device)
        return self._var_dict[name]

    def op_to_backend(self, outputs: list, func: callable, inputs: list):
        """Wrap (outputs, func, inputs) in an ``Op`` owned by this module and add it to the backend.

        *outputs* and *inputs* may each be a str or a list of str. Names the
        backend does not already know are prefixed with this module's label via
        ``_add_label``. The caller's lists are not modified.

        Raises:
            ValueError: if *inputs* or *outputs* is neither a str nor a list.
        """
        inputs = self._label_op_names(inputs, 'input')
        outputs = self._label_op_names(outputs, 'output')
        addcode_op = Op(outputs, func, inputs, owner=self, operation_type='_operations')
        self._backend.add_operation(addcode_op)

    def _label_op_names(self, names, kind):
        """Return a new list of *names*, labelling each name the backend does not already have."""
        if isinstance(names, str):
            names = [names]
        elif not isinstance(names, list):
            raise ValueError(
                "the preprocessing of op_to_backend do not support this %s type" % kind)
        return [name if self._backend.has_variable(name) else self._add_label(name)
                for name in names]

    def init_op_to_backend(self, outputs, func, inputs, prefer_device=0):
        """Register an ``Op`` (placed on *prefer_device*) via ``self._backend.register_initial``."""
        addcode_op = Op(outputs, func, inputs, place=prefer_device, owner=self,
                        operation_type='_operations')
        self._backend.register_initial(addcode_op)

    def _add_label(self, key):
        """Qualify *key* with this module's id.

        ``'[dt]'`` is returned unchanged. ``'x[updated]'`` becomes
        ``'<id>:{x}[updated]'``, any other str becomes ``'<id>:{x}'``, and a
        ``VariableAgent`` yields its ``var_name``.

        Raises:
            ValueError: for any other key type.
        """
        if isinstance(key, str):
            if key == '[dt]':
                return key
            if '[updated]' in key:
                return self.id + ':' + '{' + key.replace('[updated]', "") + '}' + '[updated]'
            return self.id + ':' + '{' + key + '}'
        if isinstance(key, VariableAgent):
            return key.var_name
        raise ValueError(" the key data type is not supported for add_label")

    def get_full_name(self, name):
        """Return the registered variable name containing ``'{name}'``, or None if there is none.

        Raises:
            ValueError: if more than one registered name matches.
        """
        name = '{' + name + '}'
        full_name = None
        for key in self._var_names:
            if name in key:
                if full_name is not None:
                    raise ValueError("multiple variable with same name in this module")
                full_name = key
        return full_name

    def get_value(self, name):
        """Return ``.value`` of the registered variable matching *name* (see get_full_name)."""
        full_name = self.get_full_name(name)
        if full_name is None:
            raise ValueError("No such variable name in this module")
        return self._var_dict[full_name].value

    def set_value(self, name, value):
        """Assign ``.value`` of the registered variable matching *name* (see get_full_name)."""
        full_name = self.get_full_name(name)
        if full_name is None:
            raise ValueError("No such variable name in this module")
        self._var_dict[full_name].value = value

    def _direct_set_variable(self, name, variable):
        """Debug helper: put *variable* straight into the backend's parameter or init-variable dict.

        Intended only for use at the beginning of a network run.
        """
        full_name = self.get_full_name(name)
        if full_name is None:
            raise ValueError("No such variable name in this module")
        if self._var_dict[full_name]._is_parameter:
            self._backend._parameters_dict[full_name] = variable
        else:
            self._backend._InitVariables_dict[full_name] = variable
class VariableAgent(object):
    """Proxy for one backend variable, addressed by its full name.

    The ``value`` property goes through the backend's accessor methods when
    ``dict_label`` is None. Otherwise it indexes the backend dict that
    ``dict_label`` selects.
    """

    # (dict_label, backend attribute holding that dict)
    _LABELED_STORES = (
        ('variables_dict', '_variables'),
        ('update_dict', '_update_dict'),
        ('reduce_dict', '_reduce_dict'),
        ('temp_dict', '_temp_dict'),
    )

    def __init__(self, backend, var_name, is_parameter=False, dict_label=None):
        super().__init__()
        assert isinstance(backend, spaic.Backend)
        self._backend: spaic.Backend = backend
        self._var_name = var_name
        self._is_parameter = is_parameter
        self.data_type = None
        self.device = None
        self.dict_label = dict_label
        self.set_funcs = []
        self.get_funcs = []

    @property
    def var_name(self):
        """Full backend name of the proxied variable (read-only)."""
        return self._var_name

    def _store_attr(self):
        """Return the backend attribute name for ``dict_label``, or None if it is not a known label."""
        for label, attr in self._LABELED_STORES:
            if self.dict_label == label:
                return attr
        return None

    def new_labeled_agent(self, dict_label):
        """Return a shallow copy of this agent that targets the backend dict named by *dict_label*."""
        assert any(dict_label == label for label, _ in self._LABELED_STORES)
        labeled = copy(self)
        labeled.dict_label = dict_label
        return labeled

    @property
    def value(self):
        """Read the variable from the store selected by ``dict_label``.

        Raises:
            ValueError: if ``dict_label`` is neither None nor a known label.
        """
        if self.dict_label is None:
            # NOTE(review): 'get_varialble' (sic) is the backend method name called here;
            # renaming it requires changing the backend too.
            return self._backend.get_varialble(self._var_name)
        attr = self._store_attr()
        if attr is None:
            raise ValueError("can't find variable %s"%self._var_name)
        return getattr(self._backend, attr)[self._var_name]

    @value.setter
    def value(self, value):
        """Write the variable into the store selected by ``dict_label``.

        For ``'reduce_dict'`` the value is appended to a per-variable list
        instead of replacing it.

        Raises:
            ValueError: if ``dict_label`` is neither None nor a known label.
        """
        if self.dict_label is None:
            self._backend.set_variable_value(self._var_name, value, self._is_parameter)
            return
        attr = self._store_attr()
        if attr is None:
            raise ValueError("can't set value of variable %s" % self._var_name)
        store = getattr(self._backend, attr)
        if self.dict_label != 'reduce_dict':
            store[self._var_name] = value
        elif self._var_name in store:
            store[self._var_name].append(value)
        else:
            store[self._var_name] = [value]
class OperationCommand(object):
    """Records an (output, function, input) triple on behalf of the module that issued it.

    ``enabled`` follows the issuing module. It returns ``front_module.enabled``,
    and when ``training_only`` is set it also requires ``front_module.training``.
    """

    def __init__(self, front_module, output, function, input):
        super().__init__()
        assert isinstance(front_module, BaseModule)
        assert isinstance(output, list)
        assert isinstance(function, str) or callable(function)
        assert isinstance(input, list)
        self.front_module = front_module
        self.output = output
        self.function = function
        self.input = input
        # When True, the command is considered enabled only while the module is training.
        self.training_only = False

    @property
    def enabled(self):
        """Whether this command is currently active (see class docstring)."""
        module = self.front_module
        if not self.training_only:
            return module.enabled
        return module.enabled and module.training
@dataclass
class Op:
    """Data record for one backend operation.

    Holds output/input names, the function, placement, owning module and an
    operation-category tag.
    """

    output: Optional[List] = field(default_factory=list)
    # NOTE(review): BaseModule.op_to_backend passes its ``func`` callable positionally
    # into this slot, so it may hold a callable rather than a str; confirm the intended type.
    func_name: Optional[str] = None
    input: Optional[List] = field(default_factory=list)
    place: Optional[Any] = None
    owner: Optional[BaseModule] = None
    requires_grad: Optional[bool] = False
    # Category tag, e.g. '_operations', '_init_operations' or '_standalone_operations'.
    operation_type: Optional[str] = None
    func: Optional[Any] = None

    def set_identifier(self, nid=None):
        """Set ``self._identifier`` to *nid*, or to a fresh ``uuid.uuid1()`` string when *nid* is None."""
        self._identifier = str(uuid.uuid1()) if nid is None else nid
# class NetModule(BaseModule): # ''' # Base class for snn network modules: assemblies, connection # ''' # # def __init__(self): # super(NetModule, self).__init__() # # self.trainable_parameter_names = OrderedDict() # # def add_trainable_names(self, name): # pass