# -*- coding: utf-8 -*-
"""
Created on 2020/9/9
@project: SPAIC
@filename: BaseModule
@author: Hong Chaofei
@contact: hongchf@gmail.com
@description:
"""
from abc import abstractmethod
from collections import OrderedDict
from typing import Optional, Any, List
from dataclasses import dataclass, field
from copy import copy
import uuid
from ..Backend.Backend import Backend
from .Operator import Op
from .. import global_module_name_count
# NOTE: a ``global`` statement at module scope has no effect; the name is
# already bound in this module's namespace by the import above.
# BaseModule.set_name() rebinds this module-level name when it generates an
# automatic name, so the counter it increments is this module's binding.
# NOTE(review): rebinding here does not update the attribute on the package
# the value was imported from - confirm that a per-module counter is intended.
global global_module_name_count
class BaseModule():
    '''
    Base class for all snn modules (assemblies, connection, learner, monitor, pipelines).

    Instance attributes set by ``__init__``:
        id: identifier string built by ``set_id`` (``None`` until then).
        name: module name assigned by ``set_name`` (``None`` until then).
        enabled / training: boolean flags, both ``True`` by default.
        _backend: backend the module registers variables/operations with.
        _supers: list of parent modules used by ``set_id`` to prefix the id.
        _var_names / _var_dict: names and handles recorded by
            ``variable_to_backend``.
        _ops: list initialised empty here; not used by the methods of this class.
        prefer_device: ``None`` by default; not used by the methods of this class.
    '''

    _Module_Count = 0  # not referenced by the methods of this class
    _class_label = '<bm>'  # suffix appended to the module name when building ids

    def __init__(self):
        self.id = None
        self.name = None
        self.enabled = True
        self.training = True
        self._backend: Backend = None
        self._supers = []
        self._var_names = []
        self._var_dict = dict()
        self._ops = list()
        self.prefer_device = None

    @abstractmethod
    def build(self, backend):
        '''
        Build this module on ``backend``; meant to be overridden by subclasses.
        '''
        # NOTE(review): the exception is instantiated but never raised, and the
        # class does not derive from abc.ABC, so this default silently returns
        # None. Kept as-is because subclasses outside this file may rely on it;
        # consider ``raise NotImplementedError`` once all callers are checked.
        NotImplementedError()

    @abstractmethod
    def get_str(self, level):
        '''
        Return a string description of the module; meant to be overridden.
        '''
        # NOTE(review): same as build() - the exception is created, not raised.
        NotImplementedError()

    def set_name(self, given_name):
        '''
        Assign the module name and return the name now in effect.

        A string ``given_name`` is applied only when the module has no name yet
        or its current name contains ``'autoname'``; otherwise the existing name
        is kept. Any non-string ``given_name`` makes the module take the next
        ``'autoname<N>'`` from the module-level counter.
        '''
        global global_module_name_count
        if isinstance(given_name, str):
            # An explicit name never overrides an earlier explicit name.
            if self.name is None or 'autoname' in self.name:
                self.name = given_name
        # elif isinstance(given_name, list):
        #     context = given_name[-1]
        #     spaic.global_module_name_count += 1
        #     self.name = context.name +'subgroup' + str(spaic.global_module_name_count)
        else:
            global_module_name_count += 1
            self.name = 'autoname' + str(global_module_name_count)
        return self.name

    def set_id(self):
        '''
        Build, store and return ``self.id``.

        The base id is ``name + _class_label``. With exactly one parent in
        ``_supers`` the id becomes ``<parent id>_<base id>``; with several it
        becomes ``/<id1>,<id2>,.../_<base id>`` (each parent id is followed by
        a comma). Parents whose id is still ``None`` get theirs built first.
        '''
        if len(self._supers) == 0:
            self.id = self.name + self.__class__._class_label
            return self.id

        # Resolve parent ids first, building any that are still missing.
        super_ids = [
            parent.id if parent.id is not None else parent.set_id()
            for parent in self._supers
        ]
        own_id = self.name + self.__class__._class_label
        if len(super_ids) == 1:
            self.id = super_ids[0] + '_' + own_id
        else:
            # Every parent id keeps its trailing comma, e.g. '/a,b,/'.
            prefix = '/' + ''.join(sid + ',' for sid in super_ids) + '/'
            self.id = prefix + '_' + own_id
        return self.id

    def set_build_level(self, level):
        '''
        Record ``level`` as the build level when none is set yet (negative) or
        when ``level`` is lower than the stored one.
        '''
        # ``build_level`` is not created in __init__; treat a missing attribute
        # as "unset" (-1) instead of failing with AttributeError.
        current = getattr(self, 'build_level', -1)
        if current < 0 or current > level:
            self.build_level = level

    def variable_to_backend(self, name, shape, value=None, is_parameter=False, is_sparse=False, init=None,
                            init_param=None,
                            min=None, max=None, is_constant=False, prefer_device=None):
        '''
        Register a variable with the backend and remember it under ``name``.

        All arguments are forwarded unchanged to ``self._backend.add_variable``.
        The object it returns is stored in ``_var_dict[name]`` and returned.
        NOTE(review): ``get_full_name`` looks for ``'{<short name>}'`` inside the
        recorded names, so ``name`` is presumably already in labelled form
        (see ``_add_label``) - confirm against callers.
        '''
        self._var_names.append(name)
        self._var_dict[name] = self._backend.add_variable(self, name, shape, value, is_parameter, is_sparse, init,
                                                          init_param, min, max, is_constant, prefer_device)
        return self._var_dict[name]

    def op_to_backend(self, outputs: list, func: callable, inputs: list):
        '''
        Wrap ``func`` in an ``Op`` and add it to the backend's operations.

        ``inputs`` and ``outputs`` may each be a list of names or a single name
        string (a string is wrapped into a one-element list). Names the backend
        does not know are prefixed with this module's label via ``_add_label``;
        list arguments are updated in place.

        Raises:
            ValueError: if ``inputs`` or ``outputs`` is neither list nor str.
        '''
        # check if the inputs and outputs variables belongs to this module object, if backend don't have this variable it will be added the module label
        # NOTE(review): the nesting below was reconstructed from a copy of the
        # file whose indentation was lost - verify against the repository.
        if isinstance(inputs, list):
            for ind, input_name in enumerate(inputs):
                if '[updated]' in input_name:
                    input_name = input_name.replace('[updated]', '')
                if not self._backend.has_variable(input_name):
                    # NOTE(review): the '[updated]' tag was stripped above, so a
                    # relabelled input no longer carries it - confirm intent.
                    input_name = self._add_label(input_name)
                    inputs[ind] = input_name
                # assert self._backend.has_variable(input_name)
        elif isinstance(inputs, str):
            if not self._backend.has_variable(inputs):
                inputs = self._add_label(inputs)
            inputs = [inputs]
            # assert self._backend.has_variable(inputs[-1])
        else:
            raise ValueError("the preprocessing of op_to_backend do not support this input type")

        if isinstance(outputs, list):
            for ind, output_name in enumerate(outputs):
                if not self._backend.has_variable(output_name):
                    output_name = self._add_label(output_name)
                    outputs[ind] = output_name
                # assert self._backend.has_variable(output_name)
        elif isinstance(outputs, str):
            if not self._backend.has_variable(outputs):
                outputs = self._add_label(outputs)
            outputs = [outputs]
            # assert self._backend.has_variable(outputs[-1])
        else:
            raise ValueError("the preprocessing of op_to_backend do not support this output type")

        addcode_op = Op(outputs, func, inputs, owner=self, operation_type='_operations')
        self._backend.add_operation(addcode_op)

    def init_op_to_backend(self, outputs, func, inputs, prefer_device=0):
        '''
        Wrap ``func`` in an ``Op`` placed on ``prefer_device`` and register it
        with the backend through ``register_initial``. Names are passed through
        without labelling.
        '''
        addcode_op = Op(outputs, func, inputs, place=prefer_device, owner=self, operation_type='_operations')
        self._backend.register_initial(addcode_op)

    # adding label of the module object, cut from neurongroup and generalized to all Modules
    def _add_label(self, key):
        '''
        Return ``key`` qualified with this module's id.

        ``'[dt]'`` is returned unchanged. Other strings become
        ``'<id>:{<key>}'``, with a ``'[updated]'`` tag moved behind the closing
        brace. A ``VariableAgent`` yields its ``var_name``.

        Raises:
            ValueError: for any other key type.
        '''
        if isinstance(key, str):
            if key == '[dt]':
                return key
            if '[updated]' in key:
                return self.id + ':' + '{' + key.replace('[updated]', "") + '}' + '[updated]'
            return self.id + ':' + '{' + key + '}'
        if isinstance(key, VariableAgent):
            return key.var_name
        raise ValueError(" the key data type is not supported for add_label")

    def get_full_name(self, name):
        '''
        Return the recorded variable name containing ``'{name}'``, or ``None``.

        Raises:
            ValueError: if more than one recorded name matches.
        '''
        wrapped = '{' + name + '}'
        full_name = None
        for key in self._var_names:
            if wrapped in key:
                if full_name is not None:
                    raise ValueError("multiple variable with same name in this module")
                full_name = key
        return full_name

    def _require_full_name(self, name):
        '''Like ``get_full_name`` but raise ``ValueError`` when nothing matches.'''
        full_name = self.get_full_name(name)
        if full_name is None:
            raise ValueError("No such variable name in this module")
        return full_name

    def get_value(self, name):
        '''
        Return the ``value`` of the variable whose recorded name contains
        ``'{name}'``.

        Raises:
            ValueError: if no recorded name matches, or several do.
        '''
        return self._var_dict[self._require_full_name(name)].value

    def set_value(self, name, value):
        '''
        Assign ``value`` to the variable whose recorded name contains
        ``'{name}'``.

        Raises:
            ValueError: if no recorded name matches, or several do.
        '''
        self._var_dict[self._require_full_name(name)].value = value

    def _direct_set_variable(self, name, variable):
        '''
        Write ``variable`` straight into the backend's ``_parameters_dict`` (for
        parameters) or ``_InitVariables_dict`` (otherwise).

        Raises:
            ValueError: if no recorded name matches, or several do.
        '''
        # only for debug at the beginning of the network run
        full_name = self._require_full_name(name)
        if self._var_dict[full_name]._is_parameter:
            self._backend._parameters_dict[full_name] = variable
        else:
            self._backend._InitVariables_dict[full_name] = variable
