Skip to content

Commit

Permalink
Refactor CalQL reproduction script
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed May 7, 2024
1 parent 645b49d commit 1625361
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions reproductions/finetuning/cal_ql_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ def main() -> None:
parser.add_argument("--gpu", type=int)
args = parser.parse_args()

# sparse reward setup requires special treatment for failure trajectories
transition_picker = d3rlpy.dataset.SparseRewardTransitionPicker(
horizon_length=100,
step_reward=0,
)

dataset, env = d3rlpy.datasets.get_d4rl(
args.dataset,
transition_picker=d3rlpy.dataset.SparseRewardTransitionPicker(
horizon_length=100,
step_reward=0,
),
transition_picker=transition_picker,
)

# fix seed
Expand Down Expand Up @@ -60,7 +63,11 @@ def main() -> None:
)

# prepare FIFO buffer filled with dataset episodes
buffer = d3rlpy.dataset.create_fifo_replay_buffer(1000000, env=env)
buffer = d3rlpy.dataset.create_fifo_replay_buffer(
limit=1000000,
env=env,
transition_picker=transition_picker,
)

# sample half from offline dataset and the rest from online buffer
mixed_buffer = d3rlpy.dataset.MixedReplayBuffer(
Expand Down

0 comments on commit 1625361

Please sign in to comment.