This is not an immediate issue, but I was playing around with jax_resnet and noticed that ConvBlock decides whether to update its batch statistics based on whether the batch_stats collection is mutable (jax-resnet/jax_resnet/common.py, lines 43 to 44 at 5b00735). This initially sounds like a safe bet, but if you embed ResNet inside another module that also happens to use BatchNorm, and you want to train the other module while freezing ResNet, it is not clear how you would do this.
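For context, here is a condensed, paraphrased sketch of the pattern in question (illustrative only, not the exact jax-resnet source): the norm layer's use_running_average flag is derived from whether the batch_stats collection is mutable in the current apply call.

```python
from flax import linen as nn

class ConvBlock(nn.Module):  # paraphrased sketch, not the actual jax-resnet code
    n_filters: int

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(self.n_filters, (3, 3), use_bias=False)(x)
        # Batch statistics are updated only when the caller made the
        # 'batch_stats' collection mutable, e.g. mutable=['batch_stats'].
        x = nn.BatchNorm(
            use_running_average=not self.is_mutable_collection('batch_stats'))(x)
        return nn.relu(x)
```

Because mutability is specified per collection for the whole apply call, there is no way to keep batch_stats mutable for an outer module's own BatchNorm while having the embedded ResNet treat its statistics as frozen.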
To solve this you have to:
1. Accept a use_running_average (or equivalent) argument in ConvBlock.__call__ and pass it to norm_cls.
2. Refactor ResNet to be a custom Module (instead of Sequential) so it also accepts this argument in __call__ and passes it to the relevant submodules that expect it (see the sketch after this list).
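A minimal sketch of what (1) and (2) could look like, assuming a simplified ConvBlock and ResNet (names and defaults here are illustrative, not the actual jax-resnet API):

```python
from functools import partial
from typing import Any, Sequence

from flax import linen as nn

ModuleDef = Any

class ConvBlock(nn.Module):
    n_filters: int
    norm_cls: ModuleDef = partial(nn.BatchNorm, momentum=0.9)

    @nn.compact
    def __call__(self, x, use_running_average: bool = True):
        x = nn.Conv(self.n_filters, (3, 3), use_bias=False)(x)
        # The caller now controls the statistics explicitly instead of the
        # block inferring the behaviour from collection mutability.
        x = self.norm_cls(use_running_average=use_running_average)(x)
        return nn.relu(x)

class ResNet(nn.Module):
    hidden_sizes: Sequence[int] = (64, 128, 256)

    @nn.compact
    def __call__(self, x, use_running_average: bool = True):
        # Thread the flag through every submodule that expects it.
        for n_filters in self.hidden_sizes:
            x = ConvBlock(n_filters)(x, use_running_average=use_running_average)
        return x
```

With something like this, an outer module could call the embedded ResNet with use_running_average=True while still running its own BatchNorm layers in training mode.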
Some repos use a single train flag to determine the state of both BatchNorm and Dropout.
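For illustration, a minimal sketch of that convention (a hypothetical outer module, not part of jax-resnet):

```python
from flax import linen as nn

class Classifier(nn.Module):
    num_classes: int = 10

    @nn.compact
    def __call__(self, x, train: bool = False):
        x = nn.Dense(256)(x)
        # One flag drives both pieces of train/eval behaviour:
        # BatchNorm updates its statistics only when train=True ...
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.relu(x)
        # ... and Dropout is only applied when train=True.
        x = nn.Dropout(rate=0.5, deterministic=not train)(x)
        return nn.Dense(self.num_classes)(x)
```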
Anyway, this is not an immediate issue for me, but it might help some users in the future. Happy to send a PR if the changes make sense.
Thanks for raising this @cgarciae, definitely a relevant use case. I would prefer having a use_running_average member variable in ConvBlock. Perhaps in the future we can add a use_running_average=None argument in ConvBlock.__call__ if there is sufficient demand, then use nn.merge_param just like Flax does, but my general preference is to configure the behaviour of the module during construction (with @nn.compact you do both at once anyway).
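As a rough sketch of that preference (attribute and argument names are illustrative), the flag would live on the module, with an optional nn.merge_param-based call argument left as a possible later addition:

```python
from typing import Optional

from flax import linen as nn

class ConvBlock(nn.Module):
    n_filters: int
    # Configure the behaviour at construction time, like other hyperparameters.
    use_running_average: Optional[bool] = None

    @nn.compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        # Optional escape hatch: resolve the attribute and the call argument
        # the same way flax.linen.BatchNorm does (exactly one must be set).
        use_running_average = nn.merge_param(
            'use_running_average', self.use_running_average, use_running_average)
        x = nn.Conv(self.n_filters, (3, 3), use_bias=False)(x)
        x = nn.BatchNorm(use_running_average=use_running_average)(x)
        return nn.relu(x)
```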
Would be amazing if you could open a PR. Let me know if you have any issues with the environment/tests.
Hey @n2cholas!