diff --git a/dacapo/experiments/model.py b/dacapo/experiments/model.py
index 8ca2b2b9e..75777cd81 100644
--- a/dacapo/experiments/model.py
+++ b/dacapo/experiments/model.py
@@ -40,6 +40,13 @@ def __init__(
         )
         self.eval_activation = eval_activation
 
+        # Update weight initialization to use Kaiming.
+        # TODO: put this somewhere better; there might be
+        # conv layers that aren't followed by ReLUs?
+        for _name, layer in self.named_modules():
+            if isinstance(layer, torch.nn.modules.conv._ConvNd):
+                torch.nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
+
     def forward(self, x):
         result = self.chain(x)
         if not self.training and self.eval_activation is not None:
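
Notes on the change: kaiming_normal_ with nonlinearity="relu" draws weights from
N(0, 2 / fan_in), the scaling derived by He et al. for ReLU networks, and
matching on torch.nn.modules.conv._ConvNd covers Conv1d/2d/3d and their
transposed variants in one isinstance check. Below is a minimal, self-contained
sketch, not part of the PR (the toy Sequential model is invented purely for
illustration), showing the same initialization pattern plus a quick sanity
check of the resulting weight scale:

import torch

# Toy model for illustration only; it stands in for the Model class in the PR.
model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 16, kernel_size=3),
    torch.nn.ReLU(),
)

# Same loop as in the diff: re-initialize every conv layer's weights with
# Kaiming normal, using the gain appropriate for ReLU activations.
for _name, layer in model.named_modules():
    if isinstance(layer, torch.nn.modules.conv._ConvNd):
        torch.nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")

# Sanity check: in the default fan_in mode with ReLU gain, the empirical
# weight std should land near sqrt(2 / fan_in).
w = model[2].weight    # shape (16, 8, 3, 3)
fan_in = w[0].numel()  # in_channels * kernel_h * kernel_w = 72
print(w.std().item(), (2.0 / fan_in) ** 0.5)

As the TODO in the diff notes, the ReLU gain of sqrt(2) is only correct for
conv layers that actually feed into a ReLU; a final conv with no activation
after it would be initialized slightly "hot" by this loop.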