diff --git a/rlds/__init__.py b/rlds/__init__.py index 35d16fe..da0d4a5 100644 --- a/rlds/__init__.py +++ b/rlds/__init__.py @@ -15,6 +15,7 @@ # coding=utf-8 """RLDS basic API.""" + from rlds import metadata @@ -30,10 +31,14 @@ from rlds.rlds_types import CORE_STEP_FIELDS from rlds.rlds_types import DISCOUNT from rlds.rlds_types import Episode +from rlds.rlds_types import EpisodeFilterFn from rlds.rlds_types import IS_FIRST from rlds.rlds_types import IS_LAST from rlds.rlds_types import IS_TERMINAL from rlds.rlds_types import OBSERVATION from rlds.rlds_types import REWARD from rlds.rlds_types import Step +from rlds.rlds_types import StepFilterFn +from rlds.rlds_types import StepMapFn from rlds.rlds_types import STEPS +from rlds.rlds_types import StepsToStepsFn diff --git a/rlds/rlds_types.py b/rlds/rlds_types.py index 97be666..51b77cb 100644 --- a/rlds/rlds_types.py +++ b/rlds/rlds_types.py @@ -14,7 +14,7 @@ # coding=utf-8 """Types used in RL Datasets.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import tensorflow as tf @@ -42,6 +42,12 @@ BatchedStep = Step BatchedEpisode = Episode +# Step(s) transformation function types. +StepFilterFn = Callable[[Step], bool] +EpisodeFilterFn = Callable[[Episode], bool] +StepsToStepsFn = Callable[[tf.data.Dataset], tf.data.Dataset] +StepMapFn = Callable[[Step], Step] + def build_step( observation: Optional[Any],