Skip to content

Commit

Permalink
feat: use Sequential from flax (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
SauravMaheshkar authored Jun 5, 2022
1 parent 43f4119 commit 5b00735
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 17 deletions.
14 changes: 1 addition & 13 deletions jax_resnet/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from functools import partial
from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple,
Union)
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union

import flax
import flax.linen as nn
import jax.numpy as jnp

ModuleDef = Callable[..., Callable]
# InitFn = Callable[[PRNGKey, Shape, DType], Array]
Expand Down Expand Up @@ -50,16 +48,6 @@ def __call__(self, x):
return x


class Sequential(nn.Module):
layers: Sequence[Union[nn.Module, Callable[[jnp.ndarray], jnp.ndarray]]]

@nn.compact
def __call__(self, x):
for layer in self.layers:
x = layer(x)
return x


def slice_variables(variables: Mapping[str, Any],
start: int = 0,
end: Optional[int] = None) -> flax.core.FrozenDict:
Expand Down
6 changes: 3 additions & 3 deletions jax_resnet/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
from flax import linen as nn

from .common import ConvBlock, ModuleDef, Sequential
from .common import ConvBlock, ModuleDef
from .splat import SplAtConv2d

STAGE_SIZES = {
Expand Down Expand Up @@ -185,7 +185,7 @@ def ResNet(
window_shape=(3, 3),
strides=(2, 2),
padding=((1, 1), (1, 1))),
) -> Sequential:
) -> nn.Sequential:
conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls)
block_cls = partial(block_cls, conv_block_cls=conv_block_cls)
Expand All @@ -199,7 +199,7 @@ def ResNet(

layers.append(partial(jnp.mean, axis=(1, 2))) # global average pool
layers.append(nn.Dense(n_classes))
return Sequential(layers)
return nn.Sequential(layers)


# yapf: disable
Expand Down
3 changes: 2 additions & 1 deletion tests/test_resnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jax
import pytest
from flax import linen as nn

from jax_resnet import * # noqa

Expand Down Expand Up @@ -137,7 +138,7 @@ def test_slice_variables(start, end):

variables = model.init(key, jnp.ones((1, 224, 224, 3)))
sliced_vars = slice_variables(variables, start, end)
sliced_model = Sequential(model.layers[start:end])
sliced_model = nn.Sequential(model.layers[start:end])

# Need the correct number of input channels for slice:
first = variables['params'][f'layers_{start}']['ConvBlock_0']['Conv_0']['kernel']
Expand Down

0 comments on commit 5b00735

Please sign in to comment.