Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 4, 2024
1 parent a982ebc commit 22f6da8
Show file tree
Hide file tree
Showing 8 changed files with 11 additions and 16 deletions.
3 changes: 1 addition & 2 deletions examples/gradient-based-offpolicy/q_learning_offpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,7 @@ def __init__(self) -> None:
+ cs.sum2(f.T @ cs.vertcat(x[:, :-1], u))
+ 0.5
* cs.sum2(
gammapowers
* (cs.sum1(x[:, :-1] ** 2) + 0.5 * cs.sum1(u**2) + w.T @ s)
gammapowers * (cs.sum1(x[:, :-1] ** 2) + 0.5 * cs.sum1(u**2) + w.T @ s)
)
)

Expand Down
3 changes: 1 addition & 2 deletions examples/gradient-based-onpolicy/dpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ def __init__(self) -> None:
+ cs.sum2(f.T @ cs.vertcat(x[:, :-1], u))
+ 0.5
* cs.sum2(
gammapowers
* (cs.sum1(x[:, :-1] ** 2) + 0.5 * cs.sum1(u**2) + w.T @ s)
gammapowers * (cs.sum1(x[:, :-1] ** 2) + 0.5 * cs.sum1(u**2) + w.T @ s)
)
)

Expand Down
3 changes: 1 addition & 2 deletions examples/gradient-based-onpolicy/q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,7 @@ def __init__(self) -> None:
+ cs.sum2(f.T @ cs.vertcat(x[:, :-1], u))
+ 0.5
* cs.sum2(
gammapowers
* (cs.sum1(x[:, :-1] ** 2) + 0.5 * cs.sum1(u**2) + w.T @ s)
gammapowers * (cs.sum1(x[:, :-1] ** 2) + 0.5 * cs.sum1(u**2) + w.T @ s)
)
)

Expand Down
1 change: 0 additions & 1 deletion src/mpcrl/core/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ def __repr__(self) -> str:


class Chain(Scheduler[ScType]):

"""Chains multiple schedulers together.
Parameters
Expand Down
5 changes: 4 additions & 1 deletion src/mpcrl/optim/gradient_free_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ class GradientFreeOptimizer(BaseOptimizer[SymType], ABC):
@abstractmethod
def ask(
self,
) -> tuple[Union[dict[str, npt.ArrayLike], npt.ArrayLike], Optional[str],]:
) -> tuple[
Union[dict[str, npt.ArrayLike], npt.ArrayLike],
Optional[str],
]:
"""Asks the learning agent for a new set of parameters to evaluate.
Returns
Expand Down
6 changes: 2 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,7 @@ def test_epsilon_greedy_exploration__sometimes_explores(self):
self.assertTrue(found_true and found_false)

def test_epsilon_greedy_exploration__decays_strength(self):
class MockScheduler(S.NoScheduling):
...
class MockScheduler(S.NoScheduling): ...

epsilon_scheduler = MockScheduler(None)
strength_scheduler = MockScheduler(None)
Expand Down Expand Up @@ -347,8 +346,7 @@ def test_ornsteinuhlenbeck_exploration__always_explores(self):
self.assertTrue(exploration.can_explore())

def test_ornsteinuhlenbeck_exploration__decays_mean_and_sigma(self):
class MockScheduler(S.NoScheduling):
...
class MockScheduler(S.NoScheduling): ...

mean_scheduler = MockScheduler(None)
sigma_scheduler = MockScheduler(None)
Expand Down
3 changes: 1 addition & 2 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from mpcrl.util import control, iters, math, named, seeding


class DummyAgent(named.Named):
...
class DummyAgent(named.Named): ...


class TestNamedAgent(unittest.TestCase):
Expand Down
3 changes: 1 addition & 2 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ def get_mpc(horizon: int, multistart: bool):
return mpc


class DummyAgent(Agent):
...
class DummyAgent(Agent): ...


class DummyLearningAgent(LearningAgent):
Expand Down

0 comments on commit 22f6da8

Please sign in to comment.