Skip to content

Commit

Permalink
Change get_hash to get_hashable
Browse files Browse the repository at this point in the history
This way someone wanting a more stable hash key can still use the same
infrastructure.
  • Loading branch information
wence- committed Oct 10, 2024
1 parent a454876 commit 87970fc
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 38 deletions.
8 changes: 4 additions & 4 deletions python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from cudf_polars.utils import dtypes, sorting

if TYPE_CHECKING:
from collections.abc import Mapping
from collections.abc import Hashable, Mapping

import polars as pl
import polars.type_aliases as pl_types
Expand Down Expand Up @@ -298,12 +298,12 @@ def __init__(self, dtype: plc.DataType, value: pl.Series) -> None:
data = value.to_arrow()
self.value = data.cast(dtypes.downcast_arrow_lists(data.type))

def get_hash(self) -> int:
"""Compute a hash of the column."""
def get_hashable(self) -> Hashable:
"""Compute a hashable representation of the column."""
# This is stricter than necessary, but we only need this hash
# for identity in groupby replacements so it's OK. And this
# way we avoid doing potentially expensive compute.
return hash((type(self), self.dtype, id(self.value)))
return (type(self), self.dtype, id(self.value))

def do_evaluate(
self,
Expand Down
63 changes: 38 additions & 25 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from cudf_polars.utils import dtypes

if TYPE_CHECKING:
from collections.abc import Callable, MutableMapping, Sequence
from collections.abc import Callable, Hashable, MutableMapping, Sequence
from typing import Literal

from cudf_polars.typing import Schema
Expand Down Expand Up @@ -131,12 +131,18 @@ class IR(Node):
"""Mapping from column names to their data types."""
children: tuple[IR, ...] = ()

def get_hash(self) -> int:
"""Hash of node, treating schema dictionary."""
def get_hashable(self) -> Hashable:
"""
Hashable representation of the node, with special handling for the schema dictionary.
Since the schema is a dictionary, even though it is morally
immutable, it is not hashable. We therefore convert it to
tuples for hashing purposes.
"""
# Schema is the first constructor argument
args = self._ctor_arguments(self.children)[1:]
schema_hash = tuple(self.schema.items())
return hash((type(self), schema_hash, args))
return (type(self), schema_hash, args)

def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""
Expand Down Expand Up @@ -309,21 +315,26 @@ def __init__(
"Reading only parquet metadata to produce row index."
)

def get_hash(self) -> int:
"""Hash of the node."""
return hash(
(
type(self),
self.typ,
json.dumps(self.reader_options),
json.dumps(self.cloud_options),
tuple(self.paths),
tuple(self.with_columns) if self.with_columns is not None else None,
self.skip_rows,
self.n_rows,
self.row_index,
self.predicate,
)
def get_hashable(self) -> Hashable:
"""
Hashable representation of the node.
The options dictionaries are serialised for hashing purposes
as json strings.
"""
schema_hash = tuple(self.schema.items())
return (
type(self),
schema_hash,
self.typ,
json.dumps(self.reader_options),
json.dumps(self.cloud_options),
tuple(self.paths),
tuple(self.with_columns) if self.with_columns is not None else None,
self.skip_rows,
self.n_rows,
self.row_index,
self.predicate,
)

def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
Expand Down Expand Up @@ -526,13 +537,15 @@ def __init__(
self.projection = tuple(projection) if projection is not None else None
self.predicate = predicate

def get_hash(self) -> int:
"""Compute a hash of the node."""
# Stricter than necessary, but avoid hashing the dataframe.
def get_hashable(self) -> Hashable:
"""
Hashable representation of the node.
The (heavy) dataframe object is hashed by its id, so the result is
not stable across runs, nor across distinct instances of equal dataframes.
"""
schema_hash = tuple(self.schema.items())
return hash(
(type(self), schema_hash, id(self.df), self.projection, self.predicate)
)
return (type(self), schema_hash, id(self.df), self.projection, self.predicate)

def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
"""Evaluate and return a dataframe."""
Expand Down
43 changes: 34 additions & 9 deletions python/cudf_polars/cudf_polars/dsl/nodebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import TYPE_CHECKING, Any, ClassVar

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Hashable, Sequence

from typing_extensions import Self

Expand Down Expand Up @@ -55,23 +55,42 @@ def reconstruct(
"""
return type(self)(*self._ctor_arguments(children))

def get_hash(self) -> int:
"""Return a hash of the node."""
return hash((type(self), self._ctor_arguments(self.children)))
def get_hashable(self) -> Hashable:
"""
Return a hashable object for the node.
Returns
-------
Hashable object.
Notes
-----
This method is used by the :meth:`__hash__` implementation
(which does caching). If your node type needs special-case
handling for some of its attributes, override this method, not
:meth:`__hash__`.
"""
return (type(self), self._ctor_arguments(self.children))

def __hash__(self) -> int:
"""Hash of an expression with caching."""
"""
Hash of an expression with caching.
See Also
--------
get_hashable
"""
try:
return self._hash_value
except AttributeError:
self._hash_value = self.get_hash()
self._hash_value = hash(self.get_hashable())
return self._hash_value

def is_equal(self, other: Any) -> bool:
"""
Equality of two expressions.
Override this in subclasses, rather than __eq__.
Override this in subclasses, rather than :meth:`__eq__`.
Parameters
----------
Expand All @@ -80,7 +99,7 @@ def is_equal(self, other: Any) -> bool:
Notes
-----
Since nodes are immutable, this does common-subexpression
Since nodes are immutable, this does common subexpression
elimination when two nodes are determined to be equal.
Returns
Expand All @@ -100,7 +119,13 @@ def is_equal(self, other: Any) -> bool:
return result

def __eq__(self, other: Any) -> bool:
"""Equality of expressions."""
"""
Equality of expressions.
See Also
--------
is_equal
"""
if type(self) is not type(other) or hash(self) != hash(other):
return False
else:
Expand Down

0 comments on commit 87970fc

Please sign in to comment.