diff --git a/modal/_utils/function_utils.py b/modal/_utils/function_utils.py index d6b9aacc9..f0c1840c7 100644 --- a/modal/_utils/function_utils.py +++ b/modal/_utils/function_utils.py @@ -307,7 +307,7 @@ def class_parameter_info(self) -> api_pb2.ClassParameterInfo: format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO, schema=modal_parameters ) - def get_entrypoint_mount(self) -> list[_Mount]: + def get_entrypoint_mount(self) -> dict[str, _Mount]: """ Includes: * Implicit mount of the function itself (the module or package that the function is part of) @@ -321,32 +321,38 @@ def get_entrypoint_mount(self) -> list[_Mount]: """ if self._type == FunctionInfoType.NOTEBOOK: # Don't auto-mount anything for notebooks. - return [] + return {} # make sure the function's own entrypoint is included: if self._type == FunctionInfoType.PACKAGE: + top_level_package = self.module_name.split(".", 1)[0] + if config.get("automount"): - return [_Mount.from_local_python_packages(self.module_name)] + # with automount, sys.modules will include the top level package an automount it anyways, + # so let's include it here for correctness and let it be deduplicated: + return {top_level_package: _Mount.from_local_python_packages(top_level_package)} elif not self.is_serialized(): - # mount only relevant file and __init__.py:s - return [ - _Mount.from_local_dir( + # mount only relevant modal file and __init__.py:s of the package? + return { + top_level_package: _Mount.from_local_dir( self._base_dir, remote_path=self._remote_dir, recursive=True, condition=entrypoint_only_package_mount_condition(self._file), ) - ] + } elif not self.is_serialized(): - remote_path = ROOT_DIR / Path(self._file).name + module_file = Path(self._file) + container_module_name = module_file.stem + remote_path = ROOT_DIR / module_file.name if not _is_modal_path(remote_path): - return [ - _Mount.from_local_file( + return { + container_module_name: _Mount.from_local_file( self._file, remote_path=remote_path, ) - ] - return [] + } + return {} def get_tag(self): return self.function_name diff --git a/modal/functions.py b/modal/functions.py index fe536e265..74181bc26 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -69,7 +69,7 @@ ) from .gpu import GPU_T, parse_gpu_config from .image import _Image -from .mount import _get_client_mount, _Mount, get_auto_mounts +from .mount import _get_client_mount, _Mount, get_sys_modules_mounts from .network_file_system import _NetworkFileSystem, network_file_system_mount_protos from .object import _get_environment_name, _Object, live_method, live_method_gen from .output import _get_output_manager @@ -502,11 +502,16 @@ def from_local( all_mounts = [ _get_client_mount(), *explicit_mounts, - *entrypoint_mounts, + *entrypoint_mounts.values(), ] if config.get("automount"): - auto_mounts = get_auto_mounts() + auto_mounts = get_sys_modules_mounts() + print("entrypoints", entrypoint_mounts.keys()) + # don't need to add entrypoint modules to automounts: + for entrypoint_module in entrypoint_mounts: + auto_mounts.pop(entrypoint_module, None) + warn_missing_modules = set(auto_mounts.keys()) - image._added_python_source_set print(warn_missing_modules) if warn_missing_modules: diff --git a/modal/mount.py b/modal/mount.py index 99e30f440..55c7e59f2 100644 --- a/modal/mount.py +++ b/modal/mount.py @@ -760,7 +760,7 @@ def _is_modal_path(remote_path: PurePosixPath): return False -def get_auto_mounts() -> dict[str, _Mount]: +def get_sys_modules_mounts() -> dict[str, _Mount]: """mdmd:hidden Auto-mount local modules that have been imported in global scope. diff --git a/test/mount_test.py b/test/mount_test.py index 001517793..d8fb4b3ad 100644 --- a/test/mount_test.py +++ b/test/mount_test.py @@ -3,7 +3,6 @@ import os import platform import pytest -import re from pathlib import Path, PurePosixPath from test.helpers import deploy_app_externally @@ -206,7 +205,7 @@ def test_missing_python_source_warning(servicer, credentials, supports_dir): # should warn if function doesn't have an imported non-third-party package attached # either through add OR copy mode, unless automount=False mode is used def has_warning(output: str): - return re.match(r".*added the source for the following modules.*:\npkg_d\n.*", output, re.DOTALL) + return 'image.add_local_python_source("pkg_a")' in output output = deploy_app_externally(servicer, credentials, "pkg_d.main", cwd=supports_dir, capture_output=True) assert has_warning(output) diff --git a/test/mounted_files_test.py b/test/mounted_files_test.py index dcc49e39e..5db1510be 100644 --- a/test/mounted_files_test.py +++ b/test/mounted_files_test.py @@ -9,7 +9,7 @@ import modal from modal import Mount -from modal.mount import get_auto_mounts +from modal.mount import get_sys_modules_mounts from . import helpers from .supports.skip import skip_windows @@ -51,7 +51,7 @@ async def env_mount_files(): # If something is installed using pip -e, it will be bundled up as a part of the environment. # Those are env-specific so we ignore those as a part of the test filenames = [] - for mount in get_auto_mounts().values(): + for mount in get_sys_modules_mounts().values(): async for file_info in mount._get_files(mount.entries): filenames.append(file_info.mount_filename) diff --git a/test/supports/pkg_d/main.py b/test/supports/pkg_d/main.py index ea8ba0232..7d758a0d9 100644 --- a/test/supports/pkg_d/main.py +++ b/test/supports/pkg_d/main.py @@ -1,18 +1,18 @@ import os -import modal +from pkg_a import a # noqa # this would cause an automount warning -from . import sibling # noqa # warn if sibling source isn't attached +import modal app = modal.App() image = modal.Image.debian_slim() if os.environ.get("ADD_SOURCE") == "add": - image = image.add_local_python_source("pkg_d") + image = image.add_local_python_source("pkg_a") elif os.environ.get("ADD_SOURCE") == "copy": - image = image.add_local_python_source("pkg_d", copy=True) + image = image.add_local_python_source("pkg_a", copy=True) @app.function(image=image)