Skip to content

Commit

Permalink
Move Protocol to typing from typing_extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Nov 3, 2024
1 parent 588c6d8 commit 48e2ac1
Show file tree
Hide file tree
Showing 15 changed files with 42 additions and 25 deletions.
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/utility.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Protocol

import torch
from typing_extensions import Protocol

from ....models.torch import (
ContinuousEnsembleQFunctionForwarder,
Expand Down
3 changes: 1 addition & 2 deletions d3rlpy/algos/transformer/action_samplers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Union
from typing import Protocol, Union

import numpy as np
from typing_extensions import Protocol

from ...types import NDArray

Expand Down
4 changes: 1 addition & 3 deletions d3rlpy/dataset/buffers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from collections import deque
from typing import Sequence

from typing_extensions import Protocol
from typing import Protocol, Sequence

from .components import EpisodeBase

Expand Down
3 changes: 1 addition & 2 deletions d3rlpy/dataset/components.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import dataclasses
from typing import Any, Sequence
from typing import Any, Protocol, Sequence

import numpy as np
from typing_extensions import Protocol

from ..constants import ActionSpace
from ..types import (
Expand Down
3 changes: 1 addition & 2 deletions d3rlpy/dataset/episode_generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional, Sequence
from typing import Optional, Protocol, Sequence

import numpy as np
from typing_extensions import Protocol

from ..types import Float32NDArray, NDArray, ObservationSequence
from .components import Episode, EpisodeBase
Expand Down
3 changes: 2 additions & 1 deletion d3rlpy/dataset/trajectory_slicers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Protocol

import numpy as np
from typing_extensions import Protocol

from ..types import Float32NDArray, Int32NDArray
from .components import EpisodeBase, PartialTrajectory
Expand Down
2 changes: 1 addition & 1 deletion d3rlpy/dataset/transition_pickers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
from typing import Protocol

import numpy as np
from typing_extensions import Protocol

from ..types import Float32NDArray
from .components import EpisodeBase, Transition
Expand Down
3 changes: 1 addition & 2 deletions d3rlpy/dataset/writers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Sequence, Union
from typing import Any, Protocol, Sequence, Union

import numpy as np
from typing_extensions import Protocol

from ..types import NDArray, Observation, ObservationSequence
from .buffers import BufferProtocol
Expand Down
4 changes: 1 addition & 3 deletions d3rlpy/interface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import Optional, Union

from typing_extensions import Protocol
from typing import Optional, Protocol, Union

from .preprocessing import ActionScaler, ObservationScaler, RewardScaler
from .types import NDArray, Observation
Expand Down
3 changes: 1 addition & 2 deletions d3rlpy/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Iterator, Optional
from typing import Any, Iterator, Optional, Protocol

import structlog
from torch import nn
from typing_extensions import Protocol

from ..types import Float32NDArray

Expand Down
3 changes: 1 addition & 2 deletions d3rlpy/metrics/evaluators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Iterator, Optional, Sequence
from typing import Iterator, Optional, Protocol, Sequence

import numpy as np
from typing_extensions import Protocol

from ..dataset import (
EpisodeBase,
Expand Down
3 changes: 2 additions & 1 deletion d3rlpy/tokenizers/tokenizers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Protocol, runtime_checkable

import numpy as np
from typing_extensions import Protocol, runtime_checkable

from ..types import Float32NDArray, Int32NDArray, NDArray
from .utils import mu_law_decode, mu_law_encode
Expand Down
3 changes: 2 additions & 1 deletion d3rlpy/torch_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Generic,
Iterator,
Optional,
Protocol,
Sequence,
TypeVar,
Union,
Expand All @@ -19,7 +20,7 @@
from torch.cuda import CUDAGraph
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from typing_extensions import Protocol, Self
from typing_extensions import Self

from .dataclass_utils import asdict_without_copy
from .dataset import TrajectoryMiniBatch, TransitionMiniBatch
Expand Down
3 changes: 1 addition & 2 deletions d3rlpy/types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import Any, Mapping, Sequence, Union
from typing import Any, Mapping, Protocol, Sequence, Union, runtime_checkable

import gym
import gymnasium
import numpy as np
import numpy.typing as npt
import torch
from torch.optim import Optimizer
from typing_extensions import Protocol, runtime_checkable

__all__ = [
"NDArray",
Expand Down
25 changes: 25 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
line-length = 80
indent-width = 4
target-version = "py39"
unsafe-fixes = true

[lint]
select = ["E4", "E7", "E9", "F", "UP006", "I", "W"]
ignore = ["F403"]

# Allow fix for all enabled rules (when `--fix`) is provided.
fixable = ["ALL"]
unfixable = []

[format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

0 comments on commit 48e2ac1

Please sign in to comment.