Code for revised paper 2 and for additional investigations in thesis
luboeinski committed Feb 24, 2022
1 parent f993476 commit 89abbff
Showing 208 changed files with 5,243 additions and 718 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
*.out
*pycache*/
30 changes: 30 additions & 0 deletions BIBTEX.md
@@ -0,0 +1,30 @@
## BibTeX code for publications

@article{LuboeinskiTetzlaff2021a,
title={Memory consolidation and improvement by synaptic tagging and capture in recurrent neural networks},
author={Luboeinski, Jannik and Tetzlaff, Christian},
journal={Communications Biology},
volume={4},
number={275},
year={2021},
publisher={Nature Publishing Group},
doi={10.1038/s42003-021-01778-y}
}

@article{LuboeinskiTetzlaff2021b,
title={Organization and priming of long-term memory representations with two-phase plasticity},
author={Luboeinski, Jannik and Tetzlaff, Christian},
journal={bioRxiv preprint},
year={2021},
publisher={Cold Spring Harbor Laboratory},
doi={10.1101/2021.04.15.439982}
}

@phdthesis{Luboeinski2021thesis,
title={The Role of Synaptic Tagging and Capture for Memory Dynamics in Spiking Neural Networks},
author={Luboeinski, Jannik},
year={2021},
school={University of G\"{o}ttingen},
type={Dissertation},
url={http://hdl.handle.net/21.11130/00-1735-0000-0008-58f8-e}
}
180 changes: 108 additions & 72 deletions README.md

Large diffs are not rendered by default.

157 changes: 105 additions & 52 deletions analysis/adjacencyFunctions.py
100755 → 100644
@@ -3,24 +3,24 @@
### in a network that contains multiple cell assemblies ###
#######################################################################################

### Copyright 2019-2021 Jannik Luboeinski
### Copyright 2019-2022 Jannik Luboeinski
### licensed under Apache-2.0 (http://www.apache.org/licenses/LICENSE-2.0)
### Contact: jannik.lubo[at]gmx.de

import numpy as np
import sys
from utilityFunctions import cond_print

h_0 = 0.420075 # nC, initial synaptic weight and normalization factor for z

np.set_printoptions(precision=8, threshold=1e10, linewidth=200)
epsilon = 1e-9

# loadWeightMatrix
# Loads complete weight matrix from a file (only for excitatory neurons, though)
# filename: name of the file to read the data from
# N_pop: the number of neurons in the considered population
# h_0: initial synaptic weight and normalization factor for z
# return: the adjacency matrix, the early-phase weight matrix, the late-phase weight matrix, the firing rate vector
def loadWeightMatrix(filename, N_pop):
def loadWeightMatrix(filename, N_pop, h_0):
global h
global z
global adj
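For orientation, a minimal usage sketch of the revised signature (not part of the commit; the timestamp in the file name is a placeholder, and the return values follow the docstring above):

import adjacencyFunctions as adj  # imported under this alias in analyzeWeights.py

N_pop = 1600   # number of excitatory neurons, as in the single-assembly script further below
h_0 = 4.20075  # initial/median synaptic weight, as set elsewhere in this commit

# Reads "<timestamp>_net_<time>.txt"; h_0 is now passed explicitly instead of being a module-level constant.
# Per the docstring, the call returns the adjacency matrix, the early- and late-phase weight matrices,
# and the firing rate vector.
adj_matrix, h_matrix, z_matrix, rates = adj.loadWeightMatrix("21-02-22_10-00-00_net_20.0.txt", N_pop, h_0)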
@@ -125,7 +125,7 @@ def outgoingConnections(i, pr = True):
# Prints and returns all the early-phase synaptic weights incoming to neuron i
# i: neuron index
# pr [optional]: specifies if result is printed
# return: array of early-phase weights in units of nC
# return: array of early-phase weights
def incomingEarlyPhaseWeights(i, pr = True):
global adj
global h
@@ -140,7 +140,7 @@ def incomingEarlyPhaseWeights(i, pr = True):
# Prints and returns all the early-phase synaptic weights outgoing from neuron i
# i: neuron index
# pr [optional]: specifies if result is printed
# return: array of early-phase weights in units of nC
# return: array of early-phase weights
def outgoingEarlyPhaseWeights(i, pr = True):
global adj
global h
@@ -155,7 +155,7 @@ def outgoingEarlyPhaseWeights(i, pr = True):
# Prints and returns all the early-phase synaptic weights incoming to neuron i from a given set of neurons
# i: neuron index
# set: the set of presynaptic neurons
# return: array of early-phase weights in units of nC
# return: array of early-phase weights
def earlyPhaseWeightsFromSet(i, set):
global adj
global h
@@ -169,7 +169,7 @@ def earlyPhaseWeightsFromSet(i, set):
# Prints and returns all the late-phase synaptic weights incoming to neuron i from a given set of neurons
# i: neuron index
# set: the set of presynaptic neurons
# return: array of late-phase weights in units of nC
# return: array of late-phase weights
def latePhaseWeightsFromSet(i, set):
global adj
global z
@@ -185,7 +185,7 @@ def latePhaseWeightsFromSet(i, set):
# set: the first set of neurons (presynaptic)
# set2 [optional]: the second set of neurons (postsynaptic); if not specified, connections within "set" are considered
# pr [optional]: specifies if result shall be printed
# return: early-phase weight in units of nC
# return: early-phase weight
def meanEarlyPhaseWeight(set, set2 = None, pr = True):
summed_weight = 0
connection_num = 0
@@ -218,7 +218,7 @@ def meanEarlyPhaseWeight(set, set2 = None, pr = True):
# set: the first set of neurons (presynaptic)
# set2 [optional]: the second set of neurons (postsynaptic); if not specified, connections within "set" are considered
# pr [optional]: specifies if result shall be printed
# return: early-phase weight in units of nC
# return: early-phase weight
def sdEarlyPhaseWeight(set, set2 = None, pr = True):
mean = meanEarlyPhaseWeight(set, set2, False)
summed_qu_dev = 0
@@ -248,7 +248,7 @@ def sdEarlyPhaseWeight(set, set2 = None, pr = True):
# set: the first set of neurons (presynaptic)
# set2 [optional]: the second set of neurons (postsynaptic); if not specified, connections within "set" are considered
# pr [optional]: specifies if result shall be printed
# return: late-phase weight in units of nC
# return: late-phase weight
def meanLatePhaseWeight(set, set2 = None, pr = True):
summed_weight = 0
connection_num = 0
@@ -286,7 +286,7 @@ def meanLatePhaseWeight(set, set2 = None, pr = True):
# set: the first set of neurons (presynaptic)
# set2 [optional]: the second set of neurons (postsynaptic); if not specified, connections within "set" are considered
# pr [optional]: specifies if result shall be printed
# return: late-phase weight in units of nC
# return: late-phase weight
def sdLatePhaseWeight(set, set2 = [], pr = True):
mean = meanLatePhaseWeight(set, set2, False)
summed_qu_dev = 0
@@ -318,46 +318,96 @@ def sdLatePhaseWeight(set, set2 = [], pr = True):
# coreB: array of indices of the second cell assembly (core) neurons
# coreC: array of indices of the third cell assembly (core) neurons
# N_pop: the number of neurons in the considered population
# h_0: initial synaptic weight and normalization factor for z
# pr [optional]: specifies if result shall be printed
def meanCoreWeights(ts, time_for_readout, coreA, coreB, coreC, N_pop, pr = True):

def meanCoreWeights(ts, time_for_readout, coreA, coreB, coreC, N_pop, h_0, pr = True):
cond_print(pr, "##############################################")
cond_print(pr, "At time", time_for_readout)
loadWeightMatrix(ts + "_net_" + time_for_readout + ".txt", N_pop)
loadWeightMatrix(ts + "_net_" + time_for_readout + ".txt", N_pop, h_0)

# early-phase weights
f = open("cores_mean_early_weights.txt", "a")
f.write(time_for_readout + "\t\t")

cond_print(pr, "--------------------------------")
cond_print(pr, "A -> A:")
hm_A = meanEarlyPhaseWeight(coreA, coreA, pr)
hsd_A = sdEarlyPhaseWeight(coreA, coreA, pr)
f.write(str(hm_A) + "\t\t")
f.write(str(hsd_A) + "\t\t")

cond_print(pr, "--------------------------------")
cond_print(pr, "B -> B:")
hm_B = meanEarlyPhaseWeight(coreB, coreB, pr)
hsd_B = sdEarlyPhaseWeight(coreB, coreB, pr)
f.write(str(hm_B) + "\t\t")
f.write(str(hsd_B) + "\t\t")

cond_print(pr, "--------------------------------")
cond_print(pr, "C -> C:")
hm_C = meanEarlyPhaseWeight(coreC, coreC, pr)
hsd_C = sdEarlyPhaseWeight(coreC, coreC, pr)
f.write(str(hm_C) + "\t\t")
f.write(str(hsd_C) + "\n")
f.close()

# late-phase weights
f = open("cores_mean_late_weights.txt", "a")
f.write(time_for_readout + "\t\t")

cond_print(pr, "--------------------------------")
cond_print(pr, "A -> A:")
zm_A = meanLatePhaseWeight(coreA, coreA, pr)
zsd_A = sdLatePhaseWeight(coreA, coreA, pr)
f.write(str(zm_A) + "\t\t")
f.write(str(zsd_A) + "\t\t")

cond_print(pr, "--------------------------------")
cond_print(pr, "B -> B:")
zm_B = meanLatePhaseWeight(coreB, coreB, pr)
zsd_B = sdLatePhaseWeight(coreB, coreB, pr)
f.write(str(zm_B) + "\t\t")
f.write(str(zsd_B) + "\t\t")

cond_print(pr, "--------------------------------")
cond_print(pr, "C -> C:")
zm_C = meanLatePhaseWeight(coreC, coreC, pr)
zsd_C = sdLatePhaseWeight(coreC, coreC, pr)
f.write(str(zm_C) + "\t\t")
f.write(str(zsd_C) + "\n")
f.close()

# total weights
f = open("cores_mean_tot_weights.txt", "a")

f.write(time_for_readout + "\t\t")

cond_print(pr, "--------------------------------")
cond_print(pr, "A -> A:")
hm = meanEarlyPhaseWeight(coreA, coreA, pr)
hsd = sdEarlyPhaseWeight(coreA, coreA, pr)
zm = meanLatePhaseWeight(coreA, coreA, pr)
zsd = sdLatePhaseWeight(coreA, coreA, pr)
f.write(str(hm + zm) + "\t\t")
f.write(str(np.sqrt(hsd**2 + zsd**2)) + "\t\t")
cond_print(pr, "Mean total weight: " + str(hm + zm))

wm_A = hm_A + zm_A
wsd_A = np.sqrt(hsd_A**2 + zsd_A**2)
f.write(str(wm_A) + "\t\t")
f.write(str(wsd_A) + "\t\t")
cond_print(pr, "Mean total weight: " + str(wm_A))
cond_print(pr, "Std. dev. of total weight: " + str(wsd_A))

cond_print(pr, "--------------------------------")
cond_print(pr, "B -> B:")
hm = meanEarlyPhaseWeight(coreB, coreB, pr)
hsd = sdEarlyPhaseWeight(coreB, coreB, pr)
zm = meanLatePhaseWeight(coreB, coreB, pr)
zsd = sdLatePhaseWeight(coreB, coreB, pr)
f.write(str(hm + zm) + "\t\t")
f.write(str(np.sqrt(hsd**2 + zsd**2)) + "\t\t")
cond_print(pr, "Mean total weight: " + str(hm + zm))

wm_B = hm_B + zm_B
wsd_B = np.sqrt(hsd_B**2 + zsd_B**2)
f.write(str(wm_B) + "\t\t")
f.write(str(wsd_B) + "\t\t")
cond_print(pr, "Mean total weight: " + str(wm_B))
cond_print(pr, "Std. dev. of total weight: " + str(wsd_B))

cond_print(pr, "--------------------------------")
cond_print(pr, "C -> C:")
hm = meanEarlyPhaseWeight(coreC, coreC, pr)
hsd = sdEarlyPhaseWeight(coreC, coreC, pr)
zm = meanLatePhaseWeight(coreC, coreC, pr)
zsd = sdLatePhaseWeight(coreC, coreC, pr)
f.write(str(hm + zm) + "\t\t")
f.write(str(np.sqrt(hsd**2 + zsd**2)) + "\n")
cond_print(pr, "Mean total weight: " + str(hm + zm))

wm_C = hm_C + zm_C
wsd_C = np.sqrt(hsd_C**2 + zsd_C**2)
f.write(str(wm_C) + "\t\t")
f.write(str(wsd_C) + "\n")
cond_print(pr, "Mean total weight: " + str(wm_C))
cond_print(pr, "Std. dev. of total weight: " + str(wsd_C))
f.close()
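In the total-weight block above, the mean total weight is the sum of the early- and late-phase means, and the standard deviations of the two phases are combined in quadrature. A minimal standalone sketch of that combination (function name hypothetical, not part of the commit):

import numpy as np

def combine_phases(h_mean, h_sd, z_mean, z_sd):
    # mean total weight: early-phase mean plus late-phase mean
    w_mean = h_mean + z_mean
    # standard deviation of the total weight: early- and late-phase SDs combined in quadrature
    w_sd = np.sqrt(h_sd**2 + z_sd**2)
    return w_mean, w_sd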

# meanWeightMatrix
@@ -368,8 +418,9 @@ def meanCoreWeights(ts, time_for_readout, coreA, coreB, coreC, N_pop, pr = True)
# coreB: array of indices of the second cell assembly (core) neurons
# coreC: array of indices of the third cell assembly (core) neurons
# N_pop: the number of neurons in the considered population
# h_0: initial synaptic weight and normalization factor for z
# pr [optional]: specifies if result shall be printed
def meanWeightMatrix(ts, time_for_readout, coreA, coreB, coreC, N_pop, pr = True):
def meanWeightMatrix(ts, time_for_readout, coreA, coreB, coreC, N_pop, h_0, pr = True):
# define the whole considered population
all = np.arange(N_pop)

@@ -405,7 +456,7 @@ def meanWeightMatrix(ts, time_for_readout, coreA, coreB, coreC, N_pop, pr = True

cond_print(pr, "##############################################")
cond_print(pr, "At time", time_for_readout)
loadWeightMatrix(ts + "_net_" + time_for_readout + ".txt", N_pop)
loadWeightMatrix(ts + "_net_" + time_for_readout + ".txt", N_pop, h_0)
f = open("mean_tot_weights_" + time_for_readout + ".txt", "w")
fsd = open("sd_tot_weights_" + time_for_readout + ".txt", "w")

@@ -1063,19 +1114,20 @@ def meanWeightMatrix(ts, time_for_readout, coreA, coreB, coreC, N_pop, pr = True


# printMeanWeightsSingleCA
# Computes and prints mean and standard deviation of CA, outgoing, incoming, and control weight in units of nC
# Computes and prints mean and standard deviation of CA, outgoing, incoming, and control weights
# ts: timestamp of the data file to read
# time_for_readout: the time of the data file to read
# N_pop: the number of neurons in the considered population
# core: array of indices of the cell assembly (core) neurons
def printMeanWeightsSingleCA(ts, time_for_readout, core, N_pop):
# N_pop: the number of neurons in the considered population
# h_0: initial synaptic weight and normalization factor for z
def printMeanWeightsSingleCA(ts, time_for_readout, core, N_pop, h_0):

all = np.arange(N_pop) # the whole considered population
noncore = all[np.logical_not(np.in1d(all, core))] # the neurons outside the core

print("##############################################")
print("At time", time_for_readout)
loadWeightMatrix(ts + "_net_" + time_for_readout + ".txt", N_pop)
loadWeightMatrix(ts + "_net_" + time_for_readout + ".txt", N_pop, h_0)

print("--------------------------------")
print("Core -> core ('CA'):")
@@ -1127,20 +1179,21 @@ def printMeanWeightsSingleCA(ts, time_for_readout, core, N_pop):
ts2 = str(sys.argv[2]) # timestamp for simulation data after consolidation

core = np.arange(150) # define the cell assembly core
N_pop = 1600
N_pop = 1600 # number of neurons in the population
h_0 = 4.20075 # initial/median synaptic weight

print("##############################################")
print("Before 10s-recall:")
printMeanWeightsSingleCA(ts1, "20.0", core, N_pop)
printMeanWeightsSingleCA(ts1, "20.0", core, N_pop, h_0)

print("##############################################")
print("After 10s-recall:")
printMeanWeightsSingleCA(ts1, "20.1", core, N_pop)
printMeanWeightsSingleCA(ts1, "20.1", core, N_pop, h_0)

print("##############################################")
print("Before 8h-recall:")
printMeanWeightsSingleCA(ts2, "28810.0", core, N_pop)
printMeanWeightsSingleCA(ts2, "28810.0", core, N_pop, h_0)

print("##############################################")
print("After 8h-recall:")
printMeanWeightsSingleCA(ts2, "28810.1", core, N_pop)
printMeanWeightsSingleCA(ts2, "28810.1", core, N_pop, h_0)
14 changes: 10 additions & 4 deletions analysis/analyzeWeights.py
100755 → 100644
@@ -3,8 +3,9 @@
### and mean weight within and between subpopulations) ###
#################################################################################

### Copyright 2020-2021 Jannik Luboeinski
### Copyright 2020-2022 Jannik Luboeinski
### licensed under Apache-2.0 (http://www.apache.org/licenses/LICENSE-2.0)
### Contact: jannik.lubo[at]gmx.de

### example call from shell: python3 analyzeWeights.py "Weight Distributions and Mean Weight Matrix" "OVERLAP10 no AC, no ABC"

@@ -17,6 +18,7 @@
##############################################################################################
### initialize
N_pop = 2500 # number of neurons in the considered population
h_0 = 4.20075 # initial/median synaptic weight
core_size = 600 # number of excitatory neurons in one cell assembly
MWM = False # specifies whether to create abstract mean weight matrix
MCW = False # specifies whether to create file with mean core weights
@@ -51,6 +53,7 @@
##############################################################################################
### look for network output files in this directory
rawpaths = Path(".")
timestamp = None

for x in sorted(rawpaths.iterdir()):

@@ -67,12 +70,15 @@
if WD:
print("Plotting weight distributions from dataset", timestamp, "with time", time_for_readout)
N_pop_row = int(round(np.sqrt(N_pop)))
vd.plotWeightDistributions3CAs(".", timestamp, "", N_pop_row, time_for_readout, coreA, coreB, coreC)
vd.plotWeightDistributions3CAs(".", timestamp, "", N_pop_row, h_0, time_for_readout, coreA, coreB, coreC)

if MWM:
print("Creating abstract mean weight matrix from dataset", timestamp, "with time", time_for_readout)
adj.meanWeightMatrix(timestamp, time_for_readout, coreA, coreB, coreC, N_pop, pr = True)
adj.meanWeightMatrix(timestamp, time_for_readout, coreA, coreB, coreC, N_pop, h_0, pr = True)

if MCW:
print("Computing mean core weights from dataset", timestamp, "with time", time_for_readout)
adj.meanCoreWeights(timestamp, time_for_readout, coreA, coreB, coreC, N_pop, pr = True)
adj.meanCoreWeights(timestamp, time_for_readout, coreA, coreB, coreC, N_pop, h_0, pr = True)
if timestamp is None:
print("No data found!")

