Skip to content

Commit

Permalink
update to new transform version
Browse files Browse the repository at this point in the history
  • Loading branch information
finn-rudolph committed Sep 25, 2024
1 parent 2dda5d2 commit bae7cff
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 60 deletions.
35 changes: 17 additions & 18 deletions docs/source/examples/realistic_pipeline.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import pydiverse.transform as pdt
import xgboost
import xgboost as xgb
from pydiverse.pipedag import Blob, Flow, Stage, Table, materialize
from pydiverse.transform import aligned, λ
from pydiverse.transform.core.verbs import (
from pydiverse.transform import aligned, C
from pydiverse.transform.pipe.verbs import (
alias,
collect,
filter,
Expand All @@ -36,8 +36,7 @@ from pydiverse.transform.core.verbs import (
select,
build_query,
)
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.sql.sql_table import SQLTableImpl
from pydiverse.transform.backend import PolarsImpl, SqlImpl


@pdt.verb
Expand Down Expand Up @@ -77,12 +76,12 @@ def read_input_data(src_dir="data/pipedag_example_data"):
]


@materialize(input_type=SQLTableImpl, lazy=True)
@materialize(input_type=SqlImpl, lazy=True)
def clean(src_tbls: list[pdt.Table]):
return [tbl >> trim_all_str() for tbl in src_tbls]


@materialize(input_type=SQLTableImpl, lazy=True, nout=3)
@materialize(input_type=SqlImpl, lazy=True, nout=3)
def transform(src_tbls: list[pdt.Table]):
named_tbls = get_named_tables(src_tbls)
a = named_tbls["a"]
Expand All @@ -99,7 +98,7 @@ def transform(src_tbls: list[pdt.Table]):
return new_a, new_b, new_c


@materialize(input_type=SQLTableImpl, lazy=True)
@materialize(input_type=SqlImpl, lazy=True)
def lazy_features(a: pdt.Table, src_tbls: list[pdt.Table]):
named_tbls = get_named_tables(src_tbls)
b = named_tbls["b"]
Expand All @@ -111,7 +110,7 @@ def lazy_features(a: pdt.Table, src_tbls: list[pdt.Table]):
)


@materialize(input_type=PolarsEager
@materialize(input_type=PolarsImpl
, version="2.3.5")
def eager_features(a: pdt.Table, src_tbls: list[pdt.Table]):
named_tbls = get_named_tables(src_tbls)
Expand All @@ -124,7 +123,7 @@ def eager_features(a: pdt.Table, src_tbls: list[pdt.Table]):
)


@materialize(input_type=SQLTableImpl, lazy=True)
@materialize(input_type=SqlImpl, lazy=True)
def combine_features(features1: pdt.Table, features2: pdt.Table):
return (
features1
Expand All @@ -135,7 +134,7 @@ def combine_features(features1: pdt.Table, features2: pdt.Table):
)


