diff --git a/skrl/utils/runner/jax/runner.py b/skrl/utils/runner/jax/runner.py index 73e17ff3..27ce4f7b 100644 --- a/skrl/utils/runner/jax/runner.py +++ b/skrl/utils/runner/jax/runner.py @@ -14,7 +14,7 @@ from skrl.resources.schedulers.jax import KLAdaptiveLR # noqa from skrl.trainers.jax import SequentialTrainer, Trainer from skrl.utils import set_seed -from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model +from skrl.utils.model_instantiators.jax import categorical_model, deterministic_model, gaussian_model class Runner: @@ -35,6 +35,7 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, self._class_mapping = { # model "gaussianmixin": gaussian_model, + "categoricalmixin": categorical_model, "deterministicmixin": deterministic_model, "shared": None, # memory diff --git a/skrl/utils/runner/torch/runner.py b/skrl/utils/runner/torch/runner.py index 7f7d32c4..5e4f6b52 100644 --- a/skrl/utils/runner/torch/runner.py +++ b/skrl/utils/runner/torch/runner.py @@ -14,7 +14,7 @@ from skrl.resources.schedulers.torch import KLAdaptiveLR # noqa from skrl.trainers.torch import SequentialTrainer, Trainer from skrl.utils import set_seed -from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model +from skrl.utils.model_instantiators.torch import categorical_model, deterministic_model, gaussian_model, shared_model class Runner: @@ -35,6 +35,7 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, self._class_mapping = { # model "gaussianmixin": gaussian_model, + "categoricalmixin": categorical_model, "deterministicmixin": deterministic_model, "shared": shared_model, # memory