# -*- coding: utf-8 -*-
"""
Created on 2021/4/12
@project: SPAIC
@filename: Module
@author: Hong Chaofei
@contact: hongchf@gmail.com
@description:
Wraps a deep learning module, such as a CNN network or an LSTM cell.
"""
import torch
from ..Network.Assembly import Assembly
from ..Network.Operator import Op
def _to_list(value, default=None):
    """Normalise a constructor argument to a list.

    ``None`` yields a fresh copy of *default* (an empty list when no default
    is given), a list is returned unchanged, and any other value is wrapped
    in a one-element list.
    """
    if value is None:
        return list(default) if default is not None else []
    if isinstance(value, list):
        return value
    return [value]


class Module(Assembly):
    """Assembly that wraps a callable deep learning module (``torch.nn.Module``).

    On ``build`` the wrapped module is registered with the backend as a
    standalone operation that reads the configured input variables and writes
    the first configured output variable.

    Args:
        module: The wrapped module; it is called with the input values in
            ``standalone_run``.
        name: Name forwarded to ``Assembly.__init__``.
        input_targets: Object or list of objects providing the inputs.
            ``None`` means no input targets.
        input_var_names: Variable name or list of names read from the input
            targets, paired with ``input_targets`` by position. Defaults to
            ``['O[updated]']``.
        output_targets: Object or list of objects receiving the output.
            ``None`` means no output targets; ``build`` needs at least one.
        output_var_names: Variable name or list of names written on the
            output targets. Defaults to ``['Isyn']``.
        module_backend: Identifier of the module's framework, stored as is.
    """

    _class_label = '<mod>'

    def __init__(self, module=None, name=None, input_targets=None, input_var_names=None,
                 output_targets=None, output_var_names=None, module_backend='pytorch'):
        super(Module, self).__init__(name)
        self.module: torch.nn.Module = module

        # ``None`` defaults are resolved here instead of using list literals in
        # the signature, so instances never share one mutable default object.
        self.input_targets = _to_list(input_targets)
        self.input_var_names = _to_list(input_var_names, ['O[updated]'])
        # A missing output target becomes an empty list rather than ``[None]``,
        # which would fail below with an AttributeError on ``None``.
        self.output_targets = _to_list(output_targets)
        self.output_var_names = _to_list(output_var_names, ['Isyn'])

        # Second argument of register_module: True for inputs, False for outputs.
        for in_targ in self.input_targets:
            in_targ.register_module(self, True)
        for out_targ in self.output_targets:
            out_targ.register_module(self, False)

        self.module_backend = module_backend

    def standalone_run(self, *args):
        """Call the wrapped module with ``args`` and return its result."""
        return self.module(*args)

    def init_variable(self, var_names=None, var_shapes=None, var_value_dict=None):
        """Set up the module-owned variable names, shapes and initial values.

        Args:
            var_names: A name or an iterable of names. ``None`` means no
                variables.
            var_shapes: A list with one shape per name, or a single shape
                that is used for every name. ``None`` means no shapes.
            var_value_dict: Mapping from variable name to initial value.
                Names that are not in the mapping start at ``0.0``.
        """
        if var_names is None:
            self._var_names = []
        elif isinstance(var_names, str):
            # A str is iterable too; without this branch a single name would
            # be treated as a sequence of one-character names.
            self._var_names = [var_names]
        elif hasattr(var_names, '__iter__'):
            # Copy: ``build`` appends to this list and must not modify the
            # caller's object (and needs ``append`` to exist at all).
            self._var_names = list(var_names)
        else:
            self._var_names = [var_names]

        if var_shapes is None:
            # Fixed: this used to assign ``self.var_shapes``, leaving the
            # ``self._var_shapes`` read by ``build`` undefined.
            self._var_shapes = []
        elif isinstance(var_shapes, list):
            self._var_shapes = var_shapes
        else:
            self._var_shapes = [var_shapes for _ in self._var_names]

        if var_value_dict is None:
            var_value_dict = {}
        elif not hasattr(var_value_dict, '__iter__'):
            # A bare scalar carries no per-name values; wrapping it keeps the
            # membership test below valid.
            var_value_dict = [var_value_dict]

        self._var_values = [
            var_value_dict[var_name] if var_name in var_value_dict else 0.0
            for var_name in self._var_names
        ]

    def build(self, backend):
        """Register the module's variables and its standalone op with ``backend``.

        Args:
            backend: Backend object; this method uses its
                ``register_standalone`` method and ``device0`` attribute.

        Raises:
            ValueError: If there is no output target or output variable name.
        """
        # Add module owned variables to backend
        self._backend = backend
        # NOTE(review): this call resets the variable lists, so values from an
        # earlier ``init_variable(...)`` call are discarded unless a subclass
        # overrides ``init_variable`` -- confirm that this is intended.
        self.init_variable()
        for var_name, var_shape, var_value in zip(self._var_names, self._var_shapes, self._var_values):
            key = f"{self.id}:{{{var_name}}}"
            shape = (1, *var_shape)
            self.variable_to_backend(key, shape, var_value)

        # add standalone operation
        if not self.output_targets or not self.output_var_names:
            raise ValueError("Module.build requires at least one output target and output variable name")
        # Only the first output target / variable name is used.
        output_var_name = f"{self.output_targets[0].id}:{{{self.output_var_names[0]}}}"
        self._var_names.append(output_var_name)

        input_var_names = []
        for input_target, input_name in zip(self.input_targets, self.input_var_names):
            input_var_name = f"{input_target.id}:{{{input_name}}}"
            self._var_names.append(input_var_name)
            input_var_names.append(input_var_name)

        backend.register_standalone(Op(output_var_name, self.standalone_run, input_var_names, owner=self))
        self.module.to(backend.device0)

    @property
    def parameters(self):
        """Result of ``parameters()`` on the wrapped module."""
        return self.module.parameters()  # .state_dict()

    @property
    def state_dict(self):
        """Result of ``state_dict()`` on the wrapped module."""
        return self.module.state_dict()

    def load_state_dict(self, state):
        """Load ``state`` into the wrapped module and return its result."""
        return self.module.load_state_dict(state)

    def train(self, mode=True):
        """Forward ``mode`` to the wrapped module's ``train`` method."""
        self.module.train(mode)