Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require credentials for non-container clients #2361

Closed
wants to merge 9 commits into from
8 changes: 4 additions & 4 deletions modal/cli/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ async def list(json: Optional[bool] = False):
table.add_row(*content, style=highlight if active else "dim")
console.print(table)

if env_based_workspace is not None:
console.print(
f"Using [bold]{env_based_workspace}[/bold] workspace based on environment variables", style="yellow"
)
if env_based_workspace is not None:
console.print(
f"Using [bold]{env_based_workspace}[/bold] workspace based on environment variables", style="yellow"
)
71 changes: 31 additions & 40 deletions modal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from ._utils.async_utils import TaskContext, synchronize_api
from ._utils.grpc_utils import create_channel, retry_transient_errors
from ._utils.http_utils import ClientSessionRegistry
from .config import _check_config, config, logger
from .config import _check_config, _is_remote, config, logger
from .exception import AuthError, ClientClosed, ConnectionError, DeprecationError, VersionError

HEARTBEAT_INTERVAL: float = config.get("heartbeat_interval")
Expand Down Expand Up @@ -100,6 +100,7 @@ class _Client:
_cancellation_context: TaskContext
_cancellation_context_event_loop: asyncio.AbstractEventLoop = None
_stub: Optional[api_grpc.ModalClientStub]
_credentials: Optional[Tuple[str, str]]

