Skip to content

Commit

Permalink
Merge pull request #33 from invrs-io/readme
Browse files Browse the repository at this point in the history
Expand readme and add animation
  • Loading branch information
mfschubert authored Oct 26, 2023
2 parents 03cb21e + 204a670 commit 1e34959
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 2 deletions.
44 changes: 42 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,56 @@ The `invrs_gym` package is an open-source gym containing a diverse set of photon

Each of the challenges consists of a high-dimensional problem in which a physical structure (the photonic device) is optimized. The structure includes typically >10,000 degrees of freedom (DoF), generally including one or more arrays representing the structure or patterning of a layer, and may also include scalar variables representing e.g. layer thickness. In general, the DoF must satisfy certain constraints to be physical: thicknesses must be positive, and layer patterns must be _manufacturable_---they must not include features that are too small, or too closely spaced.

In general, we seek optimization techniques that _reliably_ produce manufacturable, high-quality solutions and require reasonable compute resources.
In general, we seek optimization techniques that _reliably_ produce manufacturable, high-quality solutions and require reasonable compute resources. Among the techniques that could be applied are topology optimization, inverse design, and AI-guided design.

`invrs_gym` is intended to facilitate research on such methods within the jax ecosystem. It includes several challenges that have been used in previous works, so that researchers may directly compare their results to those of the literature. While some challenges are test problems (e.g. where the structure is two-dimensional, which is unphysical but allows fast simulation), others are actual problems that are relevant e.g. for quantum computing or 3D sensing.

## Key concepts
The key types of the challenge are the `Challenge` and `Component` objects.

The `Component` represents the physical structure to be optimized, and has some intended excitation or operating condition (e.g. illumination with a particular wavelength from a particular direction). The `Component` includes methods to obtain initial parameters, and to compute the _response_ of a component to the excitation.

Each `Challenge` has a `Component` as an attribute, and also has a target that can be used to determine whether particular parameters "solve" the challenge. The `Challenge` also provides functions to compute a scalar loss for use with gradient-based optimization, and additional metrics.

## Example
```python
# Select the challenge.
challenge = invrs_gym.challenges.ceviche_lightweight_waveguide_bend()

# Define loss function, which also returns auxilliary quantities.
def loss_fn(params):
response, aux = challenge.component.response(params)
loss = challenge.loss(response)
distance = challenge.distance_to_target(response)
metrics = challenge.metrics(response, params, aux)
return loss, (response, distance, aux)

value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

# Select an optimizer.
opt = invrs_opt.density_lbfgsb(beta=4)

# Generate initial parameters, and use these to initialize the optimizer state.
params = challenge.component.init(jax.random.PRNGKey(0))
state = opt.init(params)

# Carry out the optimization.
for i in range(steps):
params = opt.params(state)
(value, (response, distance, aux)), grad = value_and_grad_fn(params)
state = opt.update(grad=grad, value=value, params=params, state=state)
```
With some plotting (see the [example notebook](notebooks/readme_example.ipynb)), this code will produce the following waveguide bend:

![Animated evolution of waveguide bend design](docs/img/waveguide_bend.gif)

## Challenges
The current list of challenges is below. Check out the notebooks for ready-to-go examples of each.

