Source code for spaic.Network.Network

# -*- coding: utf-8 -*-
"""
Created on 2020/8/5
@project: SPAIC
@filename: Network
@author: Hong Chaofei
@contact: hongchf@gmail.com

@description:
定义网络以及子网络,网络包含所有的神经网络元素、如神经集群、连接以及学习算法、仿真器等,实现最终的网络仿真与学习。
执行过程:网络定义->网络生成->网络仿真与学习
"""
from .Assembly import Assembly
from collections import OrderedDict
from warnings import warn
from ..Backend.Backend import Backend
from ..Backend.Torch_Backend import Torch_Backend

try:
    import torch
except:
    pass


[docs]class Network(Assembly): _class_label = '<net>' def __init__(self, name=None): super(Network, self).__init__(name=name) self._monitors = OrderedDict() self._learners = OrderedDict() self._pipline = None self._backend: Backend = None self._forward_build = False pass # --------- Frontend code ----------
[docs] def set_backend(self, backend=None, device='cpu', partition=False): if isinstance(device, str): device = [device] if backend is None: self._backend = Torch_Backend(device) self._backend.partition = partition elif isinstance(backend, Backend): self._backend = backend elif isinstance(backend, str): if backend == 'torch' or backend == 'pytorch': self._backend = Torch_Backend(device) self._backend.partition = partition
# elif backend == 'tensorflow': # self._backend = spaic.Tensorflow_Backend(device)
[docs] def set_backend_dt(self, dt=0.1, partition=False): if self._backend is None: warn("have not set backend, default pytorch backend is set automatically") self._backend = Torch_Backend('cpu') self._backend.dt = dt else: self._backend.dt = dt self._backend.partition = partition
[docs] def set_random_seed(self, seed): if isinstance(self._backend, Torch_Backend): import torch torch.random.manual_seed(int(seed)) if self._backend.device == 'cuda': torch.cuda.manual_seed(int(seed))
[docs] def get_testparams(self): self.all_Wparams = list() for key, value in self._backend._parameters_dict.items(): self.all_Wparams.append(value) return self.all_Wparams
[docs] def add_learner(self, name, learner): from ..Learning.Learner import Learner assert isinstance(learner, Learner) self.__setattr__(name, learner)
# TODO: 这里的setattr是否有必要要,是否可以全部放到Assembly里? def __setattr__(self, name, value): from ..Monitor.Monitor import Monitor from ..Learning.Learner import Learner super(Network, self).__setattr__(name, value) if isinstance(value, Monitor): self._monitors[name] = value elif isinstance(value, Learner): self._learners[name] = value # --------- backend code ----------
[docs] def build(self, backend=None, strategy=0, full_enable_grad=None, device=None): if full_enable_grad is not None: self.enable_full_grad(full_enable_grad) if self._backend is None: if backend is not None: if device is not None: self.set_backend(backend, device) else: self.set_backend(backend) else: if device is not None: self.set_backend(device=device) else: self.set_backend() self._backend.clear_step() # build 试运行时,假设一个runtime if self._backend.runtime is None: self._backend.runtime = 10 * self._backend.dt all_groups = self.get_groups() for asb in all_groups: asb.set_id() self.build_projections(self._backend) all_connections = self.get_connections() # for debug con_debug = False con_syn_count = 0 for con in all_connections: con.set_id() # ----根据连接,对每个神经元建立input_connection和output_connection con.pre.register_connection(con, True) con.post.register_connection(con, False) if strategy == 1: # 采取单纯的从头递归地build,一旦出现环路会陷入死循环,可以避开固有延迟的问题, # Use directly build strategy to avoid inherent delay. But cannot be used on models with loop, will fall in an endless loop. # Unfortunately, self.forward_build(all_groups, all_connections) self._backend.forward_build = True # elif strategy == 2: # # 采取策略性构建,但是目前存在两个问题: # # 1. 网络中存在Assembly块时会出现bug,尚未修复 # # 2. Connection所使用的input_spike为上一步的,需要添加[updated],目前暂不使用所以未添加 # # 该构建方式可以较大程度上避开固有延迟的问题 # self.strategy_build(self.get_groups(False)) else: # 原本的构建方式,首先构建连接,每个连接都是用上一轮神经元的输出脉冲,从而存在固有延迟的问题 # 但是可以避开环路的问题。 from multiprocessing.pool import ThreadPool as Pool self._backend.forward_build = False def build_fn(module): # if con_debug: # con_syn_count += torch.count_nonzero(connection.weight.value).item() module.build(self._backend) # for connection in all_connections: # connection.build(self._backend) # if con_debug: # import torch # con_syn_count += torch.count_nonzero(connection.weight.value).item() # for group in all_groups: # group.build(self._backend) pool = Pool(4) pool.map(build_fn, all_connections) pool.close() pool.join() pool = Pool(4) pool.map(build_fn, all_groups) pool.close() pool.join() for learner in self._learners.values(): learner.set_id() learner.build(self._backend) for monitor in self._monitors.values(): monitor.build(self._backend) self._backend.build_graph() # self._backend.build() self._backend.builded = True # if con_debug: # print("Connection synapses count:%d"%con_syn_count) # for group in all_groups: # if hasattr(group, 'index'): # group.index = 0 pass
[docs] def forward_build(self, all_groups=None, all_connections=None): builded_groups = [] builded_connections = [] nod_groups = [] for group in all_groups.copy(): if group._class_label == '<nod>': if (group._node_sub_class == '<encoder>') or (group._node_sub_class == '<generator>'): group.build(self._backend) builded_groups.append(group) all_groups.remove(group) for conn in group._output_connections: self.deep_forward_build(conn, all_groups, all_connections, builded_groups, builded_connections) for module in group._output_modules: self.deep_forward_build(module, all_groups, all_connections, builded_groups, builded_connections) else: all_groups.remove(group) nod_groups.append(group) while all_groups or all_connections: for group in all_groups: self.deep_forward_build(group, all_groups, all_connections, builded_groups, builded_connections) for conn in all_connections: self.deep_forward_build(conn, all_groups, all_connections, builded_groups, builded_connections) for group in nod_groups: group.build(self._backend) builded_groups.append(group)
[docs] def deep_forward_build(self, target, all_groups, all_connections, builded_groups, builded_connections): if (target in builded_groups) or (target in builded_connections): return if target._class_label == '<con>': pre = [target.pre] post = [target.post] elif target._class_label == '<neg>': pre = target._input_connections + target._input_modules post = target._output_connections + target._output_modules elif target._class_label == '<mod>': pre = target.input_targets.copy() post = target.output_targets.copy() else: raise ValueError("Deep forward build Error, unsupported class label.") for pr in pre: if (pr in all_groups) or (pr in all_connections): return target.build(self._backend) if target._class_label == '<con>': builded_connections.append(target) all_connections.remove(target) elif (target._class_label == '<neg>') or (target._class_label == '<mod>'): builded_groups.append(target) all_groups.remove(target) for po in post: self.deep_forward_build(po, all_groups, all_connections, builded_groups, builded_connections) return
# def strategy_build(self, all_groups=None): # builded_groups = [] # unbuild_groups = {} # output_groups = [] # level = 0 # from ..Neuron.Node import Encoder, Decoder, Generator # # ===================从input开始按深度构建计算图============== # for group in all_groups: # if isinstance(group, Encoder) or isinstance(group, Generator): # # 如果是input节点,则开始深度构建计算图 # group.build(self._backend) # builded_groups.append(group) # # all_groups.remove(group) # for conn in group._output_connections: # builded_groups, unbuild_groups = self.deep_build_conn(conn, builded_groups, # unbuild_groups, level) # elif isinstance(group, Decoder): # # 如果节点是output节点,则放入output组在最后进行构建 # output_groups.append(group) # else: # if (not group._input_connections) and (not group._output_connections): # # 孤立点的情况 # import warnings # warnings.warn('Isolated group occurs, please check the network.') # group.build(self._backend) # # if unbuild_groups: # import warnings # warnings.warn('Loop occurs') # # ====================开始构建环路================== # for key in unbuild_groups.keys(): # for i in unbuild_groups[key]: # if i in builded_groups: # continue # else: # builded_groups = self.deep_build_neurongroup_with_delay(i, builded_groups) # # # ====================构建output节点=============== # for group in output_groups: # group.build(self._backend) # # def deep_build_neurongroup(self, neuron=None, builded_groups=None, unbuild_groups=None, level=0): # conns = [i for i in neuron._input_connections if i not in builded_groups] # # conns表示神经元还没有被建立的依赖连接 # if conns: #==========如果存在conns说明有input_connections还没有被build=========== # if str(level) in unbuild_groups.keys(): # unbuild_groups[str(level)].append(neuron) # else: # unbuild_groups[str(level)] = [neuron] # return builded_groups, unbuild_groups # else: # # if neuron not in builded_groups: # if neuron._class_label == '<asb>': # neuron.build(self._backend, strategy=2) # else: # neuron.build(self._backend) # builded_groups.append(neuron) # for conn in neuron._output_connections: # builded_groups, unbuild_groups = self.deep_build_conn(conn, builded_groups, # unbuild_groups, level) # return builded_groups, unbuild_groups # # def deep_build_conn(self, conn=None, builded_groups=None, unbuild_groups=None, level=0): # conn.build(self._backend) # builded_groups.append(conn) # level += 1 # builded_groups, unbuild_groups = self.deep_build_neurongroup(conn.post_assembly, builded_groups, unbuild_groups, level) # return builded_groups, unbuild_groups # # def deep_build_conn_with_delay(self, conn, builded_groups): # conn.build(self._backend) # builded_groups.append(conn) # if conn.post_assembly not in builded_groups: # builded_groups = self.deep_build_neurongroup_with_delay(conn.post_assembly, builded_groups) # return builded_groups # # def deep_build_neurongroup_with_delay(self, neuron, builded_groups): # conns = [i for i in neuron._input_connections if i not in builded_groups] # if conns: # for conn in conns: # conn.build(self._backend) # builded_groups.append(conn) # neuron.build(self._backend) # else: # neuron.build(self._backend) # builded_groups.append(neuron) # for conn in neuron._output_connections: # if conn not in builded_groups: # builded_groups = self.deep_build_conn_with_delay(conn, builded_groups) # return builded_groups
[docs] def run(self, backend_time): self._backend.set_runtime(backend_time) if self._backend.builded is False: self.build() self._backend.initial_step() self._backend.update_time_steps()
[docs] def run_continue(self, backend_time): self._backend.set_runtime(backend_time) if self._backend.builded is False: self.build() self._backend.initial_step() self._backend.initial_continue_step() self._backend.update_time_steps()
[docs] def reset(self, ): if self._backend.builded is True: self._backend.initial_step()
[docs] def enable_full_grad(self, requires_grad=True): self._backend.full_enable_grad = requires_grad
[docs] def init_run(self): self._backend.initial_step()
[docs] def add_monitor(self, name, monitor): from ..Monitor.Monitor import Monitor assert isinstance(monitor, Monitor), "Type Error, it is not monitor" assert monitor not in self._monitors.values(), "monitor %s is already added" % (name) assert name not in self._monitors.keys(), "monitor with name: %s have the same name with an already exists monitor" % ( name) self.__setattr__(name, monitor)
# self._monitors[name] = monitor
[docs] def get_elements(self): element_dict = dict() for element in self.get_groups(): element_dict[element.id] = element return element_dict
[docs] def save_state(self, filename=None, direct=None, save=True, hdf5=False): """ Save weights in memory or on hard disk. Args: filename: The name of saved file. direct: Target direction for saving state. mode: Determines whether saved in hard disk, default set false, it means will not save on disk. Returns: state: Connections' weight of the network. """ from ..Neuron.Module import Module state = self._backend._parameters_dict if not save: return state if not filename: filename = self.name if self.name else 'autoname' if not direct: direct = './' file = filename.split('.')[0] path = direct + file + '/parameters/' import os import torch origin_path = os.getcwd() os.chdir(direct) if file not in os.listdir(): os.mkdir(file) if 'parameters' not in os.listdir('./' + file): os.mkdir('./' + file + '/parameters') # os.mkdir('./NetData/' + dict + '/backend/_parameters_dict') os.chdir('./' + file + '/parameters') if hdf5: import h5py filename = filename if direct.endswith('.hdf5') else direct + '.hdf5' with h5py.File(direct, "w") as f: for i, item in enumerate(state): f.create_dataset(item, data=self._backend._parameters_dict[item].cpu().detach().numpy()) # torch.save(self._backend._parameters_dict[item], os.getcwd()+'/'+str(i)+'.pt') print(i, item, ': saved') else: torch.save(self._backend._parameters_dict, './_parameters_dict.pt') module_dict = {} module_exist = False for group in self.get_groups(): if isinstance(group, Module): module_dict[group.id] = group.state_dict module_exist = True if module_exist: torch.save(module_dict, './module_dict.pt') os.chdir(origin_path) return
[docs] def state_from_dict(self, state=False, filename=None, direct=None, device=None): """ Reload states from memory or disk. Args: state: contains backend._parameters_dict . filename: The name of saved file. direct: Target direction for reloading state. mode: Determines whether saved in hard disk, default set false, it means will not save on disk. Returns: state: Connections' weight of the network. """ from ..Neuron.Module import Module if not self._backend: if device: self.set_backend('torch', device=device) else: self.set_backend('torch') if self._backend.builded is False: self.build() if self._backend.device != device: import warnings warnings.warn( 'Backend device setting is ' + str(self._backend.device) + '. Backend device selection is priority.') # device = self._backend.device if state: import torch if isinstance(state, dict) or isinstance(state, torch.Tensor): for key, para in state.items(): backend_key = self._backend.check_key(key, self._backend._parameters_dict) if backend_key: target_device = self._backend._parameters_dict[backend_key].device self._backend._parameters_dict[backend_key] = para.to(target_device) # if self._backend.device return else: raise ValueError("Given state has wrong type") if direct: if filename: path = direct + '/' + filename + '/parameters/' else: path = direct + '/parameters/' else: if filename: path = './' + filename + '/parameters/' else: path = './parameters/' import os import torch origin_path = os.getcwd() try: os.chdir(path) except: raise ValueError('Wrong Path.') if '_parameters_dict.pt' in os.listdir('./'): data = torch.load('./_parameters_dict.pt', map_location=self._backend.device0) for key, para in data.items(): backend_key = self._backend.check_key(key, self._backend._parameters_dict) if backend_key: target_device = self._backend._parameters_dict[backend_key].device self._backend._parameters_dict[backend_key] = para.to(target_device) if 'module_dict.pt' in os.listdir('./'): module_data = torch.load('./module_dict.pt', map_location=self._backend.device0) for group in self.get_groups(): if isinstance(group, Module): target_key = self._backend.check_key(group.id, module_data) group.load_state_dict(module_data[target_key]) else: for file in os.listdir('./'): if file.endswith('.hdf5'): import h5py with h5py.File(direct, 'r') as f: for key, para in f.items(): backend_key = self._backend.check_key(key, self._backend._parameters_dict) if key: target_device = self._backend._parameters_dict[backend_key].device self._backend._parameters_dict[backend_key] = para.to(target_device) os.chdir(origin_path) return