diff --git a/d3rlpy/algos/qlearning/torch/utility.py b/d3rlpy/algos/qlearning/torch/utility.py index 2a548d2b..f2af6ace 100644 --- a/d3rlpy/algos/qlearning/torch/utility.py +++ b/d3rlpy/algos/qlearning/torch/utility.py @@ -1,6 +1,6 @@ +from typing import Protocol import torch -from typing_extensions import Protocol from ....models.torch import ( ContinuousEnsembleQFunctionForwarder, diff --git a/d3rlpy/algos/transformer/action_samplers.py b/d3rlpy/algos/transformer/action_samplers.py index 43cd8ce6..e404661f 100644 --- a/d3rlpy/algos/transformer/action_samplers.py +++ b/d3rlpy/algos/transformer/action_samplers.py @@ -1,7 +1,6 @@ -from typing import Union +from typing import Protocol, Union import numpy as np -from typing_extensions import Protocol from ...types import NDArray diff --git a/d3rlpy/dataset/buffers.py b/d3rlpy/dataset/buffers.py index 36bacd77..6723dff9 100644 --- a/d3rlpy/dataset/buffers.py +++ b/d3rlpy/dataset/buffers.py @@ -1,7 +1,5 @@ from collections import deque -from typing import Sequence - -from typing_extensions import Protocol +from typing import Protocol, Sequence from .components import EpisodeBase diff --git a/d3rlpy/dataset/components.py b/d3rlpy/dataset/components.py index a41a2198..b8e50f76 100644 --- a/d3rlpy/dataset/components.py +++ b/d3rlpy/dataset/components.py @@ -1,8 +1,7 @@ import dataclasses -from typing import Any, Sequence +from typing import Any, Protocol, Sequence import numpy as np -from typing_extensions import Protocol from ..constants import ActionSpace from ..types import ( diff --git a/d3rlpy/dataset/episode_generator.py b/d3rlpy/dataset/episode_generator.py index 4d594690..5857bc15 100644 --- a/d3rlpy/dataset/episode_generator.py +++ b/d3rlpy/dataset/episode_generator.py @@ -1,7 +1,6 @@ -from typing import Optional, Sequence +from typing import Optional, Protocol, Sequence import numpy as np -from typing_extensions import Protocol from ..types import Float32NDArray, NDArray, ObservationSequence from .components import Episode, EpisodeBase diff --git a/d3rlpy/dataset/trajectory_slicers.py b/d3rlpy/dataset/trajectory_slicers.py index c8c5d23d..01853008 100644 --- a/d3rlpy/dataset/trajectory_slicers.py +++ b/d3rlpy/dataset/trajectory_slicers.py @@ -1,5 +1,6 @@ +from typing import Protocol + import numpy as np -from typing_extensions import Protocol from ..types import Float32NDArray, Int32NDArray from .components import EpisodeBase, PartialTrajectory diff --git a/d3rlpy/dataset/transition_pickers.py b/d3rlpy/dataset/transition_pickers.py index f390e3d7..0b059a10 100644 --- a/d3rlpy/dataset/transition_pickers.py +++ b/d3rlpy/dataset/transition_pickers.py @@ -1,7 +1,7 @@ import dataclasses +from typing import Protocol import numpy as np -from typing_extensions import Protocol from ..types import Float32NDArray from .components import EpisodeBase, Transition diff --git a/d3rlpy/dataset/writers.py b/d3rlpy/dataset/writers.py index 5b4a8621..eab4649f 100644 --- a/d3rlpy/dataset/writers.py +++ b/d3rlpy/dataset/writers.py @@ -1,7 +1,6 @@ -from typing import Any, Sequence, Union +from typing import Any, Protocol, Sequence, Union import numpy as np -from typing_extensions import Protocol from ..types import NDArray, Observation, ObservationSequence from .buffers import BufferProtocol diff --git a/d3rlpy/interface.py b/d3rlpy/interface.py index 6ec9d728..65d8f59f 100644 --- a/d3rlpy/interface.py +++ b/d3rlpy/interface.py @@ -1,6 +1,4 @@ -from typing import Optional, Union - -from typing_extensions import Protocol +from typing import Optional, Protocol, Union from .preprocessing import ActionScaler, ObservationScaler, RewardScaler from .types import NDArray, Observation diff --git a/d3rlpy/logging/logger.py b/d3rlpy/logging/logger.py index 6811a8b3..87507430 100644 --- a/d3rlpy/logging/logger.py +++ b/d3rlpy/logging/logger.py @@ -2,11 +2,10 @@ from collections import defaultdict from contextlib import contextmanager from datetime import datetime -from typing import Any, Iterator, Optional +from typing import Any, Iterator, Optional, Protocol import structlog from torch import nn -from typing_extensions import Protocol from ..types import Float32NDArray diff --git a/d3rlpy/metrics/evaluators.py b/d3rlpy/metrics/evaluators.py index 6aef7385..d2398e2c 100644 --- a/d3rlpy/metrics/evaluators.py +++ b/d3rlpy/metrics/evaluators.py @@ -1,7 +1,6 @@ -from typing import Iterator, Optional, Sequence +from typing import Iterator, Optional, Protocol, Sequence import numpy as np -from typing_extensions import Protocol from ..dataset import ( EpisodeBase, diff --git a/d3rlpy/tokenizers/tokenizers.py b/d3rlpy/tokenizers/tokenizers.py index 36c068e3..a2998304 100644 --- a/d3rlpy/tokenizers/tokenizers.py +++ b/d3rlpy/tokenizers/tokenizers.py @@ -1,5 +1,6 @@ +from typing import Protocol, runtime_checkable + import numpy as np -from typing_extensions import Protocol, runtime_checkable from ..types import Float32NDArray, Int32NDArray, NDArray from .utils import mu_law_decode, mu_law_encode diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py index 8fcd19ae..e7b2a354 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -6,6 +6,7 @@ Generic, Iterator, Optional, + Protocol, Sequence, TypeVar, Union, @@ -19,7 +20,7 @@ from torch.cuda import CUDAGraph from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer -from typing_extensions import Protocol, Self +from typing_extensions import Self from .dataclass_utils import asdict_without_copy from .dataset import TrajectoryMiniBatch, TransitionMiniBatch diff --git a/d3rlpy/types.py b/d3rlpy/types.py index 02687b73..9528f3d2 100644 --- a/d3rlpy/types.py +++ b/d3rlpy/types.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Sequence, Union +from typing import Any, Mapping, Protocol, Sequence, Union, runtime_checkable import gym import gymnasium @@ -6,7 +6,6 @@ import numpy.typing as npt import torch from torch.optim import Optimizer -from typing_extensions import Protocol, runtime_checkable __all__ = [ "NDArray", diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 00000000..36851f55 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,25 @@ +line-length = 80 +indent-width = 4 +target-version = "py39" +unsafe-fixes = true + +[lint] +select = ["E4", "E7", "E9", "F", "UP006", "I", "W"] +ignore = ["F403"] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +[format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto"