Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Nov 3, 2024
1 parent 0f37065 commit 92d4013
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions tests/test_torch_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
TorchMiniBatch,
TorchTrajectoryMiniBatch,
View,
copy_recursively,
eval_api,
get_batch_size,
get_device,
Expand Down Expand Up @@ -168,6 +169,19 @@ def test_to_cpu() -> None:
pass


def test_copy_recursively() -> None:
x = torch.rand(10)
y = torch.rand(10)
copy_recursively(x, y)
assert torch.all(x == y)

x_list = [torch.rand(10), torch.rand(20)]
y_list = [torch.rand(10), torch.rand(20)]
copy_recursively(x_list, y_list)
assert torch.all(x_list[0] == y_list[0])
assert torch.all(x_list[1] == y_list[1])


def test_get_device() -> None:
x = torch.rand(10)
assert get_device(x) == "cpu"
Expand Down Expand Up @@ -323,6 +337,29 @@ def test_torch_mini_batch(
assert np.all(torch_batch.terminals.numpy() == batch.terminals)
assert np.all(torch_batch.intervals.numpy() == batch.intervals)

torch_batch2 = TorchMiniBatch(
observations=torch.zeros_like(torch_batch.observations),
actions=torch.zeros_like(torch_batch.actions),
rewards=torch.zeros_like(torch_batch.rewards),
next_observations=torch.zeros_like(torch_batch.next_observations),
next_actions=torch.zeros_like(torch_batch.next_actions),
returns_to_go=torch.zeros_like(torch_batch.returns_to_go),
terminals=torch.zeros_like(torch_batch.terminals),
intervals=torch.zeros_like(torch_batch.intervals),
device=torch_batch.device,
)
torch_batch2.copy_(torch_batch)
assert torch.all(torch_batch2.observations == torch_batch.observations)
assert torch.all(torch_batch2.actions == torch_batch.actions)
assert torch.all(torch_batch2.rewards == torch_batch.rewards)
assert torch.all(
torch_batch2.next_observations == torch_batch.next_observations
)
assert torch.all(torch_batch2.next_actions == torch_batch.next_actions)
assert torch.all(torch_batch2.returns_to_go == torch_batch.returns_to_go)
assert torch.all(torch_batch2.terminals == torch_batch.terminals)
assert torch.all(torch_batch2.intervals == torch_batch.intervals)


@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("length", [32])
Expand Down Expand Up @@ -397,6 +434,25 @@ def test_torch_trajectory_mini_batch(

assert np.all(torch_batch.terminals.numpy() == batch.terminals)

torch_batch2 = TorchTrajectoryMiniBatch(
observations=torch.zeros_like(torch_batch.observations),
actions=torch.zeros_like(torch_batch.actions),
rewards=torch.zeros_like(torch_batch.rewards),
returns_to_go=torch.zeros_like(torch_batch.returns_to_go),
terminals=torch.zeros_like(torch_batch.terminals),
timesteps=torch.zeros_like(torch_batch.timesteps),
masks=torch.zeros_like(torch_batch.masks),
device=torch_batch.device,
)
torch_batch2.copy_(torch_batch)
assert torch.all(torch_batch2.observations == torch_batch.observations)
assert torch.all(torch_batch2.actions == torch_batch.actions)
assert torch.all(torch_batch2.rewards == torch_batch.rewards)
assert torch.all(torch_batch2.returns_to_go == torch_batch.returns_to_go)
assert torch.all(torch_batch2.terminals == torch_batch.terminals)
assert torch.all(torch_batch2.timesteps == torch_batch.timesteps)
assert torch.all(torch_batch2.masks == torch_batch.masks)


def test_checkpointer() -> None:
fc1 = torch.nn.Linear(100, 100)
Expand Down

0 comments on commit 92d4013

Please sign in to comment.