diff --git a/pvnet_summation/models/model.py b/pvnet_summation/models/model.py index 18ae98e..96528a5 100644 --- a/pvnet_summation/models/model.py +++ b/pvnet_summation/models/model.py @@ -59,7 +59,7 @@ def __init__( output_network_kwargs = dict() self.model = output_network( - in_features=np.product(self.pvnet_output_shape), + in_features=np.prod(self.pvnet_output_shape), out_features=self.num_output_features, **output_network_kwargs, )