Skip to content

Commit

Permalink
Add experiment script and notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Schubert authored and Martin Schubert committed Oct 28, 2023
1 parent e5611eb commit db8c4ec
Show file tree
Hide file tree
Showing 4 changed files with 474 additions and 0 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/experiment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: CI

on:
pull_request:
push:
branches:
- main
schedule:
- cron: "0 13 * * 1" # Every Monday at 9AM EST

jobs:
experiment:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: "pip"
cache-dependency-path: pyproject.toml

- name: Setup environment
run: |
python -m pip install --upgrade pip
pip install ".[tests,dev]"
- name: Run experiment
run: |
python scripts/experiment.py --workers=1
126 changes: 126 additions & 0 deletions notebooks/experiment_analysis.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "fc33dfc4-383e-4b94-b857-c293d67d5f9f",
"metadata": {},
"outputs": [],
"source": [
"import dataclasses\n",
"import json\n",
"import glob\n",
"import time\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import numpy as onp\n",
"from skimage import measure\n",
"\n",
"from totypes import json_utils\n",
"\n",
"from invrs_gym import challenges"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ccd23b8-e946-4e6e-9a2a-c4aa9485f3ed",
"metadata": {},
"outputs": [],
"source": [
"# Launch an experiment.\n",
"!python ../scripts/experiment.py --path=\"../experiments/test_experiment\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0dcda0b3-17db-4fce-ba38-9d3d1f68d846",
"metadata": {},
"outputs": [],
"source": [
"# Recover logged scalars and parameters.\n",
"experiment_path = \"../experiments/test_experiment\"\n",
"wid_paths = glob.glob(experiment_path + \"/*\")\n",
"wid_paths.sort()\n",
"\n",
"scalars = {}\n",
"hparams = {}\n",
"params = {}\n",
"\n",
"for path in wid_paths:\n",
" print(path)\n",
" name = path.split(\"/\")[-1]\n",
" checkpoint_fname = glob.glob(path + \"/checkpoint_*.json\")\n",
" if not checkpoint_fname:\n",
" continue\n",
" checkpoint_fname.sort()\n",
" with open(checkpoint_fname[-1], \"r\") as f:\n",
" checkpoint = json_utils.pytree_from_json(f.read())\n",
" scalars[name] = checkpoint[\"scalars\"]\n",
" params[name] = checkpoint[\"params\"]\n",
" with open(path + \"/setup.json\", \"r\") as f:\n",
" hparams[name] = json.load(f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b127996-77f0-47a3-8c41-3325db52f832",
"metadata": {},
"outputs": [],
"source": [
"# Plot the efficiency trajectory, and the final, optimized and designs.\n",
"\n",
"plt.figure(figsize=(6, 9))\n",
"for i, wid in enumerate(scalars.keys()):\n",
" efficiency = scalars[wid][\"average_efficiency\"] * 100\n",
" mask = scalars[wid][\"distance\"] <= 0\n",
" step = onp.arange(1, len(efficiency) + 1)\n",
" plt.subplot(3, 2, 2 * i + 1)\n",
" line, = plt.plot(step, efficiency)\n",
" plt.plot(step[mask], efficiency[mask], 'o', color=line.get_color())\n",
" plt.xlabel(\"step\")\n",
" plt.ylabel(\"Efficiency (%)\")\n",
"\n",
" ax = plt.subplot(3, 2, 2 * i + 2)\n",
" im = ax.imshow(params[wid].array, cmap=\"gray\")\n",
" im.set_clim([0, 1])\n",
" ax.axis(False)\n",
"\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "44b2d5a4-03af-497b-a60b-474037d47852",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 2 additions & 0 deletions notebooks/readme_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"source": [
"challenge = invrs_gym.challenges.ceviche_lightweight_waveguide_bend()\n",
"\n",
"\n",
"# Define loss function, which also returns auxilliary quantities.\n",
"def loss_fn(params):\n",
" response, aux = challenge.component.response(params)\n",
Expand All @@ -36,6 +37,7 @@
" metrics = challenge.metrics(response, params, aux)\n",
" return loss, (response, distance, metrics, aux)\n",
"\n",
"\n",
"value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
"\n",
"# Select an optimizer.\n",
Expand Down
Loading

0 comments on commit db8c4ec

Please sign in to comment.