-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Martin Schubert
authored and
Martin Schubert
committed
Oct 28, 2023
1 parent
e5611eb
commit db8c4ec
Showing
4 changed files
with
474 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
name: CI | ||
|
||
on: | ||
pull_request: | ||
push: | ||
branches: | ||
- main | ||
schedule: | ||
- cron: "0 13 * * 1" # Every Monday at 9AM EST | ||
|
||
jobs: | ||
experiment: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Checkout repository | ||
uses: actions/checkout@v4 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: "3.10" | ||
cache: "pip" | ||
cache-dependency-path: pyproject.toml | ||
|
||
- name: Setup environment | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install ".[tests,dev]" | ||
- name: Run experiment | ||
run: | | ||
python scripts/experiment.py --workers=1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "fc33dfc4-383e-4b94-b857-c293d67d5f9f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import dataclasses\n", | ||
"import json\n", | ||
"import glob\n", | ||
"import time\n", | ||
"\n", | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as onp\n", | ||
"from skimage import measure\n", | ||
"\n", | ||
"from totypes import json_utils\n", | ||
"\n", | ||
"from invrs_gym import challenges" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2ccd23b8-e946-4e6e-9a2a-c4aa9485f3ed", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Launch an experiment.\n", | ||
"!python ../scripts/experiment.py --path=\"../experiments/test_experiment\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0dcda0b3-17db-4fce-ba38-9d3d1f68d846", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Recover logged scalars and parameters.\n", | ||
"experiment_path = \"../experiments/test_experiment\"\n", | ||
"wid_paths = glob.glob(experiment_path + \"/*\")\n", | ||
"wid_paths.sort()\n", | ||
"\n", | ||
"scalars = {}\n", | ||
"hparams = {}\n", | ||
"params = {}\n", | ||
"\n", | ||
"for path in wid_paths:\n", | ||
" print(path)\n", | ||
" name = path.split(\"/\")[-1]\n", | ||
" checkpoint_fname = glob.glob(path + \"/checkpoint_*.json\")\n", | ||
" if not checkpoint_fname:\n", | ||
" continue\n", | ||
" checkpoint_fname.sort()\n", | ||
" with open(checkpoint_fname[-1], \"r\") as f:\n", | ||
" checkpoint = json_utils.pytree_from_json(f.read())\n", | ||
" scalars[name] = checkpoint[\"scalars\"]\n", | ||
" params[name] = checkpoint[\"params\"]\n", | ||
" with open(path + \"/setup.json\", \"r\") as f:\n", | ||
" hparams[name] = json.load(f)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "8b127996-77f0-47a3-8c41-3325db52f832", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Plot the efficiency trajectory, and the final, optimized and designs.\n", | ||
"\n", | ||
"plt.figure(figsize=(6, 9))\n", | ||
"for i, wid in enumerate(scalars.keys()):\n", | ||
" efficiency = scalars[wid][\"average_efficiency\"] * 100\n", | ||
" mask = scalars[wid][\"distance\"] <= 0\n", | ||
" step = onp.arange(1, len(efficiency) + 1)\n", | ||
" plt.subplot(3, 2, 2 * i + 1)\n", | ||
" line, = plt.plot(step, efficiency)\n", | ||
" plt.plot(step[mask], efficiency[mask], 'o', color=line.get_color())\n", | ||
" plt.xlabel(\"step\")\n", | ||
" plt.ylabel(\"Efficiency (%)\")\n", | ||
"\n", | ||
" ax = plt.subplot(3, 2, 2 * i + 2)\n", | ||
" im = ax.imshow(params[wid].array, cmap=\"gray\")\n", | ||
" im.set_clim([0, 1])\n", | ||
" ax.axis(False)\n", | ||
"\n", | ||
"plt.tight_layout()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "44b2d5a4-03af-497b-a60b-474037d47852", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.