@materialize(input_type=SQLTableImpl, lazy=True, nout=2)
@materialize(input_type=SqlImpl, lazy=True, nout=2)
def train_and_test_set(flat_table: pdt.Table, features: pdt.Table):
tbl = (
flat_table
Expand All @@ -146,12 +145,12 @@ def train_and_test_set(flat_table: pdt.Table, features: pdt.Table):

training_set = (
tbl
>> filter(λ.row_num % 10 != 0)
>> select(-λ.row_num)
>> filter(C.row_num % 10 != 0)
>> select(-C.row_num)
>> alias("training_set")
)
test_set = (
tbl >> filter(λ.row_num % 10 == 0) >> select(-λ.row_num) >> alias("test_set")
tbl >> filter(C.row_num % 10 == 0) >> select(-C.row_num) >> alias("test_set")
)

return (training_set, test_set)
Expand All @@ -171,7 +170,7 @@ def model_training(train_set: pd.DataFrame):

@aligned(with_="test_set")
def predict(model: xgboost.Booster, test_set: pdt.Table):
x = test_set >> select(-λ.target) >> collect()
x = test_set >> select(-C.target) >> collect()

# Ugly hack to convert new pandas dtypes to numpy dtypes, because xgboost
# requires numpy dtypes.
Expand All @@ -183,20 +182,20 @@ def predict(model: xgboost.Booster, test_set: pdt.Table):
predict_col = model.predict(dx)

return pdt.Table(
PolarsEager
PolarsImpl
("prediction", pd.DataFrame({"prediction": predict_col}))
).prediction


@materialize(input_type=PolarsEager
@materialize(input_type=PolarsImpl
, version="3.4.5")
def model_evaluation(model: xgboost.Booster, test_set: pdt.Table):
prediction = predict(model, test_set) # produces an aligned vector with input
return (
test_set
>> select(λ.target)
>> select(C.target)
>> mutate(prediction=prediction)
>> mutate(abs_error=abs(λ.target - λ.prediction))
>> mutate(abs_error=abs(C.target - C.prediction))
>> alias("evaluation")
)

Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def join_tables(names: sa.Alias, ages: sa.Alias):
Or with pydiverse.transform:

```python
@materialize(lazy=True, input_type=pdt.SQLTableImpl)
@materialize(lazy=True, input_type=pdt.SqlImpl)
def join_tables(names: pdt.Table, ages: pdt.Table):
return (
names
Expand Down
4 changes: 2 additions & 2 deletions docs/source/table_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ decorator out-of-the-box:

- `sqlalchemy.Table` (see [https://www.sqlalchemy.org/](https://www.sqlalchemy.org/); recommended with `lazy=True`;
can also be used for composing handwritten SQL strings)
- `pydiverse.transform.eager.PolarsEager
- `pydiverse.transform.backend.PolarsImpl` (see
[https://pydiversetransform.readthedocs.io/en/latest/](https://pydiversetransform.readthedocs.io/en/latest/);
recommended with manual version bumping and `version="X.Y.Z"`)
- `pydiverse.transform.lazy.SQLTableImpl` (
- `pydiverse.transform.backend.SqlImpl` (
see [https://pydiversetransform.readthedocs.io/en/latest/](https://pydiversetransform.readthedocs.io/en/latest/);
recommended with `lazy=True`)
- `ibis.Table` (see [https://ibis-project.org/](https://ibis-project.org/); recommended with `lazy=True`)
Expand Down
16 changes: 8 additions & 8 deletions src/pydiverse/pipedag/backend/table/cache/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,21 +202,21 @@ def can_materialize(cls, type_) -> bool:

@classmethod
def can_retrieve(cls, type_) -> bool:
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.backend import PolarsImpl

return issubclass(type_, PolarsEager)
return issubclass(type_, PolarsImpl)

@classmethod
def materialize(
cls, store: ParquetTableCache, table: Table[pdt.Table], stage_name: str
):
from pydiverse.transform.core.verbs import collect
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.backend import PolarsImpl
from pydiverse.transform.pipe.verbs import collect

t = table.obj
table = table.copy_without_obj()

if isinstance(t._impl, PolarsEager):
if isinstance(t._impl, PolarsImpl):
table.obj = t >> collect()
return store.get_hook_subclass(PandasTableHook).materialize(
store, table, stage_name
Expand All @@ -232,11 +232,11 @@ def retrieve(
stage_name: str | None,
as_type: type,
):
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.backend import PolarsImpl

if isinstance(as_type, PolarsEager):
if isinstance(as_type, PolarsImpl):
hook = store.get_hook_subclass(PandasTableHook)
df = hook.retrieve(store, table, stage_name, pd.DataFrame)
return pdt.Table(PolarsEager(table.name, df))
return pdt.Table(PolarsImpl(table.name, df))

raise ValueError(f"Invalid type {as_type}")
10 changes: 5 additions & 5 deletions src/pydiverse/pipedag/backend/table/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ def can_materialize(cls, type_) -> bool:

@classmethod
def can_retrieve(cls, type_) -> bool:
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.backend import PolarsImpl

return issubclass(type_, PolarsEager)
return issubclass(type_, PolarsImpl)

@classmethod
def materialize(
Expand All @@ -188,18 +188,18 @@ def materialize(
table: Table[pdt.Table],
stage_name,
):
from pydiverse.transform.core.verbs import collect
from pydiverse.transform.pipe.verbs import collect

table.obj = table.obj >> collect()
# noinspection PyTypeChecker
return PandasTableHook.materialize(store, table, stage_name)

@classmethod
def retrieve(cls, store, table, stage_name, as_type):
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.backend import PolarsImpl

df = PandasTableHook.retrieve(store, table, stage_name, pd.DataFrame)
return pdt.Table(PolarsEager(table.name, df))
return pdt.Table(PolarsImpl(table.name, df))

@classmethod
def auto_table(cls, obj: pdt.Table):
Expand Down
31 changes: 14 additions & 17 deletions src/pydiverse/pipedag/backend/table/sql/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,30 +816,28 @@ def can_materialize(cls, type_) -> bool:

@classmethod
def can_retrieve(cls, type_) -> bool:
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.sql.sql_table import SQLTableImpl
from pydiverse.transform.backend import PolarsImpl, SqlImpl

return issubclass(type_, (PolarsEager, SQLTableImpl))
return issubclass(type_, (PolarsImpl, SqlImpl))

@classmethod
def retrieve_as_reference(cls, type_) -> bool:
from pydiverse.transform.sql.sql_table import SQLTableImpl
from pydiverse.transform.backend import SqlImpl

return issubclass(type_, SQLTableImpl)
return issubclass(type_, SqlImpl)

@classmethod
def materialize(cls, store, table: Table[pdt.Table], stage_name):
from pydiverse.transform.core.verbs import collect
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.sql.sql_table import SQLTableImpl
from pydiverse.transform.backend import PolarsImpl, SqlImpl
from pydiverse.transform.pipe.verbs import collect

t = table.obj
table = table.copy_without_obj()
if isinstance(t._impl, PolarsEager):
if isinstance(t._impl, PolarsImpl):
table.obj = t >> collect()
hook = store.get_hook_subclass(PandasTableHook)
return hook.materialize(store, table, stage_name)
if isinstance(t._impl, SQLTableImpl):
if isinstance(t._impl, SqlImpl):
table.obj = t._impl.build_select()
hook = store.get_hook_subclass(SQLAlchemyTableHook)
return hook.materialize(store, table, stage_name)
Expand All @@ -853,17 +851,16 @@ def retrieve(
stage_name: str | None,
as_type: type[T],
) -> T:
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.sql.sql_table import SQLTableImpl
from pydiverse.transform.backend import PolarsImpl, SqlImpl

if issubclass(as_type, PolarsEager):
if issubclass(as_type, PolarsImpl):
hook = store.get_hook_subclass(PandasTableHook)
df = hook.retrieve(store, table, stage_name, pd.DataFrame)
return pdt.Table(PolarsEager(table.name, df))
if issubclass(as_type, SQLTableImpl):
return pdt.Table(PolarsImpl(table.name, df))
if issubclass(as_type, SqlImpl):
hook = store.get_hook_subclass(SQLAlchemyTableHook)
sa_tbl = hook.retrieve(store, table, stage_name, sa.Table)
return pdt.Table(SQLTableImpl(store.engine, sa_tbl))
return pdt.Table(SqlImpl(store.engine, sa_tbl))
raise NotImplementedError

@classmethod
Expand All @@ -872,7 +869,7 @@ def auto_table(cls, obj: pdt.Table):

@classmethod
def lazy_query_str(cls, store, obj: pdt.Table) -> str:
from pydiverse.transform.core.verbs import build_query
from pydiverse.transform.pipe.verbs import build_query

query = obj >> build_query()

Expand Down
4 changes: 2 additions & 2 deletions src/pydiverse/pipedag/backend/table/sql/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ class SQLTableStore(BaseTableStore):
* - pydiverse.transform
- ``pdt.Table``
- | ``pdt.eager.PolarsEager
- | ``pdt.backend.PolarsImpl``
| ``pdt.lazy.SQLTableImpl``
| ``pdt.backend.SqlImpl``
* - pydiverse.pipedag table reference
- :py:class:`~.ExternalTableReference` (no materialization)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_table_hooks/test_pdtransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@


try:
from pydiverse.transform.core.verbs import mutate
from pydiverse.transform.polars.polars_table import PolarsEager
from pydiverse.transform.sql.sql_table import SQLTableImpl
from pydiverse.transform.backend.polars import PolarsImpl
from pydiverse.transform.backend.sql import SqlImpl
from pydiverse.transform.pipe.verbs import mutate
except ImportError:
SQLTableImpl = None
PolarsEager = None
SqlImpl = None
PolarsImpl = None


@pytest.mark.parametrize(
"impl_type",
[SQLTableImpl, PolarsEager],
[SqlImpl, PolarsImpl],
)
def test_table_store(impl_type: type):
def cache_fn(*args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ def unicode(src):

@skip_instances("mssql", "mssql_pytsql")
def test_unicode_beyond_mssql():
test_unicode("λ")
test_unicode("C")

0 comments on commit bae7cff

Please sign in to comment.