From 5b00735aa0a68ec239af4a728ad4a596c1b551f6 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Mon, 6 Jun 2022 05:18:39 +0530 Subject: [PATCH] feat: use Sequential from flax (#7) --- jax_resnet/common.py | 14 +------------- jax_resnet/resnet.py | 6 +++--- tests/test_resnet.py | 3 ++- 3 files changed, 6 insertions(+), 17 deletions(-) diff --git a/jax_resnet/common.py b/jax_resnet/common.py index ac03ee1..73a6d78 100644 --- a/jax_resnet/common.py +++ b/jax_resnet/common.py @@ -1,10 +1,8 @@ from functools import partial -from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, - Union) +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union import flax import flax.linen as nn -import jax.numpy as jnp ModuleDef = Callable[..., Callable] # InitFn = Callable[[PRNGKey, Shape, DType], Array] @@ -50,16 +48,6 @@ def __call__(self, x): return x -class Sequential(nn.Module): - layers: Sequence[Union[nn.Module, Callable[[jnp.ndarray], jnp.ndarray]]] - - @nn.compact - def __call__(self, x): - for layer in self.layers: - x = layer(x) - return x - - def slice_variables(variables: Mapping[str, Any], start: int = 0, end: Optional[int] = None) -> flax.core.FrozenDict: diff --git a/jax_resnet/resnet.py b/jax_resnet/resnet.py index 85aadad..f62c8be 100644 --- a/jax_resnet/resnet.py +++ b/jax_resnet/resnet.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from flax import linen as nn -from .common import ConvBlock, ModuleDef, Sequential +from .common import ConvBlock, ModuleDef from .splat import SplAtConv2d STAGE_SIZES = { @@ -185,7 +185,7 @@ def ResNet( window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1))), -) -> Sequential: +) -> nn.Sequential: conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls) stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls) block_cls = partial(block_cls, conv_block_cls=conv_block_cls) @@ -199,7 +199,7 @@ def ResNet( layers.append(partial(jnp.mean, axis=(1, 2))) # global average pool layers.append(nn.Dense(n_classes)) - return Sequential(layers) + return nn.Sequential(layers) # yapf: disable diff --git a/tests/test_resnet.py b/tests/test_resnet.py index 036f8b7..2f41269 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -1,5 +1,6 @@ import jax import pytest +from flax import linen as nn from jax_resnet import * # noqa @@ -137,7 +138,7 @@ def test_slice_variables(start, end): variables = model.init(key, jnp.ones((1, 224, 224, 3))) sliced_vars = slice_variables(variables, start, end) - sliced_model = Sequential(model.layers[start:end]) + sliced_model = nn.Sequential(model.layers[start:end]) # Need the correct number of input channels for slice: first = variables['params'][f'layers_{start}']['ConvBlock_0']['Conv_0']['kernel']