Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand readme and add animation #33

Merged
merged 5 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,56 @@ The `invrs_gym` package is an open-source gym containing a diverse set of photon

Each of the challenges consists of a high-dimensional problem in which a physical structure (the photonic device) is optimized. The structure includes typically >10,000 degrees of freedom (DoF), generally including one or more arrays representing the structure or patterning of a layer, and may also include scalar variables representing e.g. layer thickness. In general, the DoF must satisfy certain constraints to be physical: thicknesses must be positive, and layer patterns must be _manufacturable_---they must not include features that are too small, or too closely spaced.

In general, we seek optimization techniques that _reliably_ produce manufacturable, high-quality solutions and require reasonable compute resources.
In general, we seek optimization techniques that _reliably_ produce manufacturable, high-quality solutions and require reasonable compute resources. Among the techniques that could be applied are topology optimization, inverse design, and AI-guided design.

`invrs_gym` is intended to facilitate research on such methods within the jax ecosystem. It includes several challenges that have been used in previous works, so that researchers may directly compare their results to those of the literature. While some challenges are test problems (e.g. where the structure is two-dimensional, which is unphysical but allows fast simulation), others are actual problems that are relevant e.g. for quantum computing or 3D sensing.

## Key concepts
The key types of the challenge are the `Challenge` and `Component` objects.

The `Component` represents the physical structure to be optimized, and has some intended excitation or operating condition (e.g. illumination with a particular wavelength from a particular direction). The `Component` includes methods to obtain initial parameters, and to compute the _response_ of a component to the excitation.

Each `Challenge` has a `Component` as an attribute, and also has a target that can be used to determine whether particular parameters "solve" the challenge. The `Challenge` also provides functions to compute a scalar loss for use with gradient-based optimization, and additional metrics.

## Example
```python
# Select the challenge.
challenge = invrs_gym.challenges.ceviche_lightweight_waveguide_bend()

# Define loss function, which also returns auxilliary quantities.
def loss_fn(params):
response, aux = challenge.component.response(params)
loss = challenge.loss(response)
distance = challenge.distance_to_target(response)
metrics = challenge.metrics(response, params, aux)
return loss, (response, distance, aux)

value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

# Select an optimizer.
opt = invrs_opt.density_lbfgsb(beta=4)

# Generate initial parameters, and use these to initialize the optimizer state.
params = challenge.component.init(jax.random.PRNGKey(0))
state = opt.init(params)

# Carry out the optimization.
for i in range(steps):
params = opt.params(state)
(value, (response, distance, aux)), grad = value_and_grad_fn(params)
state = opt.update(grad=grad, value=value, params=params, state=state)
```
With some plotting (see the [example notebook](notebooks/readme_example.ipynb)), this code will produce the following waveguide bend:

![Animated evolution of waveguide bend design](docs/img/waveguide_bend.gif)

## Challenges
The current list of challenges is below. Check out the notebooks for ready-to-go examples of each.

