From e8cb3943ee926ba13df2992c1397790d6f42e693 Mon Sep 17 00:00:00 2001 From: Telios Date: Fri, 25 Oct 2024 15:41:09 +0200 Subject: [PATCH 1/3] add class mapping of categorical model --- skrl/utils/runner/jax/runner.py | 3 ++- skrl/utils/runner/torch/runner.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/skrl/utils/runner/jax/runner.py b/skrl/utils/runner/jax/runner.py index 73e17ff3..d16f0ad8 100644 --- a/skrl/utils/runner/jax/runner.py +++ b/skrl/utils/runner/jax/runner.py @@ -14,7 +14,7 @@ from skrl.resources.schedulers.jax import KLAdaptiveLR # noqa from skrl.trainers.jax import SequentialTrainer, Trainer from skrl.utils import set_seed -from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model +from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model, categorical_model class Runner: @@ -37,6 +37,7 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, "gaussianmixin": gaussian_model, "deterministicmixin": deterministic_model, "shared": None, + "categoricalmixin": categorical_model, # memory "randommemory": RandomMemory, # agent diff --git a/skrl/utils/runner/torch/runner.py b/skrl/utils/runner/torch/runner.py index 7f7d32c4..64c54ce9 100644 --- a/skrl/utils/runner/torch/runner.py +++ b/skrl/utils/runner/torch/runner.py @@ -14,7 +14,7 @@ from skrl.resources.schedulers.torch import KLAdaptiveLR # noqa from skrl.trainers.torch import SequentialTrainer, Trainer from skrl.utils import set_seed -from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model +from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model, categorical_model class Runner: @@ -37,6 +37,7 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, "gaussianmixin": gaussian_model, "deterministicmixin": deterministic_model, "shared": shared_model, + "categoricalmixin": categorical_model, # memory "randommemory": RandomMemory, # agent From 0dc36ea852ede12cca82542b815102da6d3432a7 Mon Sep 17 00:00:00 2001 From: Toni-SM Date: Sun, 3 Nov 2024 09:58:58 -0500 Subject: [PATCH 2/3] Apply format to runner.py in jax --- skrl/utils/runner/jax/runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skrl/utils/runner/jax/runner.py b/skrl/utils/runner/jax/runner.py index d16f0ad8..27ce4f7b 100644 --- a/skrl/utils/runner/jax/runner.py +++ b/skrl/utils/runner/jax/runner.py @@ -14,7 +14,7 @@ from skrl.resources.schedulers.jax import KLAdaptiveLR # noqa from skrl.trainers.jax import SequentialTrainer, Trainer from skrl.utils import set_seed -from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model, categorical_model +from skrl.utils.model_instantiators.jax import categorical_model, deterministic_model, gaussian_model class Runner: @@ -35,9 +35,9 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, self._class_mapping = { # model "gaussianmixin": gaussian_model, + "categoricalmixin": categorical_model, "deterministicmixin": deterministic_model, "shared": None, - "categoricalmixin": categorical_model, # memory "randommemory": RandomMemory, # agent From b65872cbfc186b401c285be96397bfd3046db7ec Mon Sep 17 00:00:00 2001 From: Toni-SM Date: Sun, 3 Nov 2024 10:00:14 -0500 Subject: [PATCH 3/3] Apply format to runner.py in torch --- skrl/utils/runner/torch/runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skrl/utils/runner/torch/runner.py b/skrl/utils/runner/torch/runner.py index 64c54ce9..5e4f6b52 100644 --- a/skrl/utils/runner/torch/runner.py +++ b/skrl/utils/runner/torch/runner.py @@ -14,7 +14,7 @@ from skrl.resources.schedulers.torch import KLAdaptiveLR # noqa from skrl.trainers.torch import SequentialTrainer, Trainer from skrl.utils import set_seed -from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model, categorical_model +from skrl.utils.model_instantiators.torch import categorical_model, deterministic_model, gaussian_model, shared_model class Runner: @@ -35,9 +35,9 @@ def __init__(self, env: Union[Wrapper, MultiAgentEnvWrapper], cfg: Mapping[str, self._class_mapping = { # model "gaussianmixin": gaussian_model, + "categoricalmixin": categorical_model, "deterministicmixin": deterministic_model, "shared": shared_model, - "categoricalmixin": categorical_model, # memory "randommemory": RandomMemory, # agent