-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimpleNeuralModel.py
More file actions
138 lines (117 loc) · 6.1 KB
/
simpleNeuralModel.py
File metadata and controls
138 lines (117 loc) · 6.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# simpleNeuralModel.py
#
# Bryan Daniels
# 2023/8/25 branched from CAS-503-Collectives/neural
# 2021/9/10
#
# A simple model of neural dynamics. This is equivalent to the model found in the
# following reference:
# Daniels, Bryan C., Jessica C. Flack, and David C. Krakauer.
# “Dual Coding Theory Explains Biphasic Collective Computation in Neural
# Decision-Making.” Frontiers in Neuroscience 11, 1–16 (2017).
# https://doi.org/10.3389/fnins.2017.00313
#
import numpy as np
import scipy.optimize as opt
import pandas as pd
def simpleNeuralDynamics(weightMatrix,inputExt=0,noiseVar=1,
tFinal=10,deltat=1e-3,initialState=None,nonlinearity=np.tanh,sigma=1):
"""
Simulates the following stochastic process:
dx_i / dt = inputExt - x_i + sum_j weightMatrix_{i,j} tanh(x_j/sigma_{i,j}) + xi
where xi is uncorrelated Gaussian noise with variance 'noiseVar' per unit time.
Time is discretized into units of deltat, and the simulation is run until time tFinal.
weightMatrix : (N x N) matrix indicating the synaptic strength from
neuron j to neuron i
inputExt (0) : If given a constant or list of length N, add this as
a constant external input.
If given an array of shape (# timepoints)x(N),
add external current as an input that
varies over time. (# timepoints = t_final/delta_t)
initialState (None) : If given a list of length N, start the system in the
given state. If None, initial state defaults to
all zeros.
nonlinearity (np.tanh) : A function taking neural states x to synaptic currents
sigma (1) : Parameter defining the scale over which the nonlinear
function acts. Given a single number, this is treated
as a constant over all interactions. Given a matrix
of shape (N x N), this specifies sigma individually for
each interaction.
"""
N = len(weightMatrix)
# make sure the weight matrix is square
assert(len(weightMatrix[0])==N)
# make sure sigma has the right shape
assert(np.shape(sigma) == () or np.shape(sigma) == (N,N) )
# set up the initial state
if initialState is None:
initialState = np.zeros(N)
# make sure the initial state has the correct length
assert(len(initialState)==N)
# set up the simulation times and a list to hold the simulated steps
times = np.arange(0,tFinal+deltat,deltat)
stateList = [initialState,]
# set up the external input, possibly varying as a function of time
if np.shape(inputExt) == (len(times)-1,N):
# input is varying in time
inputExtVsT = inputExt
elif np.shape(inputExt) == () or np.shape(inputExt) == (N,):
# input is constant in time
inputExtVsT = [ inputExt for t in times[:-1] ]
else:
raise Exception("Unrecognized form of inputExt")
# run the simulation (we already have the state for t=0)
for time,inputCurrent in zip(times[1:],inputExtVsT):
currentState = stateList[-1]
# compute deltax for current timestep
if np.shape(sigma) == (N,N):
scaledStates = np.tile(currentState,(N,1))/sigma
synapticCurrent = np.sum(weightMatrix * nonlinearity(scaledStates),axis=1)
else: # faster calculation for constant sigma
synapticCurrent = np.dot(weightMatrix,nonlinearity(currentState/sigma))
deterministicPart = deltat*( inputCurrent - currentState + synapticCurrent )
stochasticPart = np.sqrt(deltat*noiseVar)*np.random.normal(size=N)
deltax = deterministicPart + stochasticPart
# update to find the new state
newState = currentState + deltax
# record the new state
stateList.append(newState)
# return simulation output as a pandas dataframe
df = pd.DataFrame(stateList,index=times,columns=['Neuron {}'.format(i) for i in range(N)])
df.index.set_names('Time',inplace=True)
return df
def allToAllNetworkAdjacency(N):
return 1 - np.eye(N)
def findFixedPoint(weightMatrix,initialGuessState,inputExt=0,nonlinearity=np.tanh,sigma=1):
"""
Find a fixed point of the deterministic part of dynamics
"""
N = len(weightMatrix)
# make sure the input is either a simple number or length-N
assert(np.shape(inputExt)==() or np.shape(inputExt)==(N,))
if np.shape(sigma) == (N,N):
deterministicDeltaX = lambda x: inputExt - x + np.sum(weightMatrix * nonlinearity(np.tile(x,(N,1))/sigma),axis=1)
else: # simpler in case of constant sigma
deterministicDeltaX = lambda x: inputExt - x + np.dot(weightMatrix,nonlinearity(x/sigma))
sol = opt.root(deterministicDeltaX,initialGuessState)
return sol.x
def findFixedPoints(weightMatrix,inputExt=0,useMeanField=True,startMin=-10,
startMax=10,numToTest=100):
"""
look for all fixed points nearby a set of starting points
"""
N = len(weightMatrix)
# make sure the input is either a simple number or length-N
assert(np.shape(inputExt)==() or np.shape(inputExt)==(N,))
fixedPointList = []
if useMeanField and np.mean(np.sum(weightMatrix,axis=0)) > 1.:
xMF = 2.*np.sqrt(np.mean(np.sum(weightMatrix,axis=0))-1.)
startingPoints = [-xMF,0.,+xMF]
else:
startingPoints = list(np.linspace(startMin,startMax,numToTest))
for startingPoint in startingPoints:
initialGuessState = startingPoint*np.ones(N)
fixedPoint = findFixedPoint(weightMatrix,initialGuessState,inputExt=inputExt)
fixedPointList.append(fixedPoint)
uniqueFixedPoints = np.unique(np.round(fixedPointList,5),axis=0)
return pd.DataFrame(uniqueFixedPoints,columns=['Neuron {}'.format(i) for i in range(N)])