-
Notifications
You must be signed in to change notification settings - Fork 76
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
POC for new model DX #1311
base: main
Are you sure you want to change the base?
POC for new model DX #1311
Conversation
@@ -648,7 +648,7 @@ def _inplace_fill_base_image( | |||
) | |||
|
|||
|
|||
def _make_truss_config( | |||
def write_truss_config_yaml( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I renamed this because I thought it was confusing that it returned a TrussConfig
object that was never used, took me a while to figure out it was persisting to yaml which a later flow picked up
module_path = pathlib.Path(module_path).resolve() | ||
module_name = module_path.stem # Use the file's name as the module name | ||
if not os.path.isfile(module_path): | ||
resolved_module_path = pathlib.Path(module_path).resolve() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed to avoid shadowing
truss/cli/cli.py
Outdated
@@ -1132,11 +1133,32 @@ def push( | |||
TARGET_DIRECTORY: A Truss directory. If none, use current directory. | |||
|
|||
""" | |||
from truss_chains import framework | |||
from truss_chains.deployment.code_gen import write_truss_config_yaml |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have this (really cumbersome) trailing dependency to use pydantic v1 and support python 3.8 for truss in general, but chains requires newer - so the chains import should always be in a guard for "traditional" truss code paths - see other locations for reference.
truss/cli/cli.py
Outdated
# the config file so _get_truss_from_directory will pick it up | ||
target_path = Path(target_directory) | ||
with framework.import_model_target( | ||
target_path / "model/model.py" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The truss dir structure (and also the relative import quirkiness of how "packages" are imported) is something we could simplify now.
What about this:
We generalize TARGET_DIRECTORY
:
- If it's a directory (or
.
, or left out), assume traditional truss. - If it's a file, assume it's a chains entrypoint.
Have useful error handling and messaging in either case if the situation doesn't match either of these two assumptions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting! I don't love overloading the term directory
here since I imagine that could get confusing for users, but if the model/
directory structure has proven to be annoying, I think the benefit of having a single file could be worth it (can maybe think about some options to truss init
to account for this as well down the road)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's in particular annoying (at least to me):
You have model/model.py
and then you have packages/some_name.py
but inside model you don't import it from packages
namespace, but just as import some_name
- also breaks any IDE...
truss/cli/cli.py
Outdated
with framework.import_model_target( | ||
target_path / "model/model.py" | ||
) as entrypoint_cls: | ||
write_truss_config_yaml( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
write_truss_config_yaml
should be abstracted away from CLI (and keept internal/private to the chains deployment module).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, I don't think it's good to modify the user code dir. This is "generated" and in a way redundant to what they have in their python source file. If they use source versioning this will be messy. Better to use a tmp dir like for chains code generation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally there's a function in chains analoguous to deployment_client.push
, e.g. deployment_client.push_model
then you could also add a local test of the final built truss model in a docker container like those tests here: https://github.com/basetenlabs/truss/blob/main/truss-chains/tests/test_e2e.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Totally agreed on it being abstracted into a deployment_client
, for this POC I put before since this cli peeked into the config early for model_name
Lines 1162 to 1166 in 5ecb0dd
tr = _get_truss_from_directory(target_directory=target_directory) | |
model_name = model_name or tr.spec.config.model_name | |
if not model_name: | |
model_name = inquire_model_name() |
except ValueError: # In case the value was already removed for whatever reason. | ||
pass | ||
|
||
|
||
# NB(nikhil): mainly taken from above, but with some dependency logic removed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume this can be consolidated into one function down the road?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definitely!
return cls.remote_config.name or cls.name | ||
|
||
@abc.abstractmethod | ||
def predict(self, request: Any) -> Any: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't use request, it's really confusing if this is a request object or just the body/data.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sg to me! Will likely end up not enforcing this anyways, since requiring users to specify a given arg name seems like bad UX
|
||
# NB(nikhil): This seems to pass even when my definition doesn't have remote_config, likely | ||
# pulling from default | ||
if not hasattr(cls, definitions.REMOTE_CONFIG_NAME): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The class var has a default assigned - do we have the same validation for chainlets? It doesn't really make sense to validate this, if it's given by default anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep we do! I mainly copied from the chainlet definition here, agreed I can clean up both if we don't think it's valuable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it can be removed.
# NB(nikhil): Following lines cause ERROR TypeError: <class 'model.Model'> is a built-in class | ||
# src_path = os.path.abspath(inspect.getfile(cls)) | ||
# line = inspect.getsourcelines(cls)[1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's track this down - but first make the change away from the convoluted truss dir structure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I was just missing this line in my first attempt, doesn't seem to be an issue anymore
6f78f3e
to
76468b3
Compare
@@ -138,9 +138,10 @@ def _prepare_push( | |||
if model_name.isspace(): | |||
raise ValueError("Model name cannot be empty") | |||
|
|||
gathered_truss = TrussHandle(truss_handle.gather()) | |||
if truss_handle.is_scattered(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small cleanup, no need to recreate the entire TrussHandle (which reloads the YAML) if we're not scattered
|
||
logging.debug( | ||
f"Deleting modules when exiting import context: {modules_to_delete}" | ||
def _load_module( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_load_module
and _cleanup_module_imports
are pure refactors to extract logic out of import_target
above. It was more helpful when I had separate code to import a ChainletBase and a ModelBase, but now those have been consolidated
@@ -122,6 +122,19 @@ def __init_with_arg_check__(self, *args, **kwargs): | |||
cls.__init__ = __init_with_arg_check__ # type: ignore[method-assign] | |||
|
|||
|
|||
class ModelBase(definitions.ABCChainlet): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This new class allows us to change behavior slightly between traditional chains and solo models being deployed using the chains framework. For now, the main difference is the is_entrypoint
below, but can imagine this growing over time
cls: Type[definitions.ABCChainlet], location: _ErrorLocation | ||
) -> None: | ||
if not hasattr(cls, definitions.REMOTE_CONFIG_NAME): | ||
_collect_error( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can safely remove this since it'll never fire - the base abstract class instantiates this
@@ -138,10 +137,8 @@ class _ChainSourceGenerator: | |||
def __init__( | |||
self, | |||
options: definitions.PushOptions, | |||
gen_root: pathlib.Path, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small cleanup, this gen_root
was passed around to allow callers to optionally pass in a given directory, but every callsite I tracked down defaulted to a generated temp dir
write_truss_config_yaml( | ||
chainlet_dir=chainlet_dir, | ||
chains_config=chainlet_descriptor.chainlet_cls.remote_config, | ||
model_name=model_name or chain_name, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Chain models have their name suffixed, but for traditional trusses we'd want to preserve the name specified in RemoteConfig
exactly (or the one autoderived from the class name)
We can push this coalesce down into write_truss_config_yaml
if we think that's cleaner
name="OverridePassthroughModelName", | ||
docker_image=chains.DockerImage( | ||
pip_requirements=[ | ||
"truss==0.9.59rc2", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The traditional docker entrypoint doesn't have an equivalent to use_local_chains_src
, this was a hack to get the e2e test to pass for now
@@ -31,6 +31,9 @@ def populate_chainlet_service_predict_urls( | |||
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor], | |||
) -> Mapping[str, definitions.DeployedServiceDescriptor]: | |||
chainlet_to_deployed_service: Dict[str, definitions.DeployedServiceDescriptor] = {} | |||
# If there are no dependencies of this chainlet, no need to derive dynamic URLs | |||
if len(chainlet_to_service) == 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was needed to fix the runtime behavior for a traditional truss deployed through the chains framework. The better long term fix might be to change the generated code, but I thought this should be fine for now
76468b3
to
f09cfcb
Compare
f09cfcb
to
c5cad3d
Compare
🚀 What
Introduces a POC of a python first DX for traditional truss models, inspiration taken from https://www.notion.so/ml-infra/Model-Like-DX-17491d24727380e394bfd0f5ac312bdd.
Can look at the introduced e2e test to see what the user would write with this new framework. The overall goal is to hook into as much of the existing chains code gen framework as possible, so we have little bespoke behavior between the two paths.
💻 How
🔬 Testing