diff --git a/.gitignore b/.gitignore index 153c7f59744..619e1464b2a 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,7 @@ jupyter_execute # clang tooling compile_commands.json .clangd/ + +# pytest artifacts +rmm_log.txt +python/cudf/cudf_pandas_tests/data/rmm_log.txt diff --git a/ci/cudf_pandas_scripts/run_tests.sh b/ci/cudf_pandas_scripts/run_tests.sh index 5f34a7d1ed0..fd15020b51b 100755 --- a/ci/cudf_pandas_scripts/run_tests.sh +++ b/ci/cudf_pandas_scripts/run_tests.sh @@ -61,6 +61,9 @@ else "$(echo ./dist/pylibcudf_${RAPIDS_PY_CUDA_SUFFIX}*.whl)" fi +python -m pip install ipykernel +python -m ipykernel install --user --name python3 + # We're ignoring third-party library tests because they are run nightly in a seperate CI job python -m pytest -p cudf.pandas \ --ignore=./python/cudf/cudf_pandas_tests/third_party_integration_tests/ \ diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 96596958636..7f6967d7287 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -37,6 +37,7 @@ dependencies: - hypothesis - identify>=2.5.20 - ipython +- jupyter_client - libcufile-dev=1.4.0.31 - libcufile=1.4.0.31 - libcurand-dev=10.3.0.86 @@ -48,6 +49,8 @@ dependencies: - moto>=4.0.8 - msgpack-python - myst-nb +- nbconvert +- nbformat - nbsphinx - ninja - notebook @@ -57,12 +60,14 @@ dependencies: - nvcc_linux-64=11.8 - nvcomp==3.0.6 - nvtx>=0.2.1 +- openpyxl - packaging - pandas - pandas>=2.0,<2.2.3dev0 - pandoc - pre-commit - ptxcompiler +- pyarrow>=14.0.0,<18.0.0a0 - pydata-sphinx-theme!=0.14.2 - pytest-benchmark - pytest-cases>=3.8.2 diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index efc5f76b90f..c1315e73f16 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -38,6 +38,7 @@ dependencies: - hypothesis - identify>=2.5.20 - ipython +- jupyter_client - libcufile-dev - libcurand-dev - libkvikio==24.10.*,>=0.0.0a0 @@ -47,6 +48,8 @@ dependencies: - moto>=4.0.8 - msgpack-python - myst-nb +- nbconvert +- nbformat - nbsphinx - ninja - notebook @@ -55,11 +58,13 @@ dependencies: - numpydoc - nvcomp==3.0.6 - nvtx>=0.2.1 +- openpyxl - packaging - pandas - pandas>=2.0,<2.2.3dev0 - pandoc - pre-commit +- pyarrow>=14.0.0,<18.0.0a0 - pydata-sphinx-theme!=0.14.2 - pynvjitlink>=0.0.0a0 - pytest-benchmark diff --git a/conda/recipes/cudf/meta.yaml b/conda/recipes/cudf/meta.yaml index 53f52a35651..e22b4a4eddc 100644 --- a/conda/recipes/cudf/meta.yaml +++ b/conda/recipes/cudf/meta.yaml @@ -82,7 +82,7 @@ requirements: - cupy >=12.0.0 - numba >=0.57 - numpy >=1.23,<3.0a0 - - pyarrow ==16.1.0.* + - pyarrow>=14.0.0,<18.0.0a0 - libcudf ={{ version }} - pylibcudf ={{ version }} - {{ pin_compatible('rmm', max_pin='x.x') }} diff --git a/conda/recipes/pylibcudf/meta.yaml b/conda/recipes/pylibcudf/meta.yaml index 67b9b76bb8c..7c1efa0176c 100644 --- a/conda/recipes/pylibcudf/meta.yaml +++ b/conda/recipes/pylibcudf/meta.yaml @@ -79,7 +79,7 @@ requirements: - typing_extensions >=4.0.0 - pandas >=2.0,<2.2.3dev0 - numpy >=1.23,<3.0a0 - - pyarrow ==16.1.0.* + - pyarrow>=14.0.0,<18.0.0a0 - {{ pin_compatible('rmm', max_pin='x.x') }} - fsspec >=0.6.0 {% if cuda_major == "11" %} diff --git a/dependencies.yaml b/dependencies.yaml index b55860815bf..c6851d9cb90 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -19,6 +19,7 @@ files: - docs - notebooks - py_version + - 
pyarrow_run - rapids_build_skbuild - rapids_build_setuptools - run_common @@ -31,6 +32,7 @@ files: - test_python_cudf - test_python_dask_cudf - test_python_pylibcudf + - test_python_cudf_pandas test_static_build: output: none includes: @@ -45,10 +47,10 @@ files: includes: - cuda_version - py_version - - pyarrow_run - test_python_common - test_python_cudf - test_python_dask_cudf + - test_python_cudf_pandas test_java: output: none includes: @@ -134,13 +136,6 @@ files: - build_base - build_cpp - depends_on_librmm - py_run_libcudf: - output: pyproject - pyproject_dir: python/libcudf - extras: - table: project - includes: - - pyarrow_run py_build_pylibcudf: output: pyproject pyproject_dir: python/pylibcudf @@ -388,8 +383,7 @@ dependencies: common: - output_types: [conda, requirements, pyproject] packages: - # Allow runtime version to float up to patch version - - pyarrow>=16.1.0,<16.2.0a0 + - pyarrow>=14.0.0,<18.0.0a0 cuda_version: specific: - output_types: conda @@ -934,9 +928,13 @@ dependencies: # installation issues with `psycopg2`. - pandas[test, pyarrow, performance, computation, fss, excel, parquet, feather, hdf5, spss, html, xml, plot, output-formatting, clipboard, compression] - pytest-reportlog + - ipython test_python_cudf_pandas: common: - - output_types: [requirements, pyproject] + - output_types: [conda, requirements, pyproject] packages: - ipython + - jupyter_client + - nbconvert + - nbformat - openpyxl diff --git a/docs/cudf/source/user_guide/10min.ipynb b/docs/cudf/source/user_guide/10min.ipynb index c3da2558db8..2eaa75b3189 100644 --- a/docs/cudf/source/user_guide/10min.ipynb +++ b/docs/cudf/source/user_guide/10min.ipynb @@ -15,7 +15,11 @@ "\n", "[Dask](https://dask.org/) is a flexible library for parallel computing in Python that makes scaling out your workflow smooth and simple. On the CPU, Dask uses Pandas to execute operations in parallel on DataFrame partitions.\n", "\n", - "[Dask-cuDF](https://github.com/rapidsai/cudf/tree/main/python/dask_cudf) extends Dask where necessary to allow its DataFrame partitions to be processed using cuDF GPU DataFrames instead of Pandas DataFrames. For instance, when you call `dask_cudf.read_csv(...)`, your cluster's GPUs do the work of parsing the CSV file(s) by calling [`cudf.read_csv()`](https://docs.rapids.ai/api/cudf/stable/api_docs/api/cudf.read_csv.html).\n", + "[Dask cuDF](https://github.com/rapidsai/cudf/tree/main/python/dask_cudf) extends Dask where necessary to allow its DataFrame partitions to be processed using cuDF GPU DataFrames instead of Pandas DataFrames. For instance, when you call `dask_cudf.read_csv(...)`, your cluster's GPUs do the work of parsing the CSV file(s) by calling [`cudf.read_csv()`](https://docs.rapids.ai/api/cudf/stable/api_docs/api/cudf.read_csv.html).\n", + "\n", + "\n", + "> [!NOTE] \n", + "> This notebook uses the explicit Dask cuDF API (`dask_cudf`) for clarity. However, we strongly recommend that you use Dask's [configuration infrastructure](https://docs.dask.org/en/latest/configuration.html) to set the `\"dataframe.backend\"` to `\"cudf\"`, and work with the `dask.dataframe` API directly. 
Please see the [Dask cuDF documentation](https://github.com/rapidsai/cudf/tree/main/python/dask_cudf) for more information.\n", "\n", "\n", "## When to use cuDF and Dask-cuDF\n", diff --git a/pyproject.toml b/pyproject.toml index e15cb7b3cdd..8f9aa165e5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,9 @@ select = [ # non-pep585-annotation "UP006", # non-pep604-annotation - "UP007" + "UP007", + # Import from `collections.abc` instead: `Callable` + "UP035", ] ignore = [ # whitespace before : diff --git a/python/cudf/cudf/_typing.py b/python/cudf/cudf/_typing.py index 34c96cc8cb3..6e8ad556b08 100644 --- a/python/cudf/cudf/_typing.py +++ b/python/cudf/cudf/_typing.py @@ -1,7 +1,8 @@ # Copyright (c) 2021-2024, NVIDIA CORPORATION. import sys -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, TypeVar, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Dict, Iterable, TypeVar, Union import numpy as np from pandas import Period, Timedelta, Timestamp diff --git a/python/cudf/cudf/core/_internals/where.py b/python/cudf/cudf/core/_internals/where.py index 0c754317185..2199d4d5ba5 100644 --- a/python/cudf/cudf/core/_internals/where.py +++ b/python/cudf/cudf/core/_internals/where.py @@ -106,19 +106,6 @@ def _check_and_cast_columns_with_other( return _normalize_categorical(source_col.astype(common_dtype), other) -def _make_categorical_like(result, column): - if isinstance(column, cudf.core.column.CategoricalColumn): - result = cudf.core.column.build_categorical_column( - categories=column.categories, - codes=result, - mask=result.base_mask, - size=result.size, - offset=result.offset, - ordered=column.ordered, - ) - return result - - def _can_cast(from_dtype, to_dtype): """ Utility function to determine if we can cast diff --git a/python/cudf/cudf/core/column/__init__.py b/python/cudf/cudf/core/column/__init__.py index e7119fcdf47..5781d77ee9a 100644 --- a/python/cudf/cudf/core/column/__init__.py +++ b/python/cudf/cudf/core/column/__init__.py @@ -8,7 +8,6 @@ from cudf.core.column.column import ( ColumnBase, as_column, - build_categorical_column, build_column, column_empty, column_empty_like, diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index a7e98e5218f..de5ed15771d 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -52,6 +52,15 @@ _DEFAULT_CATEGORICAL_VALUE = np.int8(-1) +def as_unsigned_codes( + num_cats: int, codes: NumericalColumn +) -> NumericalColumn: + codes_dtype = min_unsigned_type(num_cats) + return cast( + cudf.core.column.numerical.NumericalColumn, codes.astype(codes_dtype) + ) + + class CategoricalAccessor(ColumnMethods): """ Accessor object for categorical properties of the Series values. 
@@ -637,13 +646,12 @@ def __setitem__(self, key, value): value = value.codes codes = self.codes codes[key] = value - out = cudf.core.column.build_categorical_column( - categories=self.categories, - codes=codes, - mask=codes.base_mask, + out = type(self)( + data=self.data, size=codes.size, - offset=self.offset, - ordered=self.ordered, + dtype=self.dtype, + mask=codes.base_mask, + children=(codes,), ) self._mimic_inplace(out, inplace=True) @@ -669,16 +677,13 @@ def _fill( def slice(self, start: int, stop: int, stride: int | None = None) -> Self: codes = self.codes.slice(start, stop, stride) - return cast( - Self, - cudf.core.column.build_categorical_column( - categories=self.categories, - codes=codes, - mask=codes.base_mask, - ordered=self.ordered, - size=codes.size, - offset=codes.offset, - ), + return type(self)( + data=self.data, # type: ignore[arg-type] + size=codes.size, + dtype=self.dtype, + mask=codes.base_mask, + offset=codes.offset, + children=(codes,), ) def _reduce( @@ -719,7 +724,7 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase: ) return self.codes._binaryop(other.codes, op) - def normalize_binop_value(self, other: ScalarLike) -> CategoricalColumn: + def normalize_binop_value(self, other: ScalarLike) -> Self: if isinstance(other, column.ColumnBase): if not isinstance(other, CategoricalColumn): return NotImplemented @@ -727,30 +732,27 @@ def normalize_binop_value(self, other: ScalarLike) -> CategoricalColumn: raise TypeError( "Categoricals can only compare with the same type" ) - return other - - ary = column.as_column( + return cast(Self, other) + codes = column.as_column( self._encode(other), length=len(self), dtype=self.codes.dtype ) - return column.build_categorical_column( - categories=self.dtype.categories._values, - codes=column.as_column(ary), + return type(self)( + data=None, + size=self.size, + dtype=self.dtype, mask=self.base_mask, - ordered=self.dtype.ordered, + children=(codes,), # type: ignore[arg-type] ) - def sort_values( - self, ascending: bool = True, na_position="last" - ) -> CategoricalColumn: + def sort_values(self, ascending: bool = True, na_position="last") -> Self: codes = self.codes.sort_values(ascending, na_position) - col = column.build_categorical_column( - categories=self.dtype.categories._values, - codes=codes, - mask=codes.base_mask, + return type(self)( + data=self.data, # type: ignore[arg-type] size=codes.size, - ordered=self.dtype.ordered, + dtype=self.dtype, + mask=codes.base_mask, + children=(codes,), ) - return col def element_indexing(self, index: int) -> ScalarLike: val = self.codes.element_indexing(index) @@ -777,12 +779,12 @@ def to_pandas( if self.categories.dtype.kind == "f": new_mask = bools_to_mask(self.notnull()) - col = column.build_categorical_column( - categories=self.categories, - codes=column.as_column(self.codes, dtype=self.codes.dtype), + col = type(self)( + data=self.data, # type: ignore[arg-type] + size=self.size, + dtype=self.dtype, mask=new_mask, - ordered=self.dtype.ordered, - size=self.codes.size, + children=self.children, ) else: col = self @@ -849,15 +851,15 @@ def data_array_view( ) -> numba.cuda.devicearray.DeviceNDArray: return self.codes.data_array_view(mode=mode) - def unique(self) -> CategoricalColumn: + def unique(self) -> Self: codes = self.codes.unique() - return column.build_categorical_column( - categories=self.categories, - codes=codes, + return type(self)( + data=self.data, # type: ignore[arg-type] + size=codes.size, + dtype=self.dtype, mask=codes.base_mask, offset=codes.offset, - 
size=codes.size, - ordered=self.ordered, + children=(codes,), ) def _encode(self, value) -> ScalarLike: @@ -988,14 +990,17 @@ def find_and_replace( output = libcudf.replace.replace( replaced_codes, to_replace_col, replacement_col ) + codes = as_unsigned_codes(len(new_cats["cats"]), output) - result = column.build_categorical_column( - categories=new_cats["cats"], - codes=output, - mask=output.base_mask, - offset=output.offset, - size=output.size, - ordered=self.dtype.ordered, + result = type(self)( + data=self.data, # type: ignore[arg-type] + size=codes.size, + dtype=CategoricalDtype( + categories=new_cats["cats"], ordered=self.dtype.ordered + ), + mask=codes.base_mask, + offset=codes.offset, + children=(codes,), ) if result.dtype != self.dtype: warnings.warn( @@ -1082,7 +1087,7 @@ def is_monotonic_increasing(self) -> bool: def is_monotonic_decreasing(self) -> bool: return bool(self.ordered) and self.codes.is_monotonic_decreasing - def as_categorical_column(self, dtype: Dtype) -> CategoricalColumn: + def as_categorical_column(self, dtype: Dtype) -> Self: if isinstance(dtype, str) and dtype == "category": return self if isinstance(dtype, pd.CategoricalDtype): @@ -1099,7 +1104,23 @@ def as_categorical_column(self, dtype: Dtype) -> CategoricalColumn: if not isinstance(self.categories, type(dtype.categories._column)): # If both categories are of different Column types, # return a column full of Nulls. - return _create_empty_categorical_column(self, dtype) + codes = cast( + cudf.core.column.numerical.NumericalColumn, + column.as_column( + _DEFAULT_CATEGORICAL_VALUE, + length=self.size, + dtype=self.codes.dtype, + ), + ) + codes = as_unsigned_codes(len(dtype.categories), codes) + return type(self)( + data=self.data, # type: ignore[arg-type] + size=self.size, + dtype=dtype, + mask=self.base_mask, + offset=self.offset, + children=(codes,), + ) return self.set_categories( new_categories=dtype.categories, ordered=bool(dtype.ordered) @@ -1185,26 +1206,29 @@ def _concat( codes = [o for o in codes if len(o)] codes_col = libcudf.concat.concat_columns(objs) - return column.build_categorical_column( - categories=column.as_column(cats), - codes=codes_col, - mask=codes_col.base_mask, + codes_col = as_unsigned_codes( + len(cats), + cast(cudf.core.column.numerical.NumericalColumn, codes_col), + ) + return CategoricalColumn( + data=None, size=codes_col.size, + dtype=CategoricalDtype(categories=cats), + mask=codes_col.base_mask, offset=codes_col.offset, + children=(codes_col,), # type: ignore[arg-type] ) - def _with_type_metadata( - self: CategoricalColumn, dtype: Dtype - ) -> CategoricalColumn: + def _with_type_metadata(self: Self, dtype: Dtype) -> Self: if isinstance(dtype, CategoricalDtype): - return column.build_categorical_column( - categories=dtype.categories._values, - codes=self.codes, - mask=self.codes.base_mask, - ordered=dtype.ordered, + return type(self)( + data=self.data, # type: ignore[arg-type] size=self.codes.size, + dtype=dtype, + mask=self.codes.base_mask, offset=self.codes.offset, null_count=self.codes.null_count, + children=(self.codes,), ) return self @@ -1213,7 +1237,7 @@ def set_categories( new_categories: Any, ordered: bool = False, rename: bool = False, - ) -> CategoricalColumn: + ) -> Self: # See CategoricalAccessor.set_categories. 
ordered = ordered if ordered is not None else self.ordered @@ -1232,25 +1256,39 @@ def set_categories( "new_categories must have the same " "number of items as old categories" ) - - out_col = column.build_categorical_column( - categories=new_categories, - codes=self.base_children[0], - mask=self.base_mask, + out_col = type(self)( + data=self.data, # type: ignore[arg-type] size=self.size, + dtype=CategoricalDtype( + categories=new_categories, ordered=ordered + ), + mask=self.base_mask, offset=self.offset, - ordered=ordered, + children=(self.codes,), ) else: out_col = self if type(out_col.categories) is not type(new_categories): # If both categories are of different Column types, # return a column full of Nulls. - out_col = _create_empty_categorical_column( - self, - CategoricalDtype( + new_codes = cast( + cudf.core.column.numerical.NumericalColumn, + column.as_column( + _DEFAULT_CATEGORICAL_VALUE, + length=self.size, + dtype=self.codes.dtype, + ), + ) + new_codes = as_unsigned_codes(len(new_categories), new_codes) + out_col = type(self)( + data=self.data, # type: ignore[arg-type] + size=self.size, + dtype=CategoricalDtype( categories=new_categories, ordered=ordered ), + mask=self.base_mask, + offset=self.offset, + children=(new_codes,), ) elif ( not out_col._categories_equal(new_categories, ordered=True) @@ -1335,19 +1373,19 @@ def _set_categories( df.reset_index(drop=True, inplace=True) ordered = ordered if ordered is not None else self.ordered - new_codes = df._data["new_codes"] + new_codes = cast( + cudf.core.column.numerical.NumericalColumn, df._data["new_codes"] + ) # codes can't have masks, so take mask out before moving in - return cast( - Self, - column.build_categorical_column( - categories=new_cats, - codes=new_codes, - mask=new_codes.base_mask, - size=new_codes.size, - offset=new_codes.offset, - ordered=ordered, - ), + new_codes = as_unsigned_codes(len(new_cats), new_codes) + return type(self)( + data=self.data, # type: ignore[arg-type] + size=new_codes.size, + dtype=CategoricalDtype(categories=new_cats, ordered=ordered), + mask=new_codes.base_mask, + offset=new_codes.offset, + children=(new_codes,), ) def add_categories(self, new_categories: Any) -> Self: @@ -1425,56 +1463,16 @@ def remove_unused_categories(self) -> Self: "remove_unused_categories is currently not supported." 
) - def as_ordered(self, ordered: bool): + def as_ordered(self, ordered: bool) -> Self: if self.dtype.ordered == ordered: return self - return column.build_categorical_column( - categories=self.categories, - codes=self.codes, - mask=self.base_mask, + return type(self)( + data=self.data, # type: ignore[arg-type] size=self.size, + dtype=CategoricalDtype( + categories=self.categories, ordered=ordered + ), + mask=self.base_mask, offset=self.offset, - ordered=ordered, + children=self.children, ) - - -def _create_empty_categorical_column( - categorical_column: CategoricalColumn, dtype: "CategoricalDtype" -) -> CategoricalColumn: - return column.build_categorical_column( - categories=column.as_column(dtype.categories), - codes=column.as_column( - _DEFAULT_CATEGORICAL_VALUE, - length=categorical_column.size, - dtype=categorical_column.codes.dtype, - ), - offset=categorical_column.offset, - size=categorical_column.size, - mask=categorical_column.base_mask, - ordered=dtype.ordered, - ) - - -def pandas_categorical_as_column( - categorical: ColumnLike, codes: ColumnLike | None = None -) -> CategoricalColumn: - """Creates a CategoricalColumn from a pandas.Categorical - - If ``codes`` is defined, use it instead of ``categorical.codes`` - """ - codes = categorical.codes if codes is None else codes - codes = column.as_column(codes) - - valid_codes = codes != codes.dtype.type(_DEFAULT_CATEGORICAL_VALUE) - - mask = None - if not valid_codes.all(): - mask = bools_to_mask(valid_codes) - - return column.build_categorical_column( - categories=categorical.categories, - codes=codes, - size=codes.size, - mask=mask, - ordered=categorical.ordered, - ) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 60b4126ddd4..885476a897c 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -352,13 +352,17 @@ def from_arrow(cls, array: pa.Array) -> ColumnBase: codes = libcudf.interop.from_arrow(indices_table)[0] categories = libcudf.interop.from_arrow(dictionaries_table)[0] - - return build_categorical_column( - categories=categories, - codes=codes, - mask=codes.base_mask, + codes = cudf.core.column.categorical.as_unsigned_codes( + len(categories), codes + ) + return cudf.core.column.CategoricalColumn( + data=None, size=codes.size, - ordered=array.type.ordered, + dtype=CategoricalDtype( + categories=categories, ordered=array.type.ordered + ), + mask=codes.base_mask, + children=(codes,), ) result = libcudf.interop.from_arrow(data)[0] @@ -950,10 +954,10 @@ def is_monotonic_decreasing(self) -> bool: ) def sort_values( - self: ColumnBase, + self: Self, ascending: bool = True, na_position: str = "last", - ) -> ColumnBase: + ) -> Self: if (not ascending and self.is_monotonic_decreasing) or ( ascending and self.is_monotonic_increasing ): @@ -1041,12 +1045,16 @@ def as_categorical_column(self, dtype) -> ColumnBase: and dtype._categories is not None ): cat_col = dtype._categories - labels = self._label_encoding(cats=cat_col) - return build_categorical_column( - categories=cat_col, - codes=labels, + codes = self._label_encoding(cats=cat_col) + codes = cudf.core.column.categorical.as_unsigned_codes( + len(cat_col), codes + ) + return cudf.core.column.categorical.CategoricalColumn( + data=None, + size=None, + dtype=dtype, mask=self.mask, - ordered=dtype.ordered, + children=(codes,), ) # Categories must be unique and sorted in ascending order. 
@@ -1058,15 +1066,16 @@ def as_categorical_column(self, dtype) -> ColumnBase: # columns include null index in factorization; remove: if self.has_nulls(): cats = cats.dropna() - min_type = min_unsigned_type(len(cats), 8) - if cudf.dtype(min_type).itemsize < labels.dtype.itemsize: - labels = labels.astype(min_type) - return build_categorical_column( - categories=cats, - codes=labels, + labels = cudf.core.column.categorical.as_unsigned_codes( + len(cats), labels + ) + return cudf.core.column.categorical.CategoricalColumn( + data=None, + size=None, + dtype=CategoricalDtype(categories=cats, ordered=ordered), mask=self.mask, - ordered=ordered, + children=(labels,), ) def as_numerical_column( @@ -1186,7 +1195,7 @@ def searchsorted( na_position=na_position, ) - def unique(self) -> ColumnBase: + def unique(self) -> Self: """ Get unique values in the data """ @@ -1695,51 +1704,6 @@ def build_column( raise TypeError(f"Unrecognized dtype: {dtype}") -def build_categorical_column( - categories: ColumnBase, - codes: ColumnBase, - mask: Buffer | None = None, - size: int | None = None, - offset: int = 0, - null_count: int | None = None, - ordered: bool = False, -) -> "cudf.core.column.CategoricalColumn": - """ - Build a CategoricalColumn - - Parameters - ---------- - categories : Column - Column of categories - codes : Column - Column of codes, the size of the resulting Column will be - the size of `codes` - mask : Buffer - Null mask - size : int, optional - offset : int, optional - ordered : bool, default False - Indicates whether the categories are ordered - """ - codes_dtype = min_unsigned_type(len(categories)) - codes = as_column(codes) - if codes.dtype != codes_dtype: - codes = codes.astype(codes_dtype) - - dtype = CategoricalDtype(categories=categories, ordered=ordered) - - result = build_column( - data=None, - dtype=dtype, - mask=mask, - size=size, - offset=offset, - null_count=null_count, - children=(codes,), - ) - return cast("cudf.core.column.CategoricalColumn", result) - - def check_invalid_array(shape: tuple, dtype): """Invalid ndarrays properties that are not supported""" if len(shape) > 1: diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index 90bec049831..78d2814ed26 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -3,7 +3,7 @@ from __future__ import annotations import functools -from typing import TYPE_CHECKING, Any, Callable, Sequence, cast +from typing import TYPE_CHECKING, Any, Sequence, cast import numpy as np import pandas as pd @@ -28,6 +28,8 @@ from .numerical_base import NumericalBaseColumn if TYPE_CHECKING: + from collections.abc import Callable + from cudf._typing import ( ColumnBinaryOperand, ColumnLike, @@ -649,22 +651,20 @@ def can_cast_safely(self, to_dtype: DtypeObj) -> bool: return False - def _with_type_metadata(self: ColumnBase, dtype: Dtype) -> ColumnBase: + def _with_type_metadata(self: Self, dtype: Dtype) -> ColumnBase: if isinstance(dtype, CategoricalDtype): - return column.build_categorical_column( - categories=dtype.categories._values, - codes=cudf.core.column.NumericalColumn( - self.base_data, # type: ignore[arg-type] - self.size, - dtype=self.dtype, - ), - mask=self.base_mask, - ordered=dtype.ordered, + codes = cudf.core.column.categorical.as_unsigned_codes( + len(dtype.categories), self + ) + return cudf.core.column.CategoricalColumn( + data=None, size=self.size, + dtype=dtype, + mask=self.base_mask, offset=self.offset, null_count=self.null_count, + 
children=(codes,), ) - return self def to_pandas( diff --git a/python/cudf/cudf/core/column_accessor.py b/python/cudf/cudf/core/column_accessor.py index 34076fa0060..09b0f453692 100644 --- a/python/cudf/cudf/core/column_accessor.py +++ b/python/cudf/cudf/core/column_accessor.py @@ -6,7 +6,7 @@ import sys from collections import abc from functools import cached_property, reduce -from typing import TYPE_CHECKING, Any, Callable, Mapping, cast +from typing import TYPE_CHECKING, Any, Mapping, cast import numpy as np import pandas as pd @@ -639,7 +639,7 @@ def _pad_key( def rename_levels( self, - mapper: Mapping[abc.Hashable, abc.Hashable] | Callable, + mapper: Mapping[abc.Hashable, abc.Hashable] | abc.Callable, level: int | None = None, ) -> Self: """ diff --git a/python/cudf/cudf/core/cut.py b/python/cudf/cudf/core/cut.py index a4ceea266b4..c9b1fa2669c 100644 --- a/python/cudf/cudf/core/cut.py +++ b/python/cudf/cudf/core/cut.py @@ -8,7 +8,8 @@ import cudf from cudf.api.types import is_list_like -from cudf.core.column import as_column, build_categorical_column +from cudf.core.column import as_column +from cudf.core.column.categorical import CategoricalColumn, as_unsigned_codes from cudf.core.index import IntervalIndex, interval_range @@ -282,13 +283,17 @@ def cut( # should allow duplicate categories. return interval_labels[index_labels] - col = build_categorical_column( - categories=interval_labels, - codes=index_labels, + index_labels = as_unsigned_codes(len(interval_labels), index_labels) + + col = CategoricalColumn( + data=None, + size=index_labels.size, + dtype=cudf.CategoricalDtype( + categories=interval_labels, ordered=ordered + ), mask=index_labels.base_mask, offset=index_labels.offset, - size=index_labels.size, - ordered=ordered, + children=(index_labels,), ) # we return a categorical index, as we don't have a Categorical method diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 14b63c2b0d7..0d632f4775f 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -13,8 +13,8 @@ import textwrap import warnings from collections import abc, defaultdict -from collections.abc import Iterator -from typing import TYPE_CHECKING, Any, Callable, Literal, MutableMapping, cast +from collections.abc import Callable, Iterator +from typing import TYPE_CHECKING, Any, Literal, MutableMapping, cast import cupy import numba @@ -48,10 +48,10 @@ ColumnBase, StructColumn, as_column, - build_categorical_column, column_empty, concat_columns, ) +from cudf.core.column.categorical import as_unsigned_codes from cudf.core.column_accessor import ColumnAccessor from cudf.core.copy_types import BooleanMask from cudf.core.groupby.groupby import DataFrameGroupBy, groupby_doc_template @@ -414,8 +414,9 @@ def _setitem_tuple_arg(self, key, value): ) else: - value = cupy.asarray(value) - if value.ndim == 2: + if not is_column_like(value): + value = cupy.asarray(value) + if getattr(value, "ndim", 1) == 2: # If the inner dimension is 1, it's broadcastable to # all columns of the dataframe. 
indexed_shape = columns_df.loc[key[0]].shape @@ -558,8 +559,9 @@ def _setitem_tuple_arg(self, key, value): else: # TODO: consolidate code path with identical counterpart # in `_DataFrameLocIndexer._setitem_tuple_arg` - value = cupy.asarray(value) - if value.ndim == 2: + if not is_column_like(value): + value = cupy.asarray(value) + if getattr(value, "ndim", 1) == 2: indexed_shape = columns_df.iloc[key[0]].shape if value.shape[1] == 1: if value.shape[0] != indexed_shape[0]: @@ -678,7 +680,9 @@ class DataFrame(IndexedFrame, Serializable, GetAttrGetItemMixin): 3 3 0.3 """ - _PROTECTED_KEYS = frozenset(("_data", "_index")) + _PROTECTED_KEYS = frozenset( + ("_data", "_index", "_ipython_canary_method_should_not_exist_") + ) _accessors: set[Any] = set() _loc_indexer_type = _DataFrameLocIndexer _iloc_indexer_type = _DataFrameIlocIndexer @@ -3063,7 +3067,6 @@ def where(self, cond, other=None, inplace=False, axis=None, level=None): from cudf.core._internals.where import ( _check_and_cast_columns_with_other, - _make_categorical_like, ) # First process the condition. @@ -3115,7 +3118,7 @@ def where(self, cond, other=None, inplace=False, axis=None, level=None): out = [] for (name, col), other_col in zip(self._data.items(), other_cols): - col, other_col = _check_and_cast_columns_with_other( + source_col, other_col = _check_and_cast_columns_with_other( source_col=col, other=other_col, inplace=inplace, @@ -3123,16 +3126,16 @@ def where(self, cond, other=None, inplace=False, axis=None, level=None): if cond_col := cond._data.get(name): result = cudf._lib.copying.copy_if_else( - col, other_col, cond_col + source_col, other_col, cond_col ) - out.append(_make_categorical_like(result, self._data[name])) + out.append(result._with_type_metadata(col.dtype)) else: out_mask = cudf._lib.null_mask.create_null_mask( - len(col), + len(source_col), state=cudf._lib.null_mask.MaskState.ALL_NULL, ) - out.append(col.set_mask(out_mask)) + out.append(source_col.set_mask(out_mask)) return self._mimic_inplace( self._from_data_like_self(self._data._from_columns_like_self(out)), @@ -3292,9 +3295,7 @@ def _insert(self, loc, name, value, nan_as_null=None, ignore_index=True): # least require a deprecation cycle because we currently support # inserting a pd.Categorical. 
if isinstance(value, pd.Categorical): - value = cudf.core.column.categorical.pandas_categorical_as_column( - value - ) + value = as_column(value) if _is_scalar_or_zero_d_array(value): dtype = None @@ -8506,12 +8507,16 @@ def _cast_cols_to_common_dtypes(col_idxs, list_of_columns, dtypes, categories): def _reassign_categories(categories, cols, col_idxs): for name, idx in zip(cols, col_idxs): if idx in categories: - cols[name] = build_categorical_column( - categories=categories[idx], - codes=cols[name], - mask=cols[name].base_mask, - offset=cols[name].offset, - size=cols[name].size, + codes = as_unsigned_codes(len(categories[idx]), cols[name]) + cols[name] = CategoricalColumn( + data=None, + size=codes.size, + dtype=cudf.CategoricalDtype( + categories=categories[idx], ordered=False + ), + mask=codes.base_mask, + offset=codes.offset, + children=(codes,), ) diff --git a/python/cudf/cudf/core/df_protocol.py b/python/cudf/cudf/core/df_protocol.py index a70a42c04af..5250a741d3d 100644 --- a/python/cudf/cudf/core/df_protocol.py +++ b/python/cudf/cudf/core/df_protocol.py @@ -13,7 +13,12 @@ import cudf from cudf.core.buffer import Buffer, as_buffer -from cudf.core.column import as_column, build_categorical_column, build_column +from cudf.core.column import ( + CategoricalColumn, + NumericalColumn, + as_column, + build_column, +) # Implementation of interchange protocol classes # ---------------------------------------------- @@ -830,18 +835,19 @@ def _protocol_to_cudf_column_categorical( assert buffers["data"] is not None, "data buffer should not be None" codes_buffer, codes_dtype = buffers["data"] codes_buffer = _ensure_gpu_buffer(codes_buffer, codes_dtype, allow_copy) - cdtype = protocol_dtype_to_cupy_dtype(codes_dtype) - codes = build_column( - codes_buffer._buf, - cdtype, + cdtype = np.dtype(protocol_dtype_to_cupy_dtype(codes_dtype)) + codes = NumericalColumn( + data=codes_buffer._buf, + size=None, + dtype=cdtype, ) - - cudfcol = build_categorical_column( - categories=categories, - codes=codes, - mask=codes.base_mask, + cudfcol = CategoricalColumn( + data=None, size=codes.size, - ordered=ordered, + dtype=cudf.CategoricalDtype(categories=categories, ordered=ordered), + mask=codes.base_mask, + offset=codes.offset, + children=(codes,), ) return _set_missing_values(col, cudfcol, allow_copy), buffers diff --git a/python/cudf/cudf/core/dtypes.py b/python/cudf/cudf/core/dtypes.py index 6d532e01cba..2110e610c37 100644 --- a/python/cudf/cudf/core/dtypes.py +++ b/python/cudf/cudf/core/dtypes.py @@ -7,7 +7,7 @@ import textwrap import warnings from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -27,6 +27,8 @@ PANDAS_NUMPY_DTYPE = pd.core.dtypes.dtypes.PandasDtype if TYPE_CHECKING: + from collections.abc import Callable + from cudf._typing import Dtype from cudf.core.buffer import Buffer diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index 3e1efd7c97a..7b2bc85b13b 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -6,7 +6,7 @@ import pickle import warnings from collections import abc -from typing import TYPE_CHECKING, Any, Callable, Literal, MutableMapping +from typing import TYPE_CHECKING, Any, Literal, MutableMapping # TODO: The `numpy` import is needed for typing purposes during doc builds # only, need to figure out why the `np` alias is insufficient then remove. 
@@ -24,10 +24,10 @@ from cudf.core.column import ( ColumnBase, as_column, - build_categorical_column, deserialize_columns, serialize_columns, ) +from cudf.core.column.categorical import CategoricalColumn, as_unsigned_codes from cudf.core.column_accessor import ColumnAccessor from cudf.core.mixins import BinaryOperand, Scannable from cudf.utils import ioutils @@ -403,7 +403,7 @@ def __arrow_array__(self, type=None): @_performance_tracking def _to_array( self, - get_array: Callable, + get_array: abc.Callable, module: ModuleType, copy: bool, dtype: Dtype | None = None, @@ -889,18 +889,21 @@ def from_arrow(cls, data: pa.Table) -> Self: for name in dict_dictionaries.keys() } - cudf_category_frame = { - name: build_categorical_column( - cudf_dictionaries_columns[name], - codes, - mask=codes.base_mask, + for name, codes in zip( + dict_indices_table.column_names, indices_columns + ): + categories = cudf_dictionaries_columns[name] + codes = as_unsigned_codes(len(categories), codes) + cudf_category_frame[name] = CategoricalColumn( + data=None, size=codes.size, - ordered=dict_ordered[name], - ) - for name, codes in zip( - dict_indices_table.column_names, indices_columns + dtype=cudf.CategoricalDtype( + categories=categories, + ordered=dict_ordered[name], + ), + mask=codes.base_mask, + children=(codes,), ) - } # Handle non-dict arrays cudf_non_category_frame = { diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index df8af856f4f..500fc580097 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -1443,7 +1443,21 @@ def __repr__(self): output[:break_idx].replace("'", "") + output[break_idx:] ) else: - output = repr(preprocess.to_pandas()) + # Too many non-unique categories will cause + # the output to take too long. In this case, we + # split the categories into data and categories + # and generate the repr separately and + # merge them. 
+ pd_cats = pd.Categorical( + preprocess.astype(preprocess.categories.dtype).to_pandas() + ) + pd_preprocess = pd.CategoricalIndex(pd_cats) + data_repr = repr(pd_preprocess).split("\n") + pd_preprocess.dtype._categories = ( + preprocess.categories.to_pandas() + ) + cats_repr = repr(pd_preprocess).split("\n") + output = "\n".join(data_repr[:-1] + cats_repr[-1:]) output = output.replace("nan", str(cudf.NA)) elif preprocess._values.nullable: @@ -3065,22 +3079,8 @@ def __init__( name = _getdefault_name(data, name=name) if isinstance(data, CategoricalColumn): data = data - elif isinstance(data, pd.Series) and ( - isinstance(data.dtype, pd.CategoricalDtype) - ): - codes_data = column.as_column(data.cat.codes.values) - data = column.build_categorical_column( - categories=data.cat.categories, - codes=codes_data, - ordered=data.cat.ordered, - ) - elif isinstance(data, (pd.Categorical, pd.CategoricalIndex)): - codes_data = column.as_column(data.codes) - data = column.build_categorical_column( - categories=data.categories, - codes=codes_data, - ordered=data.ordered, - ) + elif isinstance(getattr(data, "dtype", None), pd.CategoricalDtype): + data = column.as_column(data) else: data = column.as_column( data, dtype="category" if dtype is None else dtype diff --git a/python/cudf/cudf/core/indexed_frame.py b/python/cudf/cudf/core/indexed_frame.py index ad6aa56d472..fd6bf37f0e6 100644 --- a/python/cudf/cudf/core/indexed_frame.py +++ b/python/cudf/cudf/core/indexed_frame.py @@ -173,17 +173,7 @@ def _drop_columns(f: Frame, columns: abc.Iterable, errors: str): def _indices_from_labels(obj, labels): if not isinstance(labels, cudf.MultiIndex): labels = cudf.core.column.as_column(labels) - - if isinstance(obj.index.dtype, cudf.CategoricalDtype): - labels = labels.astype("category") - codes = labels.codes.astype(obj.index.codes.dtype) - labels = cudf.core.column.build_categorical_column( - categories=labels.dtype.categories, - codes=codes, - ordered=labels.dtype.ordered, - ) - else: - labels = labels.astype(obj.index.dtype) + labels = labels.astype(obj.index.dtype) idx_labels = cudf.Index._from_column(labels) else: idx_labels = labels diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 4be10752651..a831a798772 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -38,7 +38,9 @@ as_column, ) from cudf.core.column.categorical import ( + _DEFAULT_CATEGORICAL_VALUE, CategoricalAccessor as CategoricalAccessor, + CategoricalColumn, ) from cudf.core.column.column import concat_columns from cudf.core.column.lists import ListMethods @@ -511,9 +513,22 @@ def from_categorical(cls, categorical, codes=None): dtype: category Categories (3, object): ['a', 'b', 'c'] """ # noqa: E501 - col = cudf.core.column.categorical.pandas_categorical_as_column( - categorical, codes=codes - ) + col = as_column(categorical) + if codes is not None: + codes = as_column(codes) + + valid_codes = codes != codes.dtype.type(_DEFAULT_CATEGORICAL_VALUE) + + mask = None + if not valid_codes.all(): + mask = libcudf.transform.bools_to_mask(valid_codes) + col = CategoricalColumn( + data=col.data, + size=codes.size, + dtype=col.dtype, + mask=mask, + children=(codes,), + ) return Series._from_column(col) @classmethod diff --git a/python/cudf/cudf/core/single_column_frame.py b/python/cudf/cudf/core/single_column_frame.py index eb6714029cf..55dda34a576 100644 --- a/python/cudf/cudf/core/single_column_frame.py +++ b/python/cudf/cudf/core/single_column_frame.py @@ -350,7 +350,6 @@ def 
_get_elements_from_column(self, arg) -> ScalarLike | ColumnBase: def where(self, cond, other=None, inplace=False): from cudf.core._internals.where import ( _check_and_cast_columns_with_other, - _make_categorical_like, ) if isinstance(other, cudf.DataFrame): @@ -366,14 +365,12 @@ def where(self, cond, other=None, inplace=False): if not cudf.api.types.is_scalar(other): other = cudf.core.column.as_column(other) - self_column = self._column input_col, other = _check_and_cast_columns_with_other( - source_col=self_column, other=other, inplace=inplace + source_col=self._column, other=other, inplace=inplace ) result = cudf._lib.copying.copy_if_else(input_col, other, cond) - - return _make_categorical_like(result, self_column) + return result._with_type_metadata(self.dtype) @_performance_tracking def transpose(self): diff --git a/python/cudf/cudf/core/udf/utils.py b/python/cudf/cudf/core/udf/utils.py index d616761cb3b..6d7362952c9 100644 --- a/python/cudf/cudf/core/udf/utils.py +++ b/python/cudf/cudf/core/udf/utils.py @@ -3,7 +3,7 @@ import functools import os -from typing import Any, Callable +from typing import TYPE_CHECKING, Any import cachetools import cupy as cp @@ -41,6 +41,9 @@ from cudf.utils.performance_tracking import _performance_tracking from cudf.utils.utils import initfunc +if TYPE_CHECKING: + from collections.abc import Callable + # Maximum size of a string column is 2 GiB _STRINGS_UDF_DEFAULT_HEAP_SIZE = os.environ.get("STRINGS_UDF_HEAP_SIZE", 2**31) _heap_size = 0 diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index 6b895abbf66..984115dcbbe 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -10,7 +10,7 @@ from collections import defaultdict from contextlib import ExitStack from functools import partial, reduce -from typing import Callable +from typing import TYPE_CHECKING from uuid import uuid4 import numpy as np @@ -20,10 +20,15 @@ import cudf from cudf._lib import parquet as libparquet from cudf.api.types import is_list_like -from cudf.core.column import as_column, build_categorical_column, column_empty +from cudf.core.column import as_column, column_empty +from cudf.core.column.categorical import CategoricalColumn, as_unsigned_codes from cudf.utils import ioutils from cudf.utils.performance_tracking import _performance_tracking +if TYPE_CHECKING: + from collections.abc import Callable + + BYTE_SIZES = { "kb": 1000, "mb": 1000000, @@ -807,12 +812,17 @@ def _parquet_to_frame( partition_categories[name].index(value), length=_len, ) - dfs[-1][name] = build_categorical_column( - categories=partition_categories[name], - codes=codes, + codes = as_unsigned_codes( + len(partition_categories[name]), codes + ) + dfs[-1][name] = CategoricalColumn( + data=None, size=codes.size, + dtype=cudf.CategoricalDtype( + categories=partition_categories[name], ordered=False + ), offset=codes.offset, - ordered=False, + children=(codes,), ) else: # Not building categorical columns, so diff --git a/python/cudf/cudf/options.py b/python/cudf/cudf/options.py index 94e73021cec..df7bbe22a61 100644 --- a/python/cudf/cudf/options.py +++ b/python/cudf/cudf/options.py @@ -5,10 +5,10 @@ import textwrap from contextlib import ContextDecorator from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from collections.abc import Container + from collections.abc import Callable, Container @dataclass diff --git a/python/cudf/cudf/pandas/_wrappers/pandas.py 
b/python/cudf/cudf/pandas/_wrappers/pandas.py index 478108f36f1..6d03063fa27 100644 --- a/python/cudf/cudf/pandas/_wrappers/pandas.py +++ b/python/cudf/cudf/pandas/_wrappers/pandas.py @@ -61,6 +61,12 @@ TimeGrouper as pd_TimeGrouper, ) +try: + from IPython import get_ipython + + ipython_shell = get_ipython() +except ImportError: + ipython_shell = None cudf.set_option("mode.pandas_compatible", True) @@ -208,6 +214,12 @@ def _DataFrame__dir__(self): ] +def ignore_ipython_canary_check(self, **kwargs): + raise AttributeError( + "_ipython_canary_method_should_not_exist_ doesn't exist" + ) + + DataFrame = make_final_proxy_type( "DataFrame", cudf.DataFrame, @@ -220,10 +232,26 @@ def _DataFrame__dir__(self): "_constructor": _FastSlowAttribute("_constructor"), "_constructor_sliced": _FastSlowAttribute("_constructor_sliced"), "_accessors": set(), + "_ipython_canary_method_should_not_exist_": ignore_ipython_canary_check, }, ) +def custom_repr_html(obj): + # This custom function is needed to register an HTML formatter + # for IPython + return _fast_slow_function_call( + lambda obj: obj._repr_html_(), + obj, + )[0] + + +if ipython_shell: + # See: https://ipython.readthedocs.io/en/stable/config/integrating.html#formatters-for-third-party-types + html_formatter = ipython_shell.display_formatter.formatters["text/html"] + html_formatter.for_type(DataFrame, custom_repr_html) + + Series = make_final_proxy_type( "Series", cudf.Series, diff --git a/python/cudf/cudf/pandas/fast_slow_proxy.py b/python/cudf/cudf/pandas/fast_slow_proxy.py index bb678fd1efe..4b0fd9a5b36 100644 --- a/python/cudf/cudf/pandas/fast_slow_proxy.py +++ b/python/cudf/cudf/pandas/fast_slow_proxy.py @@ -10,9 +10,9 @@ import pickle import types import warnings -from collections.abc import Iterator +from collections.abc import Callable, Iterator from enum import IntEnum -from typing import Any, Callable, Literal, Mapping +from typing import Any, Literal, Mapping import numpy as np diff --git a/python/cudf/cudf/testing/_utils.py b/python/cudf/cudf/testing/_utils.py index a6a2d4eea00..540f12c8382 100644 --- a/python/cudf/cudf/testing/_utils.py +++ b/python/cudf/cudf/testing/_utils.py @@ -1,6 +1,7 @@ # Copyright (c) 2020-2024, NVIDIA CORPORATION. import itertools +import signal import string from collections import abc from contextlib import contextmanager @@ -368,3 +369,23 @@ def sv_to_udf_str_testing_lowering(context, builder, sig, args): return cast_string_view_to_udf_string( context, builder, sig.args[0], sig.return_type, args[0] ) + + +class cudf_timeout: + """ + Context manager to raise a TimeoutError after a specified number of seconds.
+ """ + + def __init__(self, seconds, *, timeout_message=""): + self.seconds = int(seconds) + self.timeout_message = timeout_message + + def _timeout_handler(self, signum, frame): + raise TimeoutError(self.timeout_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self._timeout_handler) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) diff --git a/python/cudf/cudf/tests/test_indexing.py b/python/cudf/cudf/tests/test_indexing.py index 716b4dc6acd..9df2852dde8 100644 --- a/python/cudf/cudf/tests/test_indexing.py +++ b/python/cudf/cudf/tests/test_indexing.py @@ -2369,3 +2369,13 @@ def test_duplicate_labels_raises(): df[["a", "a"]] with pytest.raises(ValueError): df.loc[:, ["a", "a"]] + + +@pytest.mark.parametrize("indexer", ["iloc", "loc"]) +@pytest.mark.parametrize("dtype", ["category", "timedelta64[ns]"]) +def test_loc_iloc_setitem_col_slice_non_cupy_types(indexer, dtype): + df_pd = pd.DataFrame(range(2), dtype=dtype) + df_cudf = cudf.DataFrame.from_pandas(df_pd) + getattr(df_pd, indexer)[:, 0] = getattr(df_pd, indexer)[:, 0] + getattr(df_cudf, indexer)[:, 0] = getattr(df_cudf, indexer)[:, 0] + assert_eq(df_pd, df_cudf) diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index db4f1c9c8bd..879b2bd3d74 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -515,10 +515,6 @@ def test_parquet_read_filtered_multiple_files(tmpdir): ) -@pytest.mark.skipif( - version.parse(pa.__version__) < version.parse("1.0.1"), - reason="pyarrow 1.0.0 needed for various operators and operand types", -) @pytest.mark.parametrize( "predicate,expected_len", [ @@ -2393,6 +2389,10 @@ def test_parquet_writer_list_large_mixed(tmpdir): @pytest.mark.parametrize("store_schema", [True, False]) def test_parquet_writer_list_chunked(tmpdir, store_schema): + if store_schema and version.parse(pa.__version__) < version.parse( + "15.0.0" + ): + pytest.skip("https://github.com/apache/arrow/pull/37792") table1 = cudf.DataFrame( { "a": list_gen(string_gen, 128, 80, 50), @@ -2578,6 +2578,10 @@ def normalized_equals(value1, value2): @pytest.mark.parametrize("add_nulls", [True, False]) @pytest.mark.parametrize("store_schema", [True, False]) def test_parquet_writer_statistics(tmpdir, pdf, add_nulls, store_schema): + if store_schema and version.parse(pa.__version__) < version.parse( + "15.0.0" + ): + pytest.skip("https://github.com/apache/arrow/pull/37792") file_path = tmpdir.join("cudf.parquet") if "col_category" in pdf.columns: pdf = pdf.drop(columns=["col_category", "col_bool"]) @@ -2957,6 +2961,10 @@ def test_per_column_options_string_col(tmpdir, encoding): assert encoding in fmd.row_group(0).column(0).encodings +@pytest.mark.skipif( + version.parse(pa.__version__) < version.parse("16.0.0"), + reason="https://github.com/apache/arrow/pull/39748", +) @pytest.mark.parametrize( "num_rows", [200, 10000], @@ -3557,6 +3565,10 @@ def test_parquet_reader_roundtrip_structs_with_arrow_schema(tmpdir, data): @pytest.mark.parametrize("index", [None, True, False]) +@pytest.mark.skipif( + version.parse(pa.__version__) < version.parse("15.0.0"), + reason="https://github.com/apache/arrow/pull/37792", +) def test_parquet_writer_roundtrip_with_arrow_schema(index): # Ensure that the concrete and nested types are faithfully being roundtripped # across Parquet with arrow schema @@ -3707,6 +3719,10 @@ def test_parquet_writer_int96_timestamps_and_arrow_schema(): ], ) @pytest.mark.parametrize("index", [None, True, 
False]) +@pytest.mark.skipif( + version.parse(pa.__version__) < version.parse("15.0.0"), + reason="https://github.com/apache/arrow/pull/37792", +) def test_parquet_writer_roundtrip_structs_with_arrow_schema( tmpdir, data, index ): diff --git a/python/cudf/cudf/tests/test_repr.py b/python/cudf/cudf/tests/test_repr.py index a013745f71e..57eef9e3463 100644 --- a/python/cudf/cudf/tests/test_repr.py +++ b/python/cudf/cudf/tests/test_repr.py @@ -1480,3 +1480,14 @@ def test_interval_index_repr(): gi = cudf.from_pandas(pi) assert repr(pi) == repr(gi) + + +def test_large_unique_categories_repr(): + # Unfortunately, this is a long running test (takes about 1 minute) + # and there is no way we can reduce the time + pi = pd.CategoricalIndex(range(100_000_000)) + gi = cudf.CategoricalIndex(range(100_000_000)) + expected_repr = repr(pi) + with utils.cudf_timeout(2, timeout_message="Failed to repr fast enough"): + actual_repr = repr(gi) + assert expected_repr == actual_repr diff --git a/python/cudf/cudf/utils/ioutils.py b/python/cudf/cudf/utils/ioutils.py index e5944d7093c..94974e595b1 100644 --- a/python/cudf/cudf/utils/ioutils.py +++ b/python/cudf/cudf/utils/ioutils.py @@ -4,9 +4,9 @@ import os import urllib import warnings +from collections.abc import Callable from io import BufferedWriter, BytesIO, IOBase, TextIOWrapper from threading import Thread -from typing import Callable import fsspec import fsspec.implementations.local diff --git a/python/cudf/cudf_pandas_tests/data/repr_slow_down_test.ipynb b/python/cudf/cudf_pandas_tests/data/repr_slow_down_test.ipynb new file mode 100644 index 00000000000..c7d39b78810 --- /dev/null +++ b/python/cudf/cudf_pandas_tests/data/repr_slow_down_test.ipynb @@ -0,0 +1,69 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext cudf.pandas" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "np.random.seed(0)\n", + "\n", + "num_rows = 25_000_000\n", + "num_columns = 12\n", + "\n", + "# Create a DataFrame with random data\n", + "df = pd.DataFrame(np.random.randint(0, 100, size=(num_rows, num_columns)),\n", + " columns=[f'Column_{i}' for i in range(1, num_columns + 1)])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/python/cudf/cudf_pandas_tests/test_cudf_pandas.py b/python/cudf/cudf_pandas_tests/test_cudf_pandas.py index 028f5f173ac..0827602852d 100644 --- a/python/cudf/cudf_pandas_tests/test_cudf_pandas.py +++ b/python/cudf/cudf_pandas_tests/test_cudf_pandas.py @@ -14,9 +14,12 @@ import types from io import BytesIO, StringIO +import jupyter_client +import nbformat import numpy as np import pyarrow as pa import pytest +from nbconvert.preprocessors import ExecutePreprocessor from numba import NumbaDeprecationWarning from pytz import utc @@ -1650,3 
+1653,36 @@ def test_change_index_name(index): assert s.index.name == name assert df.index.name == name + + +def test_notebook_slow_repr(): + notebook_filename = ( + os.path.dirname(os.path.abspath(__file__)) + + "/data/repr_slow_down_test.ipynb" + ) + with open(notebook_filename, "r", encoding="utf-8") as f: + nb = nbformat.read(f, as_version=4) + + ep = ExecutePreprocessor( + timeout=20, kernel_name=jupyter_client.KernelManager().kernel_name + ) + + try: + ep.preprocess(nb, {"metadata": {"path": "./"}}) + except Exception as e: + assert False, f"Error executing the notebook: {e}" + + # Collect the outputs + html_result = nb.cells[2]["outputs"][0]["data"]["text/html"] + for string in { + "div", + "Column_1", + "Column_2", + "Column_3", + "Column_4", + "tbody", + "", + }: + assert ( + string in html_result + ), f"Expected string {string} not found in the output" diff --git a/python/cudf/pyproject.toml b/python/cudf/pyproject.toml index 8386935fab0..17d1292980b 100644 --- a/python/cudf/pyproject.toml +++ b/python/cudf/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "packaging", "pandas>=2.0,<2.2.3dev0", "ptxcompiler", - "pyarrow>=16.1.0,<16.2.0a0", + "pyarrow>=14.0.0,<18.0.0a0", "pylibcudf==24.10.*,>=0.0.0a0", "rich", "rmm==24.10.*,>=0.0.0a0", @@ -63,11 +63,15 @@ test = [ "tzdata", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. pandas-tests = [ + "ipython", "pandas[test, pyarrow, performance, computation, fss, excel, parquet, feather, hdf5, spss, html, xml, plot, output-formatting, clipboard, compression]", "pytest-reportlog", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. cudf-pandas-tests = [ "ipython", + "jupyter_client", + "nbconvert", + "nbformat", "openpyxl", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. 
diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index ebc7dee6bfb..e334e6f5cc5 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -18,7 +18,7 @@ import types from functools import cache from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar import pyarrow as pa import pylibcudf as plc @@ -31,7 +31,7 @@ from cudf_polars.utils import sorting if TYPE_CHECKING: - from collections.abc import MutableMapping + from collections.abc import Callable, MutableMapping from typing import Literal from cudf_polars.typing import Schema diff --git a/python/cudf_polars/cudf_polars/typing/__init__.py b/python/cudf_polars/cudf_polars/typing/__init__.py index 5276073e62a..adab10bdded 100644 --- a/python/cudf_polars/cudf_polars/typing/__init__.py +++ b/python/cudf_polars/cudf_polars/typing/__init__.py @@ -13,7 +13,8 @@ from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir if TYPE_CHECKING: - from typing import Callable, TypeAlias + from collections.abc import Callable + from typing import TypeAlias import polars as pl diff --git a/python/cudf_polars/pyproject.toml b/python/cudf_polars/pyproject.toml index 0382e3ce6a2..f2bab9e6623 100644 --- a/python/cudf_polars/pyproject.toml +++ b/python/cudf_polars/pyproject.toml @@ -115,7 +115,6 @@ ignore = [ # tryceratops "TRY003", # Avoid specifying long messages outside the exception class # pyupgrade - "UP035", # Import from `collections.abc` instead: `Callable` "UP038", # Use `X | Y` in `isinstance` call instead of `(X, Y)` # Lints below are turned off because of conflicts with the ruff # formatter diff --git a/python/dask_cudf/README.md b/python/dask_cudf/README.md deleted file mode 120000 index fe840054137..00000000000 --- a/python/dask_cudf/README.md +++ /dev/null @@ -1 +0,0 @@ -../../README.md \ No newline at end of file diff --git a/python/dask_cudf/README.md b/python/dask_cudf/README.md new file mode 100644 index 00000000000..6edb9f87d48 --- /dev/null +++ b/python/dask_cudf/README.md @@ -0,0 +1,135 @@ +#