# -*- coding: utf-8 -*-
"""
Created on 2020/8/12
@project: SPAIC
@filename: Learner
@author: Hong Chaofei
@contact: hongchf@gmail.com
@description:
定义学习模块,包括各种学习算法对仿真计算过程中插入的各种计算模块,以及记录需要学习连接的接口
"""
from ..Network.Assembly import BaseModule
from abc import ABC, abstractmethod
from collections import OrderedDict
import torch
import torch.nn.functional as F
import numpy as np
from ..Backend.Backend import Backend
from ..Network.Assembly import Assembly
from ..Network.Topology import Connection, Projection
from ..Neuron.Neuron import NeuronGroup
from ..Neuron.Module import Module
from ..Neuron.Node import Node
[docs]class Learner(BaseModule, ABC):
'''
Base learner model for all the learner model.
Args:
parameters(dict) : The parameters for learner.
super_parameters(dict) : Super parameters for future use.
backend_functions(dict) : Contains all the learner model we can choose.
name(str) : The typical name for the learner.
preferred_backend(str) : Choose which kind of backend to use. Like "Pytorch", "Tensorflow" or "Jax".
trainable_groups(dict) : Trainable container, includes nodes and layers to train.
trainable_connections(dict) : Trainable container, includes connections to train.
init_trainable: The initial state of this learner of whether it is trainable.
Methods:
add_trainable(self, trainable) : Add target object (Network, Assembly, Connection,
or list of them) to the trainable container
build(self, backend) : Build Learner, choose the backend as user wish, if we have already finished the api.
'''
_class_label = '<learner>'
learning_algorithms = dict()
learning_optims = dict()
optim_dict = {'Adam': torch.optim.Adam, 'AdamW': torch.optim.AdamW,
'SparseAdam': torch.optim.SparseAdam, 'Adamax': torch.optim.Adamax,
'ASGD': torch.optim.ASGD, 'LBFGS': torch.optim.LBFGS,
'RMSprop': torch.optim.RMSprop, 'Rpop': torch.optim.Rprop,
'SGD': torch.optim.SGD, 'Adadelta': torch.optim.Adadelta,
'Adagrad': torch.optim.Adagrad}
lr_schedule_dict = {'LambdaLR': torch.optim.lr_scheduler.LambdaLR,
'StepLR': torch.optim.lr_scheduler.StepLR,
'MultiStepLR': torch.optim.lr_scheduler.MultiStepLR,
'ExponentialLR': torch.optim.lr_scheduler.ExponentialLR,
'CosineAnnealingLR': torch.optim.lr_scheduler.CosineAnnealingLR,
'ReduceLROnPlateau': torch.optim.lr_scheduler.ReduceLROnPlateau,
'CyclicLR': torch.optim.lr_scheduler.CyclicLR,
'CosineAnnealingWarmRestarts': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts}
def __init__(self, trainable=None, pathway=None, algorithm=('STCA', 'STBP', 'RSTDP', '...'), name=None, **kwargs):
super(Learner, self).__init__()
self.parameters = kwargs
self.super_parameters = OrderedDict()
self.backend_functions = OrderedDict()
self.name = self.set_name(name)
self.optim_name = None
self.optim = None
self.lr_schedule_name = None
self.gradient_based = True
self.prefered_backend = ['pytorch']
self.trainable_groups = OrderedDict()
self.trainable_connections = OrderedDict()
self.trainable_nodes = OrderedDict()
self.trainable_modules = OrderedDict()
self.trainable_others = OrderedDict()
self.init_trainable = trainable
self.training_param_name = []
self.param_run_update = kwargs.get("param_run_update", False)
self._variables = dict()
self._constant_variables = dict()
self._tau_constant_variables = dict()
self._tau_membrane_variables = dict()
self.pathway_groups = OrderedDict()
self.pathway_connections = OrderedDict()
self.pathway_nodes = OrderedDict()
self.pathway_modules = OrderedDict()
self.pathway_others = OrderedDict()
self.init_pathway = pathway
self._operations = []
self._active = True
[docs] def add_trainable(self, trainable: list):
'''
Add target object (Assembly, Connection, or list of them) to the trainable container
Args:
trainable(list) : The trainable target waiting for added.
'''
if not isinstance(trainable, list):
trainable = [trainable]
for target in trainable:
if isinstance(target, NeuronGroup):
self.trainable_groups[target.id] = target
elif isinstance(target, Connection):
self.trainable_connections[target.id] = target
elif isinstance(target, Node):
self.trainable_nodes[target.id] = target
elif isinstance(target, Module):
self.trainable_modules[target.id] = target
elif isinstance(target, Assembly):
for sub_t in target.get_groups():
trainable.append(sub_t)
for sub_t in target.get_connections():
trainable.append(sub_t)
elif isinstance(target, Projection):
for sub_t in target.get_connections():
trainable.append(sub_t)
elif isinstance(target, BaseModule):
self.trainable_others[target.id] = target
[docs] def add_pathway(self, pathway: list):
'''
Add target object (Assembly, Connection, or list of them) to the pathway container
Args:
pathway(list) : The pathway target waiting for added.
'''
if not isinstance(pathway, list):
pathway = [pathway]
for target in pathway:
if isinstance(target, NeuronGroup):
self.pathway_groups[target.id] = target
elif isinstance(target, Connection):
self.pathway_connections[target.id] = target
elif isinstance(target, Node):
self.pathway_nodes[target.id] = target
elif isinstance(target, Module):
self.pathway_modules[target.id] = target
elif isinstance(target, Assembly):
for sub_t in target.get_groups():
pathway.append(sub_t)
for sub_t in target.get_connections():
pathway.append(sub_t)
elif isinstance(target, Projection):
for sub_t in target.get_connections():
pathway.append(sub_t)
elif isinstance(target, BaseModule):
self.pathway_others[target.id] = target
[docs] def build(self, backend: Backend):
'''
Build Learner, choose the backend as user wish, if we have already finished the api.
Args:
backend(backend) : Backend we have.
'''
if self.init_trainable is not None: # If user has given the 'trainable' parameter.
self.add_trainable(self.init_trainable)
if self.init_pathway is not None:
self.add_pathway(self.init_pathway)
if backend.backend_name in self.prefered_backend:
self._backend = backend
else:
raise ValueError(
"the backend %s is not supported by the learning rule %s" % (backend.backend_name, self.name))
if self.optim_name is not None:
self.build_optimizer()
else:
self.get_param()
if self.lr_schedule_name is not None:
self.build_lr_schedule()
# set all trainable or pathway compontents to requires_grad=Ture
if self.gradient_based:
for node in self.trainable_nodes.values():
for op in node._ops:
op.requires_grad = True
for group in self.trainable_groups.values():
for op in group._ops:
op.requires_grad = True
for con in self.trainable_connections.values():
for op in con._ops:
op.requires_grad = True
for other in self.trainable_others.values():
for op in other._ops:
op.requires_grad = True
for mod in self.trainable_modules.values():
mod.module.requires_grad_()
for op in mod._ops:
op.requires_grad = True
for node in self.pathway_nodes.values():
for op in node._ops:
op.requires_grad = True
for group in self.pathway_groups.values():
for op in group._ops:
op.requires_grad = True
for con in self.pathway_connections.values():
for op in con._ops:
op.requires_grad = True
for other in self.pathway_others.values():
for op in other._ops:
op.requires_grad = True
for mod in self.pathway_modules.values():
mod.module.requires_grad_()
for op in mod._ops:
op.requires_grad = True
# add learner variables to the backend
self.dt = backend.dt
for (key, tau_var) in self._tau_constant_variables.items():
tau_var = np.exp(-self.dt / tau_var)
shape = ()
self.variable_to_backend(self._add_label(key), shape, value=tau_var)
for (key, tau_membrane_var) in self._tau_membrane_variables.items():
tau_membrane_var = self.dt/tau_membrane_var
shape = () # (1, neuron_num)
self.variable_to_backend(self._add_label(key), shape, value=tau_membrane_var)
for (key, var) in self._constant_variables.items():
if isinstance(var, np.ndarray):
if var.size > 1:
var_shape = var.shape
shape = (1, *var_shape) # (1, shape)
else:
shape = ()
elif isinstance(var, list):
if len(var) > 1:
var_len = len(var)
shape = (1, var_len) # (1, shape)
else:
shape = ()
else:
shape = ()
self.variable_to_backend(self._add_label(key), shape, value=var)
for (key, var) in self._variables.items():
if isinstance(var, np.ndarray):
if var.size > 1:
var_shape = var.shape
shape = (1, *var_shape) # (1, shape)
else:
shape = ()
elif isinstance(var, list):
if len(var) > 1:
var_len = len(var)
shape = (1, var_len) # (1, shape)
else:
shape = ()
else:
shape = ()
self.variable_to_backend(self._add_label(key), shape, value=var)
# build custom learning rules
if self.is_overridden(self.custom_rule):
self.custom_rule(backend)
if self.is_overridden(self.connection_rule):
for conn in self.trainable_connections.values():
self.connection_rule(conn, backend, 'trainable')
for conn in self.pathway_connections.values():
self.connection_rule(conn, backend, 'pathway')
if self.is_overridden(self.neuron_rule):
for neuron in self.trainable_groups.values():
self.neuron_rule(neuron, backend, 'trainable')
for neuron in self.pathway_groups.values():
self.neuron_rule(neuron, backend, 'pathway')
def __new__(cls, trainable=None, pathway=None, algorithm=('STCA', 'STBP', 'RSTDP', '...'), **kwargs):
if cls is not Learner:
return super().__new__(cls)
if algorithm == ('STCA', 'STBP', 'RSTDP', '...'):
return super().__new__(cls)
elif algorithm.lower() in cls.learning_algorithms:
return cls.learning_algorithms[algorithm.lower()](trainable=trainable, pathway=None, **kwargs)
else:
raise ValueError("No algorithm %s in algorithm list" % algorithm)
[docs] def active(self):
self._active = True
#TODO: setting all learnerable and pathway into gradient = True
[docs] def deactive(self):
self._active = False
#TODO: setting all learnerable and pathway into gradient = Fasle
[docs] def set_optimizer(self, optim_name, optim_lr, **kwargs):
self.optim_lr = optim_lr
self.optim_para = kwargs
self.optim_name = optim_name
if self.optim_name not in Learner.optim_dict.keys():
raise ValueError("No optim %s in optim list" % Learner.optim_dict)
[docs] def set_schedule(self, lr_schedule_name, **kwargs):
self.lr_schedule_para = kwargs
self.lr_schedule_name = lr_schedule_name
if self.lr_schedule_name not in Learner.lr_schedule_dict.keys():
raise ValueError("No lr_schedule %s in lr_schedule list")
[docs] def get_param(self):
param = list()
var_name = list()
for key, conn in self.trainable_connections.items():
for name in conn._var_names:
var_name.append(name)
for key, node in self.trainable_nodes.items():
for name in node._var_names:
var_name.append(name)
for key, group in self.trainable_groups.items():
for name in group._var_names:
var_name.append(name)
self.training_param_name = []
for key, value in self._backend._parameters_dict.items():
if key in var_name:
value.requires_grad = True
param.append(value)
self.training_param_name.append(key)
else:
value.requires_grad =False
for mod in self.trainable_modules.values():
param.extend(mod.parameters)
return param
[docs] def get_varname(self, key):
name = self.name + ':{' + key + '}'
return name
[docs] def get_var_names(self):
return self._var_names
[docs] def build_optimizer(self):
param = self.get_param()
self.optim = Learner.optim_dict[self.optim_name](param, self.optim_lr, **self.optim_para)
[docs] def build_lr_schedule(self):
self.schedule = Learner.lr_schedule_dict[self.lr_schedule_name](self.optim, **self.lr_schedule_para)
[docs] def optim_step(self):
if self.param_run_update:
with torch.no_grad():
for key, value in self._backend._parameters_dict.items():
if key in self.training_param_name:
varialbe_value = self._backend._variables[key]
if varialbe_value is not value:
value.data = varialbe_value.data
self._backend._variables[key] = value #Is this step needed?
if self.optim is not None:
self.optim.step()
[docs] def optim_zero_grad(self):
self.optim.zero_grad()
[docs] def optim_schedule(self):
self.schedule.step()
[docs] @staticmethod
def register(name, algorithm):
name = name.lower()
if name in Learner.learning_algorithms:
raise ValueError(('A learning algorithm with the name "%s" has already been registered') % name)
if not issubclass(algorithm, Learner):
raise ValueError(
('Given algorithm of type %s does not seem to be a valid algorithm.' % str(type(algorithm))))
Learner.learning_algorithms[name] = algorithm
[docs] def custom_rule(self, backend: Backend):
return None
[docs] def connection_rule(self, con : Connection, backend: Backend, obj_type='trainable'):
return None
[docs] def neuron_rule(self, neuron: NeuronGroup, backend: Backend, obj_type='trainable'):
return None
[docs] def is_overridden(self, func):
if isinstance(self, Learner):
return False
super_func = getattr(super(type(self), self), func.__name__)
return super_func != func
[docs]class ReSuMe(Learner):
def __init__(self):
super(ReSuMe, self).__init__()
pass
# Learner.register("ReSuMe", ReSuMe)
[docs]class FORCE(Learner):
def __init__(self):
super(FORCE, self).__init__()
pass
# Learner.register("force", FORCE)