diff --git a/.github/workflows/build-ci.yml b/.github/workflows/build-ci.yml index 0e57d34..4011192 100644 --- a/.github/workflows/build-ci.yml +++ b/.github/workflows/build-ci.yml @@ -15,13 +15,6 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 - - name: Checkout submodule - uses: actions/checkout@v4 - with: - repository: invrs-io/totypes - path: totypes - token: ${{ secrets.INVRSIO_PAT }} - - name: Set up Python uses: actions/setup-python@v4 with: @@ -32,7 +25,6 @@ jobs: - name: Test pre-commit hooks run: | python -m pip install --upgrade pip - pip install "totypes/" pip install pre-commit pre-commit run -a @@ -66,6 +58,10 @@ jobs: run: | darglint src --strictness=short --ignore-raise=ValueError + - name: Validate docstrings + run: | + mypy src + tests: runs-on: ubuntu-latest steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4f02a87..9cbc4a1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,13 +41,6 @@ repos: args: [--exit-zero] exclude: ^tests/ - # Type annotation - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.0.1 - hooks: - - id: mypy - exclude: ^(notebooks|tests) - # Linter - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.0.287 diff --git a/pyproject.toml b/pyproject.toml index 2f9e458..8571f3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dev = [ "bump-my-version", "darglint", "invrs_gym[tests]", + "mypy", "pre-commit", ] @@ -46,10 +47,6 @@ build-backend = "setuptools.build_meta" line-length = 88 target-version = ['py310'] -[tool.mypy] -python_version = "3.10" -strict = true - [tool.isort] multi_line_output = 3 line_length = 88 diff --git a/src/invrs_gym/challenge/ceviche/challenge.py b/src/invrs_gym/challenge/ceviche/challenge.py index f6dc0dc..dabbad8 100644 --- a/src/invrs_gym/challenge/ceviche/challenge.py +++ b/src/invrs_gym/challenge/ceviche/challenge.py @@ -4,11 +4,11 @@ import functools from typing import Any, Callable, Dict, Optional, Sequence, Tuple -import agjax +import agjax # type: ignore[import] import jax import jax.numpy as jnp import numpy as onp -from totypes import types # type: ignore[attr-defined] +from totypes import types # type: ignore[import,attr-defined,unused-ignore] from invrs_gym.challenge.ceviche import defaults from invrs_gym.loss import transmission_loss @@ -85,7 +85,7 @@ def sim_fn( excite_port_idxs: Sequence[int], wavelengths_nm: Optional[jnp.ndarray], max_parallelizm: Optional[int], - ) -> Tuple[jnp.ndarray, onp.ndarray]: + ) -> Tuple[jnp.ndarray, onp.ndarray[Any, Any]]: s_params, fields = self.ceviche_model.simulate( design_variable, excite_port_idxs, wavelengths_nm, max_parallelizm ) @@ -227,8 +227,8 @@ def loss(self, response: jnp.ndarray) -> jnp.ndarray: transmission=transmission, window_lower_bound=lb, window_upper_bound=ub, - transmission_exponent=TRANSMISSION_EXPONENT, - scalar_exponent=SCALAR_EXPONENT, + transmission_exponent=jnp.asarray(TRANSMISSION_EXPONENT), + scalar_exponent=jnp.asarray(SCALAR_EXPONENT), ) def metrics( @@ -269,7 +269,8 @@ def _wavelength_bound( ) repeats = transmission_shape[0] // band_bound.shape[0] - return jnp.repeat(band_bound, repeats, axis=0) + repeated: jnp.ndarray = jnp.repeat(band_bound, repeats, axis=0) + return repeated # ----------------------------------------------------------------------------- diff --git a/src/invrs_gym/challenge/ceviche/defaults.py b/src/invrs_gym/challenge/ceviche/defaults.py index 9283d80..7f9ba61 100644 --- a/src/invrs_gym/challenge/ceviche/defaults.py +++ b/src/invrs_gym/challenge/ceviche/defaults.py @@ -3,10 +3,15 @@ from typing import Union import jax.numpy as jnp -from ceviche_challenges import beam_splitter, mode_converter, model_base, params +from ceviche_challenges import ( # type: ignore[import] + beam_splitter, + mode_converter, + model_base, + params, +) from ceviche_challenges import units as u from ceviche_challenges import waveguide_bend, wdm -from totypes import symmetry # type: ignore[attr-defined] +from totypes import symmetry # type: ignore[import,attr-defined,unused-ignore] DeviceSpec = Union[ beam_splitter.spec.BeamSplitterSpec, @@ -17,9 +22,10 @@ Model = model_base.Model -def _linear_from_decibels(x_decibels: jnp.ndarray) -> jnp.ndarray: +def _linear_from_decibels(x_decibels: float) -> jnp.ndarray: """Converts a quantity in decibels to linear.""" - return 10 ** (x_decibels / 10) + linear: jnp.ndarray = jnp.asarray(10 ** (x_decibels / 10)) + return linear WG_WIDTH = 400 * u.nm diff --git a/src/invrs_gym/loss/transmission_loss.py b/src/invrs_gym/loss/transmission_loss.py index eac61d6..bf614cc 100644 --- a/src/invrs_gym/loss/transmission_loss.py +++ b/src/invrs_gym/loss/transmission_loss.py @@ -16,10 +16,11 @@ def orthotope_smooth_transmission_loss( ) -> jnp.ndarray: """Compute a scalar loss from a array based on an orthotope transmission window. - The loss is related to an orthotope window, i.e. a orthotope target region within - the space that contains valid responses. With values for `transmission_exponent` - and `scalar_exponent` of `0.5` and `2.0`, respectively, this loss function is - equivalent to that of [2022 Schubert](https://arxiv.org/abs/2201.12965). + The loss is related to an orthotope window in tansmission space, i.e. the space + of the squared magnitude of scattering parameters. With values for + `transmission_exponent` and `scalar_exponent` of `1.0` and `2.0`, respectively, + this loss function is equivalent to that of [2022 Schubert] + (https://arxiv.org/abs/2201.12965). Args: transmission: The transmission array for which the loss is to be calculated. @@ -33,23 +34,28 @@ def orthotope_smooth_transmission_loss( Returns: The scalar loss value. """ - # The signed distance to the target window is positive - elementwise_signed_distance = elementwise_signed_distance_to_window( + # Compute the signed psuedodistance. This is equal to the signed distance to the + # nearest bound, except when the bounds are the min and max physical transmission + # values, in which case the distance is equal to the window size. + elementwise_signed_psuedodistance = elementwise_signed_psuedodistance_to_window( transmission=transmission**transmission_exponent, window_lower_bound=window_lower_bound**transmission_exponent, window_upper_bound=window_upper_bound**transmission_exponent, ) - window_size = jnp.minimum( - window_upper_bound, _MAX_PHYSICAL_TRANSMISSION - ) - jnp.maximum(window_lower_bound, _MIN_PHYSICAL_TRANSMISSION) + # Scale the signed distance by the maximum window size. + lower_bound = jnp.maximum(window_lower_bound, _MIN_PHYSICAL_TRANSMISSION) + upper_bound = jnp.minimum(window_upper_bound, _MAX_PHYSICAL_TRANSMISSION) + window_size = upper_bound - lower_bound + elementwise_signed_psuedodistance /= jnp.amin(window_size) - # TODO: check whether dividing by max or min is appropriate. transformed_elementwise_signed_distance = jax.nn.softplus( - elementwise_signed_distance / jnp.amax(window_size) + elementwise_signed_psuedodistance ) - - return jnp.linalg.norm(transformed_elementwise_signed_distance) ** scalar_exponent + loss: jnp.ndarray = ( + jnp.linalg.norm(transformed_elementwise_signed_distance) ** scalar_exponent + ) + return loss def distance_to_window( @@ -70,21 +76,37 @@ def distance_to_window( Returns: The elementwise signed distance. """ - elementwise_signed_distance = elementwise_signed_distance_to_window( + elementwise_signed_distance = elementwise_signed_psuedodistance_to_window( transmission=transmission, window_lower_bound=window_lower_bound, window_upper_bound=window_upper_bound, ) elementwise_distance = jnp.maximum(elementwise_signed_distance, 0.0) - return jnp.linalg.norm(elementwise_distance) + distance: jnp.ndarray = jnp.linalg.norm(elementwise_distance) + return distance -def elementwise_signed_distance_to_window( +def elementwise_signed_psuedodistance_to_window( transmission: jnp.ndarray, window_lower_bound: jnp.ndarray, window_upper_bound: jnp.ndarray, ) -> jnp.ndarray: - """Returns the elementwise signed distance to a transmission window. + """Returns the elementwise signed psuedodistance to a transmission window. + + The psuedodistance is given by, + + - the distance to the nearest bound, when both bounds are within the + window (e.g. the lower bound is greater than the minimum physical + transmission value). + - the distance to the upper bound, when the lower bound is less than or + equal to the minimum physical transmission value, and the upper bound + is less than the maximum physical transmission value. + - the distance to the lower bound, when the upper bound is greater than + or equal to the maximum physicall transmission value, and the lower + bound is greater than the minimum physical transmission value. + - the difference between the maximum and minimum physical transmission + value, when both bounds equal or exceed their physical extremal values. + In this case, the psuedodistance has no dependence on `transmission`. Args: transmission: The transmission for which the signed distance is sought. @@ -92,7 +114,7 @@ def elementwise_signed_distance_to_window( window_upper_bound: Array defining the transmission window upper bound. Returns: - The elementwise signed distance. + The elementwise signed psuedodistance. """ # Signed distance to lower bound is positive when `transmission` is below the lower # bound and negative when it is above the lower bound. When the lower bound is less