From 12e9b39e9f4e605e946bb38d98af2bfc8377e1a3 Mon Sep 17 00:00:00 2001 From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com> Date: Wed, 15 Jan 2025 08:40:59 +0100 Subject: [PATCH] fix import in utils --- narwhals/_spark_like/utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index f69745247..dc49a8466 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -4,14 +4,13 @@ from typing import TYPE_CHECKING from typing import Any -from pyspark.sql import types as pyspark_types - from narwhals.exceptions import InvalidIntoExprError from narwhals.utils import import_dtypes_module from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: from pyspark.sql import Column + from pyspark.sql import types as pyspark_types from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.typing import IntoSparkLikeExpr @@ -24,9 +23,10 @@ def native_to_narwhals_dtype( dtype: pyspark_types.DataType, version: Version, ) -> DType: # pragma: no cover - dtypes = import_dtypes_module(version=version) from pyspark.sql import types as pyspark_types + dtypes = import_dtypes_module(version=version) + if isinstance(dtype, pyspark_types.DoubleType): return dtypes.Float64() if isinstance(dtype, pyspark_types.FloatType): @@ -65,7 +65,10 @@ def native_to_narwhals_dtype( def narwhals_to_native_dtype( dtype: DType | type[DType], version: Version ) -> pyspark_types.DataType: + from pyspark.sql import types as pyspark_types + dtypes = import_dtypes_module(version) + if isinstance_or_issubclass(dtype, dtypes.Float64): return pyspark_types.DoubleType() if isinstance_or_issubclass(dtype, dtypes.Float32):