def __init__(
self,
Expand All @@ -115,7 +116,6 @@ def __init__(
self.client_type = client_type
self._credentials = credentials
self.version = version
self._authenticated = False
self._closed = False
self._channel: Optional[grpclib.client.Channel] = None
self._stub: Optional[modal_api_grpc.ModalClientModal] = None
Expand All @@ -127,7 +127,7 @@ def is_closed(self) -> bool:

@property
def authenticated(self):
return self._authenticated
return self._credentials is not None
Comment on lines 129 to +130
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might just be hard to think about because there are many different ways of creating a Client, but this strikes me as wrong? Can't you end up with a Client that has credentials attached, but we don't know if they are valid? Not sure what the practical implications would be, though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So yeah I think this is actually wrong (and let me try to add a test for it) but for a different reason I think. If a Modal app running in the cloud is trying to deploy a Modal app then I think it asserts that the client is authenticated, which will fail. Honestly think it's maybe better to remove this code altogether and just fail in respective handlers – I'm not sure why we need this check.


@property
def stub(self) -> modal_api_grpc.ModalClientModal:
Expand Down Expand Up @@ -159,7 +159,7 @@ async def _close(self, prep_for_restore: bool = False):
# Remove cached client.
self.set_env_client(None)

async def _init(self):
async def _hello(self):
"""Connect to server and retrieve version information; raise appropriate error for various failures."""
logger.debug("Client: Starting")
_check_config()
Expand All @@ -174,7 +174,6 @@ async def _init(self):
if resp.warning:
ALARM_EMOJI = chr(0x1F6A8)
warnings.warn(f"{ALARM_EMOJI} {resp.warning} {ALARM_EMOJI}", DeprecationError)
self._authenticated = True
except GRPCError as exc:
if exc.status == Status.FAILED_PRECONDITION:
raise VersionError(
Expand All @@ -191,7 +190,7 @@ async def _init(self):
async def __aenter__(self):
await self._open()
try:
await self._init()
await self._hello()
except BaseException:
await self._close()
raise
Expand All @@ -210,7 +209,7 @@ async def anonymous(cls, server_url: str) -> AsyncIterator["_Client"]:
client = cls(server_url, api_pb2.CLIENT_TYPE_CLIENT, credentials=None)
try:
await client._open()
# Skip client._init
# Skip client._hello
yield client
finally:
await client._close()
Expand All @@ -226,44 +225,36 @@ async def from_env(cls, _override_config=None) -> "_Client":
else:
c = config

server_url = c["server_url"]

token_id = c["token_id"]
token_secret = c["token_secret"]
task_id = c["task_id"]
credentials = None

if task_id:
client_type = api_pb2.CLIENT_TYPE_CONTAINER
else:
client_type = api_pb2.CLIENT_TYPE_CLIENT
if token_id and token_secret:
credentials = (token_id, token_secret)

if cls._client_from_env_lock is None:
cls._client_from_env_lock = asyncio.Lock()

async with cls._client_from_env_lock:
if cls._client_from_env:
return cls._client_from_env

server_url = c["server_url"]

if _is_remote():
credentials = None
client_type = api_pb2.CLIENT_TYPE_CONTAINER
Comment on lines +237 to +239
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we're setting ourselves up for (more) confusion about the two distinct meanings of the local/remote dichotomy in Modal:

  • practical sense:
    • local: not on modal infrastructure
    • remote: on modal infrastructure
  • functional sense:
    • local: the context that you run or deploy your applications from
    • remote: the context that you run or deploy your applications in

(This is maybe not a perfect set of definitions but hopefully it makes sense)

The _is_remote helper indicates which side you're on in the practical sense (based on the environment variables that we define in containers).

But what about users who want to run or deploy applications from Modal infrastructure? I thought we had people doing that; including people doing it on behalf of their users (such that the orchestrator app and worker app are using different tokens). Are we breaking that behavior? Maybe we're not — but it feels very hard to reason about.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me add a test for deploying Modal apps from inside containers. It's possible this PR broke it because of the authenticated check. That being said, the authentication system (client vs container vs web) is basically orthogonal to what we're doing (creating an app vs running code inside an app).

else:
client = _Client(server_url, client_type, credentials)
await client._open()
async_utils.on_shutdown(client._close())
try:
await client._init()
except AuthError:
if not credentials:
creds_missing_msg = (
"Token missing. Could not authenticate client."
" If you have token credentials, see modal.com/docs/reference/modal.config for setup help."
" If you are a new user, register an account at modal.com, then run `modal token new`."
)
raise AuthError(creds_missing_msg)
else:
raise
cls._client_from_env = client
return client
token_id = c["token_id"]
token_secret = c["token_secret"]
if not token_id or not token_secret:
raise AuthError(
"Token missing. Could not authenticate client."
" If you have token credentials, see modal.com/docs/reference/modal.config for setup help."
" If you are a new user, register an account at modal.com, then run `modal token new`."
)
credentials = (token_id, token_secret)
client_type = api_pb2.CLIENT_TYPE_CLIENT

client = _Client(server_url, client_type, credentials)
await client._open()
async_utils.on_shutdown(client._close())
await client._hello()
cls._client_from_env = client
return client

@classmethod
async def from_credentials(cls, token_id: str, token_secret: str) -> "_Client":
Expand All @@ -284,7 +275,7 @@ async def from_credentials(cls, token_id: str, token_secret: str) -> "_Client":
client = _Client(server_url, client_type, credentials)
await client._open()
try:
await client._init()
await client._hello()
except BaseException:
await client._close()
raise
Expand Down Expand Up @@ -345,7 +336,7 @@ async def _reset_on_pid_change(self):
self.set_env_client(None)
# TODO(elias): reset _cancellation_context in case ?
await self._open()
# intentionally not doing self._init since we should already be authenticated etc.
# intentionally not doing self._hello since we should already be authenticated etc.

async def _get_grpclib_method(self, method_name: str) -> Any:
# safely get grcplib method that is bound to a valid channel
Expand Down
2 changes: 2 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1768,6 +1768,8 @@ async def container_client(servicer):
@pytest_asyncio.fixture(scope="function")
async def server_url_env(servicer, monkeypatch):
monkeypatch.setenv("MODAL_SERVER_URL", servicer.client_addr)
monkeypatch.setenv("MODAL_TOKEN_ID", "ak-123")
monkeypatch.setenv("MODAL_TOKEN_SECRET", "as-123")
yield


Expand Down
36 changes: 20 additions & 16 deletions test/container_app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import pytest
import time
from contextlib import contextmanager
from typing import Dict
from unittest import mock

Expand Down Expand Up @@ -71,6 +72,20 @@ async def stop_app(client, app_id):
return await retry_transient_errors(client.stub.AppStop, api_pb2.AppStopRequest(app_id=app_id))


@contextmanager
def set_env_vars(restore_path, container_addr):
with mock.patch.dict(
os.environ,
{
"MODAL_RESTORE_STATE_PATH": str(restore_path),
"MODAL_SERVER_URL": container_addr,
"MODAL_TOKEN_ID": "ak-123",
"MODAL_TOKEN_SECRET": "as-123",
},
):
yield


@pytest.mark.asyncio
async def test_container_snapshot_reference_capture(container_client, tmpdir, servicer, client):
app = App()
Expand All @@ -86,9 +101,7 @@ async def test_container_snapshot_reference_capture(container_client, tmpdir, se
assert f.object_id == "fu-1"
io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
restore_path = temp_restore_path(tmpdir)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
with set_env_vars(restore_path, servicer.container_addr):
io_manager.memory_snapshot()

# Stop the App, invalidating the fu- ID stored in `f`.
Expand All @@ -112,10 +125,7 @@ def test_container_snapshot_restore_heartbeats(tmpdir, servicer, container_clien
# Ensure that heartbeats only run after the snapshot
heartbeat_interval_secs = 0.01
with io_manager.heartbeats(True):
with mock.patch.dict(
os.environ,
{"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr},
):
with set_env_vars(restore_path, servicer.container_addr):
with mock.patch("modal.runner.HEARTBEAT_INTERVAL", heartbeat_interval_secs):
time.sleep(heartbeat_interval_secs * 2)
assert not list(
Expand All @@ -140,9 +150,7 @@ async def test_container_debug_snapshot(container_client, tmpdir, servicer):
# Test that the breakpoint was called
test_breakpoint = mock.Mock()
with mock.patch("sys.breakpointhook", test_breakpoint):
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
with set_env_vars(restore_path, servicer.container_addr):
io_manager.memory_snapshot()
test_breakpoint.assert_called_once()

Expand Down Expand Up @@ -192,9 +200,7 @@ async def test_container_snapshot_patching(fake_torch_module, container_client,

# Write out a restore file so that snapshot+restore will complete
restore_path = temp_restore_path(tmpdir)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
with set_env_vars(restore_path, servicer.container_addr):
io_manager.memory_snapshot()
assert torch.cuda.device_count() == 2

Expand All @@ -211,9 +217,7 @@ async def test_container_snapshot_patching_err(weird_torch_module, container_cli

assert torch.IM_WEIRD == 42

with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
with set_env_vars(restore_path, servicer.container_addr):
io_manager.memory_snapshot() # should not crash


Expand Down
4 changes: 3 additions & 1 deletion test/container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ def _run_container(
with pathlib.Path(tmp_file_name).open("w") as target:
json.dump({}, target)
env["MODAL_RESTORE_STATE_PATH"] = tmp_file_name
env["MODAL_TOKEN_ID"] = "ak-123"
env["MODAL_TOKEN_SECRET"] = "as-123"

# Override server URL to reproduce restore behavior.
env["MODAL_SERVER_URL"] = servicer.container_addr
Expand Down Expand Up @@ -1091,7 +1093,7 @@ def test_cli(servicer):
servicer.container_inputs = _get_inputs()

# Launch subprocess
env = {"MODAL_SERVER_URL": servicer.container_addr}
env = {"MODAL_TOKEN_ID": "ak-123", "MODAL_TOKEN_SECRET": "as-12", "MODAL_SERVER_URL": servicer.container_addr}
lib_dir = pathlib.Path(__file__).parent.parent
args: List[str] = [sys.executable, "-m", "modal._container_entrypoint", data_base64]
ret = subprocess.run(args, cwd=lib_dir, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
Expand Down
2 changes: 2 additions & 0 deletions test/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def _cli(args, server_url, extra_env={}, check=True) -> Tuple[int, str, str]:
args = [sys.executable] + args
env = {
"MODAL_SERVER_URL": server_url,
"MODAL_TOKEN_ID": "ak-123",
"MODAL_TOKEN_SECRET": "as-123",
**os.environ,
"PYTHONUTF8": "1", # For windows
**extra_env,
Expand Down
9 changes: 8 additions & 1 deletion test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@ def deploy_app_externally(
**{"PYTHONUTF8": "1"},
} # windows apparently needs a bunch of env vars to start python...

env = {**windows_support, "MODAL_SERVER_URL": servicer.client_addr, "MODAL_ENVIRONMENT": "main", **env}
env = {
**windows_support,
"MODAL_SERVER_URL": servicer.client_addr,
"MODAL_ENVIRONMENT": "main",
"MODAL_TOKEN_ID": "ak-123",
"MODAL_TOKEN_SECRET": "as-123",
**env,
}
if cwd is None:
cwd = pathlib.Path(__file__).parent.parent

Expand Down
Loading