From 11adcee98811d9eb7e6b084009b1da9b20b91ac5 Mon Sep 17 00:00:00 2001 From: takuseno Date: Thu, 18 Jan 2024 20:12:29 +0900 Subject: [PATCH] Fix use_batch_norm error --- d3rlpy/models/torch/encoders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/d3rlpy/models/torch/encoders.py b/d3rlpy/models/torch/encoders.py index 8125b2da..86745806 100644 --- a/d3rlpy/models/torch/encoders.py +++ b/d3rlpy/models/torch/encoders.py @@ -282,8 +282,8 @@ def compute_output_size( inputs = [] for shape in input_shapes: if isinstance(shape[0], (list, tuple)): - inputs.append([torch.rand(1, *s, device=device) for s in shape]) + inputs.append([torch.rand(2, *s, device=device) for s in shape]) else: - inputs.append(torch.rand(1, *shape, device=device)) + inputs.append(torch.rand(2, *shape, device=device)) y = encoder(*inputs) return int(y.shape[1])