Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 597303593
Change-Id: I88989a0bf18a66eb2b0d2e7add0f69c5c2697583
  • Loading branch information
pwohlhart authored and copybara-github committed Jan 10, 2024
1 parent 76906e9 commit 5ed51a2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
5 changes: 5 additions & 0 deletions rlds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# coding=utf-8
"""RLDS basic API."""


from rlds import metadata


Expand All @@ -30,10 +31,14 @@
from rlds.rlds_types import CORE_STEP_FIELDS
from rlds.rlds_types import DISCOUNT
from rlds.rlds_types import Episode
from rlds.rlds_types import EpisodeFilterFn
from rlds.rlds_types import IS_FIRST
from rlds.rlds_types import IS_LAST
from rlds.rlds_types import IS_TERMINAL
from rlds.rlds_types import OBSERVATION
from rlds.rlds_types import REWARD
from rlds.rlds_types import Step
from rlds.rlds_types import StepFilterFn
from rlds.rlds_types import StepMapFn
from rlds.rlds_types import STEPS
from rlds.rlds_types import StepsToStepsFn
8 changes: 7 additions & 1 deletion rlds/rlds_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# coding=utf-8
"""Types used in RL Datasets."""
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import tensorflow as tf

Expand Down Expand Up @@ -42,6 +42,12 @@
BatchedStep = Step
BatchedEpisode = Episode

# Step(s) transformation function types.
StepFilterFn = Callable[[Step], bool]
EpisodeFilterFn = Callable[[Episode], bool]
StepsToStepsFn = Callable[[tf.data.Dataset], tf.data.Dataset]
StepMapFn = Callable[[Step], Step]


def build_step(
observation: Optional[Any],
Expand Down

0 comments on commit 5ed51a2

Please sign in to comment.