Skip to content

Commit

Permalink
Change automounting logic to exclude entrypoint package from auto-mou…
Browse files Browse the repository at this point in the history
…nts (since it is added by another path)
  • Loading branch information
freider committed Jan 10, 2025
1 parent 0dbac39 commit 0cf5340
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 24 deletions.
30 changes: 18 additions & 12 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def class_parameter_info(self) -> api_pb2.ClassParameterInfo:
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO, schema=modal_parameters
)

def get_entrypoint_mount(self) -> list[_Mount]:
def get_entrypoint_mount(self) -> dict[str, _Mount]:
"""
Includes:
* Implicit mount of the function itself (the module or package that the function is part of)
Expand All @@ -321,32 +321,38 @@ def get_entrypoint_mount(self) -> list[_Mount]:
"""
if self._type == FunctionInfoType.NOTEBOOK:
# Don't auto-mount anything for notebooks.
return []
return {}

# make sure the function's own entrypoint is included:
if self._type == FunctionInfoType.PACKAGE:
top_level_package = self.module_name.split(".", 1)[0]

if config.get("automount"):
return [_Mount.from_local_python_packages(self.module_name)]
# with automount, sys.modules will include the top level package an automount it anyways,
# so let's include it here for correctness and let it be deduplicated:
return {top_level_package: _Mount.from_local_python_packages(top_level_package)}
elif not self.is_serialized():
# mount only relevant file and __init__.py:s
return [
_Mount.from_local_dir(
# mount only relevant modal file and __init__.py:s of the package?
return {
top_level_package: _Mount.from_local_dir(
self._base_dir,
remote_path=self._remote_dir,
recursive=True,
condition=entrypoint_only_package_mount_condition(self._file),
)
]
}
elif not self.is_serialized():
remote_path = ROOT_DIR / Path(self._file).name
module_file = Path(self._file)
container_module_name = module_file.stem
remote_path = ROOT_DIR / module_file.name
if not _is_modal_path(remote_path):
return [
_Mount.from_local_file(
return {
container_module_name: _Mount.from_local_file(
self._file,
remote_path=remote_path,
)
]
return []
}
return {}

def get_tag(self):
return self.function_name
Expand Down
11 changes: 8 additions & 3 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
)
from .gpu import GPU_T, parse_gpu_config
from .image import _Image
from .mount import _get_client_mount, _Mount, get_auto_mounts
from .mount import _get_client_mount, _Mount, get_sys_modules_mounts
from .network_file_system import _NetworkFileSystem, network_file_system_mount_protos
from .object import _get_environment_name, _Object, live_method, live_method_gen
from .output import _get_output_manager
Expand Down Expand Up @@ -502,11 +502,16 @@ def from_local(
all_mounts = [
_get_client_mount(),
*explicit_mounts,
*entrypoint_mounts,
*entrypoint_mounts.values(),
]

if config.get("automount"):
auto_mounts = get_auto_mounts()
auto_mounts = get_sys_modules_mounts()
print("entrypoints", entrypoint_mounts.keys())
# don't need to add entrypoint modules to automounts:
for entrypoint_module in entrypoint_mounts:
auto_mounts.pop(entrypoint_module, None)

warn_missing_modules = set(auto_mounts.keys()) - image._added_python_source_set
print(warn_missing_modules)
if warn_missing_modules:
Expand Down
2 changes: 1 addition & 1 deletion modal/mount.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ def _is_modal_path(remote_path: PurePosixPath):
return False


def get_auto_mounts() -> dict[str, _Mount]:
def get_sys_modules_mounts() -> dict[str, _Mount]:
"""mdmd:hidden
Auto-mount local modules that have been imported in global scope.
Expand Down
3 changes: 1 addition & 2 deletions test/mount_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import platform
import pytest
import re
from pathlib import Path, PurePosixPath
from test.helpers import deploy_app_externally

Expand Down Expand Up @@ -206,7 +205,7 @@ def test_missing_python_source_warning(servicer, credentials, supports_dir):
# should warn if function doesn't have an imported non-third-party package attached
# either through add OR copy mode, unless automount=False mode is used
def has_warning(output: str):
return re.match(r".*added the source for the following modules.*:\npkg_d\n.*", output, re.DOTALL)
return 'image.add_local_python_source("pkg_a")' in output

output = deploy_app_externally(servicer, credentials, "pkg_d.main", cwd=supports_dir, capture_output=True)
assert has_warning(output)
Expand Down
4 changes: 2 additions & 2 deletions test/mounted_files_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import modal
from modal import Mount
from modal.mount import get_auto_mounts
from modal.mount import get_sys_modules_mounts

from . import helpers
from .supports.skip import skip_windows
Expand Down Expand Up @@ -51,7 +51,7 @@ async def env_mount_files():
# If something is installed using pip -e, it will be bundled up as a part of the environment.
# Those are env-specific so we ignore those as a part of the test
filenames = []
for mount in get_auto_mounts().values():
for mount in get_sys_modules_mounts().values():
async for file_info in mount._get_files(mount.entries):
filenames.append(file_info.mount_filename)

Expand Down
8 changes: 4 additions & 4 deletions test/supports/pkg_d/main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import os

import modal
from pkg_a import a # noqa # this would cause an automount warning

from . import sibling # noqa # warn if sibling source isn't attached
import modal

app = modal.App()

image = modal.Image.debian_slim()

if os.environ.get("ADD_SOURCE") == "add":
image = image.add_local_python_source("pkg_d")
image = image.add_local_python_source("pkg_a")

elif os.environ.get("ADD_SOURCE") == "copy":
image = image.add_local_python_source("pkg_d", copy=True)
image = image.add_local_python_source("pkg_a", copy=True)


@app.function(image=image)
Expand Down

0 comments on commit 0cf5340

Please sign in to comment.