Traditionally, a designer faced with such challenges would use their knowledge to define a low-dimensional solution space, and use gradient-free methods such as particle swarms to find a local optimum.
- The **ceviche** challenges are jax-wrapped versions of the [Ceviche Challenges](https://github.com/google/ceviche-challenges) open-sourced by Google, with defaults matching [Inverse Design of Photonic Devices with Strict Foundry Fabrication Constraints](https://pubs.acs.org/doi/10.1021/acsphotonics.2c00313) by Schubert et al.
- The **metagrating** challenge is a re-implementation of the [Metagrating3D](https://github.com/NanoComp/photonics-opt-testbed/tree/main/Metagrating3D) problem using the [fmmax](https://github.com/facebookresearch/fmmax) simulator.
- The **diffractive splitter** challenge involves designing a non-paraxial diffractive beamsplitter useful for 3D sensing, as discussed in [LightTrans documentation](https://www.lighttrans.com/use-cases/application/design-and-rigorous-analysis-of-non-paraxial-diffractive-beam-splitter.html).
- The **photon extractor** challenge is based on [Inverse-designed photon extractors for optically addressable defect qubits](https://opg.optica.org/optica/fulltext.cfm?uri=optica-7-12-1805) by Chakravarthi et al., and aims to create structures that increase photon extraction efficiency for quantum applications.


## Install
Expand Down
Binary file added docs/img/waveguide_bend.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
127 changes: 127 additions & 0 deletions notebooks/readme_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "e4f6b9c2-e829-46b1-bfd1-de01c690ef27",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as onp\n",
"from skimage import measure\n",
"import gifcm\n",
"\n",
"import invrs_gym\n",
"import invrs_opt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "340fab20-b508-4d47-8c7d-ae884282dce8",
"metadata": {},
"outputs": [],
"source": [
"challenge = invrs_gym.challenges.ceviche_lightweight_waveguide_bend()\n",
"\n",
"# Define loss function, which also returns auxilliary quantities.\n",
"def loss_fn(params):\n",
" response, aux = challenge.component.response(params)\n",
" loss = challenge.loss(response)\n",
" distance = challenge.distance_to_target(response)\n",
" metrics = challenge.metrics(response, params, aux)\n",
" return loss, (response, distance, aux)\n",
"\n",
"value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
"\n",
"# Select an optimizer.\n",
"opt = invrs_opt.density_lbfgsb(beta=4)\n",
"\n",
"# Generate initial parameters, and use these to initialize the optimizer state.\n",
"params = challenge.component.init(jax.random.PRNGKey(0))\n",
"state = opt.init(params)\n",
"\n",
"# Carry out the optimization.\n",
"data = []\n",
"for i in range(36):\n",
" params = opt.params(state)\n",
" (value, (response, distance, aux)), grad = value_and_grad_fn(params)\n",
" state = opt.update(grad=grad, value=value, params=params, state=state)\n",
" data.append((i, value, params, aux))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f6c6fc3-11b5-4f4e-91ab-191c123f5d06",
"metadata": {},
"outputs": [],
"source": [
"# Create an animated gif showing the evolution of the waveguide bend design.\n",
"anim = gifcm.AnimatedFigure(figure=plt.figure(figsize=(8, 4)))\n",
"\n",
"for i, _, params, aux in data:\n",
" with anim.frame():\n",
" # Plot fields, using some of the methods specific to the underlying ceviche model.\n",
" density = challenge.component.ceviche_model.density(params.array)\n",
"\n",
" ax = plt.subplot(121)\n",
" ax.imshow(density, cmap=\"gray\")\n",
" plt.text(100, 90, f\"step {i:02}\", color=\"w\", fontsize=20)\n",
" ax.axis(False)\n",
" ax.set_xlim(ax.get_xlim()[::-1])\n",
" ax.set_ylim(ax.get_ylim()[::-1])\n",
"\n",
" # Plot the field, which is a part of the `aux` returned with the challenge response.\n",
" # The field will be overlaid with contours of the binarized design.\n",
" field = onp.real(aux[\"fields\"])\n",
" field = field[0, 0, :, :] # First wavelength, first excitation port.\n",
" contours = measure.find_contours(density)\n",
"\n",
" ax = plt.subplot(122)\n",
" im = ax.imshow(field, cmap=\"bwr\")\n",
" im.set_clim([-onp.amax(field), onp.amax(field)])\n",
" for c in contours:\n",
" plt.plot(c[:, 1], c[:, 0], \"k\", lw=1)\n",
" ax.axis(False)\n",
" ax.set_xlim(ax.get_xlim()[::-1])\n",
" ax.set_ylim(ax.get_ylim()[::-1])\n",
"\n",
"anim.save_gif(\"waveguide_bend.gif\", duration=200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f75622f1-b207-4ba1-aeec-2c989985b1eb",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 2 additions & 0 deletions src/invrs_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@

__version__ = "v0.0.0"
__author__ = "Martin F. Schubert <[email protected]>"

from invrs_gym import challenges as challenges