Skip to content

Commit

Permalink
Merge branch 'branch-24.08' of github.com:rapidsai/cudf into pylibcud…
Browse files Browse the repository at this point in the history
…f-lists-extract
  • Loading branch information
Matt711 committed Jul 9, 2024
2 parents cf4ccad + 433e959 commit 7b9542a
Show file tree
Hide file tree
Showing 14 changed files with 480 additions and 296 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,28 @@ jobs:
sha: ${{ inputs.sha }}
date: ${{ inputs.date }}
package-name: dask_cudf
wheel-build-cudf-polars:
needs: wheel-publish-cudf
secrets: inherit
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
with:
# This selects "ARCH=amd64 + the latest supported Python + CUDA".
matrix_filter: map(select(.ARCH == "amd64")) | group_by(.CUDA_VER|split(".")|map(tonumber)|.[0]) | map(max_by([(.PY_VER|split(".")|map(tonumber)), (.CUDA_VER|split(".")|map(tonumber))]))
build_type: ${{ inputs.build_type || 'branch' }}
branch: ${{ inputs.branch }}
sha: ${{ inputs.sha }}
date: ${{ inputs.date }}
script: ci/build_wheel_cudf_polars.sh
wheel-publish-cudf-polars:
needs: wheel-build-cudf-polars
secrets: inherit
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
with:
build_type: ${{ inputs.build_type || 'branch' }}
branch: ${{ inputs.branch }}
sha: ${{ inputs.sha }}
date: ${{ inputs.date }}
package-name: cudf_polars
trigger-pandas-tests:
if: inputs.build_type == 'nightly'
needs: wheel-build-cudf
Expand Down
5 changes: 0 additions & 5 deletions cpp/cmake/thirdparty/patches/cccl_override.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@
"packages" : {
"CCCL" : {
"patches" : [
{
"file" : "cccl/revert_pr_211.diff",
"issue" : "thrust::copy introduced a change in behavior that causes failures with cudaErrorInvalidValue.",
"fixed_in" : ""
},
{
"file" : "${current_json_dir}/thrust_disable_64bit_dispatching.diff",
"issue" : "Remove 64bit dispatching as not needed by libcudf and results in compiling twice as many kernels [https://github.com/rapidsai/cudf/pull/11437]",
Expand Down
324 changes: 162 additions & 162 deletions cpp/src/io/text/multibyte_split.cu

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion python/cudf/cudf/_lib/csv.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def read_csv(
col_name = df._data.names[index]
df._data[col_name] = df._data[col_name].astype(col_dtype)

if names is not None and isinstance(names[0], (int)):
if names is not None and len(names) and isinstance(names[0], (int)):
df.columns = [int(x) for x in df._data]

# Set index if the index_col parameter is passed
Expand Down
61 changes: 21 additions & 40 deletions python/cudf/cudf/core/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import cupy as cp
import numpy as np

from cudf.core.column import as_column
from cudf.core.copy_types import BooleanMask
from cudf.core.index import RangeIndex, ensure_index
from cudf.core.indexed_frame import IndexedFrame
from cudf.core.scalar import Scalar
from cudf.options import get_option
from cudf.utils.dtypes import can_convert_to_column

if TYPE_CHECKING:
from cudf.core.column.column import ColumnBase
from cudf.core.index import BaseIndex


def factorize(values, sort=False, use_na_sentinel=True, size_hint=None):
"""Encode the input values as integer labels
Expand Down Expand Up @@ -110,55 +115,31 @@ def factorize(values, sort=False, use_na_sentinel=True, size_hint=None):
return labels, cats.values if return_cupy_array else ensure_index(cats)


def _linear_interpolation(column, index=None):
"""
Interpolate over a float column. Implicitly assumes that values are
evenly spaced with respect to the x-axis, for example the data
[1.0, NaN, 3.0] will be interpolated assuming the NaN is half way
between the two valid values, yielding [1.0, 2.0, 3.0]
"""

index = RangeIndex(start=0, stop=len(column), step=1)
return _index_or_values_interpolation(column, index=index)


def _index_or_values_interpolation(column, index=None):
def _interpolation(column: ColumnBase, index: BaseIndex) -> ColumnBase:
"""
Interpolate over a float column. assumes a linear interpolation
strategy using the index of the data to denote spacing of the x
values. For example the data and index [1.0, NaN, 4.0], [1, 3, 4]
would result in [1.0, 3.0, 4.0]
would result in [1.0, 3.0, 4.0].
"""
# figure out where the nans are
mask = cp.isnan(column)
mask = column.isnull()

# trivial cases, all nan or no nans
num_nan = mask.sum()
if num_nan == 0 or num_nan == len(column):
return column
if not mask.any() or mask.all():
return column.copy()

to_interp = IndexedFrame(data={None: column}, index=index)
known_x_and_y = to_interp._apply_boolean_mask(
BooleanMask(~mask, len(to_interp))
)

known_x = known_x_and_y.index.to_cupy()
known_y = known_x_and_y._data.columns[0].values
valid_locs = ~mask
if isinstance(index, RangeIndex):
# Each point is evenly spaced, index values don't matter
known_x = cp.flatnonzero(valid_locs.values)
else:
known_x = index._column.apply_boolean_mask(valid_locs).values # type: ignore[attr-defined]
known_y = column.apply_boolean_mask(valid_locs).values

result = cp.interp(index.to_cupy(), known_x, known_y)

# find the first nan
first_nan_idx = (mask == 0).argmax().item()
first_nan_idx = valid_locs.values.argmax().item()
result[:first_nan_idx] = np.nan
return result


def get_column_interpolator(method):
interpolator = {
"linear": _linear_interpolation,
"index": _index_or_values_interpolation,
"values": _index_or_values_interpolation,
}.get(method, None)
if not interpolator:
raise ValueError(f"Interpolation method `{method}` not found")
return interpolator
return as_column(result)
20 changes: 7 additions & 13 deletions python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,24 +1113,18 @@ def is_monotonic_decreasing(self) -> bool:
def as_categorical_column(self, dtype: Dtype) -> CategoricalColumn:
if isinstance(dtype, str) and dtype == "category":
return self
if isinstance(dtype, pd.CategoricalDtype):
dtype = cudf.CategoricalDtype.from_pandas(dtype)
if (
isinstance(
dtype, (cudf.core.dtypes.CategoricalDtype, pd.CategoricalDtype)
)
and (dtype.categories is None)
and (dtype.ordered is None)
isinstance(dtype, cudf.CategoricalDtype)
and dtype.categories is None
and dtype.ordered is None
):
return self

if isinstance(dtype, pd.CategoricalDtype):
dtype = CategoricalDtype(
categories=dtype.categories, ordered=dtype.ordered
)

if not isinstance(dtype, CategoricalDtype):
elif not isinstance(dtype, CategoricalDtype):
raise ValueError("dtype must be CategoricalDtype")

if not isinstance(self.categories, type(dtype.categories._values)):
if not isinstance(self.categories, type(dtype.categories._column)):
# If both categories are of different Column types,
# return a column full of Nulls.
return _create_empty_categorical_column(self, dtype)
Expand Down
91 changes: 44 additions & 47 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,59 +962,59 @@ def astype(self, dtype: Dtype, copy: bool = False) -> ColumnBase:
if len(self) == 0:
dtype = cudf.dtype(dtype)
if self.dtype == dtype:
if copy:
return self.copy()
else:
return self
result = self
else:
return column_empty(0, dtype=dtype, masked=self.nullable)
if copy:
col = self.copy()
else:
col = self
if dtype == "category":
result = column_empty(0, dtype=dtype, masked=self.nullable)
elif dtype == "category":
# TODO: Figure out why `cudf.dtype("category")`
# astype's different than just the string
return col.as_categorical_column(dtype)
result = self.as_categorical_column(dtype)
elif (
isinstance(dtype, str)
and dtype == "interval"
and isinstance(self.dtype, cudf.IntervalDtype)
):
# astype("interval") (the string only) should no-op
return col
was_object = dtype == object or dtype == np.dtype(object)
dtype = cudf.dtype(dtype)
if self.dtype == dtype:
return col
elif isinstance(dtype, CategoricalDtype):
return col.as_categorical_column(dtype)
elif isinstance(dtype, IntervalDtype):
return col.as_interval_column(dtype)
elif isinstance(dtype, (ListDtype, StructDtype)):
if not col.dtype == dtype:
raise NotImplementedError(
f"Casting {self.dtype} columns not currently supported"
)
return col
elif isinstance(dtype, cudf.core.dtypes.DecimalDtype):
return col.as_decimal_column(dtype)
elif dtype.kind == "M":
return col.as_datetime_column(dtype)
elif dtype.kind == "m":
return col.as_timedelta_column(dtype)
elif dtype.kind == "O":
if cudf.get_option("mode.pandas_compatible") and was_object:
raise ValueError(
f"Casting to {dtype} is not supported, use "
"`.astype('str')` instead."
)
return col.as_string_column(dtype)
result = self
else:
return col.as_numerical_column(dtype)
was_object = dtype == object or dtype == np.dtype(object)
dtype = cudf.dtype(dtype)
if self.dtype == dtype:
result = self
elif isinstance(dtype, CategoricalDtype):
result = self.as_categorical_column(dtype)
elif isinstance(dtype, IntervalDtype):
result = self.as_interval_column(dtype)
elif isinstance(dtype, (ListDtype, StructDtype)):
if not self.dtype == dtype:
raise NotImplementedError(
f"Casting {self.dtype} columns not currently supported"
)
result = self
elif isinstance(dtype, cudf.core.dtypes.DecimalDtype):
result = self.as_decimal_column(dtype)
elif dtype.kind == "M":
result = self.as_datetime_column(dtype)
elif dtype.kind == "m":
result = self.as_timedelta_column(dtype)
elif dtype.kind == "O":
if cudf.get_option("mode.pandas_compatible") and was_object:
raise ValueError(
f"Casting to {dtype} is not supported, use "
"`.astype('str')` instead."
)
result = self.as_string_column(dtype)
else:
result = self.as_numerical_column(dtype)

if copy and result is self:
return result.copy()
return result

def as_categorical_column(self, dtype) -> ColumnBase:
if isinstance(dtype, (cudf.CategoricalDtype, pd.CategoricalDtype)):
if isinstance(dtype, pd.CategoricalDtype):
dtype = cudf.CategoricalDtype.from_pandas(dtype)
if isinstance(dtype, cudf.CategoricalDtype):
ordered = dtype.ordered
else:
ordered = False
Expand All @@ -1023,14 +1023,11 @@ def as_categorical_column(self, dtype) -> ColumnBase:
if (
isinstance(dtype, cudf.CategoricalDtype)
and dtype._categories is not None
) or (
isinstance(dtype, pd.CategoricalDtype)
and dtype.categories is not None
):
labels = self._label_encoding(cats=as_column(dtype.categories))

cat_col = dtype._categories
labels = self._label_encoding(cats=cat_col)
return build_categorical_column(
categories=as_column(dtype.categories),
categories=cat_col,
codes=labels,
mask=self.mask,
ordered=dtype.ordered,
Expand Down
14 changes: 11 additions & 3 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

import cudf
import cudf._lib as libcudf
import cudf.core
import cudf.core.algorithms
from cudf.api.extensions import no_default
from cudf.api.types import (
_is_non_decimal_numeric_dtype,
Expand Down Expand Up @@ -1987,6 +1989,8 @@ def interpolate(
"Use obj.ffill() or obj.bfill() instead.",
FutureWarning,
)
elif method not in {"linear", "values", "index"}:
raise ValueError(f"Interpolation method `{method}` not found")

data = self

Expand All @@ -2000,7 +2004,10 @@ def interpolate(
)
)

interpolator = cudf.core.algorithms.get_column_interpolator(method)
if method == "linear":
interp_index = RangeIndex(self._num_rows)
else:
interp_index = data.index
columns = []
for col in data._columns:
if isinstance(col, cudf.core.column.StringColumn):
Expand All @@ -2012,8 +2019,9 @@ def interpolate(
if col.nullable:
col = col.astype("float64").fillna(np.nan)

# Interpolation methods may or may not need the index
columns.append(interpolator(col, index=data.index))
columns.append(
cudf.core.algorithms._interpolation(col, index=interp_index)
)

result = self._from_data_like_self(
self._data._from_columns_like_self(columns)
Expand Down
4 changes: 1 addition & 3 deletions python/cudf/cudf/core/multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from cudf.api.types import is_integer, is_list_like, is_object_dtype
from cudf.core import column
from cudf.core._base_index import _return_get_indexer_result
from cudf.core.algorithms import factorize
from cudf.core.column_accessor import ColumnAccessor
from cudf.core.frame import Frame
from cudf.core.index import (
Expand Down Expand Up @@ -1373,9 +1374,6 @@ def from_arrays(
(2, 'blue')],
names=['number', 'color'])
"""
# Imported here due to circular import
from cudf.core.algorithms import factorize

error_msg = "Input must be a list / sequence of array-likes."
if not is_list_like(arrays):
raise TypeError(error_msg)
Expand Down
6 changes: 6 additions & 0 deletions python/cudf/cudf/tests/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,9 @@ def test_interpolate_dataframe_error_cases(data, kwargs):
lfunc_args_and_kwargs=([], kwargs),
rfunc_args_and_kwargs=([], kwargs),
)


def test_interpolate_noop_new_column():
ser = cudf.Series([1.0, 2.0, 3.0])
result = ser.interpolate()
assert ser._column is not result._column
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence

import polars.polars as plrs
import polars as pl
import polars.type_aliases as pl_types

from cudf_polars.containers import DataFrame
Expand Down Expand Up @@ -377,7 +377,7 @@ class LiteralColumn(Expr):
value: pa.Array[Any, Any]
children: tuple[()]

def __init__(self, dtype: plc.DataType, value: plrs.PySeries) -> None:
def __init__(self, dtype: plc.DataType, value: pl.Series) -> None:
super().__init__(dtype)
data = value.to_arrow()
self.value = data.cast(dtypes.downcast_arrow_lists(data.type))
Expand Down
Loading

0 comments on commit 7b9542a

Please sign in to comment.