class VariableAgent(object):
    '''
    Handle to one named backend variable.

    ``dict_label`` selects where ``value`` reads from and writes to: ``None``
    goes through the backend's accessor methods, while ``'variables_dict'``,
    ``'update_dict'``, ``'reduce_dict'`` and ``'temp_dict'`` address the
    backend's ``_variables``, ``_update_dict``, ``_reduce_dict`` and
    ``_temp_dict`` mappings directly.
    '''

    def __init__(self, backend, var_name, is_parameter=False, dict_label=None):
        super().__init__()
        assert isinstance(backend, Backend)
        self._backend: Backend = backend
        self._var_name = var_name
        self._is_parameter = is_parameter
        self.data_type = None
        self.device = None
        self.dict_label = dict_label
        self.set_funcs = []
        self.get_funcs = []

    @property
    def var_name(self):
        '''Name of the variable this agent refers to (read-only).'''
        return self._var_name

    def new_labeled_agent(self, dict_label):
        '''
        Return a shallow copy of this agent bound to ``dict_label``, which must
        be one of the four dictionary labels. The copy shares ``set_funcs`` and
        ``get_funcs`` with the original because ``copy`` is shallow.
        '''
        assert dict_label in ('variables_dict', 'update_dict', 'reduce_dict', 'temp_dict')
        labeled = copy(self)
        labeled.dict_label = dict_label
        return labeled

    @property
    def value(self):
        '''
        Current value of the variable, read from the store selected by
        ``dict_label``.

        Raises:
            ValueError: if ``dict_label`` is not a recognised label.
        '''
        label = self.dict_label
        key = self._var_name
        if label is None:
            # NOTE: ``get_varialble`` is the spelling called here; it has to
            # match the method name defined on the backend.
            return self._backend.get_varialble(key)
        if label == 'variables_dict':
            return self._backend._variables[key]
        if label == 'update_dict':
            return self._backend._update_dict[key]
        if label == 'reduce_dict':
            return self._backend._reduce_dict[key]
        if label == 'temp_dict':
            return self._backend._temp_dict[key]
        raise ValueError("can't find variable %s" % self._var_name)

    @value.setter
    def value(self, value):
        '''
        Store ``value`` in the location selected by ``dict_label``. For
        ``'reduce_dict'`` the value is appended to a per-variable list instead
        of replacing the entry.

        Raises:
            ValueError: if ``dict_label`` is not a recognised label.
        '''
        label = self.dict_label
        key = self._var_name
        if label is None:
            self._backend.set_variable_value(key, value, self._is_parameter)
        elif label == 'update_dict':
            self._backend._update_dict[key] = value
        elif label == 'reduce_dict':
            # Start a list on first use, then accumulate.
            self._backend._reduce_dict.setdefault(key, []).append(value)
        elif label == 'temp_dict':
            self._backend._temp_dict[key] = value
        elif label == 'variables_dict':
            self._backend._variables[key] = value
        else:
            raise ValueError("can't set value of variable %s" % self._var_name)
class OperationCommand(object):
    '''
    Record of one operation (output names, function, input names) together
    with the front-end module that owns it.
    '''

    def __init__(self, front_module, output, function, input):
        super().__init__()
        # Argument checks; note they are skipped when Python runs with -O.
        assert isinstance(front_module, BaseModule)
        assert isinstance(output, list)
        assert isinstance(function, str) or callable(function)
        assert isinstance(input, list)
        self.front_module = front_module
        self.output = output
        self.function = function
        self.input = input
        self.training_only = False

    @property
    def enabled(self):
        '''
        Whether the command is active: follows the owning module's ``enabled``
        flag, and additionally its ``training`` flag when ``training_only`` is
        set.
        '''
        owner = self.front_module
        if not self.training_only:
            return owner.enabled
        return owner.enabled and owner.training
# class NetModule(BaseModule):
# '''
# Base class for snn network modules: assemblies, connection
# '''
#
# def __init__(self):
# super(NetModule, self).__init__()
#
# self.trainable_parameter_names = OrderedDict()
#
# def add_trainable_names(self, name):
# pass