Skip to content

Commit

Permalink
Add TD3
Browse files Browse the repository at this point in the history
  • Loading branch information
noahfarr committed Feb 4, 2024
1 parent df0a557 commit bbb5385
Show file tree
Hide file tree
Showing 4 changed files with 429 additions and 20 deletions.
69 changes: 49 additions & 20 deletions rlx/sac/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,27 @@ def thunk():
class SoftQNetwork(nn.Module):
def __init__(
self,
env,
num_layers,
input_dim,
hidden_dim,
activations,
):
super().__init__()
self.fc1 = nn.Linear(
np.array(env.single_observation_space.shape).prod()
+ np.prod(env.single_action_space.shape),
256,
)
self.fc2 = nn.Linear(256, 256)
self.fc3 = nn.Linear(256, 1)
layer_sizes = [input_dim] + [hidden_dim] * num_layers + [1]
self.layers = [
nn.Linear(idim, odim)
for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
]
self.activations = activations
assert (
len(self.layers) == len(self.activations) + 1
), "Number of layers and activations should match"

def __call__(self, x, a):
x = mx.concatenate([x, a], axis=1)
x = nn.relu(self.fc1(x))
x = nn.relu(self.fc2(x))
x = self.fc3(x)
for layer, activation in zip(self.layers[:-1], self.activations):
x = activation(layer(x))
x = self.layers[-1](x)
return x


Expand Down Expand Up @@ -116,10 +121,10 @@ def __call__(self, x):

def copy_weights(source, target, tau):
weights = []
for i, ((_, target_params), (_, q_params)) in enumerate(
for i, ((target_params), (q_params)) in enumerate(
zip(
target.parameters().items(),
source.parameters().items(),
target.parameters()["layers"],
source.parameters()["layers"],
)
):
target_weight = target_params["weight"]
Expand All @@ -130,8 +135,8 @@ def copy_weights(source, target, tau):
weight = tau * q_weight + (1.0 - tau) * target_weight
bias = tau * q_bias + (1.0 - tau) * target_bias

weights.append((f"fc{i+1}.weight", weight))
weights.append((f"fc{i+1}.bias", bias))
weights.append((f"layers.{i}.weight", weight))
weights.append((f"layers.{i}.bias", bias))
target.load_weights(weights)


Expand All @@ -153,12 +158,36 @@ def copy_weights(source, target, tau):

actor = Actor(envs)
mx.eval(actor.parameters())
qf1 = SoftQNetwork(envs)
qf1 = SoftQNetwork(
2,
np.array(envs.single_observation_space.shape).prod()
+ np.prod(envs.single_action_space.shape),
256,
[nn.relu, nn.relu],
)
mx.eval(qf1.parameters())
qf2 = SoftQNetwork(envs)
qf2 = SoftQNetwork(
2,
np.array(envs.single_observation_space.shape).prod()
+ np.prod(envs.single_action_space.shape),
256,
[nn.relu, nn.relu],
)
mx.eval(qf2.parameters())
qf1_target = SoftQNetwork(envs)
qf2_target = SoftQNetwork(envs)
qf1_target = SoftQNetwork(
2,
np.array(envs.single_observation_space.shape).prod()
+ np.prod(envs.single_action_space.shape),
256,
[nn.relu, nn.relu],
)
qf2_target = SoftQNetwork(
2,
np.array(envs.single_observation_space.shape).prod()
+ np.prod(envs.single_action_space.shape),
256,
[nn.relu, nn.relu],
)
copy_weights(qf1, qf1_target, 1.0)
copy_weights(qf2, qf2_target, 1.0)

Expand Down
18 changes: 18 additions & 0 deletions rlx/td3/hyperparameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os

exp_name: str = os.path.basename(__file__)[: -len(".py")]
seed: int = 1

# Algorithm specific arguments
env_id: str = "Pendulum-v1"
total_timesteps: int = 1000000
learning_rate: float = 3e-4
buffer_size: int = int(1e6)
gamma: float = 0.99
tau: float = 0.005
batch_size: int = 256
policy_noise: float = 0.2
exploration_noise: float = 0.1
learning_starts: int = 25e3
policy_frequency: int = 2
noise_clip: float = 0.5
Loading

0 comments on commit bbb5385

Please sign in to comment.