diff --git a/modal/cli/profile.py b/modal/cli/profile.py index e3dec84c8..72a316421 100644 --- a/modal/cli/profile.py +++ b/modal/cli/profile.py @@ -80,7 +80,7 @@ async def list(json: Optional[bool] = False): table.add_row(*content, style=highlight if active else "dim") console.print(table) - if env_based_workspace is not None: - console.print( - f"Using [bold]{env_based_workspace}[/bold] workspace based on environment variables", style="yellow" - ) + if env_based_workspace is not None: + console.print( + f"Using [bold]{env_based_workspace}[/bold] workspace based on environment variables", style="yellow" + ) diff --git a/modal/client.py b/modal/client.py index a865fbb20..14eda7b93 100644 --- a/modal/client.py +++ b/modal/client.py @@ -33,7 +33,7 @@ from ._utils.async_utils import TaskContext, synchronize_api from ._utils.grpc_utils import create_channel, retry_transient_errors from ._utils.http_utils import ClientSessionRegistry -from .config import _check_config, config, logger +from .config import _check_config, _is_remote, config, logger from .exception import AuthError, ClientClosed, ConnectionError, DeprecationError, VersionError HEARTBEAT_INTERVAL: float = config.get("heartbeat_interval") @@ -100,6 +100,7 @@ class _Client: _cancellation_context: TaskContext _cancellation_context_event_loop: asyncio.AbstractEventLoop = None _stub: Optional[api_grpc.ModalClientStub] + _credentials: Optional[Tuple[str, str]] def __init__( self, @@ -115,7 +116,6 @@ def __init__( self.client_type = client_type self._credentials = credentials self.version = version - self._authenticated = False self._closed = False self._channel: Optional[grpclib.client.Channel] = None self._stub: Optional[modal_api_grpc.ModalClientModal] = None @@ -127,7 +127,7 @@ def is_closed(self) -> bool: @property def authenticated(self): - return self._authenticated + return self._credentials is not None @property def stub(self) -> modal_api_grpc.ModalClientModal: @@ -159,7 +159,7 @@ async def _close(self, prep_for_restore: bool = False): # Remove cached client. self.set_env_client(None) - async def _init(self): + async def _hello(self): """Connect to server and retrieve version information; raise appropriate error for various failures.""" logger.debug("Client: Starting") _check_config() @@ -174,7 +174,6 @@ async def _init(self): if resp.warning: ALARM_EMOJI = chr(0x1F6A8) warnings.warn(f"{ALARM_EMOJI} {resp.warning} {ALARM_EMOJI}", DeprecationError) - self._authenticated = True except GRPCError as exc: if exc.status == Status.FAILED_PRECONDITION: raise VersionError( @@ -191,7 +190,7 @@ async def _init(self): async def __aenter__(self): await self._open() try: - await self._init() + await self._hello() except BaseException: await self._close() raise @@ -210,7 +209,7 @@ async def anonymous(cls, server_url: str) -> AsyncIterator["_Client"]: client = cls(server_url, api_pb2.CLIENT_TYPE_CLIENT, credentials=None) try: await client._open() - # Skip client._init + # Skip client._hello yield client finally: await client._close() @@ -226,44 +225,36 @@ async def from_env(cls, _override_config=None) -> "_Client": else: c = config - server_url = c["server_url"] - - token_id = c["token_id"] - token_secret = c["token_secret"] - task_id = c["task_id"] - credentials = None - - if task_id: - client_type = api_pb2.CLIENT_TYPE_CONTAINER - else: - client_type = api_pb2.CLIENT_TYPE_CLIENT - if token_id and token_secret: - credentials = (token_id, token_secret) - if cls._client_from_env_lock is None: cls._client_from_env_lock = asyncio.Lock() async with cls._client_from_env_lock: if cls._client_from_env: return cls._client_from_env + + server_url = c["server_url"] + + if _is_remote(): + credentials = None + client_type = api_pb2.CLIENT_TYPE_CONTAINER else: - client = _Client(server_url, client_type, credentials) - await client._open() - async_utils.on_shutdown(client._close()) - try: - await client._init() - except AuthError: - if not credentials: - creds_missing_msg = ( - "Token missing. Could not authenticate client." - " If you have token credentials, see modal.com/docs/reference/modal.config for setup help." - " If you are a new user, register an account at modal.com, then run `modal token new`." - ) - raise AuthError(creds_missing_msg) - else: - raise - cls._client_from_env = client - return client + token_id = c["token_id"] + token_secret = c["token_secret"] + if not token_id or not token_secret: + raise AuthError( + "Token missing. Could not authenticate client." + " If you have token credentials, see modal.com/docs/reference/modal.config for setup help." + " If you are a new user, register an account at modal.com, then run `modal token new`." + ) + credentials = (token_id, token_secret) + client_type = api_pb2.CLIENT_TYPE_CLIENT + + client = _Client(server_url, client_type, credentials) + await client._open() + async_utils.on_shutdown(client._close()) + await client._hello() + cls._client_from_env = client + return client @classmethod async def from_credentials(cls, token_id: str, token_secret: str) -> "_Client": @@ -284,7 +275,7 @@ async def from_credentials(cls, token_id: str, token_secret: str) -> "_Client": client = _Client(server_url, client_type, credentials) await client._open() try: - await client._init() + await client._hello() except BaseException: await client._close() raise @@ -345,7 +336,7 @@ async def _reset_on_pid_change(self): self.set_env_client(None) # TODO(elias): reset _cancellation_context in case ? await self._open() - # intentionally not doing self._init since we should already be authenticated etc. + # intentionally not doing self._hello since we should already be authenticated etc. async def _get_grpclib_method(self, method_name: str) -> Any: # safely get grcplib method that is bound to a valid channel diff --git a/test/conftest.py b/test/conftest.py index 6520219b4..4a314cdd4 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1768,6 +1768,8 @@ async def container_client(servicer): @pytest_asyncio.fixture(scope="function") async def server_url_env(servicer, monkeypatch): monkeypatch.setenv("MODAL_SERVER_URL", servicer.client_addr) + monkeypatch.setenv("MODAL_TOKEN_ID", "ak-123") + monkeypatch.setenv("MODAL_TOKEN_SECRET", "as-123") yield diff --git a/test/container_app_test.py b/test/container_app_test.py index 62e79b8d1..e73721528 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -4,6 +4,7 @@ import os import pytest import time +from contextlib import contextmanager from typing import Dict from unittest import mock @@ -71,6 +72,20 @@ async def stop_app(client, app_id): return await retry_transient_errors(client.stub.AppStop, api_pb2.AppStopRequest(app_id=app_id)) +@contextmanager +def set_env_vars(restore_path, container_addr): + with mock.patch.dict( + os.environ, + { + "MODAL_RESTORE_STATE_PATH": str(restore_path), + "MODAL_SERVER_URL": container_addr, + "MODAL_TOKEN_ID": "ak-123", + "MODAL_TOKEN_SECRET": "as-123", + }, + ): + yield + + @pytest.mark.asyncio async def test_container_snapshot_reference_capture(container_client, tmpdir, servicer, client): app = App() @@ -86,9 +101,7 @@ async def test_container_snapshot_reference_capture(container_client, tmpdir, se assert f.object_id == "fu-1" io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client) restore_path = temp_restore_path(tmpdir) - with mock.patch.dict( - os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} - ): + with set_env_vars(restore_path, servicer.container_addr): io_manager.memory_snapshot() # Stop the App, invalidating the fu- ID stored in `f`. @@ -112,10 +125,7 @@ def test_container_snapshot_restore_heartbeats(tmpdir, servicer, container_clien # Ensure that heartbeats only run after the snapshot heartbeat_interval_secs = 0.01 with io_manager.heartbeats(True): - with mock.patch.dict( - os.environ, - {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, - ): + with set_env_vars(restore_path, servicer.container_addr): with mock.patch("modal.runner.HEARTBEAT_INTERVAL", heartbeat_interval_secs): time.sleep(heartbeat_interval_secs * 2) assert not list( @@ -140,9 +150,7 @@ async def test_container_debug_snapshot(container_client, tmpdir, servicer): # Test that the breakpoint was called test_breakpoint = mock.Mock() with mock.patch("sys.breakpointhook", test_breakpoint): - with mock.patch.dict( - os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} - ): + with set_env_vars(restore_path, servicer.container_addr): io_manager.memory_snapshot() test_breakpoint.assert_called_once() @@ -192,9 +200,7 @@ async def test_container_snapshot_patching(fake_torch_module, container_client, # Write out a restore file so that snapshot+restore will complete restore_path = temp_restore_path(tmpdir) - with mock.patch.dict( - os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} - ): + with set_env_vars(restore_path, servicer.container_addr): io_manager.memory_snapshot() assert torch.cuda.device_count() == 2 @@ -211,9 +217,7 @@ async def test_container_snapshot_patching_err(weird_torch_module, container_cli assert torch.IM_WEIRD == 42 - with mock.patch.dict( - os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} - ): + with set_env_vars(restore_path, servicer.container_addr): io_manager.memory_snapshot() # should not crash diff --git a/test/container_test.py b/test/container_test.py index c4ce220cc..3e5f6371f 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -302,6 +302,8 @@ def _run_container( with pathlib.Path(tmp_file_name).open("w") as target: json.dump({}, target) env["MODAL_RESTORE_STATE_PATH"] = tmp_file_name + env["MODAL_TOKEN_ID"] = "ak-123" + env["MODAL_TOKEN_SECRET"] = "as-123" # Override server URL to reproduce restore behavior. env["MODAL_SERVER_URL"] = servicer.container_addr @@ -1091,7 +1093,7 @@ def test_cli(servicer): servicer.container_inputs = _get_inputs() # Launch subprocess - env = {"MODAL_SERVER_URL": servicer.container_addr} + env = {"MODAL_TOKEN_ID": "ak-123", "MODAL_TOKEN_SECRET": "as-12", "MODAL_SERVER_URL": servicer.container_addr} lib_dir = pathlib.Path(__file__).parent.parent args: List[str] = [sys.executable, "-m", "modal._container_entrypoint", data_base64] ret = subprocess.run(args, cwd=lib_dir, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE) diff --git a/test/e2e_test.py b/test/e2e_test.py index af942682a..b8edb1dd5 100644 --- a/test/e2e_test.py +++ b/test/e2e_test.py @@ -11,6 +11,8 @@ def _cli(args, server_url, extra_env={}, check=True) -> Tuple[int, str, str]: args = [sys.executable] + args env = { "MODAL_SERVER_URL": server_url, + "MODAL_TOKEN_ID": "ak-123", + "MODAL_TOKEN_SECRET": "as-123", **os.environ, "PYTHONUTF8": "1", # For windows **extra_env, diff --git a/test/helpers.py b/test/helpers.py index 7064d1057..d726218db 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -26,7 +26,14 @@ def deploy_app_externally( **{"PYTHONUTF8": "1"}, } # windows apparently needs a bunch of env vars to start python... - env = {**windows_support, "MODAL_SERVER_URL": servicer.client_addr, "MODAL_ENVIRONMENT": "main", **env} + env = { + **windows_support, + "MODAL_SERVER_URL": servicer.client_addr, + "MODAL_ENVIRONMENT": "main", + "MODAL_TOKEN_ID": "ak-123", + "MODAL_TOKEN_SECRET": "as-123", + **env, + } if cwd is None: cwd = pathlib.Path(__file__).parent.parent