Skip to content

Commit

Permalink
more performant arrow to_dummies (#627)
Browse files (browse the repository at this point in the history)
  • Loading branch information
anopsy authored Jul 26, 2024
1 parent 5ca86b5 commit c2f1751
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions narwhals/_arrow/series.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -529,19 +529,20 @@ def to_dummies(
) -> ArrowDataFrame:
from narwhals._arrow.dataframe import ArrowDataFrame

np = get_numpy()
pa = get_pyarrow()
pc = get_pyarrow_compute()

series = self._native_series
unique_values = self.unique().sort()._native_series
columns = [pc.cast(pc.equal(series, v), pa.uint8()) for v in unique_values][
int(drop_first) :
]
names = [f"{self._name}{separator}{v}" for v in unique_values][int(drop_first) :]
da = series.dictionary_encode().combine_chunks()

columns = np.zeros((len(da.dictionary), len(da)), np.uint8)
columns[da.indices, np.arange(len(da))] = 1
names = [f"{self._name}{separator}{v}" for v in da.dictionary]

return ArrowDataFrame(
pa.Table.from_arrays(columns, names=names),
backend_version=self._backend_version,
)
).select(*sorted(names)[int(drop_first) :])

def quantile(
self: Self,
Expand Down

0 comments on commit c2f1751

Please sign in to comment.