Skip to content

Commit

Permalink
Revert CQL reprdoduction script
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Nov 3, 2024
1 parent 6c64c23 commit 3e57c75
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions reproductions/offline/cql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import math

import d3rlpy

Expand All @@ -18,7 +17,7 @@ def main() -> None:
d3rlpy.seed(args.seed)
d3rlpy.envs.seed_env(env, args.seed)

encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256])
encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256])

if "medium-v0" in args.dataset:
conservative_weight = 10.0
Expand All @@ -29,13 +28,11 @@ def main() -> None:
actor_learning_rate=1e-4,
critic_learning_rate=3e-4,
temp_learning_rate=1e-4,
alpha_learning_rate=3e-4,
initial_alpha=math.e,
actor_encoder_factory=encoder,
critic_encoder_factory=encoder,
batch_size=256,
n_action_samples=10,
alpha_threshold=10,
alpha_learning_rate=0.0,
conservative_weight=conservative_weight,
compile_graph=args.compile,
).create(device=args.gpu)
Expand Down

0 comments on commit 3e57c75

Please sign in to comment.