From 7b842fb2ab995ab70a33dd026272cb4df44701c2 Mon Sep 17 00:00:00 2001 From: Chunnien Chan <121328115+chunnienc@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:14:04 -0800 Subject: [PATCH] rollback import branches --- experimental/torch_xla2/torch_xla2/ops/op_base.py | 7 +------ experimental/torch_xla2/torch_xla2/types.py | 7 +------ 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/experimental/torch_xla2/torch_xla2/ops/op_base.py b/experimental/torch_xla2/torch_xla2/ops/op_base.py index 203ec5a3686..f81ab2487ed 100644 --- a/experimental/torch_xla2/torch_xla2/ops/op_base.py +++ b/experimental/torch_xla2/torch_xla2/ops/op_base.py @@ -7,12 +7,7 @@ from torch_xla2 import types import sys -if sys.version_info < (3, 10): - from typing_extensions import ParamSpec, Concatenate -else: - from typing import ParamSpec, Concatenate - -from typing import Callable, Optional +from typing import Callable, Optional, ParamSpec, Concatenate class InplaceOp: diff --git a/experimental/torch_xla2/torch_xla2/types.py b/experimental/torch_xla2/torch_xla2/types.py index fef65290671..72a2f678c96 100644 --- a/experimental/torch_xla2/torch_xla2/types.py +++ b/experimental/torch_xla2/torch_xla2/types.py @@ -1,14 +1,9 @@ -from typing import Callable, Any, Union +from typing import Callable, Any, Union, ParamSpec, TypeAlias import torch import jax import jax.numpy as jnp import sys -if sys.version_info < (3, 10): - from typing_extensions import ParamSpec, TypeAlias -else: - from typing import ParamSpec, TypeAlias - P = ParamSpec('P') TorchValue: TypeAlias = Union[torch.Tensor, torch.dtype, 'TorchCallable', Any]