Traditionally, a designer faced with such challenges would use their knowledge to define a low-dimensional solution space, and use gradient-free methods such as particle swarms to find a local optimum.
- The **ceviche** challenges are jax-wrapped versions of the [Ceviche Challenges](https://github.com/google/ceviche-challenges) open-sourced by Google, with defaults matching [Inverse Design of Photonic Devices with Strict Foundry Fabrication Constraints](https://pubs.acs.org/doi/10.1021/acsphotonics.2c00313) by Schubert et al.
- The **metagrating** challenge is a re-implementation of the [Metagrating3D](https://github.com/NanoComp/photonics-opt-testbed/tree/main/Metagrating3D) problem using the [fmmax](https://github.com/facebookresearch/fmmax) simulator.
- The **diffractive splitter** challenge involves designing a non-paraxial diffractive beamsplitter useful for 3D sensing, as discussed in [LightTrans documentation](https://www.lighttrans.com/use-cases/application/design-and-rigorous-analysis-of-non-paraxial-diffractive-beam-splitter.html).
- The **photon extractor** challenge is based on [Inverse-designed photon extractors for optically addressable defect qubits](https://opg.optica.org/optica/fulltext.cfm?uri=optica-7-12-1805) by Chakravarthi et al., and aims to create structures that increase photon extraction efficiency for quantum applications.


## Install
Expand Down
Binary file added docs/img/waveguide_bend.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
127 changes: 127 additions & 0 deletions notebooks/readme_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "e4f6b9c2-e829-46b1-bfd1-de01c690ef27",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as onp\n",
"from skimage import measure\n",
"import gifcm\n",
"\n",
"import invrs_gym\n",
"import invrs_opt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "340fab20-b508-4d47-8c7d-ae884282dce8",
"metadata": {},
"outputs": [],
"source": [
"challenge = invrs_gym.challenges.ceviche_lightweight_waveguide_bend()\n",
"\n",
"# Define loss function, which also returns auxilliary quantities.\n",
"def loss_fn(params):\n",
" response, aux = challenge.component.response(params)\n",
" loss = challenge.loss(response)\n",
" distance = challenge.distance_to_target(response)\n",
" metrics = challenge.metrics(response, params, aux)\n",
" return loss, (response, distance, aux)\n",
"\n",
"value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
"\n",
"# Select an optimizer.\n",
"opt = invrs_opt.density_lbfgsb(beta=4)\n",
"\n",
"# Generate initial parameters, and use these to initialize the optimizer state.\n",
"params = challenge.component.init(jax.random.PRNGKey(0))\n",
"state = opt.init(params)\n",
"\n",
"# Carry out the optimization.\n",
"data = []\n",
"for i in range(36):\n",
" params = opt.params(state)\n",
" (value, (response, distance, aux)), grad = value_and_grad_fn(params)\n",
" state = opt.update(grad=grad, value=value, params=params, state=state)\n",
" data.append((i, value, params, aux))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f6c6fc3-11b5-4f4e-91ab-191c123f5d06",
"metadata": {},
"outputs": [],
"source": [
"# Create an animated gif showing the evolution of the waveguide bend design.\n",
"anim = gifcm.AnimatedFigure(figure=plt.figure(figsize=(8, 4)))\n",
"\n",
"for i, _, params, aux in data:\n",
" with anim.frame():\n",
" # Plot fields, using some of the methods specific to the underlying ceviche model.\n",
" density = challenge.component.ceviche_model.density(params.array)\n",
"\n",
" ax = plt.subplot(121)\n",
" ax.imshow(density, cmap=\"gray\")\n",
" plt.text(100, 90, f\"step {i:02}\", color=\"w\", fontsize=20)\n",
" ax.axis(False)\n",
" ax.set_xlim(ax.get_xlim()[::-1])\n",
" ax.set_ylim(ax.get_ylim()[::-1])\n",
"\n",
" # Plot the field, which is a part of the `aux` returned with the challenge response.\n",
" # The field will be overlaid with contours of the binarized design.\n",
" field = onp.real(aux[\"fields\"])\n",
" field = field[0, 0, :, :] # First wavelength, first excitation port.\n",
" contours = measure.find_contours(density)\n",
"\n",
" ax = plt.subplot(122)\n",
" im = ax.imshow(field, cmap=\"bwr\")\n",
" im.set_clim([-onp.amax(field), onp.amax(field)])\n",
" for c in contours:\n",
" plt.plot(c[:, 1], c[:, 0], \"k\", lw=1)\n",
" ax.axis(False)\n",
" ax.set_xlim(ax.get_xlim()[::-1])\n",
" ax.set_ylim(ax.get_ylim()[::-1])\n",
"\n",
"anim.save_gif(\"waveguide_bend.gif\", duration=200)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f75622f1-b207-4ba1-aeec-2c989985b1eb",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 2 additions & 0 deletions src/invrs_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@

__version__ = "v0.0.0"
__author__ = "Martin F. Schubert <[email protected]>"

from invrs_gym import challenges as challenges

0 comments on commit 1e34959

Please sign in to comment.