Skip to content

Commit

Permalink
Fix issue with MLP
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Mar 12, 2024
1 parent 99c081e commit f46be98
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pdequinox/_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ def __init__(
)

def __call__(self, x):
if x.shape != self.in_shape:
if x.shape != self._in_shape:
raise ValueError(
f"Input shape {x.shape} does not match expected shape {self.in_shape}. For batched operation use jax.vmap"
f"Input shape {x.shape} does not match expected shape {self._in_shape}. For batched operation use jax.vmap"
)
x_flat = x.flatten()
x_flat = self.flat_mlp(x_flat)
x = x_flat.reshape(self.out_shape)
x = x_flat.reshape(self._out_shape)
return x

0 comments on commit f46be98

Please sign in to comment.