Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC for new model DX #1311

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open

POC for new model DX #1311

wants to merge 5 commits into from

Conversation

nnarayen
Copy link
Contributor

@nnarayen nnarayen commented Jan 14, 2025

🚀 What

Introduces a POC of a python first DX for traditional truss models, inspiration taken from https://www.notion.so/ml-infra/Model-Like-DX-17491d24727380e394bfd0f5ac312bdd.

Can look at the introduced e2e test to see what the user would write with this new framework. The overall goal is to hook into as much of the existing chains code gen framework as possible, so we have little bespoke behavior between the two paths.

💻 How

🔬 Testing

  • End to end deployment on local dev / codespaces
  • New integration test

@@ -648,7 +648,7 @@ def _inplace_fill_base_image(
)


def _make_truss_config(
def write_truss_config_yaml(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed this because I thought it was confusing that it returned a TrussConfig object that was never used, took me a while to figure out it was persisting to yaml which a later flow picked up

module_path = pathlib.Path(module_path).resolve()
module_name = module_path.stem # Use the file's name as the module name
if not os.path.isfile(module_path):
resolved_module_path = pathlib.Path(module_path).resolve()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to avoid shadowing

truss/cli/cli.py Outdated
@@ -1132,11 +1133,32 @@ def push(
TARGET_DIRECTORY: A Truss directory. If none, use current directory.

"""
from truss_chains import framework
from truss_chains.deployment.code_gen import write_truss_config_yaml
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have this (really cumbersome) trailing dependency to use pydantic v1 and support python 3.8 for truss in general, but chains requires newer - so the chains import should always be in a guard for "traditional" truss code paths - see other locations for reference.

truss/cli/cli.py Outdated
# the config file so _get_truss_from_directory will pick it up
target_path = Path(target_directory)
with framework.import_model_target(
target_path / "model/model.py"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The truss dir structure (and also the relative import quirkiness of how "packages" are imported) is something we could simplify now.

What about this:
We generalize TARGET_DIRECTORY:

  1. If it's a directory (or ., or left out), assume traditional truss.
  2. If it's a file, assume it's a chains entrypoint.

Have useful error handling and messaging in either case if the situation doesn't match either of these two assumptions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting! I don't love overloading the term directory here since I imagine that could get confusing for users, but if the model/ directory structure has proven to be annoying, I think the benefit of having a single file could be worth it (can maybe think about some options to truss init to account for this as well down the road)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's in particular annoying (at least to me):
You have model/model.py and then you have packages/some_name.py but inside model you don't import it from packages namespace, but just as import some_name - also breaks any IDE...

truss/cli/cli.py Outdated
with framework.import_model_target(
target_path / "model/model.py"
) as entrypoint_cls:
write_truss_config_yaml(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

write_truss_config_yaml should be abstracted away from CLI (and keept internal/private to the chains deployment module).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I don't think it's good to modify the user code dir. This is "generated" and in a way redundant to what they have in their python source file. If they use source versioning this will be messy. Better to use a tmp dir like for chains code generation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally there's a function in chains analoguous to deployment_client.push, e.g. deployment_client.push_model then you could also add a local test of the final built truss model in a docker container like those tests here: https://github.com/basetenlabs/truss/blob/main/truss-chains/tests/test_e2e.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally agreed on it being abstracted into a deployment_client, for this POC I put before since this cli peeked into the config early for model_name

truss/truss/cli/cli.py

Lines 1162 to 1166 in 5ecb0dd

tr = _get_truss_from_directory(target_directory=target_directory)
model_name = model_name or tr.spec.config.model_name
if not model_name:
model_name = inquire_model_name()

except ValueError: # In case the value was already removed for whatever reason.
pass


# NB(nikhil): mainly taken from above, but with some dependency logic removed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this can be consolidated into one function down the road?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely!

return cls.remote_config.name or cls.name

@abc.abstractmethod
def predict(self, request: Any) -> Any: ...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't use request, it's really confusing if this is a request object or just the body/data.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sg to me! Will likely end up not enforcing this anyways, since requiring users to specify a given arg name seems like bad UX


# NB(nikhil): This seems to pass even when my definition doesn't have remote_config, likely
# pulling from default
if not hasattr(cls, definitions.REMOTE_CONFIG_NAME):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class var has a default assigned - do we have the same validation for chainlets? It doesn't really make sense to validate this, if it's given by default anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep we do! I mainly copied from the chainlet definition here, agreed I can clean up both if we don't think it's valuable

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it can be removed.

Comment on lines 780 to 777
# NB(nikhil): Following lines cause ERROR TypeError: <class 'model.Model'> is a built-in class
# src_path = os.path.abspath(inspect.getfile(cls))
# line = inspect.getsourcelines(cls)[1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's track this down - but first make the change away from the convoluted truss dir structure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I was just missing this line in my first attempt, doesn't seem to be an issue anymore

@nnarayen nnarayen force-pushed the nikhil/model-base-chainlet branch 15 times, most recently from 6f78f3e to 76468b3 Compare January 18, 2025 01:01
@@ -138,9 +138,10 @@ def _prepare_push(
if model_name.isspace():
raise ValueError("Model name cannot be empty")

gathered_truss = TrussHandle(truss_handle.gather())
if truss_handle.is_scattered():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small cleanup, no need to recreate the entire TrussHandle (which reloads the YAML) if we're not scattered


logging.debug(
f"Deleting modules when exiting import context: {modules_to_delete}"
def _load_module(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_load_module and _cleanup_module_imports are pure refactors to extract logic out of import_target above. It was more helpful when I had separate code to import a ChainletBase and a ModelBase, but now those have been consolidated

@@ -122,6 +122,19 @@ def __init_with_arg_check__(self, *args, **kwargs):
cls.__init__ = __init_with_arg_check__ # type: ignore[method-assign]


class ModelBase(definitions.ABCChainlet):
Copy link
Contributor Author

@nnarayen nnarayen Jan 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This new class allows us to change behavior slightly between traditional chains and solo models being deployed using the chains framework. For now, the main difference is the is_entrypoint below, but can imagine this growing over time

cls: Type[definitions.ABCChainlet], location: _ErrorLocation
) -> None:
if not hasattr(cls, definitions.REMOTE_CONFIG_NAME):
_collect_error(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can safely remove this since it'll never fire - the base abstract class instantiates this

@@ -138,10 +137,8 @@ class _ChainSourceGenerator:
def __init__(
self,
options: definitions.PushOptions,
gen_root: pathlib.Path,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small cleanup, this gen_root was passed around to allow callers to optionally pass in a given directory, but every callsite I tracked down defaulted to a generated temp dir

write_truss_config_yaml(
chainlet_dir=chainlet_dir,
chains_config=chainlet_descriptor.chainlet_cls.remote_config,
model_name=model_name or chain_name,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chain models have their name suffixed, but for traditional trusses we'd want to preserve the name specified in RemoteConfig exactly (or the one autoderived from the class name)

We can push this coalesce down into write_truss_config_yaml if we think that's cleaner

name="OverridePassthroughModelName",
docker_image=chains.DockerImage(
pip_requirements=[
"truss==0.9.59rc2",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The traditional docker entrypoint doesn't have an equivalent to use_local_chains_src, this was a hack to get the e2e test to pass for now

@@ -31,6 +31,9 @@ def populate_chainlet_service_predict_urls(
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
) -> Mapping[str, definitions.DeployedServiceDescriptor]:
chainlet_to_deployed_service: Dict[str, definitions.DeployedServiceDescriptor] = {}
# If there are no dependencies of this chainlet, no need to derive dynamic URLs
if len(chainlet_to_service) == 0:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was needed to fix the runtime behavior for a traditional truss deployed through the chains framework. The better long term fix might be to change the generated code, but I thought this should be fine for now

@nnarayen nnarayen force-pushed the nikhil/model-base-chainlet branch from 76468b3 to f09cfcb Compare January 18, 2025 01:13
@nnarayen nnarayen force-pushed the nikhil/model-base-chainlet branch from f09cfcb to c5cad3d Compare January 18, 2025 01:13
@nnarayen nnarayen marked this pull request as ready for review January 18, 2025 01:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants