Skip to content

Commit

Permalink
Use instance over is_foo_dtype (rapidsai#14641)
Browse files Browse the repository at this point in the history
Similar to rapidsai#14638, use isinstance when we know we are checking a dtype instance

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: rapidsai#14641
  • Loading branch information
mroeschke authored Jan 19, 2024
1 parent 2c1b949 commit f785ed3
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 113 deletions.
14 changes: 7 additions & 7 deletions python/cudf/cudf/_lib/column.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ from typing import Literal

import cupy as cp
import numpy as np
import pandas as pd

import rmm

import cudf
import cudf._lib as libcudf
from cudf._lib import pylibcudf
from cudf.api.types import is_categorical_dtype, is_datetime64tz_dtype
from cudf.core.buffer import (
Buffer,
ExposureTrackedBuffer,
Expand Down Expand Up @@ -344,10 +344,10 @@ cdef class Column:
)

cdef mutable_column_view mutable_view(self) except *:
if is_categorical_dtype(self.dtype):
if isinstance(self.dtype, cudf.CategoricalDtype):
col = self.base_children[0]
data_dtype = col.dtype
elif is_datetime64tz_dtype(self.dtype):
elif isinstance(self.dtype, pd.DatetimeTZDtype):
col = self
data_dtype = _get_base_dtype(col.dtype)
else:
Expand Down Expand Up @@ -407,10 +407,10 @@ cdef class Column:
return self._view(c_null_count)

cdef column_view _view(self, libcudf_types.size_type null_count) except *:
if is_categorical_dtype(self.dtype):
if isinstance(self.dtype, cudf.CategoricalDtype):
col = self.base_children[0]
data_dtype = col.dtype
elif is_datetime64tz_dtype(self.dtype):
elif isinstance(self.dtype, pd.DatetimeTZDtype):
col = self
data_dtype = _get_base_dtype(col.dtype)
else:
Expand Down Expand Up @@ -482,7 +482,7 @@ cdef class Column:
# categoricals because cudf supports ordered and unordered categoricals
# while libcudf supports only unordered categoricals (see
# https://github.com/rapidsai/cudf/pull/8567).
if is_categorical_dtype(self.dtype):
if isinstance(self.dtype, cudf.CategoricalDtype):
col = self.base_children[0]
else:
col = self
Expand Down Expand Up @@ -648,7 +648,7 @@ cdef class Column:
"""
column_owner = isinstance(owner, Column)
mask_owner = owner
if column_owner and is_categorical_dtype(owner.dtype):
if column_owner and isinstance(owner.dtype, cudf.CategoricalDtype):
owner = owner.base_children[0]

size = cv.size()
Expand Down
74 changes: 48 additions & 26 deletions python/cudf/cudf/_lib/groupby.pyx
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
from functools import singledispatch

from pandas.core.groupby.groupby import DataError

from cudf.api.types import (
is_categorical_dtype,
is_decimal_dtype,
is_interval_dtype,
is_list_dtype,
is_string_dtype,
is_struct_dtype,
)
from cudf.api.types import is_string_dtype
from cudf.core.buffer import acquire_spill_lock
from cudf.core.dtypes import (
CategoricalDtype,
DecimalDtype,
IntervalDtype,
ListDtype,
StructDtype,
)

from libcpp cimport bool
from libcpp.memory cimport unique_ptr
Expand Down Expand Up @@ -73,6 +74,43 @@ _DECIMAL_AGGS = {
ctypedef const scalar constscalar


@singledispatch
def get_valid_aggregation(dtype):
if is_string_dtype(dtype):
return _STRING_AGGS
return "ALL"


@get_valid_aggregation.register
def _(dtype: ListDtype):
return _LIST_AGGS


@get_valid_aggregation.register
def _(dtype: CategoricalDtype):
return _CATEGORICAL_AGGS


@get_valid_aggregation.register
def _(dtype: ListDtype):
return _LIST_AGGS


@get_valid_aggregation.register
def _(dtype: StructDtype):
return _STRUCT_AGGS


@get_valid_aggregation.register
def _(dtype: IntervalDtype):
return _INTERVAL_AGGS


@get_valid_aggregation.register
def _(dtype: DecimalDtype):
return _DECIMAL_AGGS


cdef _agg_result_from_columns(
vector[libcudf_groupby.aggregation_result]& c_result_columns,
set column_included,
Expand Down Expand Up @@ -187,15 +225,7 @@ cdef class GroupBy:
for i, (col, aggs) in enumerate(zip(values, aggregations)):
dtype = col.dtype

valid_aggregations = (
_LIST_AGGS if is_list_dtype(dtype)
else _STRING_AGGS if is_string_dtype(dtype)
else _CATEGORICAL_AGGS if is_categorical_dtype(dtype)
else _STRUCT_AGGS if is_struct_dtype(dtype)
else _INTERVAL_AGGS if is_interval_dtype(dtype)
else _DECIMAL_AGGS if is_decimal_dtype(dtype)
else "ALL"
)
valid_aggregations = get_valid_aggregation(dtype)
included_aggregations_i = []

c_agg_request = move(libcudf_groupby.aggregation_request())
Expand Down Expand Up @@ -258,15 +288,7 @@ cdef class GroupBy:
for i, (col, aggs) in enumerate(zip(values, aggregations)):
dtype = col.dtype

valid_aggregations = (
_LIST_AGGS if is_list_dtype(dtype)
else _STRING_AGGS if is_string_dtype(dtype)
else _CATEGORICAL_AGGS if is_categorical_dtype(dtype)
else _STRUCT_AGGS if is_struct_dtype(dtype)
else _INTERVAL_AGGS if is_interval_dtype(dtype)
else _DECIMAL_AGGS if is_decimal_dtype(dtype)
else "ALL"
)
valid_aggregations = get_valid_aggregation(dtype)
included_aggregations_i = []

c_agg_request = move(libcudf_groupby.scan_request())
Expand Down
10 changes: 5 additions & 5 deletions python/cudf/cudf/_lib/interop.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

from cpython cimport pycapsule
from libcpp.memory cimport shared_ptr, unique_ptr
Expand All @@ -18,8 +18,8 @@ from cudf._lib.cpp.table.table cimport table
from cudf._lib.cpp.table.table_view cimport table_view
from cudf._lib.utils cimport columns_from_unique_ptr, table_view_from_columns

from cudf.api.types import is_list_dtype, is_struct_dtype
from cudf.core.buffer import acquire_spill_lock
from cudf.core.dtypes import ListDtype, StructDtype


def from_dlpack(dlpack_capsule):
Expand Down Expand Up @@ -98,7 +98,7 @@ cdef vector[column_metadata] gather_metadata(object cols_dtypes) except *:
if cols_dtypes is not None:
for idx, (col_name, col_dtype) in enumerate(cols_dtypes):
cpp_metadata.push_back(column_metadata(col_name.encode()))
if is_struct_dtype(col_dtype) or is_list_dtype(col_dtype):
if isinstance(col_dtype, (ListDtype, StructDtype)):
_set_col_children_metadata(col_dtype, cpp_metadata[idx])
else:
raise TypeError(
Expand All @@ -113,14 +113,14 @@ cdef _set_col_children_metadata(dtype,

cdef column_metadata element_metadata

if is_struct_dtype(dtype):
if isinstance(dtype, StructDtype):
for name, value in dtype.fields.items():
element_metadata = column_metadata(name.encode())
_set_col_children_metadata(
value, element_metadata
)
col_meta.children_meta.push_back(element_metadata)
elif is_list_dtype(dtype):
elif isinstance(dtype, ListDtype):
col_meta.children_meta.reserve(2)
# Offsets - child 0
col_meta.children_meta.push_back(column_metadata())
Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/_lib/io/utils.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

from cpython.buffer cimport PyBUF_READ
from cpython.memoryview cimport PyMemoryView_FromMemory
Expand All @@ -23,7 +23,7 @@ import errno
import io
import os

from cudf.api.types import is_struct_dtype
from cudf.core.dtypes import StructDtype


# Converts the Python source input to libcudf IO source_info
Expand Down Expand Up @@ -172,7 +172,7 @@ cdef Column update_column_struct_field_names(
)
col.set_base_children(tuple(children))

if is_struct_dtype(col):
if isinstance(col.dtype, StructDtype):
field_names.reserve(len(col.base_children))
for i in range(info.children.size()):
field_names.push_back(info.children[i].name)
Expand Down
21 changes: 8 additions & 13 deletions python/cudf/cudf/_lib/json.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.

# cython: boundscheck = False

Expand All @@ -17,6 +17,7 @@ from libcpp.utility cimport move
from libcpp.vector cimport vector

cimport cudf._lib.cpp.io.types as cudf_io_types
from cudf._lib.column cimport Column
from cudf._lib.cpp.io.data_sink cimport data_sink
from cudf._lib.cpp.io.json cimport (
json_reader_options,
Expand All @@ -42,10 +43,6 @@ from cudf._lib.io.utils cimport (
from cudf._lib.types cimport dtype_to_data_type
from cudf._lib.utils cimport data_from_unique_ptr, table_view_from_table

from cudf.api.types import is_list_dtype, is_struct_dtype

from cudf._lib.column cimport Column


cpdef read_json(object filepaths_or_buffers,
object dtype,
Expand Down Expand Up @@ -214,13 +211,12 @@ def write_json(
cdef schema_element _get_cudf_schema_element_from_dtype(object dtype) except *:
cdef schema_element s_element
cdef data_type lib_type
if cudf.api.types.is_categorical_dtype(dtype):
dtype = cudf.dtype(dtype)
if isinstance(dtype, cudf.CategoricalDtype):
raise NotImplementedError(
"CategoricalDtype as dtype is not yet "
"supported in JSON reader"
)

dtype = cudf.dtype(dtype)
lib_type = dtype_to_data_type(dtype)
s_element.type = lib_type
if isinstance(dtype, cudf.StructDtype):
Expand All @@ -237,19 +233,18 @@ cdef schema_element _get_cudf_schema_element_from_dtype(object dtype) except *:


cdef data_type _get_cudf_data_type_from_dtype(object dtype) except *:
if cudf.api.types.is_categorical_dtype(dtype):
dtype = cudf.dtype(dtype)
if isinstance(dtype, cudf.CategoricalDtype):
raise NotImplementedError(
"CategoricalDtype as dtype is not yet "
"supported in JSON reader"
)

dtype = cudf.dtype(dtype)
return dtype_to_data_type(dtype)

cdef _set_col_children_metadata(Column col,
column_name_info& col_meta):
cdef column_name_info child_info
if is_struct_dtype(col):
if isinstance(col.dtype, cudf.StructDtype):
for i, (child_col, name) in enumerate(
zip(col.children, list(col.dtype.fields))
):
Expand All @@ -258,7 +253,7 @@ cdef _set_col_children_metadata(Column col,
_set_col_children_metadata(
child_col, col_meta.children[i]
)
elif is_list_dtype(col):
elif isinstance(col.dtype, cudf.ListDtype):
for i, child_col in enumerate(col.children):
col_meta.children.push_back(child_info)
_set_col_children_metadata(
Expand Down
7 changes: 3 additions & 4 deletions python/cudf/cudf/_lib/orc.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

import cudf
from cudf.core.buffer import acquire_spill_lock
Expand Down Expand Up @@ -59,7 +59,6 @@ from cudf._lib.utils cimport data_from_unique_ptr, table_view_from_table
from pyarrow.lib import NativeFile

from cudf._lib.utils import _index_level_name, generate_pandas_metadata
from cudf.api.types import is_list_dtype, is_struct_dtype


cpdef read_raw_orc_statistics(filepath_or_buffer):
Expand Down Expand Up @@ -474,15 +473,15 @@ cdef class ORCWriter:
cdef _set_col_children_metadata(Column col,
column_in_metadata& col_meta,
list_column_as_map=False):
if is_struct_dtype(col):
if isinstance(col.dtype, cudf.StructDtype):
for i, (child_col, name) in enumerate(
zip(col.children, list(col.dtype.fields))
):
col_meta.child(i).set_name(name.encode())
_set_col_children_metadata(
child_col, col_meta.child(i), list_column_as_map
)
elif is_list_dtype(col):
elif isinstance(col.dtype, cudf.ListDtype):
if list_column_as_map:
col_meta.set_list_column_as_map()
_set_col_children_metadata(
Expand Down
21 changes: 7 additions & 14 deletions python/cudf/cudf/_lib/parquet.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.

# cython: boundscheck = False

Expand All @@ -18,12 +18,7 @@ import numpy as np

from cython.operator cimport dereference

from cudf.api.types import (
is_decimal_dtype,
is_list_dtype,
is_list_like,
is_struct_dtype,
)
from cudf.api.types import is_list_like

from cudf._lib.utils cimport data_from_unique_ptr

Expand Down Expand Up @@ -220,7 +215,7 @@ cpdef read_parquet(filepaths_or_buffers, columns=None, row_groups=None,

# update the decimal precision of each column
for col in names:
if is_decimal_dtype(df._data[col].dtype):
if isinstance(df._data[col].dtype, cudf.core.dtypes.DecimalDtype):
df._data[col].dtype.precision = (
meta_data_per_column[col]["metadata"]["precision"]
)
Expand Down Expand Up @@ -703,7 +698,7 @@ cdef _set_col_metadata(
# is true.
col_meta.set_nullability(True)

if is_struct_dtype(col):
if isinstance(col.dtype, cudf.StructDtype):
for i, (child_col, name) in enumerate(
zip(col.children, list(col.dtype.fields))
):
Expand All @@ -713,13 +708,11 @@ cdef _set_col_metadata(
col_meta.child(i),
force_nullable_schema
)
elif is_list_dtype(col):
elif isinstance(col.dtype, cudf.ListDtype):
_set_col_metadata(
col.children[1],
col_meta.child(1),
force_nullable_schema
)
else:
if is_decimal_dtype(col):
col_meta.set_decimal_precision(col.dtype.precision)
return
elif isinstance(col.dtype, cudf.core.dtypes.DecimalDtype):
col_meta.set_decimal_precision(col.dtype.precision)
Loading

0 comments on commit f785ed3

Please sign in to comment.