Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Docs] MessagePack IDL, Pydantic Support, and Attribute Access #1770

Merged
merged 9 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ coverage
pre-commit
codespell
mock
pydantic>2
pytest
mypy
mashumaro
Expand Down
2 changes: 1 addition & 1 deletion examples/data_types_and_io/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ RUN python3 -m venv ${VENV}
ENV PATH="${VENV}/bin:$PATH"

RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/pip,id=pip \
pip install flytekit pandas pyarrow
pip install flytekit pandas pyarrow "pydantic>2"  # quoted: unquoted `>2` is shell redirection, not a version spec
RUN --mount=type=cache,sharing=locked,mode=0777,target=/root/.cache/pip,id=pip \
pip install torch --index-url https://download.pytorch.org/whl/cpu

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from dataclasses import dataclass

from dataclasses_json import dataclass_json
from flytekit import task, workflow


Expand Down Expand Up @@ -36,7 +35,6 @@ def dict_wf():


# Directly access an attribute of a dataclass
@dataclass_json
@dataclass
class Fruit:
name: str
Expand Down
100 changes: 100 additions & 0 deletions examples/data_types_and_io/data_types_and_io/pydantic_basemodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
import tempfile

import pandas as pd
from flytekit import ImageSpec, task, workflow
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.structured import StructuredDataset
from pydantic import BaseModel

# Container image used by every task below; pandas/pyarrow are needed for the
# StructuredDataset example and pydantic for the BaseModel task I/O.
image_spec = ImageSpec(
    registry="ghcr.io/flyteorg",
    packages=["pandas", "pyarrow", "pydantic"],
)


# Python types
# Define a Pydantic model with `int`, `str`, and `dict` as the data types
class Datum(BaseModel):
    """Pydantic model mixing plain-Python field types: an int, a str, and a dict."""

    x: int
    y: str
    z: dict[int, str]


# Once declared, a Pydantic model can be returned as an output or accepted as an input
@task(container_image=image_spec)
def stringify(s: int) -> Datum:
    """Wrap *s* into a Datum.

    A Pydantic model return is treated as a single complex JSON output.
    """
    as_text = str(s)
    return Datum(x=s, y=as_text, z={s: as_text})


@task(container_image=image_spec)
def add(x: Datum, y: Datum) -> Datum:
    """Combine two Datum values field-wise and return a new Datum.

    Flytekit automatically converts the provided JSON into a Pydantic model.
    If the structures don't match, it triggers a runtime failure.
    """
    # Merge into a fresh dict instead of calling `x.z.update(y.z)` — that
    # mutated the input model in place, and task inputs should be treated as
    # read-only. Entries from `y.z` still override `x.z` on key collisions.
    merged = {**x.z, **y.z}
    return Datum(x=x.x + y.x, y=x.y + y.y, z=merged)


# Flyte types
# Flyte types
class FlyteTypes(BaseModel):
    """Pydantic model whose fields are Flyte's offloaded types: a
    StructuredDataset, a FlyteFile, and a FlyteDirectory."""

    dataframe: StructuredDataset
    file: FlyteFile
    directory: FlyteDirectory


@task(container_image=image_spec)
def upload_data() -> FlyteTypes:
    """
    Flytekit will upload FlyteFile, FlyteDirectory, and StructuredDataset to the blob store,
    such as GCP or S3.
    """
    # StructuredDataset: a small in-memory dataframe.
    frame = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

    # FlyteDirectory: a temp directory holding the dataframe as parquet.
    out_dir = tempfile.mkdtemp(prefix="flyte-")
    frame.to_parquet(os.path.join(out_dir, "df.parquet"))

    # FlyteFile: a temp file with known contents (kept on disk via delete=False
    # so flytekit can pick it up after the handle is closed).
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        tmp.write(b"Hello, World!")
    finally:
        tmp.close()

    return FlyteTypes(
        dataframe=StructuredDataset(dataframe=frame),
        file=FlyteFile(tmp.name),
        directory=FlyteDirectory(out_dir),
    )


@task(container_image=image_spec)
def download_data(res: FlyteTypes):
    """Verify the artifacts produced by upload_data after download."""
    reference = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})
    downloaded = res.dataframe.open(pd.DataFrame).all()
    assert reference.equals(downloaded), "DataFrames do not match!"

    with open(res.file, "r") as fh:
        contents = fh.read()
    assert contents == "Hello, World!", "File contents do not match!"

    assert os.listdir(res.directory) == ["df.parquet"], "Directory contents do not match!"


# Define a workflow that calls the tasks created above
@workflow
def basemodel_wf(x: int, y: int) -> (Datum, FlyteTypes):
    """Chain the tasks above: build two Datum values and add them, then
    upload the Flyte types and verify them on download."""
    summed = add(x=stringify(s=x), y=stringify(s=y))
    uploaded = upload_data()
    download_data(res=uploaded)
    return summed, uploaded


# Run the workflow locally
# Execute the workflow with flytekit's local execution mode when this module
# is run as a script (no remote Flyte cluster required).
if __name__ == "__main__":
    basemodel_wf(x=10, y=20)
1 change: 1 addition & 0 deletions examples/data_types_and_io/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ torch
tabulate
tensorflow
pyarrow
pydantic>2
Loading