Skip to content

Commit

Permalink
Generalizes pre- and post-spmd setup with init modules. (#848)
Browse files Browse the repository at this point in the history
  • Loading branch information
markblee authored Nov 19, 2024
1 parent 2803b36 commit 1af2ba8
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 27 deletions.
11 changes: 6 additions & 5 deletions axlearn/cloud/gcp/tpu_health_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
shortest timeout. Pairwise health check should have the longest timeout since different slices may
bring up their container at different times.
The main API is the `health_check` function, which is commonly enabled via context manager:
with health_check(spec, output_dir=...):
The main API is the `setup` function, which is commonly enabled via context manager:
```
with setup(spec, output_dir=...):
# Initialize jax distributed.
```
"""

import os
import signal
import subprocess
Expand All @@ -31,7 +33,6 @@
from typing import Literal, Optional, Union

import tensorflow as tf
import tensorflow_io # pylint: disable=unused-import
from absl import logging

from axlearn.cloud.gcp import tpu_health_check_main
Expand Down Expand Up @@ -127,7 +128,7 @@ def _run_health_check_program(


@contextmanager
def health_check(check_spec: str, *, output_dir: str):
def setup(check_spec: str, *, output_dir: str):
_pre_init_health_check(check_spec, output_dir=output_dir)
yield
# Skip global health check if there's an exception.
Expand Down
49 changes: 27 additions & 22 deletions axlearn/common/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@

"""A library with common flags to launch a trainer."""

import contextlib
import importlib
import os
import sys

# pylint: disable=wrong-import-position,wrong-import-order
from contextlib import nullcontext

# pylint: disable-next=ungrouped-imports
from axlearn.common import compiler_options

# pylint: disable=wrong-import-position,wrong-import-order


instance_type = os.environ.get("TPU_TYPE", "none")
num_tpu_slices = int(os.environ.get("NUM_TPU_SLICES", 1))

Expand Down Expand Up @@ -77,36 +78,40 @@
os.environ.get("PROCESS_ID", None),
"Rank of the current process. Must be None on tpu, otherwise required.",
)
flags.DEFINE_string(
"health_check_module",
None,
"Path to health check module to run, e.g. axlearn.cloud.gcp.tpu_health_check. "
"Defaults to None, meaning no health check will run.",
)
flags.DEFINE_string(
"health_check_spec",
"",
"See the docstring of your `health_check_module`.",
flags.DEFINE_multi_string(
"init_module",
[],
"Zero or more init modules to import prior to setting up JAX distributed. "
"Each flag value should be a string containing 'module_path' or 'module_path:spec', e.g. "
"'axlearn.cloud.gcp.tpu_health_check' or 'axlearn.cloud.gcp.tpu_health_check:output_dir=...'.\n"
"The module should expose a public function `setup`, a context manager exposing pre- and post-"
"SPMD setup logic which is entered prior to `setup_spmd` and exited immediately afterwards.\n"
"The spec (if provided) will be provided to `module.setup(spec)` and therefore can be "
"implementation dependent. Not specifying a spec is equivalent to passing `None` to `setup`.\n"
"If specifying multiple modules, each `setup` context is entered in the given order.",
)

FLAGS = flags.FLAGS


# Kept separate for easier testing.
@contextlib.contextmanager
def _init_context(fv: flags.FlagValues = FLAGS):
with contextlib.ExitStack() as ctx:
for module_spec in fv.init_module:
parts = module_spec.split(":", maxsplit=1) + [None]
module, spec = parts[:2]
ctx.enter_context(importlib.import_module(module).setup(spec))
yield


def setup():
if tpu_flags_exc is not None:
logging.info("LIBTPU_INIT_FLAGS was not set. Reason: %s", tpu_flags_exc)
else:
logging.info("LIBTPU_INIT_ARGS='%s'", os.environ["LIBTPU_INIT_ARGS"])

if FLAGS.health_check_module:
health_check = importlib.import_module(FLAGS.health_check_module).health_check(
FLAGS.health_check_spec,
output_dir=FLAGS.trainer_dir,
)
else:
health_check = nullcontext()

with health_check:
with _init_context():
setup_spmd(
distributed_coordinator=FLAGS.distributed_coordinator,
num_processes=FLAGS.num_processes,
Expand Down
61 changes: 61 additions & 0 deletions axlearn/common/launch_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright © 2024 Apple Inc.

"""Tests launch utils."""

import contextlib
from typing import Optional
from unittest import mock

from absl import flags
from absl.testing import parameterized

from axlearn.common.launch import _init_context
from axlearn.common.test_utils import TestCase


class TestInitContext(TestCase):
"""Tests _init_context."""

@parameterized.parameters(
dict(value=["my.module.path"], expect={"my.module.path": None}),
dict(
value=["my.module.path:k1=v1,k2=v2"],
expect={"my.module.path": "k1=v1,k2=v2"},
),
dict(
value=["my.module.path:k1:v1"],
expect={"my.module.path": "k1:v1"},
),
dict(
value=["my.module.path:k1:v1", "my.other.module:k2:v2,k3:v3"],
expect={
"my.module.path": "k1:v1",
"my.other.module": "k2:v2,k3:v3",
},
),
)
def test_init_context(self, value, expect: dict[str, Optional[str]]):
fv = flags.FlagValues()
flags.DEFINE_multi_string("init_module", value, "", flag_values=fv)
fv.mark_as_parsed()

with mock.patch("importlib.import_module") as mock_import:
with _init_context(fv):
for i, k in enumerate(expect):
self.assertEqual(k, mock_import.call_args_list[i][0][0])

side_effect = []
actual_specs = []
for _ in range(len(value)):

@contextlib.contextmanager
def mock_setup(actual):
actual_specs.append(actual)
yield

mock_module = mock.Mock(**{"setup.side_effect": mock_setup})
side_effect.append(mock_module)

with mock.patch("importlib.import_module", side_effect=side_effect):
with _init_context(fv):
self.assertEqual(list(expect.values()), actual_specs)

0 comments on commit 1af2ba8

Please sign in to comment.