Skip to content

Commit

Permalink
update include paths and polars for transform
Browse files Browse the repository at this point in the history
  • Loading branch information
finn-rudolph committed Aug 23, 2024
1 parent 94975a1 commit 2dda5d2
Show file tree
Hide file tree
Showing 7 changed files with 168 additions and 162 deletions.
13 changes: 8 additions & 5 deletions docs/source/examples/realistic_pipeline.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ from pydiverse.transform.core.verbs import (
select,
build_query,
)
from pydiverse.transform.eager import PandasTableImpl
from pydiverse.transform.lazy import SQLTableImpl
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.sql.sql_table import SQLTableImpl


@pdt.verb
Expand Down Expand Up @@ -111,7 +111,8 @@ def lazy_features(a: pdt.Table, src_tbls: list[pdt.Table]):
)


@materialize(input_type=PandasTableImpl, version="2.3.5")
@materialize(input_type=PolarsEager, version="2.3.5")
def eager_features(a: pdt.Table, src_tbls: list[pdt.Table]):
named_tbls = get_named_tables(src_tbls)
c = named_tbls["c"]
Expand Down Expand Up @@ -182,11 +183,13 @@ def predict(model: xgboost.Booster, test_set: pdt.Table):
predict_col = model.predict(dx)

return pdt.Table(
PandasTableImpl("prediction", pd.DataFrame({"prediction": predict_col}))
PolarsEager("prediction", pd.DataFrame({"prediction": predict_col}))
).prediction


@materialize(input_type=PandasTableImpl, version="3.4.5")
@materialize(input_type=PolarsEager, version="3.4.5")
def model_evaluation(model: xgboost.Booster, test_set: pdt.Table):
prediction = predict(model, test_set) # produces an aligned vector with input
return (
Expand Down
3 changes: 2 additions & 1 deletion docs/source/table_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ decorator out-of-the-box:

- `sqlalchemy.Table` (see [https://www.sqlalchemy.org/](https://www.sqlalchemy.org/); recommended with `lazy=True`;
can also be used for composing handwritten SQL strings)
- `pydiverse.transform.eager.PandasTableImpl` (see
- `pydiverse.transform.eager.PolarsEager` (see
[https://pydiversetransform.readthedocs.io/en/latest/](https://pydiversetransform.readthedocs.io/en/latest/);
recommended with manual version bumping and `version="X.Y.Z"`)
- `pydiverse.transform.lazy.SQLTableImpl` (
Expand Down
14 changes: 7 additions & 7 deletions src/pydiverse/pipedag/backend/table/cache/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,21 +202,21 @@ def can_materialize(cls, type_) -> bool:

@classmethod
def can_retrieve(cls, type_) -> bool:
from pydiverse.transform.eager import PandasTableImpl
from pydiverse.transform.polars.polars_table import PolarsEager

return issubclass(type_, PandasTableImpl)
return issubclass(type_, PolarsEager)

@classmethod
def materialize(
cls, store: ParquetTableCache, table: Table[pdt.Table], stage_name: str
):
from pydiverse.transform.core.verbs import collect
from pydiverse.transform.eager import PandasTableImpl
from pydiverse.transform.polars.polars_table import PolarsEager

t = table.obj
table = table.copy_without_obj()

if isinstance(t._impl, PandasTableImpl):
if isinstance(t._impl, PolarsEager):
table.obj = t >> collect()
return store.get_hook_subclass(PandasTableHook).materialize(
store, table, stage_name
Expand All @@ -232,11 +232,11 @@ def retrieve(
stage_name: str | None,
as_type: type,
):
from pydiverse.transform.eager import PandasTableImpl
from pydiverse.transform.polars.polars_table import PolarsEager

if isinstance(as_type, PandasTableImpl):
if isinstance(as_type, PolarsEager):
hook = store.get_hook_subclass(PandasTableHook)
df = hook.retrieve(store, table, stage_name, pd.DataFrame)
return pdt.Table(PandasTableImpl(table.name, df))
return pdt.Table(PolarsEager(table.name, df))

raise ValueError(f"Invalid type {as_type}")
8 changes: 4 additions & 4 deletions src/pydiverse/pipedag/backend/table/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ def can_materialize(cls, type_) -> bool:

@classmethod
def can_retrieve(cls, type_) -> bool:
from pydiverse.transform.eager import PandasTableImpl
from pydiverse.transform.polars.polars_table import PolarsEager

return issubclass(type_, PandasTableImpl)
return issubclass(type_, PolarsEager)

@classmethod
def materialize(
Expand All @@ -196,10 +196,10 @@ def materialize(

@classmethod
def retrieve(cls, store, table, stage_name, as_type):
from pydiverse.transform.eager import PandasTableImpl
from pydiverse.transform.polars.polars_table import PolarsEager

df = PandasTableHook.retrieve(store, table, stage_name, pd.DataFrame)
return pdt.Table(PandasTableImpl(table.name, df))
return pdt.Table(PolarsEager(table.name, df))

@classmethod
def auto_table(cls, obj: pdt.Table):
Expand Down
22 changes: 11 additions & 11 deletions src/pydiverse/pipedag/backend/table/sql/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,26 +816,26 @@ def can_materialize(cls, type_) -> bool:

@classmethod
def can_retrieve(cls, type_) -> bool:
from pydiverse.transform.eager import PandasTableImpl
from pydiverse.transform.lazy import SQLTableImpl
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.sql.sql_table import SQLTableImpl

return issubclass(type_, (PandasTableImpl, SQLTableImpl))
return issubclass(type_, (PolarsEager, SQLTableImpl))

@classmethod
def retrieve_as_reference(cls, type_) -> bool:
from pydiverse.transform.lazy import SQLTableImpl
from pydiverse.transform.sql.sql_table import SQLTableImpl

return issubclass(type_, SQLTableImpl)

@classmethod
def materialize(cls, store, table: Table[pdt.Table], stage_name):
from pydiverse.transform.core.verbs import collect
from pydiverse.transform.eager import PandasTableImpl
from pydiverse.transform.lazy import SQLTableImpl
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.sql.sql_table import SQLTableImpl

t = table.obj
table = table.copy_without_obj()
if isinstance(t._impl, PandasTableImpl):
if isinstance(t._impl, PolarsEager):
table.obj = t >> collect()
hook = store.get_hook_subclass(PandasTableHook)
return hook.materialize(store, table, stage_name)
Expand All @@ -853,13 +853,13 @@ def retrieve(
stage_name: str | None,
as_type: type[T],
) -> T:
from pydiverse.transform.eager import PandasTableImpl
from pydiverse.transform.lazy import SQLTableImpl
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.sql.sql_table import SQLTableImpl

if issubclass(as_type, PandasTableImpl):
if issubclass(as_type, PolarsEager):
hook = store.get_hook_subclass(PandasTableHook)
df = hook.retrieve(store, table, stage_name, pd.DataFrame)
return pdt.Table(PandasTableImpl(table.name, df))
return pdt.Table(PolarsEager(table.name, df))
if issubclass(as_type, SQLTableImpl):
hook = store.get_hook_subclass(SQLAlchemyTableHook)
sa_tbl = hook.retrieve(store, table, stage_name, sa.Table)
Expand Down
Loading

0 comments on commit 2dda5d2

Please sign in to comment.