# -*- coding: utf-8 -*-
"""
Created on 2020/9/14
@project: SPAIC
@filename: Torch_Backend
@author: Hong Chaofei
@contact: hongchf@gmail.com
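@description: PyTorch implementation of the SPAIC backend (Torch_Backend) and the fx-traceable graph executor (Torch_Engine).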
"""
from .Backend import Backend, backends
import torch
import numpy as np
# from torch import fx
# from torch.nn import Module, Parameter
import torch.nn.functional as fn
import torch.nn as nn
from typing import Tuple, Dict, Callable
from collections import defaultdict
from torch.cuda.amp import autocast
import torch.distributed
import threading
class Torch_Engine(torch.nn.Module):
    """
    ``nn.Module`` that replays a fixed list of graph operations, in order, on a dict of variables.

    It only uses plain Python control flow over ``graph_operations`` so that ``Torch_Backend.build``
    can trace it with ``torch.fx.symbolic_trace``.

    Each operation must expose:

    * ``input``  - sequence of ``(store, key)`` pairs read as positional arguments; ``store`` is one
      of ``'variables_dict'``, ``'temp_dict'``, ``'update_dict'``, ``'reduce_dict'``. Pairs naming
      any other store are skipped.
    * ``func``   - callable applied to those arguments.
    * ``output`` - sequence of ``(store, key)`` pairs receiving the result(s). With exactly one output
      the return value is used as-is, otherwise it is indexed by position. ``'temp_dict'`` and
      ``'update_dict'`` entries are overwritten, ``'reduce_dict'`` entries accumulate into a list;
      outputs naming any other store are ignored.
    """

    def __init__(self, graph_operations):
        super(Torch_Engine, self).__init__()
        self._graph_operations = graph_operations

    def forward(self, variables: Dict[str, torch.Tensor]):
        """
        Run every operation once and return the ``update_dict`` the operations wrote to.

        Args:
            variables: mapping read through ``('variables_dict', key)`` inputs; it is never written.

        Returns:
            dict of every value written to an ``'update_dict'`` output.
        """
        temp_dict, update_dict, reduce_dict = {}, {}, {}
        stores = {
            'variables_dict': variables,
            'temp_dict': temp_dict,
            'update_dict': update_dict,
            'reduce_dict': reduce_dict,
        }
        for op in self._graph_operations:
            args = [stores[var[0]][var[1]] for var in op.input if var[0] in stores]
            result = op.func(*args)
            if len(op.output) == 1:
                result = [result]
            for ind, var in enumerate(op.output):
                store_name, key = var[0], var[1]
                if store_name == 'reduce_dict':
                    # Several operations may contribute to the same reduction key.
                    reduce_dict.setdefault(key, []).append(result[ind])
                elif store_name in ('temp_dict', 'update_dict'):
                    stores[store_name][key] = result[ind]
        return update_dict
class Torch_Backend(Backend):
    """
    SPAIC backend implemented with PyTorch (``backend_name = 'pytorch'``).

    ``device`` may be a single device or a list of devices. ``self.device`` always holds a list,
    ``self.device0`` its first element and ``self.device_count`` its length. New tensors default to
    ``torch.float32`` (``self.data_type``).
    """
    backend_name = 'pytorch'

    def __init__(self, device='cpu'):
        super(Torch_Backend, self).__init__()
        devices = device if isinstance(device, list) else [device]
        self.device = devices
        self.device0, self.device_count = devices[0], len(devices)
        self.data_type = torch.float32
        self.debug_data = []
        # Reusable grad-mode objects; to_nograd_func / to_grad_func apply them as decorators.
        self.nograd_decorator = torch.no_grad()
        self.enablegrad_decorator = torch.enable_grad()
def build(self):
    """
    Wrap ``self._graph_operations`` in a ``Torch_Engine`` and trace it with ``torch.fx``.

    The traced ``GraphModule`` is stored both as ``self.engine`` and ``self.graph_update_step``.
    """
    from torch import fx
    self.engine = Torch_Engine(self._graph_operations)
    traced = fx.symbolic_trace(self.engine)
    self.engine = traced
    self.graph_update_step = traced
def build_graph(self):
    """Move every tensor in ``self._InitVariables_dict`` to ``self.device0``, then run the base ``build_graph``."""
    init_vars = self._InitVariables_dict
    for key, item in init_vars.items():
        if isinstance(item, torch.Tensor):
            # Re-assigning an existing key does not change the dict size, so this is safe mid-iteration.
            init_vars[key] = item.to(self.device0)
    super(Torch_Backend, self).build_graph()
def remove_tensor(self, inputs):
    """
    Move every tensor in ``inputs`` (in place) onto the device of the first tensor found.

    Despite the name nothing is removed. Non-tensor entries are left untouched. If ``inputs`` is
    empty or contains no tensor, the call is a no-op.

    Args:
        inputs: mutable sequence (e.g. list) that is modified in place.

    Returns:
        None
    """
    device = next((item.device for item in inputs if torch.is_tensor(item)), None)
    if device is None:
        return
    for ind, item in enumerate(inputs):
        if torch.is_tensor(item):
            inputs[ind] = item.to(device)
def is_insert(self, inputs):
    """
    Return True if ``inputs`` contains at least one duplicated (hashable) element, else False.

    An empty sequence returns False. Previously a duplicate-free sequence fell through and
    returned None; it now returns False (equally falsy, but a proper bool).
    """
    if not inputs:
        return False
    return len(set(inputs)) < len(inputs)
def move_compute_and_assign_tensors(self, op):
    """
    Execute a single operation on ``op.place``.

    Every tensor input (``var.value``), and every tensor element of a list input, is moved to
    ``op.place`` and written back to the input object. Then ``op.func`` is called with the input
    values, and its result is assigned to the ``value`` of each output. A single output receives
    the whole result; multiple outputs are indexed by position.
    """
    args = []
    for var in op.input:
        current = var.value
        if torch.is_tensor(current):
            var.value = current.to(op.place)
        elif isinstance(current, list):
            # Lists are updated element-wise in place.
            for pos, item in enumerate(current):
                if torch.is_tensor(item):
                    current[pos] = item.to(op.place)
        args.append(var.value)
    produced = op.func(*args)
    results = [produced] if len(op.output) == 1 else produced
    for pos, var in enumerate(op.output):
        var.value = results[pos]
def graph_update_step_multigpu(self):
    """
    Execute the graph operations according to ``self.partition``.

    * ``'multithread'``: one thread per element of ``self.groups``. Each group is a collection
      of indices into ``self._graph_operations``, executed in ascending order. The caller blocks
      until every thread finishes. The caller's grad mode is propagated into the workers, and the
      first exception raised by a worker is re-raised in the calling thread.
    * ``'thread1'``: the same index groups run sequentially in the calling thread, followed by
      the operations whose indices are listed in ``self.isolate``.
    * any other truthy ``partition``: nothing is executed.
    * falsy ``partition``: elements of ``self.groups`` are passed as-is to
      ``move_compute_and_assign_tensors``, without an index lookup (presumably they already are
      operation objects in this mode - TODO confirm).
    """
    if self.partition:
        if self.partition == 'multithread':
            # Autograd grad mode is thread-local: a new thread does not inherit the caller's
            # torch.no_grad()/enable_grad() state, so capture it here and re-apply it in each worker.
            grad_enabled = torch.is_grad_enabled()
            errors = []

            def _worker(*indices):
                try:
                    with torch.set_grad_enabled(grad_enabled):
                        for index in sorted(indices):
                            self.move_compute_and_assign_tensors(self._graph_operations[index])
                except Exception as exc:
                    # Exceptions do not cross thread boundaries; hand them to the caller.
                    errors.append(exc)

            threads = [threading.Thread(target=_worker, args=group) for group in self.groups]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
            if errors:
                raise errors[0]
        if self.partition == 'thread1':
            for group in self.groups:
                for index in sorted(group):
                    self.move_compute_and_assign_tensors(self._graph_operations[index])
            for index in self.isolate:
                self.move_compute_and_assign_tensors(self._graph_operations[index])
    else:
        for group in self.groups:
            for op in group:
                self.move_compute_and_assign_tensors(op)
def to_nograd_func(self, func):
    """Return ``func`` wrapped so that it always runs with autograd disabled (``torch.no_grad``)."""
    decorator = self.nograd_decorator
    wrapped = decorator(func)
    return wrapped


def to_grad_func(self, func):
    """Return ``func`` wrapped so that it always runs with autograd enabled (``torch.enable_grad``)."""
    decorator = self.enablegrad_decorator
    wrapped = decorator(func)
    return wrapped
# Note: autograd currently supports only floating-point tensor types (half, float, double, bfloat16) and complex tensor types (cfloat, cdouble).
def add_backend_variable(self, module, name, shape, value=None, grad=False, is_sparse=False, init=None,
                         init_param=None, prefer_device=None):
    '''
    Create the backend tensor for variable ``name``, store it in ``self._variables`` and return it.

    Creation paths, in priority order:

    1. ``init`` given: an empty ``self.data_type`` tensor of ``shape`` is filled by
       ``self.param_init_operate[init.lower()](data, **init_param)``; ``value`` is ignored.
    2. ``value`` has ``__len__`` (tensor, ndarray, list, tuple, ...):
       * a tensor of shape ``()`` or ``(1,)`` is broadcast to ``shape``;
       * otherwise its shape must equal ``shape`` (else ``ValueError``);
       * with ``is_sparse`` it is stored as a sparse COO tensor built from its non-zero entries.
         The indices / values / shape are also stored under ``name + '_sparse_index'``,
         ``'_sparse_value'`` and ``'_sparse_shape'``, and the values tensor (always trainable)
         is added to ``self._parameters_dict``;
       * otherwise the value is copied into a new tensor.
    3. scalar ``value`` with empty ``shape``: a 0-d tensor.
    4. scalar ``value`` with non-empty ``shape``: a tensor of ``shape`` filled with ``value``.

    If neither ``init`` nor ``value`` is given nothing is created, and the existing
    ``self._variables[name]`` is returned (``KeyError`` if absent).

    Parameters
    ----------
    module: not used by this method.
    name: str, key in ``self._variables``.
    shape: target shape.
    value: initial value.
    grad: bool, whether the new tensor requires grad (ignored for non-floating tensor values).
    is_sparse: bool, see path 2.
    init: str, initializer name, looked up case-insensitively in ``self.param_init_operate``.
    init_param: dict, keyword arguments for the initializer.
    prefer_device: int, index into ``self.device``. Only path 1 and the dense copy in path 2 use
        this device (or the round-robin one); all other paths place the tensor on ``self.device0``.

    Returns
    -------
    The tensor stored at ``self._variables[name]``.

    Raises
    ------
    ValueError: unknown ``init``, or a ``value`` whose shape differs from ``shape``.
    '''
    if prefer_device is not None:
        device0 = self.device[prefer_device]
    else:
        # Round-robin over the configured devices, keyed on the number of parameters so far.
        device0 = self.device[len(self._parameters_dict) % self.device_count]
    if init_param is None:
        init_param = dict()

    if init is not None:
        data = torch.empty(shape, dtype=self.data_type, device=device0, requires_grad=grad)
        init = init.lower()
        if init not in self.param_init_operate:
            raise ValueError("No initialize method: %s in param_init_operate" % init)
        self._variables[name] = self.param_init_operate[init](data, **init_param)
    elif value is not None:
        if isinstance(value, (list, tuple)):
            # Plain Python sequences have no ``.shape``; normalise them before the checks below.
            value = np.asarray(value)
        if hasattr(value, "__len__"):
            if isinstance(value, torch.Tensor) and value.shape in (torch.Size([1, ]), torch.Size([])):
                # Scalar-like tensor: broadcast to the requested shape.
                self._variables[name] = (value.to(self.device0) * torch.ones(shape, dtype=self.data_type,
                                                                             device=self.device0)).clone()
                self._variables[name].requires_grad = grad
            elif tuple(value.shape) != tuple(shape):
                raise ValueError("Value is not scalar and the shape of Value is not equal to shape")
            elif is_sparse:
                # ``np.nonzero`` on a torch.Tensor dispatches to ``Tensor.nonzero()`` and yields an
                # (nnz, ndim) tensor instead of a tuple of index arrays, so always go through numpy.
                if isinstance(value, torch.Tensor):
                    dense = value.detach().cpu().numpy()
                else:
                    dense = np.asarray(value)
                nz = np.nonzero(dense)
                # Index for sparse matrix: (ndim, nnz)
                sparse_index = name + '_sparse_index'
                self._variables[sparse_index] = torch.tensor(np.array(nz), dtype=torch.long,
                                                             device=self.device0)
                self._InitVariables_dict[sparse_index] = self._variables[sparse_index]
                # Values for sparse matrix (trainable regardless of ``grad``)
                sparse_value = name + '_sparse_value'
                self._variables[sparse_value] = torch.tensor(dense[nz], dtype=self.data_type, requires_grad=True,
                                                             device=self.device0)
                self._parameters_dict[sparse_value] = self._variables[sparse_value]
                # Shape of sparse matrix
                sparse_shape = name + '_sparse_shape'
                self._variables[sparse_shape] = torch.Size(shape)
                self._InitVariables_dict[sparse_shape] = self._variables[sparse_shape]
                # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.
                self._variables[name] = torch.sparse_coo_tensor(self._variables[sparse_index],
                                                               self._variables[sparse_value],
                                                               self._variables[sparse_shape])
            elif isinstance(value, torch.Tensor):
                copied = value.clone().detach().to(device0)
                if grad and (copied.is_floating_point() or copied.is_complex()):
                    # Honour ``grad`` like the non-tensor path below does.
                    copied.requires_grad_(True)
                self._variables[name] = copied
            else:
                self._variables[name] = torch.tensor(value, dtype=self.data_type, device=device0,
                                                     requires_grad=grad)
        elif len(shape) == 0:
            # Scalar constant
            self._variables[name] = torch.tensor(value, dtype=self.data_type, device=self.device0,
                                                 requires_grad=grad)
        else:
            # Tensor of ``shape`` filled with the scalar ``value``
            self._variables[name] = (
                value * torch.ones(shape, dtype=self.data_type, device=self.device0)).clone()
            self._variables[name].requires_grad = grad
    return self._variables[name]
def set_variable_value(self, name, value, is_parameter):
    """
    Overwrite the stored value of a parameter or of an init-variable.

    * ``is_parameter``: ``self._parameters_dict[name].data`` is replaced. A non-tensor ``value``
      is first converted to the parameter's dtype/device, so lists and numpy arrays work. The shape
      must match.
    * tensor init-variable: the same replacement is applied to both
      ``self._InitVariables_dict[name]`` and ``self._variables[name]``. NOTE(review): both then
      share ``value``'s storage - confirm that aliasing is intended.
    * float init-variable: accepts any real number (int, float, numpy scalar; not bool), stored
      as ``float``.
    * int init-variable: accepts any integral number (not bool), stored as ``int``.
    Any other combination is left unchanged.

    Raises:
        AssertionError: unknown ``name`` or mismatched tensor shape.
    """
    import numbers

    if is_parameter:
        assert name in self._parameters_dict
        param = self._parameters_dict[name]
        if not isinstance(value, torch.Tensor):
            # Convert before the shape check: lists/tuples have no ``.shape``.
            value = torch.tensor(value, dtype=param.dtype, device=param.device)
        assert param.shape == value.shape
        with torch.no_grad():
            param.data = value
        return

    assert name in self._InitVariables_dict
    current = self._InitVariables_dict[name]
    if isinstance(current, torch.Tensor):
        if not isinstance(value, torch.Tensor):
            value = torch.tensor(value, dtype=current.dtype, device=current.device)
        assert current.shape == value.shape
        with torch.no_grad():
            current.data = value
            self._variables[name].data = value
    elif type(current) is float and isinstance(value, numbers.Real) and not isinstance(value, bool):
        self._InitVariables_dict[name] = float(value)
        self._variables[name] = float(value)
    elif type(current) is int and isinstance(value, numbers.Integral) and not isinstance(value, bool):
        self._InitVariables_dict[name] = int(value)
        self._variables[name] = int(value)
# def init_param(self, grad, *init):
# if init[0] in self.param_init_operate:
# init_op = self.param_init_operate[init[0]]
# else:
# raise ValueError("No init operate %s in param_init_operate" % init[0])
# inputs = []
# shape = init[1]
# data = torch.empty(shape, dtype=self.data_type, device=self.device, requires_grad=grad)
# inputs.append(data)
#
# for var in init[2:]:
# inputs.append(var)
# return init_op(*inputs)
def sparse_to_dense(self, index_name, value_name, shape_name):
    """
    Assemble a sparse COO tensor from index / value / shape entries stored in ``self._variables``.

    Note: despite the name, the returned tensor is sparse; no ``.to_dense()`` is applied.
    ``torch.sparse_coo_tensor`` replaces the deprecated ``torch.sparse.FloatTensor`` constructor.
    It also keeps the dtype of the value tensor instead of requiring float32.
    """
    return torch.sparse_coo_tensor(self._variables[index_name], self._variables[value_name],
                                   self._variables[shape_name])
def get_str(self, level):
    """Return ``'torch_backend'`` indented by ``level`` spaces."""
    indent = ' ' * level
    return indent + 'torch_backend'
def threshold(self, x, v_th):
    """Return 1 where ``x > v_th`` and 0 elsewhere, cast to ``self.data_type``."""
    above = torch.gt(x, v_th)
    return above.type(self.data_type)


def reset(self, v, o):
    """Return ``v`` where ``o == 0`` and 0 where ``o`` is non-zero."""
    keep_mask = o.eq(0)
    return keep_mask * v
def cat(self, x, dim=1):
    """Concatenate the tensors in ``x`` along the existing dimension ``dim`` (default 1)."""
    joined = torch.cat(x, dim)
    return joined
def stack(self, x, dim=1):
    """
    Stack a sequence of tensors along a new dimension ``dim`` (``torch.stack``).

    Fallback (patch for SLIF 2[O]): if ``torch.stack`` rejects the tensors (RuntimeError, e.g.
    mismatched shapes), every 2-D tensor ``t`` is expanded to
    ``torch.stack([t, zeros_like(t)], dim=1)`` and stacking is retried. The expanded tensor holds
    ``t`` in slot 0 and zeros in slot 1 of the new dim 1. The caller's sequence is not modified,
    so tuples are accepted too.
    """
    try:
        return torch.stack(x, dim)
    except RuntimeError:
        padded = [torch.stack([t, torch.zeros_like(t)], dim=1) if t.dim() == 2 else t for t in x]
        return torch.stack(padded, dim)
def reduce_sum(self, x, *dim):
    """Sum ``x`` over the dimensions given positionally; defaults to dimension 1 when none are given."""
    axes = dim if dim else 1
    return torch.sum(x, dim=axes)
def index_select(self, x, indices, dim=1):
    """Select the entries of ``x`` at ``indices`` along ``dim`` (default 1)."""
    return torch.index_select(x, dim=dim, index=indices)


def permute(self, x, permute_dim):
    """Return ``x`` with its dimensions reordered as ``permute_dim``."""
    order = permute_dim
    return x.permute(order)


def view(self, x, view_dim):
    """Reshape ``x`` to ``view_dim``, making it contiguous first so ``view`` always succeeds."""
    return x.contiguous().view(view_dim)
def scatter(self, x, indices, src=None, value=None):
    """
    Out-of-place ``torch.scatter`` of data into ``x`` along dim 0.

    ``torch.scatter`` needs something to write. Pass either a ``src`` tensor, whose entries are
    written at ``indices``, or a scalar ``value``. Calling with neither raises ``TypeError``, as
    ``torch.scatter`` itself would.
    """
    if src is not None:
        return torch.scatter(x, 0, indices, src)
    if value is not None:
        return torch.scatter(x, 0, indices, value=value)
    raise TypeError("scatter() requires either 'src' or 'value'")
def conv1d(self, x, kernel):
    """1-D convolution of ``x`` with ``kernel`` (stride 1, no padding, no bias)."""
    return fn.conv1d(x, kernel)


def conv_trans1d(self, x, kernel, bias=None):
    """1-D transposed convolution of ``x`` with ``kernel`` and optional ``bias``."""
    return fn.conv_transpose1d(x, kernel, bias)
def conv_2d(self, x, kernel, stride, padding, dilation, groups, bias=None, padding_mode='constant'):
    """
    2-D convolution that also accepts an input with one extra leading dimension.

    If ``x`` has one more dimension than ``kernel`` (e.g. ``(B, E, C, H, W)`` for a 4-D kernel),
    the first two dimensions are folded into the batch for ``fn.conv2d``, then unfolded again.
    The output is ``(B, E, C_out, H', W')``. Otherwise ``x`` is convolved directly.

    ``padding_mode`` is accepted for interface compatibility but not applied: ``fn.conv2d`` has
    no such argument (passing it used to raise ``TypeError`` on the 5-D path), so both paths use
    zero padding.
    """
    conv_kwargs = dict(bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
    if x.dim() != kernel.dim() + 1:
        return fn.conv2d(x, kernel, **conv_kwargs)
    batch, extend = x.shape[0], x.shape[1]
    out = fn.conv2d(x.reshape(batch * extend, *x.shape[2:]), kernel, **conv_kwargs)
    return out.view(batch, extend, *out.shape[1:])
def conv_2d_complex(self, x, kernel, stride, padding, dilation, groups, beta, bias=None, delay=None):
    """
    2-D convolution on complex-encoded input.

    * Real ``x`` is mapped to ``x * 1j``.
    * Complex ``x`` is mapped to ``beta ** e * (x.real * 1j)``, with exponent ``e = x.imag``. When
      ``delay`` is given, ``ceil(delay / self.dt) - delay / self.dt`` is added to the exponent.
      ``self.dt`` is presumably set elsewhere on the backend - TODO confirm.

    The real and imaginary parts are convolved separately (``bias`` is added to both) and
    recombined with ``torch.complex``.
    """
    if not x.dtype.is_complex:
        x = x * (0 + 1.0j)
    else:
        exponent = x.imag
        if delay is not None:
            steps = delay / self.dt
            exponent = exponent + (torch.ceil(steps) - steps)
        x = beta ** exponent * (x.real * (0 + 1.0j))
    conv_kwargs = dict(bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
    real_part = fn.conv2d(x.real, kernel, **conv_kwargs)
    imag_part = fn.conv2d(x.imag, kernel, **conv_kwargs)
    return torch.complex(real_part, imag_part)
def conv_trans2d(self, x, kernel, stride=1, padding=0, dilation=1, groups=1):
    """
    2-D transposed convolution of ``x`` with ``kernel``.

    ``dilation`` defaults to 1; the previous default of 0 is rejected by
    ``torch.conv_transpose2d`` (dilation must be positive).
    """
    return torch.conv_transpose2d(x, kernel, stride=stride, padding=padding, dilation=dilation, groups=groups)
def conv_max_pool2d(self, x, kernel, pool_kernel, stride, pool_stride, padding, pool_padding, dilation, groups):
    """2-D convolution followed by 2-D max pooling."""
    conv_out = fn.conv2d(x, kernel, stride=stride, padding=padding, dilation=dilation, groups=groups)
    return fn.max_pool2d(conv_out, kernel_size=pool_kernel, stride=pool_stride, padding=pool_padding)
def conv_avg_pool2d(self, x, kernel, pool_kernel, stride, pool_stride, padding, pool_padding, dilation, groups):
    """2-D convolution followed by 2-D average pooling."""
    conv_out = fn.conv2d(x, kernel, stride=stride, padding=padding, dilation=dilation, groups=groups)
    return fn.avg_pool2d(conv_out, kernel_size=pool_kernel, stride=pool_stride, padding=pool_padding)
def conv_add_bias(self, x, bias):
    """
    Add a per-channel ``bias`` to feature maps ``x`` laid out as ``(..., C, H, W)``.

    ``bias`` is reshaped to ``(C, 1, 1)`` and broadcast over the two trailing spatial dims. The
    earlier ``repeat(H, W, 1).permute(2, 1, 0)`` produced a ``(C, W, H)`` tensor, which only
    broadcast correctly when ``H == W``.
    """
    return x + bias.reshape(-1, 1, 1)
def max_pool2d(self, x, pool_kernel, pool_stride, pool_padding):
    """2-D max pooling of ``x``."""
    pooled = fn.max_pool2d(x, kernel_size=pool_kernel, stride=pool_stride, padding=pool_padding)
    return pooled
def post_max_pool2d_complex(self, x, pool_kernel, pool_stride, pool_padding):
    """
    Max-pool a 4-D complex tensor by its imaginary part.

    The imaginary part is max-pooled. The real part is taken from the same positions (the argmax
    of the imaginary part), using the flat per-plane indices returned by ``max_pool2d``.
    """
    pool_imag, pool_index = fn.max_pool2d(x.imag, kernel_size=pool_kernel, stride=pool_stride,
                                          padding=pool_padding, return_indices=True)
    batch, channels = x.shape[0], x.shape[1]
    flat_real = x.real.view(batch, channels, -1)
    flat_index = pool_index.view(batch, channels, -1)
    out_shape = pool_index.shape
    pool_real = torch.gather(flat_real, dim=-1, index=flat_index).view(
        out_shape[0], out_shape[1], out_shape[2], out_shape[3])
    return torch.complex(real=pool_real, imag=pool_imag)
def avg_pool2d(self, x, pool_kernel, pool_stride, pool_padding):
    """2-D average pooling of ``x``."""
    pooled = fn.avg_pool2d(x, kernel_size=pool_kernel, stride=pool_stride, padding=pool_padding)
    return pooled
def batchnorm2d(self, x, num_features):
    """
    Normalise ``x`` with a freshly created ``torch.nn.BatchNorm2d(num_features)`` on ``x``'s device.

    This implementation ignores running_mean and running_var: a new module is built on every call,
    so it always normalises with the statistics of the current batch. With small batch sizes those
    statistics can deviate strongly from the global statistics, giving poor results. In that case,
    implementing batch norm with a SPAIC Module is recommended.
    """
    layer = torch.nn.BatchNorm2d(num_features)
    return layer.to(device=x.device)(x)
def dropout(self, x, p, inplace=False):
    """Apply dropout with probability ``p`` (always in training mode, i.e. ``fn.dropout``'s default)."""
    return fn.dropout(x, p=p, training=True, inplace=inplace)
def reshape_mat_mult(self, A, X):
    """
    Flatten a feature map and multiply it by ``X`` transposed.

    A 4-D ``A`` ``(B, C, H, W)`` becomes ``(B, C*H*W)``. A 5-D ``A`` ``(B, E, C, H, W)`` becomes
    ``(B, E, C*H*W)``. Other ranks are used unchanged.
    """
    rank = A.dim()
    if rank == 4:
        A = A.view(A.shape[0], -1)
    elif rank == 5:
        A = A.view(A.shape[0], A.shape[1], -1)
    weight_t = X.permute(1, 0)
    return torch.matmul(A, weight_t)
def im2col_indices(self, x, kh, kw, padding, stride):
    """Extract sliding ``kh x kw`` patches of ``x`` as columns (``fn.unfold``)."""
    kernel_size = (kh, kw)
    return fn.unfold(x, kernel_size, padding=padding, stride=stride)
def conv2d_flatten(self, x):
    """Flatten every dimension after the first two: ``(B, C, ...) -> (B, C, prod(...))``."""
    batch, channels = x.shape[0], x.shape[1]
    return x.view(batch, channels, -1)


def feature_map_flatten(self, x):
    """Flatten every dimension after the first: ``(B, ...) -> (B, prod(...))``."""
    batch = x.shape[0]
    return x.view(batch, -1)
def add(self, x, y):
    """Element-wise sum ``x + y``."""
    total = x + y
    return total


def minus(self, x, y):
    """Element-wise difference ``x - y``."""
    difference = x - y
    return difference


def div(self, x, y):
    """Element-wise true division ``torch.div(x, y)``."""
    quotient = torch.div(x, y)
    return quotient


def relu(self, x):
    """Rectified linear unit."""
    activated = torch.relu(x)
    return activated


def sigmoid(self, x):
    """Logistic sigmoid."""
    activated = torch.sigmoid(x)
    return activated
def mat_mult_weight(self, A, X):
    '''
    Multiply the input by the transposed weight: ``A @ X.T``.

    Parameters
    ----------
    A: input from the pre-synaptic group
    X: 2-D weight of the post-synaptic group (transposed with ``permute(1, 0)``)

    Returns
    -------
    torch.matmul(A, X.permute(1, 0))
    '''
    return torch.matmul(A, X.permute(1, 0))
def mat_mult_weight_complex(self, A, X, beta, delay=None):
    '''
    Weighted sum producing a complex output whose imaginary part carries the result.

    Parameters
    ----------
    A: input of the pre-synaptic group (real, or complex-encoded)
    X: weight of the post-synaptic group
    beta: beta_complex of the post-synaptic group (only used for complex ``A``)
    delay: optional delays; when given, ``A`` is permuted with ``(0, 2, 1)`` (so it must be 3-D)
        and the delay path uses ``self.dt`` (presumably set elsewhere on the backend - TODO confirm)

    Returns
    -------
    * real ``A``, no delay:   ``(A @ X.T) * 1j``
    * real ``A``, delay:      ``sum(A.permute(0, 2, 1) * X, -1) * 1j``
    * complex ``A``: ``sum(beta ** e * (A.real * 1j) * X, -1)``, with exponent ``e = A.imag``
      (plus ``ceil(delay/dt) - delay/dt`` when ``delay`` is given)
    '''
    if not A.dtype.is_complex:
        if delay is not None:
            return torch.sum(A.permute(0, 2, 1) * X, dim=-1) * (0.0 + 1.0j)
        weight_t = X.permute(1, 0)
        return torch.matmul(A.to(weight_t.dtype), weight_t) * (0.0 + 1.0j)

    beta = beta.unsqueeze(-1)
    if delay is None:
        A = A.unsqueeze(-2)
        exponent = A.imag
    else:
        A = A.permute(0, 2, 1)
        steps = delay.unsqueeze(0) / self.dt
        exponent = A.imag + (torch.ceil(steps) - steps)
    weighted = beta ** exponent * (A.real * (0 + 1.0j))
    return torch.sum(weighted * X, dim=-1)
def mat_mult_weight_2complex(self, A, X, beta, delay=None):
    """
    Weighted sum over the second-to-last axis after expanding ``A`` with two singleton dims.

    ``A`` becomes ``A[..., None, :, None]`` and is multiplied by ``X``, then summed over dim -2.
    For complex ``A`` the expanded input is first mapped to
    ``beta[..., None] ** A.imag * (A.real * 1j)``. ``delay`` is accepted but not used.
    """
    expanded = A.unsqueeze(-2).unsqueeze(-1)
    if A.dtype.is_complex:
        decay = beta.unsqueeze(-1) ** expanded.imag
        expanded = decay * (expanded.real * (0 + 1.0j))
    return torch.sum(expanded * X, dim=-2)
def mat_mult_pre(self, A, X):
    '''
    Matrix product with the first operand transposed: ``A.T @ X``.

    Parameters
    ----------
    A: 2-D operand, transposed with ``permute(1, 0)`` before the product
    X: second operand
    '''
    transposed = A.permute(1, 0)
    return torch.matmul(transposed, X)


def mat_mult(self, A, X):
    '''
    Plain ``torch.matmul(A, X)``.

    Parameters
    ----------
    A: first operand
    X: second operand
    '''
    product = torch.matmul(A, X)
    return product


def bmm(self, A, X):
    '''
    Batched matrix product ``torch.bmm(A, X)`` of two 3-D tensors.

    Parameters
    ----------
    A: (b, n, m) batch
    X: (b, m, p) batch
    '''
    product = torch.bmm(A, X)
    return product
def ger(self, A, X):
    '''
    Outer product of two 1-D tensors.

    ``torch.ger`` is a deprecated alias of ``torch.outer``; the latter is used directly.

    Parameters
    ----------
    A: 1-D tensor (rows of the result)
    X: 1-D tensor (columns of the result)

    Returns
    -------
    ``(len(A), len(X))`` tensor with ``out[i, j] = A[i] * X[j]``.
    '''
    return torch.outer(A, X)
def sparse_mat_mult_weight(self, A, X):
    '''
    Multiply a dense input by a sparse weight: ``X @ A.T``, computed as ``(A @ X.T).T``.

    Parameters
    ----------
    A: sparse weight, shape (post, pre)
    X: dense input, shape (batch, pre)

    Returns
    -------
    Dense tensor of shape (batch, post).
    '''
    product = torch.sparse.mm(A, X.permute(1, 0))
    return product.permute(1, 0)
def var_mult(self, A, X):
    """Element-wise product ``A * X``."""
    product = A * X
    return product


def mult_sum_weight(self, A, X):
    """Element-wise product summed over the last dimension."""
    return (A * X).sum(dim=-1)


def mat_linear(self, A, X, b):
    """Affine map ``A @ X + b``."""
    return torch.matmul(A, X) + b


def var_linear(self, A, X, b):
    """Element-wise affine map ``A * X + b``."""
    scaled = A * X
    return scaled + b


def unsqueeze(self, X, dim):
    """Insert a size-1 dimension into ``X`` at ``dim``."""
    return torch.unsqueeze(X, dim)
def to_numpy(self, data: torch.Tensor):
    """Return a NumPy copy of ``data``, detached from autograd and moved to the CPU."""
    detached = data.detach()
    return detached.cpu().numpy()
def to_tensor(self, data):
    """
    Convert ``data`` to a float32 tensor on the primary device ``self.device0``.

    ``self.device`` holds a *list* of devices, which neither ``Tensor.to`` nor ``torch.tensor``
    accept as a device; the first device is used instead.
    """
    if isinstance(data, torch.Tensor):
        return data.to(dtype=torch.float, device=self.device0)
    return torch.tensor(data, dtype=torch.float, device=self.device0)
def upsample(self, x, scale):
    """Nearest-neighbour upsampling of ``x`` by ``scale``."""
    return fn.interpolate(x, scale_factor=scale, mode='nearest')


def exp(self, x):
    """Element-wise exponential."""
    result = torch.exp(x)
    return result
def clamp_(self, data, min, max):
    """Clamp ``data`` in place to ``[min, max]`` without recording autograd history; returns None."""
    with torch.no_grad():
        data.clamp_(min=min, max=max)


def clamp_max_(self, data, max):
    """Clamp ``data`` in place to at most ``max`` without recording autograd history; returns None."""
    with torch.no_grad():
        data.clamp_max_(max)


def clamp_min_(self, data, min):
    """Clamp ``data`` in place to at least ``min`` without recording autograd history; returns None."""
    with torch.no_grad():
        data.clamp_min_(min)
def normal(self, data, mean=0.0, std=1.0):
    '''
    Fill ``data`` in place with samples from N(mean, std**2).

    Args:
        data(tensor): an n-dimensional torch.Tensor
        mean(float): mean of the normal distribution
        std(float): standard deviation of the normal distribution
    Returns:
        ``data`` (result of ``torch.nn.init.normal_``)
    '''
    return torch.nn.init.normal_(data, mean=mean, std=std)


def xavier_normal(self, data, gain=1.0):
    '''
    Fill ``data`` in place using Xavier/Glorot normal initialisation.

    Args:
        data(tensor): an n-dimensional torch.Tensor
        gain: optional scaling factor
    Returns:
        ``data`` (result of ``torch.nn.init.xavier_normal_``)
    '''
    return torch.nn.init.xavier_normal_(data, gain=gain)


def kaiming_normal(self, data, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    '''
    Fill ``data`` in place using Kaiming/He normal initialisation.

    Args:
        data(tensor): an n-dimensional torch.Tensor
        a: negative slope of the rectifier used after this layer (only used with 'leaky_relu')
        mode: 'fan_in' (default) preserves the variance of the weights in the forward pass;
            'fan_out' preserves it in the backward pass
        nonlinearity: name of the non-linear function (nn.functional name); recommended only
            with 'relu' or 'leaky_relu' (default)
    Returns:
        ``data`` (result of ``torch.nn.init.kaiming_normal_``)
    '''
    return torch.nn.init.kaiming_normal_(data, a=a, mode=mode, nonlinearity=nonlinearity)


def constant(self, data, constant_value=0.0):
    '''
    Fill ``data`` in place with ``constant_value``.

    Args:
        data(tensor): an n-dimensional torch.Tensor
        constant_value(float): value to fill the tensor with
    Returns:
        ``data`` (result of ``torch.nn.init.constant_``)
    '''
    return torch.nn.init.constant_(data, constant_value)


def sparse(self, data, sparsity=0.1, std=0.01):
    '''
    Fill the 2-D tensor ``data`` in place as a sparse matrix.

    Args:
        data(tensor): a 2-dimensional torch.Tensor
        sparsity(float): fraction of elements in each column set to zero
        std(float): standard deviation of the normal distribution for the non-zero values
    Returns:
        ``data`` (result of ``torch.nn.init.sparse_``)
    '''
    return torch.nn.init.sparse_(data, sparsity=sparsity, std=std)
def weight_norm(self, weight, amp):
    """
    Rescale each row of ``weight`` (dim 0) to L2 norm ``amp``: ``weight * amp / ||weight||_2``.

    The row norm is clamped to the smallest positive normal number of its dtype. An all-zero row
    therefore stays zero instead of becoming NaN (0 / 0). Rows with a non-negligible norm are
    unaffected.
    """
    w_norm = torch.norm(weight, p=2, dim=1, keepdim=True)
    w_norm = w_norm.clamp_min(torch.finfo(w_norm.dtype).tiny)
    return weight * amp / w_norm
# TODO: rename this method from "to" to "to_device".
def to(self, x, device):
    """Return ``x`` moved to ``device`` (``Tensor.to``)."""
    moved = x.to(device)
    return moved
def sin(self, x):
    """Element-wise sine."""
    return torch.sin(x)


def cos(self, x):
    """Element-wise cosine."""
    return torch.cos(x)


def tan(self, x):
    """Element-wise tangent."""
    return torch.tan(x)


def log(self, x):
    """Element-wise natural logarithm."""
    return torch.log(x)


def log2(self, x):
    """Element-wise base-2 logarithm."""
    return torch.log2(x)


def log10(self, x):
    """Element-wise base-10 logarithm."""
    return torch.log10(x)
# def reset(self, x, v_reset, u_reset, spike):
#
# # if hasattr(x, "__len__"):
# # if x.shape != spike.shape:
# # raise ValueError("%s and %s do not match" % (x.shape, spike.shape))
# mask = torch.eq(spike, 1)
# x[mask] = v_reset
# x[mask] += u_reset
# return x
# def izh_v(self, v, u, psp):
# v = v+self.dt*(0.04*v*v+5*v+140-u+psp)
# return v
#
# def izh_u(self, a, b, v, u):
# u = u+self.dt*a*(b*v-u)
# return u
# Make this backend available in the ``backends`` mapping imported from .Backend,
# keyed by its ``backend_name`` ('pytorch').
backends[Torch_Backend.backend_name] = Torch_Backend
# test = Torch_Backend()
# th = test.basic_operate['threshold']
# print(th(-1.0))