Skip to content

Commit

Permalink
add G2 model example
Browse files Browse the repository at this point in the history
  • Loading branch information
paillarj committed Jul 31, 2024
1 parent cce678e commit f04916b
Showing 1 changed file with 246 additions and 7 deletions.
253 changes: 246 additions & 7 deletions research_code/example_wo_wav.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@
"import torch\n",
"\n",
"from research_code.pl_utils import get_green_g2, GreenClassifierLM\n",
"from research_code.crossval_utils import pl_crossval\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Create a dummy dataset\n",
"n = 10 # subjects\n",
"n = 50 # subjects\n",
"f = 4 # filterbank size\n",
"c = 3 # channels \n",
"\n",
Expand All @@ -30,19 +31,49 @@
"V = np.random.randn(n, f, c, c)\n",
"spd = V @ diag @ np.transpose(V, (0, 1, 3, 2))\n",
"\n",
"y = np.random.randint(0, 2, n)\n",
"X = torch.Tensor(spd).to(torch.float32)\n",
"y = torch.Tensor(np.random.randint(2, size=(n, 2))).to(torch.float32)\n",
"\n",
"dataset = TensorDataset(torch.Tensor(spd), torch.Tensor(y))"
"dataset = TensorDataset(X, y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"GreenClassifierLM(\n",
" (model): GreenG2(\n",
" (spd_layers): Sequential(\n",
" (0): LedoitWold(n_freqs=4, init_shrinkage=-3.0, learnable=True)\n",
" (1): BiMap(d_in=3, d_out=2, n_freqs=4\n",
" )\n",
" (proj): LogEig(ref=logeuclid, reg=0.0001, n_freqs=4, size=2\n",
" (head): Sequential(\n",
" (0): BatchNorm1d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (1): Dropout(p=0.5, inplace=False)\n",
" (2): Linear(in_features=12, out_features=8, bias=True)\n",
" (3): GELU(approximate='none')\n",
" (4): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (5): Dropout(p=0.5, inplace=False)\n",
" (6): Linear(in_features=8, out_features=2, bias=True)\n",
" )\n",
" )\n",
")"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = get_green_g2(\n",
" n_ch=3,\n",
" n_ch=c,\n",
" n_freqs=f,\n",
" orth_weights=True,\n",
" dropout=.5,\n",
" hidden_dim=[8],\n",
Expand All @@ -54,6 +85,214 @@
"model_pl = GreenClassifierLM(model=model,)\n",
"model_pl"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: False, used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"c:\\Users\\paillarj\\AppData\\Local\\anaconda3\\envs\\riemann\\lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
"c:\\Users\\paillarj\\AppData\\Local\\anaconda3\\envs\\riemann\\lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "719614b4f52946e58e5f9f8a057216f5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Finding best initial lr: 0%| | 0/20 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=2` reached.\n",
"LR finder stopped early after 2 steps due to diverging loss.\n",
"Failed to compute suggestion for learning rate because there are not enough points. Increase the loop iteration limits or the size of your dataset/dataloader.\n",
"Restoring states from the checkpoint path at checkpoints\\test\\fold0\\.lr_find_53d79eea-094a-4aad-8615-13ab86532a9f.ckpt\n",
"Restored all states from the checkpoint at checkpoints\\test\\fold0\\.lr_find_53d79eea-094a-4aad-8615-13ab86532a9f.ckpt\n",
"\n",
" | Name | Type | Params\n",
"----------------------------------\n",
"0 | model | GreenG2 | 244 \n",
"----------------------------------\n",
"190 Trainable params\n",
"54 Non-trainable params\n",
"244 Total params\n",
"0.001 Total estimated model params size (MB)\n",
"Restored all states from the checkpoint at checkpoints\\test\\fold0\\.lr_find_53d79eea-094a-4aad-8615-13ab86532a9f.ckpt\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e48324e5e5274902a5eb65ec0f1f660e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\paillarj\\AppData\\Local\\anaconda3\\envs\\riemann\\lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "39e8587dce004d2582590d46678d8fde",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c58f08e0161f4949be7ef2a536b66b24",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\paillarj\\AppData\\Local\\anaconda3\\envs\\riemann\\lib\\site-packages\\sklearn\\metrics\\_classification.py:2394: UserWarning: y_pred contains classes not in y_true\n",
" warnings.warn(\"y_pred contains classes not in y_true\")\n",
"`Trainer.fit` stopped: `max_epochs=2` reached.\n",
"c:\\Users\\paillarj\\AppData\\Local\\anaconda3\\envs\\riemann\\lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:441: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1ce1ab56632f4aa5ba0ed8d0843e9271",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Predicting: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\paillarj\\AppData\\Local\\anaconda3\\envs\\riemann\\lib\\site-packages\\sklearn\\metrics\\_classification.py:2394: UserWarning: y_pred contains classes not in y_true\n",
" warnings.warn(\"y_pred contains classes not in y_true\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"pred_acc = 0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\paillarj\\AppData\\Local\\anaconda3\\envs\\riemann\\lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3da57d1058524c7b9ec6bf6b52eebe79",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Testing: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\paillarj\\AppData\\Local\\anaconda3\\envs\\riemann\\lib\\site-packages\\sklearn\\metrics\\_classification.py:2394: UserWarning: y_pred contains classes not in y_true\n",
" warnings.warn(\"y_pred contains classes not in y_true\")\n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.0 </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> test_score </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.0 </span>│\n",
"└───────────────────────────┴───────────────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36m test_score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n",
"└───────────────────────────┴───────────────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"pl_crossval_output, _ = pl_crossval(\n",
" model, \n",
" dataset=dataset,\n",
" n_epochs=2,\n",
" save_preds=True,\n",
" ckpt_prefix='checkpoints/test',\n",
" train_splits=[[0,1]],\n",
" test_splits=[[2]],\n",
" batch_size=4,\n",
" pl_module=GreenClassifierLM,\n",
" num_workers=0, \n",
")\n"
]
}
],
"metadata": {
Expand Down

0 comments on commit f04916b

Please sign in to comment.