calculate tm-loss
Sourmpis committed Oct 10, 2023
1 parent dd19b8f commit 4e8975d
Showing 9 changed files with 109 additions and 6,608 deletions.
34 changes: 29 additions & 5 deletions README.md
@@ -10,11 +10,11 @@ Contact:

## Glossary
1) [Installation](#Installation)
2) Download recorded data from Esmaeli et al. 2021
3) Generate simpler artificial data
4) Load pre-trained models
5) Code snippet for computing the trial-matching loss function
6) Generate paper figures
2) [Download recorded data from Esmaeli et al. 2021](#download-recorded-data)
3) [Generate simpler artificial data](#generate-artificial-data)
4) [Load pre-trained models](#pre-trained-models)
5) [Code snippet for computing the trial-matching loss function](#compute-the-trial-matching-loss-function)
6) [Generate paper figures](#generate-figures-from-a-pre-trained-model)
7) [Training the RSNN](#training-the-rsnn-model)

## Installation
@@ -80,6 +80,30 @@ with torch.no_grad():

## Compute the trial-matching loss function

The following function calculates the trial-matching loss with hard matching (the Hungarian algorithm).

Args:


* `filt_data_spikes` (torch.tensor): $\mathcal{T}_{trial}(z^\mathcal{D})$, with dimensions K x T (K data trials, T time steps)

* `filt_model_spikes` (torch.tensor): $\mathcal{T}_{trial}(z)$, with dimensions K' x T (K' model trials)

```python
import torch
from scipy.optimize import linear_sum_assignment

# mse_2d is the pairwise-MSE helper from infopath/losses.py (sketched below)

def hard_trial_matching_loss(filt_data_spikes, filt_model_spikes):
    # Subsample the larger tensor so data and model have the same number of trials
    min_trials = min(filt_model_spikes.shape[0], filt_data_spikes.shape[0])
    filt_data_spikes = filt_data_spikes[:min_trials]  # shape: K x T (assuming K = min(K, K'))
    filt_model_spikes = filt_model_spikes[:min_trials]  # shape: K x T
    with torch.no_grad():
        # Pairwise cost between every model trial and every data trial
        cost = mse_2d(filt_model_spikes.T, filt_data_spikes.T)  # shape: K x K
        # Hungarian algorithm: keepx and ytox are the matched trial indices
        keepx, ytox = linear_sum_assignment(cost.detach().cpu().numpy())
    return torch.nn.MSELoss()(filt_model_spikes[keepx], filt_data_spikes[ytox])
```
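
The snippet relies on the repository's `mse_2d` helper. Below is a minimal sketch of such a helper, assuming it returns the mean squared error between every pair of trials; the broadcasting scheme is an illustration, not the repository's exact code:

```python
import torch

def mse_2d(x, y):
    # x and y arrive transposed, with shape T x K: one column per trial.
    # x.T[:, None, :] is K x 1 x T and y.T[None, :, :] is 1 x K x T, so
    # broadcasting yields every pairwise difference; averaging over time
    # gives a K x K cost matrix whose entry (i, j) is the MSE between
    # trial i of x and trial j of y.
    return ((x.T[:, None, :] - y.T[None, :, :]) ** 2).mean(dim=2)
```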

The actual implementation of `hard_trial_matching_loss` (together with the `mse_2d` helper) is in the `infopath/losses.py` file.

You can explore the loss function in a simple demo in [trial_matching_loss_demo.ipynb](trial_matching_loss_demo.ipynb).
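
As a quick smoke test (the shapes and random inputs below are purely illustrative):

```python
import torch

torch.manual_seed(0)
filt_data_spikes = torch.rand(40, 100)   # K = 40 data trials, T = 100 time steps
filt_model_spikes = torch.rand(55, 100)  # K' = 55 model trials

loss = hard_trial_matching_loss(filt_data_spikes, filt_model_spikes)
print(loss)  # scalar tensor; gradients flow into the model spikes when they require grad
```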


## Generate figures from a pre-trained model
8 changes: 4 additions & 4 deletions infopath/losses.py
Expand Up @@ -12,11 +12,11 @@ def firing_rate_loss(firing_rate, model_spikes, trial_onset, timestep):
return ((model_fr - data_fr) ** 2).mean()
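
The hunk above shows only the signature and return statement of `firing_rate_loss`. The sketch below illustrates the general idea under stated assumptions (the spike-tensor layout, the intermediate names, and the averaging are guesses, not the repository's code): the loss is the mean squared error between trial-averaged firing rates.

```python
import torch

def firing_rate_loss_sketch(firing_rate, model_spikes, trial_onset, timestep):
    # Assumed layout: model_spikes is time x trials x neurons (binary spikes),
    # firing_rate is the per-neuron rate estimated from the recordings (Hz).
    # Average the spikes after trial onset over time and trials, then divide
    # by the bin size to convert spike probability per bin into a rate.
    model_fr = model_spikes[trial_onset:].mean(dim=(0, 1)) / timestep
    data_fr = firing_rate
    # Same objective as the return statement in the hunk above
    return ((model_fr - data_fr) ** 2).mean()
```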


def population_average_loss(filt_data_spikes, filt_model_spikes):
def hard_trial_matching_loss(filt_data_spikes, filt_model_spikes):
"""Calculate the trial-matching loss with the hard matching (Hungarian Algorithm)
Args:
filt_data_spikes (torch.tesnor): $\mathcal{T}_{trial}(z^\mathcal{D})$, dim Trials x Time
filt_model_spikes (torch.tesnor): $\mathcal{T}_{trial}(z)$, dim Trials x Time
filt_data_spikes (torch.tensor): $\mathcal{T}_{trial}(z^\mathcal{D})$, dim Trials x Time
filt_model_spikes (torch.tensor): $\mathcal{T}_{trial}(z)$, dim Trials x Time
"""
# downsample the biggest tensor, so both data and model have the same #trials
min_trials = min(filt_model_spikes.shape[0], filt_data_spikes.shape[0])
@@ -234,7 +234,7 @@ def trial_matching_loss(
session_info (list): information about different sessions
data_jaw (torch.tensor): filtered data jaw trace
model_jaw (torch.tensor): filtered model jaw trace
loss_fun (function): function either population_average_loss or sinkhorn loss
loss_fun (function): function either hard_trial_matching_loss or sinkhorn loss
area_index (torch.tensor): tensor with areas
z_score (bool, optional): whether to z_score or not. Defaults to True.
trial_loss_area_specific (bool, optional): Whether the T_trial is area specific. Defaults to True.
2 changes: 1 addition & 1 deletion infopath/model_loader.py
Expand Up @@ -51,7 +51,7 @@ def __init__(self, opt):
if opt.geometric_loss:
self.trial_loss_fun = SamplesLoss(loss="sinkhorn", p=1, blur=0.01)
else:
self.trial_loss_fun = population_average_loss
self.trial_loss_fun = hard_trial_matching_loss
# setup the loss splitter
task_splitters = opt.loss_trial_wise
task_splitters += opt.loss_neuron_wise
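
For context on the branch above: `SamplesLoss` comes from the `geomloss` package, and both trial losses map a pair of (trials x time) tensors to a scalar, which is what makes them interchangeable here. A hedged usage sketch with illustrative shapes (assuming `geomloss` is installed and the repository is on the Python path):

```python
import torch
from geomloss import SamplesLoss
from infopath.losses import hard_trial_matching_loss

filt_model_spikes = torch.rand(55, 100, requires_grad=True)  # K' x T
filt_data_spikes = torch.rand(40, 100)                       # K x T

# Soft matching via entropy-regularized optimal transport
sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=1, blur=0.01)
geometric = sinkhorn_loss(filt_model_spikes, filt_data_spikes)

# Hard matching via the Hungarian algorithm (note the argument order)
hard = hard_trial_matching_loss(filt_data_spikes, filt_model_spikes)
```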