From f04916b329499f3fa0fb654a1a1512e2d04872d7 Mon Sep 17 00:00:00 2001 From: paillarj Date: Wed, 31 Jul 2024 15:51:11 +0200 Subject: [PATCH] add G2 model example --- research_code/example_wo_wav.ipynb | 253 ++++++++++++++++++++++++++++- 1 file changed, 246 insertions(+), 7 deletions(-) diff --git a/research_code/example_wo_wav.ipynb b/research_code/example_wo_wav.ipynb index e0d008d..f369978 100644 --- a/research_code/example_wo_wav.ipynb +++ b/research_code/example_wo_wav.ipynb @@ -11,17 +11,18 @@ "import torch\n", "\n", "from research_code.pl_utils import get_green_g2, GreenClassifierLM\n", + "from research_code.crossval_utils import pl_crossval\n", "\n" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Create a dummy dataset\n", - "n = 10 # subjects\n", + "n = 50 # subjects\n", "f = 4 # filterbank size\n", "c = 3 # channels \n", "\n", @@ -30,19 +31,49 @@ "V = np.random.randn(n, f, c, c)\n", "spd = V @ diag @ np.transpose(V, (0, 1, 3, 2))\n", "\n", - "y = np.random.randint(0, 2, n)\n", + "X = torch.Tensor(spd).to(torch.float32)\n", + "y = torch.Tensor(np.random.randint(2, size=(n, 2))).to(torch.float32)\n", "\n", - "dataset = TensorDataset(torch.Tensor(spd), torch.Tensor(y))" + "dataset = TensorDataset(X, y)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "GreenClassifierLM(\n", + " (model): GreenG2(\n", + " (spd_layers): Sequential(\n", + " (0): LedoitWold(n_freqs=4, init_shrinkage=-3.0, learnable=True)\n", + " (1): BiMap(d_in=3, d_out=2, n_freqs=4\n", + " )\n", + " (proj): LogEig(ref=logeuclid, reg=0.0001, n_freqs=4, size=2\n", + " (head): Sequential(\n", + " (0): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (1): Dropout(p=0.5, inplace=False)\n", + " (2): Linear(in_features=12, out_features=8, bias=True)\n", + " (3): GELU(approximate='none')\n", + " (4): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): Dropout(p=0.5, inplace=False)\n", + " (6): Linear(in_features=8, out_features=2, bias=True)\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "model = get_green_g2(\n", - " n_ch=3,\n", + " n_ch=c,\n", + " n_freqs=f,\n", " orth_weights=True,\n", " dropout=.5,\n", " hidden_dim=[8],\n", @@ -54,6 +85,214 @@ "model_pl = GreenClassifierLM(model=model,)\n", "model_pl" ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: False, used: False\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "c:\\Users\\paillarj\\AppData\\Local\\anaconda3\\envs\\riemann\\lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n", + "c:\\Users\\paillarj\\AppData\\Local\\anaconda3\\envs\\riemann\\lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "719614b4f52946e58e5f9f8a057216f5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Finding best initial lr: 0%| | 0/20 [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test_loss 0.0 │\n", + "│ test_score 0.0 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pl_crossval_output, _ = pl_crossval(\n", + " model, \n", + " dataset=dataset,\n", + " n_epochs=2,\n", + " save_preds=True,\n", + " ckpt_prefix='checkpoints/test',\n", + " train_splits=[[0,1]],\n", + " test_splits=[[2]],\n", + " batch_size=4,\n", + " pl_module=GreenClassifierLM,\n", + " num_workers=0, \n", + ")\n" + ] } ], "metadata": {