diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index a19bdcd020d..d9664287a41 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -21,29 +21,35 @@ jobs: python-version: ["3.8"] steps: - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v4 with: + version: "latest" python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install yapf==0.32.0 - pip install toml==0.10.2 - pip install black==22.10.0 - pip install isort==5.12.0 + uv venv --seed ~/test-env + source ~/test-env/bin/activate + uv pip install yapf==0.32.0 + uv pip install toml==0.10.2 + uv pip install black==22.10.0 + uv pip install isort==5.12.0 - name: Running yapf run: | + source ~/test-env/bin/activate yapf --diff --recursive ./ --exclude 'sky/skylet/ray_patches/**' \ --exclude 'sky/skylet/providers/ibm/**' - name: Running black run: | + source ~/test-env/bin/activate black --diff --check sky/skylet/providers/ibm/ - name: Running isort for black formatted files run: | + source ~/test-env/bin/activate isort --diff --check --profile black -l 88 -m 3 \ sky/skylet/providers/ibm/ - name: Running isort for yapf formatted files run: | + source ~/test-env/bin/activate isort --diff --check ./ --sg 'sky/skylet/ray_patches/**' \ --sg 'sky/skylet/providers/ibm/**' diff --git a/.github/workflows/mypy-generic.yml b/.github/workflows/mypy-generic.yml deleted file mode 100644 index c28ffad9bb7..00000000000 --- a/.github/workflows/mypy-generic.yml +++ /dev/null @@ -1,22 +0,0 @@ -# This is needed for GitHub Actions for the "Waiting for status to be reported" problem, -# according to https://docs.github.com/en/repositories/configuring-branches-and-merges-in-your-repository/defining-the-mergeability-of-pull-requests/troubleshooting-required-status-checks -name: mypy - -on: - # Trigger the workflow on push or pull request, - # but only for the main branch - push: - branches: - - master - - 'releases/**' - pull_request: - branches: - - master - - 'releases/**' - merge_group: - -jobs: - mypy: - runs-on: ubuntu-latest - steps: - - run: 'echo "No mypy to run"' diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index d59e90a9e99..6df98401fcb 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -11,6 +11,8 @@ on: branches: - master - 'releases/**' + merge_group: + jobs: mypy: runs-on: ubuntu-latest @@ -19,15 +21,18 @@ jobs: python-version: ["3.8"] steps: - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v4 with: + version: "latest" python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install mypy==$(grep mypy requirements-dev.txt | cut -d'=' -f3) - pip install $(grep types- requirements-dev.txt | tr '\n' ' ') + uv venv --seed ~/test-env + source ~/test-env/bin/activate + uv pip install mypy==$(grep mypy requirements-dev.txt | cut -d'=' -f3) + uv pip install $(grep types- requirements-dev.txt | tr '\n' ' ') - name: Running mypy run: | + source ~/test-env/bin/activate mypy $(cat tests/mypy_files.txt) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 0555fb934d0..f5cf40a31ad 100644 --- a/.github/workflows/pylint.yml +++ 
b/.github/workflows/pylint.yml @@ -21,16 +21,20 @@ jobs: python-version: ["3.8"] steps: - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v4 with: + version: "latest" python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install ".[all]" - pip install pylint==2.14.5 - pip install pylint-quotes==0.2.3 + uv venv --seed ~/test-env + source ~/test-env/bin/activate + uv pip install --prerelease=allow "azure-cli>=2.65.0" + uv pip install ".[all]" + uv pip install pylint==2.14.5 + uv pip install pylint-quotes==0.2.3 - name: Analysing the code with pylint run: | + source ~/test-env/bin/activate pylint --load-plugins pylint_quotes sky diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 757bfec36d2..bface9232cf 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -35,26 +35,21 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v3 - - - name: Install Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v4 with: + version: "latest" python-version: ${{ matrix.python-version }} - - - name: Cache dependencies - uses: actions/cache@v3 - if: startsWith(runner.os, 'Linux') - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-pytest-${{ matrix.python-version }} - restore-keys: | - ${{ runner.os }}-pip-pytest-${{ matrix.python-version }} - - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install -e ".[all]" - pip install pytest pytest-xdist pytest-env>=0.6 memory-profiler==0.61.0 - + uv venv --seed ~/test-env + source ~/test-env/bin/activate + uv pip install --prerelease=allow "azure-cli>=2.65.0" + # Use -e to include examples and tests folder in the path for unit + # tests to access them. + uv pip install -e ".[all]" + uv pip install pytest pytest-xdist pytest-env>=0.6 memory-profiler==0.61.0 - name: Run tests with pytest - run: SKYPILOT_DISABLE_USAGE_COLLECTION=1 SKYPILOT_SKIP_CLOUD_IDENTITY_CHECK=1 pytest -n 0 --dist no ${{ matrix.test-path }} + run: | + source ~/test-env/bin/activate + SKYPILOT_DISABLE_USAGE_COLLECTION=1 SKYPILOT_SKIP_CLOUD_IDENTITY_CHECK=1 pytest -n 0 --dist no ${{ matrix.test-path }} diff --git a/.github/workflows/test-doc-build.yml b/.github/workflows/test-doc-build.yml index 706aa071706..4a55e4fef89 100644 --- a/.github/workflows/test-doc-build.yml +++ b/.github/workflows/test-doc-build.yml @@ -14,24 +14,28 @@ on: merge_group: jobs: - format: + doc-build: runs-on: ubuntu-latest strategy: matrix: python-version: ["3.10"] steps: - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v4 with: + version: "latest" python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install . 
+ uv venv --seed ~/test-env + source ~/test-env/bin/activate + uv pip install --prerelease=allow "azure-cli>=2.65.0" + uv pip install ".[all]" cd docs - pip install -r ./requirements-docs.txt + uv pip install -r ./requirements-docs.txt - name: Build documentation run: | + source ~/test-env/bin/activate cd ./docs ./build.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..db40b03b5fa --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,74 @@ +# Ensure this configuration aligns with format.sh and requirements.txt +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + +- repo: https://github.com/psf/black + rev: 22.10.0 # Match the version from requirements + hooks: + - id: black + name: black (IBM specific) + files: "^sky/skylet/providers/ibm/.*" # Match only files in the IBM directory + +- repo: https://github.com/pycqa/isort + rev: 5.12.0 # Match the version from requirements + hooks: + # First isort command + - id: isort + name: isort (general) + args: + - "--sg=build/**" # Matches "${ISORT_YAPF_EXCLUDES[@]}" + - "--sg=sky/skylet/providers/ibm/**" + files: "^(sky|tests|examples|llm|docs)/.*" # Only match these directories + # Second isort command + - id: isort + name: isort (IBM specific) + args: + - "--profile=black" + - "-l=88" + - "-m=3" + files: "^sky/skylet/providers/ibm/.*" # Only match IBM-specific directory + +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.991 # Match the version from requirements + hooks: + - id: mypy + args: + # From tests/mypy_files.txt + - "sky" + - "--exclude" + - "sky/benchmark|sky/callbacks|sky/skylet/providers/azure|sky/resources.py|sky/backends/monkey_patches" + pass_filenames: false + additional_dependencies: + - types-PyYAML + - types-requests<2.31 # Match the condition in requirements.txt + - types-setuptools + - types-cachetools + - types-pyvmomi + +- repo: https://github.com/google/yapf + rev: v0.32.0 # Match the version from requirements + hooks: + - id: yapf + name: yapf + exclude: (build/.*|sky/skylet/providers/ibm/.*) # Matches exclusions from the script + args: ['--recursive', '--parallel'] # Only necessary flags + additional_dependencies: [toml==0.10.2] + +- repo: https://github.com/pylint-dev/pylint + rev: v2.14.5 # Match the version from requirements + hooks: + - id: pylint + additional_dependencies: + - pylint-quotes==0.2.3 # Match the version from requirements + name: pylint + args: + - --rcfile=.pylintrc # Use your custom pylint configuration + - --load-plugins=pylint_quotes # Load the pylint-quotes plugin + files: ^sky/ # Only include files from the 'sky/' directory + exclude: ^sky/skylet/providers/ibm/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6542f6add27..85ca90b2c4a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,7 +1,7 @@ # Contributing to SkyPilot -Thank you for your interest in contributing to SkyPilot! We welcome and value -all contributions to the project, including but not limited to: +Thank you for your interest in contributing to SkyPilot! 
We welcome and value +all contributions to the project, including but not limited to: * [Bug reports](https://github.com/skypilot-org/skypilot/issues) and [discussions](https://github.com/skypilot-org/skypilot/discussions) * [Pull requests](https://github.com/skypilot-org/skypilot/pulls) for bug fixes and new features @@ -26,7 +26,7 @@ pip install -r requirements-dev.txt ### Testing To run smoke tests (NOTE: Running all smoke tests launches ~20 clusters): ``` -# Run all tests except for AWS and Lambda Cloud +# Run all tests on AWS and Azure (default smoke test clouds) pytest tests/test_smoke.py # Terminate a test's cluster even if the test failed (default is to keep it around for debugging) @@ -41,11 +41,11 @@ pytest tests/test_smoke.py::test_minimal # Only run managed spot tests pytest tests/test_smoke.py --managed-spot -# Only run test for AWS + generic tests -pytest tests/test_smoke.py --aws +# Only run test for GCP + generic tests +pytest tests/test_smoke.py --gcp -# Change cloud for generic tests to aws -pytest tests/test_smoke.py --generic-cloud aws +# Change cloud for generic tests to Azure +pytest tests/test_smoke.py --generic-cloud azure ``` For profiling code, use: @@ -78,6 +78,7 @@ It has some convenience features which you might find helpful (see [Dockerfile]( - If relevant, add tests for your changes. For changes that touch the core system, run the [smoke tests](#testing) and ensure they pass. - Follow the [Google style guide](https://google.github.io/styleguide/pyguide.html). - Ensure code is properly formatted by running [`format.sh`](https://github.com/skypilot-org/skypilot/blob/master/format.sh). + - [Optional] You can also install pre-commit hooks by running `pre-commit install` to automatically format your code on commit. - Push your changes to your fork and open a pull request in the SkyPilot repository. - In the PR description, write a `Tested:` section to describe relevant tests performed. 
diff --git a/Dockerfile_k8s b/Dockerfile_k8s index 45625871078..f031dff3668 100644 --- a/Dockerfile_k8s +++ b/Dockerfile_k8s @@ -7,7 +7,7 @@ ARG DEBIAN_FRONTEND=noninteractive # Initialize conda for root user, install ssh and other local dependencies RUN apt update -y && \ - apt install git gcc rsync sudo patch openssh-server pciutils nano fuse socat netcat curl -y && \ + apt install git gcc rsync sudo patch openssh-server pciutils nano fuse socat netcat-openbsd curl -y && \ rm -rf /var/lib/apt/lists/* && \ apt remove -y python3 && \ conda init diff --git a/Dockerfile_k8s_gpu b/Dockerfile_k8s_gpu index 09570d102df..6277e7f8d12 100644 --- a/Dockerfile_k8s_gpu +++ b/Dockerfile_k8s_gpu @@ -7,7 +7,7 @@ ARG DEBIAN_FRONTEND=noninteractive # We remove cuda lists to avoid conflicts with the cuda version installed by ray RUN rm -rf /etc/apt/sources.list.d/cuda* && \ apt update -y && \ - apt install git gcc rsync sudo patch openssh-server pciutils nano fuse unzip socat netcat curl -y && \ + apt install git gcc rsync sudo patch openssh-server pciutils nano fuse unzip socat netcat-openbsd curl -y && \ rm -rf /var/lib/apt/lists/* # Setup SSH and generate hostkeys @@ -36,6 +36,7 @@ SHELL ["/bin/bash", "-c"] # Install conda and other dependencies # Keep the conda and Ray versions below in sync with the ones in skylet.constants +# Keep this section in sync with the custom image optimization recommendations in our docs (kubernetes-getting-started.rst) RUN curl https://repo.anaconda.com/miniconda/Miniconda3-py310_23.11.0-2-Linux-x86_64.sh -o Miniconda3-Linux-x86_64.sh && \ bash Miniconda3-Linux-x86_64.sh -b && \ eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && conda config --set auto_activate_base true && conda activate base && \ diff --git a/README.md b/README.md index 2629cc4e4c8..f29b57be9ca 100644 --- a/README.md +++ b/README.md @@ -155,10 +155,11 @@ SkyPilot then performs the heavy-lifting for you, including: Refer to [Quickstart](https://skypilot.readthedocs.io/en/latest/getting-started/quickstart.html) to get started with SkyPilot. ## More Information -To learn more, see our [documentation](https://skypilot.readthedocs.io/en/latest/), [blog](https://blog.skypilot.co/), and [community integrations](https://blog.skypilot.co/community/). +To learn more, see [Concept: Sky Computing](https://docs.skypilot.co/en/latest/sky-computing.html), [SkyPilot docs](https://skypilot.readthedocs.io/en/latest/), and [SkyPilot blog](https://blog.skypilot.co/). Runnable examples: +- [**AI Gallery**](https://docs.skypilot.co/en/latest/gallery/index.html) - LLMs on SkyPilot - [Llama 3.2: lightweight and vision models](./llm/llama-3_2/) - [Pixtral](./llm/pixtral/) @@ -183,7 +184,7 @@ Runnable examples: - [LocalGPT](./llm/localgpt) - [Falcon](./llm/falcon) - Add yours here & see more in [`llm/`](./llm)! 
-- Framework examples: [PyTorch DDP](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_torch.yaml), [DeepSpeed](./examples/deepspeed-multinode/sky.yaml), [JAX/Flax on TPU](https://github.com/skypilot-org/skypilot/blob/master/examples/tpu/tpuvm_mnist.yaml), [Stable Diffusion](https://github.com/skypilot-org/skypilot/tree/master/examples/stable_diffusion), [Detectron2](https://github.com/skypilot-org/skypilot/blob/master/examples/detectron2_docker.yaml), [Distributed](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_tf_app.py) [TensorFlow](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_app_storage.yaml), [Ray Train](examples/distributed_ray_train/ray_train.yaml), [NeMo](https://github.com/skypilot-org/skypilot/blob/master/examples/nemo/nemo.yaml), [programmatic grid search](https://github.com/skypilot-org/skypilot/blob/master/examples/huggingface_glue_imdb_grid_search_app.py), [Docker](https://github.com/skypilot-org/skypilot/blob/master/examples/docker/echo_app.yaml), [Cog](https://github.com/skypilot-org/skypilot/blob/master/examples/cog/), [Unsloth](https://github.com/skypilot-org/skypilot/blob/master/examples/unsloth/unsloth.yaml), [Ollama](https://github.com/skypilot-org/skypilot/blob/master/llm/ollama), [llm.c](https://github.com/skypilot-org/skypilot/tree/master/llm/gpt-2), [Airflow](./examples/airflow/training_workflow) and [many more (`examples/`)](./examples). +- Framework examples: [PyTorch DDP](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_torch.yaml), [DeepSpeed](./examples/deepspeed-multinode/sky.yaml), [JAX/Flax on TPU](https://github.com/skypilot-org/skypilot/blob/master/examples/tpu/tpuvm_mnist.yaml), [Stable Diffusion](https://github.com/skypilot-org/skypilot/tree/master/examples/stable_diffusion), [Detectron2](https://github.com/skypilot-org/skypilot/blob/master/examples/detectron2_docker.yaml), [Distributed](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_distributed_tf_app.py) [TensorFlow](https://github.com/skypilot-org/skypilot/blob/master/examples/resnet_app_storage.yaml), [Ray Train](examples/distributed_ray_train/ray_train.yaml), [NeMo](https://github.com/skypilot-org/skypilot/blob/master/examples/nemo/), [programmatic grid search](https://github.com/skypilot-org/skypilot/blob/master/examples/huggingface_glue_imdb_grid_search_app.py), [Docker](https://github.com/skypilot-org/skypilot/blob/master/examples/docker/echo_app.yaml), [Cog](https://github.com/skypilot-org/skypilot/blob/master/examples/cog/), [Unsloth](https://github.com/skypilot-org/skypilot/blob/master/examples/unsloth/unsloth.yaml), [Ollama](https://github.com/skypilot-org/skypilot/blob/master/llm/ollama), [llm.c](https://github.com/skypilot-org/skypilot/tree/master/llm/gpt-2), [Airflow](./examples/airflow/training_workflow) and [many more (`examples/`)](./examples). Case Studies and Integrations: [Community Spotlights](https://blog.skypilot.co/community/) diff --git a/docs/source/_static/custom.js b/docs/source/_static/custom.js index 1fa28105186..5ae47b7b7be 100644 --- a/docs/source/_static/custom.js +++ b/docs/source/_static/custom.js @@ -7,7 +7,7 @@ document.addEventListener('DOMContentLoaded', function () { script.setAttribute('data-project-logo', 'https://avatars.githubusercontent.com/u/109387420?s=100&v=4'); script.setAttribute('data-modal-disclaimer', 'Results are automatically generated and may be inaccurate or contain inappropriate information. 
Do not include any sensitive information in your query.\n**To get further assistance, you can chat directly with the development team** by joining the [SkyPilot Slack](https://slack.skypilot.co/).'); script.setAttribute('data-modal-title', 'SkyPilot Docs AI - Ask a Question.'); - script.setAttribute('data-button-position-bottom', '85px'); + script.setAttribute('data-button-position-bottom', '100px'); script.async = true; document.head.appendChild(script); }); @@ -25,7 +25,6 @@ document.addEventListener('DOMContentLoaded', function () { document.addEventListener('DOMContentLoaded', () => { // New items: const newItems = [ - { selector: '.caption-text', text: 'SkyServe: Model Serving' }, { selector: '.toctree-l1 > a', text: 'Managed Jobs' }, { selector: '.toctree-l1 > a', text: 'Pixtral (Mistral AI)' }, { selector: '.toctree-l1 > a', text: 'Many Parallel Jobs' }, @@ -33,6 +32,7 @@ document.addEventListener('DOMContentLoaded', () => { { selector: '.toctree-l1 > a', text: 'Llama 3.2 (Meta)' }, { selector: '.toctree-l1 > a', text: 'Admin Policy Enforcement' }, { selector: '.toctree-l1 > a', text: 'Using Existing Machines' }, + { selector: '.toctree-l1 > a', text: 'Concept: Sky Computing' }, ]; newItems.forEach(({ selector, text }) => { document.querySelectorAll(selector).forEach((el) => { diff --git a/docs/source/cloud-setup/cloud-permissions/kubernetes.rst b/docs/source/cloud-setup/cloud-permissions/kubernetes.rst index df1d2c5e161..d337c9a2c76 100644 --- a/docs/source/cloud-setup/cloud-permissions/kubernetes.rst +++ b/docs/source/cloud-setup/cloud-permissions/kubernetes.rst @@ -96,7 +96,29 @@ SkyPilot requires permissions equivalent to the following roles to be able to ma These roles must apply to both the user account configured in the kubeconfig file and the service account used by SkyPilot (if configured). -If your tasks use object store mounting or require access to ingress resources, you will need to grant additional permissions as described below. +If you need to view real-time GPU availability with ``sky show-gpus``, your tasks use object store mounting or your tasks require access to ingress resources, you will need to grant additional permissions as described below. + +Permissions for ``sky show-gpus`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``sky show-gpus`` needs to list all pods across all namespaces to calculate GPU availability. To do this, SkyPilot needs the ``get`` and ``list`` permissions for pods in a ``ClusterRole``: + +.. code-block:: yaml + + apiVersion: rbac.authorization.k8s.io/v1 + kind: ClusterRole + metadata: + name: sky-sa-cluster-role-pod-reader + rules: + - apiGroups: [""] + resources: ["pods"] + verbs: ["get", "list"] + + +.. tip:: + + If this role is not granted to the service account, ``sky show-gpus`` will still work but it will only show the total GPUs on the nodes, not the number of free GPUs. 
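To grant this ``ClusterRole`` to the service account used by SkyPilot, a matching ``ClusterRoleBinding`` is needed as well. Below is a minimal sketch; the service account name ``sky-sa`` and namespace ``default`` are assumptions, so substitute the names used in your deployment and apply both manifests with ``kubectl apply -f``:

.. code-block:: yaml

    # Sketch only: binds the pod-reader ClusterRole above to SkyPilot's service account.
    apiVersion: rbac.authorization.k8s.io/v1
    kind: ClusterRoleBinding
    metadata:
      name: sky-sa-cluster-role-pod-reader-binding
    roleRef:
      apiGroup: rbac.authorization.k8s.io
      kind: ClusterRole
      name: sky-sa-cluster-role-pod-reader  # Must match the ClusterRole defined above.
    subjects:
    - kind: ServiceAccount
      name: sky-sa         # Assumed service account name; adjust to your setup.
      namespace: default   # Assumed namespace; adjust to your setup.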
+ Permissions for Object Store Mounting ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -225,6 +247,9 @@ To create a service account that has all necessary permissions for SkyPilot (inc - apiGroups: ["networking.k8s.io"] # Required for exposing services through ingresses resources: ["ingressclasses"] verbs: ["get", "list", "watch"] + - apiGroups: [""] # Required for `sky show-gpus` command + resources: ["pods"] + verbs: ["get", "list"] --- # ClusterRoleBinding for the service account apiVersion: rbac.authorization.k8s.io/v1 diff --git a/docs/source/docs/index.rst b/docs/source/docs/index.rst index d83bf7821c3..17f8d545fa6 100644 --- a/docs/source/docs/index.rst +++ b/docs/source/docs/index.rst @@ -62,9 +62,9 @@ SkyPilot supports your existing GPU, TPU, and CPU workloads, with no code change Ready to get started? ---------------------- -:ref:`Install SkyPilot ` in ~1 minute. Then, launch your first dev cluster in ~5 minutes in :ref:`Quickstart `. +:ref:`Install SkyPilot ` in 1 minute. Then, launch your first dev cluster in 2 minutes in :ref:`Quickstart `. -Everything is launched within your cloud accounts, VPCs, and cluster(s). +SkyPilot is BYOC: Everything is launched within your cloud accounts, VPCs, and clusters. Contact the SkyPilot team --------------------------------- @@ -74,10 +74,14 @@ You can chat with the SkyPilot team and community on the `SkyPilot Slack ` and `SkyPilot blog `_. + Runnable examples: .. Keep this section in sync with README.md in SkyPilot repo +* :ref:`AI Gallery ` + * **LLMs on SkyPilot** * `Llama 3.2: lightweight and vision models `_ @@ -130,6 +134,7 @@ Read the research: ../getting-started/quickstart ../examples/interactive-development ../getting-started/tutorial + ../sky-computing .. toctree:: diff --git a/docs/source/examples/managed-jobs.rst b/docs/source/examples/managed-jobs.rst index 018a993f588..61c33b5c43e 100644 --- a/docs/source/examples/managed-jobs.rst +++ b/docs/source/examples/managed-jobs.rst @@ -78,9 +78,9 @@ We can launch it with the following: .. code-block:: console + $ git clone https://github.com/huggingface/transformers.git ~/transformers -b v4.30.1 $ sky jobs launch -n bert-qa bert_qa.yaml - .. code-block:: yaml # bert_qa.yaml @@ -88,39 +88,37 @@ We can launch it with the following: resources: accelerators: V100:1 - # Use spot instances to save cost. - use_spot: true - - # Assume your working directory is under `~/transformers`. - # To make this example work, please run the following command: - # git clone https://github.com/huggingface/transformers.git ~/transformers -b v4.30.1 - workdir: ~/transformers + use_spot: true # Use spot instances to save cost. - setup: | + envs: # Fill in your wandb key: copy from https://wandb.ai/authorize # Alternatively, you can use `--env WANDB_API_KEY=$WANDB_API_KEY` # to pass the key in the command line, during `sky jobs launch`. - echo export WANDB_API_KEY=[YOUR-WANDB-API-KEY] >> ~/.bashrc + WANDB_API_KEY: + + # Assume your working directory is under `~/transformers`. + workdir: ~/transformers + setup: | pip install -e . 
cd examples/pytorch/question-answering/ pip install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 pip install wandb run: | - cd ./examples/pytorch/question-answering/ + cd examples/pytorch/question-answering/ python run_qa.py \ - --model_name_or_path bert-base-uncased \ - --dataset_name squad \ - --do_train \ - --do_eval \ - --per_device_train_batch_size 12 \ - --learning_rate 3e-5 \ - --num_train_epochs 50 \ - --max_seq_length 384 \ - --doc_stride 128 \ - --report_to wandb - + --model_name_or_path bert-base-uncased \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --per_device_train_batch_size 12 \ + --learning_rate 3e-5 \ + --num_train_epochs 50 \ + --max_seq_length 384 \ + --doc_stride 128 \ + --report_to wandb \ + --output_dir /tmp/bert_qa/ .. note:: @@ -162,55 +160,52 @@ An End-to-End Example Below we show an `example `_ for fine-tuning a BERT model on a question-answering task with HuggingFace. .. code-block:: yaml - :emphasize-lines: 13-16,42-45 + :emphasize-lines: 8-11,41-44 # bert_qa.yaml name: bert-qa resources: accelerators: V100:1 - use_spot: true - - # Assume your working directory is under `~/transformers`. - # To make this example work, please run the following command: - # git clone https://github.com/huggingface/transformers.git ~/transformers -b v4.30.1 - workdir: ~/transformers + use_spot: true # Use spot instances to save cost. file_mounts: /checkpoint: name: # NOTE: Fill in your bucket name mode: MOUNT - setup: | + envs: # Fill in your wandb key: copy from https://wandb.ai/authorize # Alternatively, you can use `--env WANDB_API_KEY=$WANDB_API_KEY` # to pass the key in the command line, during `sky jobs launch`. - echo export WANDB_API_KEY=[YOUR-WANDB-API-KEY] >> ~/.bashrc + WANDB_API_KEY: + + # Assume your working directory is under `~/transformers`. + workdir: ~/transformers + setup: | pip install -e . 
cd examples/pytorch/question-answering/ - pip install -r requirements.txt + pip install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 pip install wandb run: | - cd ./examples/pytorch/question-answering/ + cd examples/pytorch/question-answering/ python run_qa.py \ - --model_name_or_path bert-base-uncased \ - --dataset_name squad \ - --do_train \ - --do_eval \ - --per_device_train_batch_size 12 \ - --learning_rate 3e-5 \ - --num_train_epochs 50 \ - --max_seq_length 384 \ - --doc_stride 128 \ - --report_to wandb \ - --run_name $SKYPILOT_TASK_ID \ - --output_dir /checkpoint/bert_qa/ \ - --save_total_limit 10 \ - --save_steps 1000 - - + --model_name_or_path bert-base-uncased \ + --dataset_name squad \ + --do_train \ + --do_eval \ + --per_device_train_batch_size 12 \ + --learning_rate 3e-5 \ + --num_train_epochs 50 \ + --max_seq_length 384 \ + --doc_stride 128 \ + --report_to wandb \ + --output_dir /checkpoint/bert_qa/ \ + --run_name $SKYPILOT_TASK_ID \ + --save_total_limit 10 \ + --save_steps 1000 As HuggingFace has built-in support for periodically checkpointing, we only need to pass the highlighted arguments for setting up the output directory and frequency of checkpointing (see more diff --git a/docs/source/getting-started/installation.rst b/docs/source/getting-started/installation.rst index 69303a582e2..deb2307b67b 100644 --- a/docs/source/getting-started/installation.rst +++ b/docs/source/getting-started/installation.rst @@ -267,6 +267,14 @@ The :code:`~/.oci/config` file should contain the following fields: # Note that we should avoid using full home path for the key_file configuration, e.g. use ~/.oci instead of /home/username/.oci key_file=~/.oci/oci_api_key.pem +By default, the provisioned nodes will be in the root `compartment `__. To specify the `compartment `_ other than root, create/edit the file :code:`~/.sky/config.yaml`, put the compartment's OCID there, as the following: + +.. code-block:: text + + oci: + default: + compartment_ocid: ocid1.compartment.oc1..aaaaaaaa...... + Lambda Cloud ~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/images/sky-above-clouds-gen.jpg b/docs/source/images/sky-above-clouds-gen.jpg new file mode 100644 index 00000000000..bc4aeb82c8f Binary files /dev/null and b/docs/source/images/sky-above-clouds-gen.jpg differ diff --git a/docs/source/reference/comparison.rst b/docs/source/reference/comparison.rst index e9bffabba68..23985e5081b 100644 --- a/docs/source/reference/comparison.rst +++ b/docs/source/reference/comparison.rst @@ -46,7 +46,7 @@ SkyPilot provides faster iteration for interactive development. For example, a c * :strong:`With SkyPilot, a single command (`:literal:`sky launch`:strong:`) takes care of everything.` Behind the scenes, SkyPilot provisions pods, installs all required dependencies, executes the job, returns logs, and provides SSH and VSCode access to debug. -.. figure:: https://blog.skypilot.co/ai-on-kubernetes/images/k8s_vs_skypilot_iterative_v2.png +.. figure:: https://i.imgur.com/xfCfz4N.png :align: center :width: 95% :alt: Iterative Development with Kubernetes vs SkyPilot diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index b8255b46402..286788625bd 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -244,6 +244,10 @@ Available fields and semantics: # instances. SkyPilot will auto-create and reuse a service account (IAM # role) for AWS instances. # + # NO_UPLOAD: No credentials will be uploaded to the pods. 
Useful for + # avoiding overriding any existing credentials that may be automounted on + # the cluster. + # # Customized service account (IAM role): or # - : apply the service account with the specified name to all instances. # Example: @@ -263,7 +267,8 @@ Available fields and semantics: # # - This only affects AWS instances. Local AWS credentials will still be # uploaded to non-AWS instances (since those instances may need to access - # AWS resources). + # AWS resources). To fully disable credential upload, set + # `remote_identity: NO_UPLOAD`. # - If the SkyPilot jobs/serve controller is on AWS, this setting will make # non-AWS managed jobs / non-AWS service replicas fail to access any # resources on AWS (since the controllers don't have AWS credential @@ -406,11 +411,16 @@ Available fields and semantics: # instances. SkyPilot will auto-create and reuse a service account for GCP # instances. # + # NO_UPLOAD: No credentials will be uploaded to the pods. Useful for + # avoiding overriding any existing credentials that may be automounted on + # the cluster. + # # Two caveats of SERVICE_ACCOUNT for multicloud users: # # - This only affects GCP instances. Local GCP credentials will still be # uploaded to non-GCP instances (since those instances may need to access - # GCP resources). + # GCP resources). To fully disable credential uploads, set + # `remote_identity: NO_UPLOAD`. # - If the SkyPilot jobs/serve controller is on GCP, this setting will make # non-GCP managed jobs / non-GCP service replicas fail to access any # resources on GCP (since the controllers don't have GCP credential @@ -431,6 +441,12 @@ Available fields and semantics: # Advanced Azure configurations (optional). # Apply to all new instances but not existing ones. azure: + # By default, SkyPilot creates a unique resource group for each VM when + # launched. If specified, all VMs will be launched within the provided + # resource group. Additionally, controllers for serve and managed jobs will + # be created in this resource group. Note: This setting only applies to VMs + # and does not affect storage accounts or containers. + resource_group_vm: user-resource-group-name # Specify an existing Azure storage account for SkyPilot-managed containers. # If not set, SkyPilot will use its default naming convention to create and # use the storage account unless container endpoint URI is used as source. @@ -491,6 +507,10 @@ Available fields and semantics: # SkyPilot will auto-create and reuse a service account with necessary roles # in the user's namespace. # + # NO_UPLOAD: No credentials will be uploaded to the pods. Useful for + # avoiding overriding any existing credentials that may be automounted on + # the cluster. + # # : The name of a service account to use for all Kubernetes pods. # This service account must exist in the user's namespace and have all # necessary permissions. Refer to https://skypilot.readthedocs.io/en/latest/cloud-setup/cloud-permissions/kubernetes.html @@ -499,7 +519,8 @@ Available fields and semantics: # Using SERVICE_ACCOUNT or a custom service account only affects Kubernetes # instances. Local ~/.kube/config will still be uploaded to non-Kubernetes # instances (e.g., a serve controller on GCP or AWS may need to provision - # Kubernetes resources). + # Kubernetes resources). To fully disable credential uploads, set + # `remote_identity: NO_UPLOAD`. # # Default: 'SERVICE_ACCOUNT'. 
remote_identity: my-k8s-service-account diff --git a/docs/source/reference/kubernetes/index.rst b/docs/source/reference/kubernetes/index.rst index 86e153bd8fc..89a57862c88 100644 --- a/docs/source/reference/kubernetes/index.rst +++ b/docs/source/reference/kubernetes/index.rst @@ -103,17 +103,3 @@ Table of Contents Getting Started kubernetes-setup kubernetes-troubleshooting - - -Features and Roadmap --------------------- - -Kubernetes support is under active development. Some features are in progress and will be released soon: - -* CPU and GPU Tasks - ✅ Available -* Auto-down - ✅ Available -* Storage mounting - ✅ Available on x86_64 clusters -* Multi-node tasks - ✅ Available -* Custom images - ✅ Available -* Opening ports and exposing services - ✅ Available -* Multiple Kubernetes Clusters - 🚧 In progress diff --git a/docs/source/reference/kubernetes/kubernetes-getting-started.rst b/docs/source/reference/kubernetes/kubernetes-getting-started.rst index 99d8f56cd34..e4bbb2c8915 100644 --- a/docs/source/reference/kubernetes/kubernetes-getting-started.rst +++ b/docs/source/reference/kubernetes/kubernetes-getting-started.rst @@ -261,9 +261,25 @@ After launching the cluster with :code:`sky launch -c myclus task.yaml`, you can FAQs ---- +* **Can I use multiple Kubernetes clusters with SkyPilot?** + + SkyPilot can work with multiple Kubernetes contexts set in your kubeconfig file. By default, SkyPilot will use the current active context. To use a different context, change your current context using :code:`kubectl config use-context `. + + If you would like to use multiple contexts seamlessly during failover, check out the :code:`allowed_contexts` feature in :ref:`config-yaml`. + * **Are autoscaling Kubernetes clusters supported?** - To run on an autoscaling cluster, you may need to adjust the resource provisioning timeout (:code:`Kubernetes.TIMEOUT` in `clouds/kubernetes.py`) to a large value to give enough time for the cluster to autoscale. We are working on a better interface to adjust this timeout - stay tuned! + To run on autoscaling clusters, set the :code:`provision_timeout` key in :code:`~/.sky/config.yaml` to a large value to give enough time for the cluster autoscaler to provision new nodes. + This will direct SkyPilot to wait for the cluster to scale up before failing over to the next candidate resource (e.g., next cloud). + + If you are using GPUs in a scale-to-zero setting, you should also set the :code:`autoscaler` key to the autoscaler type of your cluster. More details in :ref:`config-yaml`. + + .. code-block:: yaml + + # ~/.sky/config.yaml + kubernetes: + provision_timeout: 900 # Wait 15 minutes for nodes to get provisioned before failover. Set to -1 to wait indefinitely. + autoscaler: gke # [gke, karpenter, generic]; required if using GPUs in scale-to-zero setting * **Can SkyPilot provision a Kubernetes cluster for me? Will SkyPilot add more nodes to my Kubernetes clusters?** @@ -280,7 +296,7 @@ FAQs * **How can I specify custom configuration for the pods created by SkyPilot?** You can override the pod configuration used by SkyPilot by setting the :code:`pod_config` key in :code:`~/.sky/config.yaml`. - The value of :code:`pod_config` should be a dictionary that follows the `Kubernetes Pod API `_. + The value of :code:`pod_config` should be a dictionary that follows the `Kubernetes Pod API `_. 
For example, to set custom environment variables and attach a volume on your pods, you can add the following to your :code:`~/.sky/config.yaml` file: @@ -296,6 +312,11 @@ FAQs volumeMounts: # Custom volume mounts for the pod - mountPath: /foo name: example-volume + resources: # Custom resource requests and limits + requests: + rdma/rdma_shared_device_a: 1 + limits: + rdma/rdma_shared_device_a: 1 volumes: - name: example-volume hostPath: @@ -303,3 +324,32 @@ FAQs type: Directory For more details refer to :ref:`config-yaml`. + +* **I am using a custom image. How can I speed up the pod startup time?** + + You can pre-install SkyPilot dependencies in your custom image to speed up the pod startup time. Simply add these lines at the end of your Dockerfile: + + .. code-block:: dockerfile + + FROM + + # Install system dependencies + RUN apt update -y && \ + apt install git gcc rsync sudo patch openssh-server pciutils fuse unzip socat netcat-openbsd curl -y && \ + rm -rf /var/lib/apt/lists/* + + # Install conda and other python dependencies + RUN curl https://repo.anaconda.com/miniconda/Miniconda3-py310_23.11.0-2-Linux-x86_64.sh -o Miniconda3-Linux-x86_64.sh && \ + bash Miniconda3-Linux-x86_64.sh -b && \ + eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && conda config --set auto_activate_base true && conda activate base && \ + grep "# >>> conda initialize >>>" ~/.bashrc || { conda init && source ~/.bashrc; } && \ + rm Miniconda3-Linux-x86_64.sh && \ + export PIP_DISABLE_PIP_VERSION_CHECK=1 && \ + python3 -m venv ~/skypilot-runtime && \ + PYTHON_EXEC=$(echo ~/skypilot-runtime)/bin/python && \ + $PYTHON_EXEC -m pip install 'skypilot-nightly[remote,kubernetes]' 'ray[default]==2.9.3' 'pycryptodome==3.12.0' && \ + $PYTHON_EXEC -m pip uninstall skypilot-nightly -y && \ + curl -LO "https://dl.k8s.io/release/v1.28.11/bin/linux/amd64/kubectl" && \ + sudo install -o root -g root -m 0755 kubectl /usr/local/bin/kubectl && \ + echo 'export PATH="$PATH:$HOME/.local/bin"' >> ~/.bashrc + diff --git a/docs/source/running-jobs/environment-variables.rst b/docs/source/running-jobs/environment-variables.rst index f7138af95fa..d88424359d1 100644 --- a/docs/source/running-jobs/environment-variables.rst +++ b/docs/source/running-jobs/environment-variables.rst @@ -16,7 +16,7 @@ User-specified environment variables User-specified environment variables are useful for passing secrets and any arguments or configurations needed for your tasks. They are made available in ``file_mounts``, ``setup``, and ``run``. -You can specify environment variables to be made available to a task in two ways: +You can specify environment variables to be made available to a task in several ways: - ``envs`` field (dict) in a :ref:`task YAML `: @@ -24,7 +24,18 @@ You can specify environment variables to be made available to a task in two ways envs: MYVAR: val - + + +- ``--env-file`` flag in ``sky launch/exec`` :ref:`CLI `, which is a path to a `dotenv` file (takes precedence over the above): + + .. code-block:: text + + # sky launch example.yaml --env-file my_app.env + # cat my_app.env + MYVAR=val + WANDB_API_KEY=MY_WANDB_API_KEY + HF_TOKEN=MY_HF_TOKEN + - ``--env`` flag in ``sky launch/exec`` :ref:`CLI ` (takes precedence over the above) .. tip:: @@ -145,9 +156,9 @@ Environment variables for ``setup`` - 0 * - ``SKYPILOT_SETUP_NODE_IPS`` - A string of IP addresses of the nodes in the cluster with the same order as the node ranks, where each line contains one IP address. 
- + Note that this is not necessarily the same as the nodes in ``run`` stage: the ``setup`` stage runs on all nodes of the cluster, while the ``run`` stage can run on a subset of nodes. - - + - .. code-block:: text 1.2.3.4 @@ -158,19 +169,19 @@ Environment variables for ``setup`` - 2 * - ``SKYPILOT_TASK_ID`` - A unique ID assigned to each task. - - This environment variable is available only when the task is submitted + + This environment variable is available only when the task is submitted with :code:`sky launch --detach-setup`, or run as a managed spot job. - + Refer to the description in the :ref:`environment variables for run `. - sky-2023-07-06-21-18-31-563597_myclus_1 - + For managed spot jobs: sky-managed-2023-07-06-21-18-31-563597_my-job-name_1-0 * - ``SKYPILOT_CLUSTER_INFO`` - A JSON string containing information about the cluster. To access the information, you could parse the JSON string in bash ``echo $SKYPILOT_CLUSTER_INFO | jq .cloud`` or in Python : .. code-block:: python - + import json json.loads( os.environ['SKYPILOT_CLUSTER_INFO'] @@ -200,7 +211,7 @@ Environment variables for ``run`` - 0 * - ``SKYPILOT_NODE_IPS`` - A string of IP addresses of the nodes reserved to execute the task, where each line contains one IP address. Read more :ref:`here `. - - + - .. code-block:: text 1.2.3.4 @@ -221,13 +232,13 @@ Environment variables for ``run`` If a task is run as a :ref:`managed spot job `, then all recoveries of that job will have the same ID value. The ID is in the format "sky-managed-_(_)_-", where ```` will appear when a pipeline is used, i.e., more than one task in a managed spot job. Read more :ref:`here `. - sky-2023-07-06-21-18-31-563597_myclus_1 - + For managed spot jobs: sky-managed-2023-07-06-21-18-31-563597_my-job-name_1-0 * - ``SKYPILOT_CLUSTER_INFO`` - A JSON string containing information about the cluster. To access the information, you could parse the JSON string in bash ``echo $SKYPILOT_CLUSTER_INFO | jq .cloud`` or in Python : .. code-block:: python - + import json json.loads( os.environ['SKYPILOT_CLUSTER_INFO'] diff --git a/docs/source/sky-computing.rst b/docs/source/sky-computing.rst new file mode 100644 index 00000000000..15134204ebf --- /dev/null +++ b/docs/source/sky-computing.rst @@ -0,0 +1,154 @@ +.. _sky-computing: + +Concept: Sky Computing +=============================== + +SkyPilot is an open-source Sky Computing framework. + +.. figure:: images/sky-above-clouds-gen.jpg + :width: 60% + :align: center + +Problem: Cloud infra's explosive complexity +------------------------------------------- + +Today's cloud infra has exploded in complexity. +Organizations are forced to deal with a combinatorially large number of infra choices, across three dimensions: + +- **Locations**: 10s of regions and 100s of zones within a single cloud. Teams are also increasingly multi-cluster and multicloud (3+ hyperscalers, 10+ + specialized clouds). +- **Hardware**: 500+ instance types per cloud; many new accelerators (e.g., GPUs, TPUs). +- **Pricing models**: On-demand, reserved, and preemptible spot instances, each with different pricing and availability. + +The search space of ``(locations, hardware, pricing models)`` is combinatorially +large, **even within one cloud**. +It is also dynamic, since availability and pricing change over time and differ by location. 
+Seemingly simple tasks like "run jobs in any of my US +regions/clusters with the lowest cost" or "monitor and manage jobs on both AWS and GCP" become highly costly: + +- Valuable engineering hours are invested to mask the differences across infra choices. +- Workloads are forced to run on suboptimal choices (to heuristically simplify the search space), wasting utilization, cost savings, and capacity. + +Sky Computing +------------------------- + +To combat this, *Sky Computing* was recently proposed in two papers from UC Berkeley: +`From Cloud Computing to Sky Computing `_ and +`The Sky Above The Clouds `_ (whitepaper). + +In a nutshell, Sky Computing **combines a team's diverse cloud infra into a unified pool**, a "sky". +Sky comes with two components to simplify---and exploit---the complex search space: + +- A unified interface to run workloads on different cloud infra. +- An optimizer to find the best infra choice (cheapest & most available) for each workload. + +Both components make using complex cloud infra simple: + +- The unified Sky interface allows workloads to be specified once with the same interface, and be able to run on different infra. +- The Sky optimizer cuts across the search space to exploit the (dynamically changing) pricing and availability differences in the compute pool. + +With Sky, cloud users and their workloads gain the following benefits: + +* **Cloud is easier to use**: With the unified interface, infra is simplified & multicloud ready. +* **Lower costs**: Engineering time is saved from dealing with cloud infra. Sky optimizes the cost of each workload. Large organizations gain pricing leverage due to portability. +* **Higher capacity**: Workloads can now run on a bigger compute pool, with many choices of locations, hardware, and pricing models. + +Why does AI benefit from Sky Computing? +--------------------------------------------------- + +AI is highly **capacity and cost intensive**, many orders of magnitude more so +than prior cloud workloads. To increase capacity and reduce costs, AI teams are using compute anywhere and in whatever forms they can. + +- Locations: AI teams use a mix of hyperscalers (AWS/GCP/Azure/..), GPU + clouds (CoreWeave/Lambda/..), many regions within a cloud, and/or many + Kubernetes clusters. +- Hardware: Different GPU generations for different tasks (e.g., H100 for + training, L4 for inference); AMD GPUs; accelerators on hyperscalers (e.g., TPUs, Trainium, Inferentia). +- Pricing models: Teams use a mix of reserved, on-demand, spot GPUs to save costs. + +These choices often interleave: e.g., it is common for AI teams to use reserved H100 on cloud X for training and spot L4 on cloud Y +for large-scale batch inference. +Therefore, AI workloads inherently require managing many compute choices in the complex search space. + +Sky Computing presents a natural solution. +Sky offers AI teams **a unified interface to easily and portably run AI** on their diverse compute. +Further, Sky intelligently orchestrates across a team's AI compute choices, providing large cost savings and higher compute capacity. + +SkyPilot and Sky Computing +--------------------------------------------------- + +SkyPilot was born out of the same `UC Berkeley lab `_ that +proposed Sky Computing. +SkyPilot is Sky's first instantiation, and it was started to implement Sky Computing for one important class of workloads: AI and compute-intensive workloads. + +Over the last few years, SkyPilot has been widely adopted by ~100s of leading companies and AI teams. 
+While the initial development team +consisted of Berkeley PhDs/researchers, the SkyPilot community today has grown to +100+ `contributors `_ from many organizations. + +SkyPilot operates in a BYOC (Bring Your Own Cloud) model, where all resources +are launched in a user's existing cloud accounts, VPCs, and clusters. + +SkyPilot is open-sourced under the permissive Apache 2 license and under +active development on `GitHub `_. + +What if I have a single cloud: Levels of Sky Computing +------------------------------------------------------ + +Just like autonomous driving has different levels of autonomy (e.g., Level 1-5), one can adopt Sky Computing and SkyPilot in increasing "levels" and benefits. + +**For users on a fixed cluster** (e.g., Kubernetes, Slurm), SkyPilot provides: + +- A simple interface to submit and manage AI workloads, tailored to AI users' ergonomics. +- Support for dev clusters, jobs, and serving on your cluster. +- Cost savings: Autostop, queueing, and higher hardware utilization. +- Future-proofness: No retooling when you add other clusters or clouds in the future. + +**For users on one cloud's single region/zone**, SkyPilot provides: + +- Auto-retry, auto-fallback provisioner: Specify many hardware fallback targets and SkyPilot will auto-optimize and auto-retry to combat GPU shortage. +- Battle-tested job recovery, including training and serving on spot instances. +- :ref:`Simple workload packaging `: Wrap your existing AI projects in a simple SkyPilot YAML and have all infra tasks handled for you (a minimal sketch is shown at the end of this page). +- Plus all of the benefits above. + +**For users on one cloud's multiple regions**, SkyPilot provides: + +- Support for a single job to utilize multiple regions for GPU availability & faster recovery. +- Support for a model's replicas to span multiple regions for availability & cost savings. +- Plus all of the benefits above. + +**For users on multiple clouds or clusters**, SkyPilot + +- Combines all of your infra into a unified pool (your *Sky*), for higher utilization, cost savings, and capacity. +- Plus all of the benefits above. + + + +Learning more +--------------------------------------------------- + +Today, the systems community at UC Berkeley --- and beyond --- has +produced several follow-up projects to expand the Sky Computing stack: + +- `SkyServe `_: SkyPilot's cross-region, cross-cloud AI serving library (:ref:`user docs `). +- `Can't Be Late `_: Advanced spot instance scheduling policy for SkyPilot (NSDI '24 Best Paper). +- `Skyplane `_: Open-source tool for fast and cost-effective cross-cloud data transfer. +- `Cloudcast `_: High-throughput, cost-aware cross-region and cross-cloud multicast. +- `FogROS2 `_: Open-source cloud robotics platform leveraging Sky Computing via SkyPilot. +- …and a few more in the pipeline. + +To learn more about SkyPilot, refer to the `project announcement blog post `_, or the `SkyPilot NSDI 2023 paper +`_ and `talk +`_. + +To learn more about Sky Computing, see the `Sky Computing whitepaper `_. + + +Getting involved +--------------------------------------------------- + +**Try out SkyPilot**: Experience Sky Computing in your cloud(s) in 5 minutes via :ref:`Quickstart `. + +**Share your feedback**: Chat with the team on `SkyPilot Slack `_ or drop a note on our `GitHub `_. + +**Contributing**: We welcome contributions from the community! See `CONTRIBUTING `_.
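As a concrete illustration of the workload packaging mentioned above, here is a minimal sketch of a SkyPilot task YAML. The file name, the accelerator choice, and the ``requirements.txt``/``train.py`` files are illustrative assumptions, not an example shipped with this repository:

.. code-block:: yaml

    # task.yaml: a sketch of packaging an existing training project.
    resources:
      accelerators: A100:1  # The optimizer searches locations, clouds, and pricing for this GPU.
      use_spot: true        # Request spot instances to cut cost; omit to use on-demand.

    workdir: .              # Upload the current project directory to the provisioned machine.

    setup: |
      pip install -r requirements.txt

    run: |
      python train.py

The same file can then be launched, unchanged, on any enabled cloud or Kubernetes cluster with ``sky launch task.yaml``.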
diff --git a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py index 7addcffbe3c..ff95162ac63 100644 --- a/examples/admin_policy/example_policy/example_policy/skypilot_policy.py +++ b/examples/admin_policy/example_policy/example_policy/skypilot_policy.py @@ -139,7 +139,7 @@ def update_current_kubernetes_clusters_from_registry(): def get_allowed_contexts(): """Mock implementation of getting allowed kubernetes contexts.""" from sky.provision.kubernetes import utils - contexts = utils.get_all_kube_config_context_names() + contexts = utils.get_all_kube_context_names() return contexts[:2] diff --git a/examples/oci/serve-http-cpu.yaml b/examples/oci/serve-http-cpu.yaml new file mode 100644 index 00000000000..68e3d18c9e5 --- /dev/null +++ b/examples/oci/serve-http-cpu.yaml @@ -0,0 +1,11 @@ +service: + readiness_probe: / + replicas: 2 + +resources: + cloud: oci + region: us-sanjose-1 + ports: 8080 + cpus: 2+ + +run: python -m http.server 8080 diff --git a/examples/oci/serve-qwen-7b.yaml b/examples/oci/serve-qwen-7b.yaml new file mode 100644 index 00000000000..004e912b088 --- /dev/null +++ b/examples/oci/serve-qwen-7b.yaml @@ -0,0 +1,25 @@ +# service.yaml +service: + readiness_probe: /v1/models + replicas: 2 + +# Fields below describe each replica. +resources: + cloud: oci + region: us-sanjose-1 + ports: 8080 + accelerators: {A10:1} + +setup: | + conda create -n vllm python=3.12 -y + conda activate vllm + pip install vllm==0.6.3.post1 + pip install vllm-flash-attn==2.6.2 + +run: | + conda activate vllm + python -u -m vllm.entrypoints.openai.api_server \ + --host 0.0.0.0 --port 8080 \ + --model Qwen/Qwen2-7B-Instruct \ + --served-model-name Qwen2-7B-Instruct \ + --device=cuda --dtype auto --max-model-len=2048 diff --git a/examples/serve/load_balancing_policies_example.yaml b/examples/serve/load_balancing_policies_example.yaml new file mode 100644 index 00000000000..50038a9083a --- /dev/null +++ b/examples/serve/load_balancing_policies_example.yaml @@ -0,0 +1,26 @@ +# SkyServe YAML to demonstrate multiple load balancing policies. +# +# Usage: +# sky serve up -n load_balancing_policy_test examples/serve/load_balancing_policies_example.yaml +# The endpoint will be printed in the console. You +# could also check the endpoint by running: +# sky serve status --endpoint load_balancing_policy_test + +service: + readiness_probe: + path: /health + initial_delay_seconds: 20 + replica_policy: + min_replicas: 2 + max_replicas: 4 + target_qps_per_replica: 1 + # Load balancing policy configuration + load_balancing_policy: round_robin # Change this to test different policies... + +resources: + ports: 8080 + cpus: 2+ + +workdir: examples/serve/http_server + +run: python3 server.py diff --git a/examples/tpu/tpuvm_mnist.yaml b/examples/tpu/tpuvm_mnist.yaml index d1fd434fad6..41b14283fac 100644 --- a/examples/tpu/tpuvm_mnist.yaml +++ b/examples/tpu/tpuvm_mnist.yaml @@ -5,7 +5,7 @@ resources: # The setup command. Will be run under the working directory. setup: | - git clone https://github.com/google/flax.git --branch v0.8.2 + git clone https://github.com/google/flax.git --branch v0.10.1 conda activate flax if [ $? -eq 0 ]; then @@ -15,7 +15,7 @@ setup: | conda activate flax # Make sure to install TPU related packages in a conda env to avoid package conflicts. 
pip install \ - -f https://storage.googleapis.com/jax-releases/libtpu_releases.html "jax[tpu]==0.4.25" \ + -f https://storage.googleapis.com/jax-releases/libtpu_releases.html "jax[tpu]==0.4.35" \ clu \ tensorflow tensorflow-datasets pip install -e flax diff --git a/llm/llama-3_2/README.md b/llm/llama-3_2/README.md index eb62071471d..987dc0d90c5 100644 --- a/llm/llama-3_2/README.md +++ b/llm/llama-3_2/README.md @@ -351,4 +351,4 @@ See more details in [SkyServe docs](https://skypilot.readthedocs.io/en/latest/se ## Developing and Finetuning Llama 3 series -SkyPilot also simplifies the development and finetuning of Llama 3 series. Check out the development and finetuning guides: [Develop](https://github.com/skypilot-org/skypilot/blob/master/llm/llama-3_1/README.md) and [Finetune](https://github.com/skypilot-org/skypilot/blob/master/llm/llama-3_1-finetuning/README.md). +SkyPilot also simplifies the development and finetuning of Llama 3 series. Check out the development and finetuning guides: [Develop](https://github.com/skypilot-org/skypilot/blob/master/llm/llama-3_1/README.md) and [Finetune](https://github.com/skypilot-org/skypilot/blob/master/llm/llama-3_1-finetuning/readme.md). diff --git a/llm/lorax/README.md b/llm/lorax/README.md index 6cc44cf1134..edd153d45f1 100644 --- a/llm/lorax/README.md +++ b/llm/lorax/README.md @@ -91,7 +91,7 @@ Here are some other interesting Mistral-7B fine-tuned models to test out: - [IlyaGusev/saiga_mistral_7b_lora](https://huggingface.co/IlyaGusev/saiga_mistral_7b_lora): Russian chatbot based on `Open-Orca/Mistral-7B-OpenOrca`. - [Undi95/Mistral-7B-roleplay_alpaca-lora](https://huggingface.co/Undi95/Mistral-7B-roleplay_alpaca-lora): Fine-tuned using role-play prompts. -You can find more LoRA adapters [here](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending&search=-lora), or try fine-tuning your own with [PEFT](https://github.com/huggingface/peft) or [Ludwig](https://ludwig.ai). +You can find more LoRA adapters [here](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending&search=lora), or try fine-tuning your own with [PEFT](https://github.com/huggingface/peft) or [Ludwig](https://ludwig.ai). ## Stop the deployment diff --git a/llm/sglang/README.md b/llm/sglang/README.md index fc79529148a..7d41b8fc168 100644 --- a/llm/sglang/README.md +++ b/llm/sglang/README.md @@ -63,7 +63,7 @@ ENDPOINT=$(sky serve status --endpoint sglang-llava) 4. Once it status is `READY`, you can use the endpoint to talk to the model with both text and image inputs:
Input image to the LLaVA model.
@@ -80,7 +80,7 @@ curl $ENDPOINT/v1/chat/completions \ { "type": "image_url", "image_url": { - "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/examples/quick_start/images/cat.jpeg" + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/examples/frontend_language/quick_start/images/cat.jpeg" } } ] diff --git a/sky/__init__.py b/sky/__init__.py index b851775dabf..4e720d63ce0 100644 --- a/sky/__init__.py +++ b/sky/__init__.py @@ -105,6 +105,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): from sky.data import StoreType from sky.execution import exec # pylint: disable=redefined-builtin from sky.execution import launch +from sky.jobs import ManagedJobStatus # TODO (zhwu): These imports are for backward compatibility, and spot APIs # should be called with `sky.spot.xxx` instead. Remove in release 0.8.0 from sky.jobs.core import spot_cancel @@ -163,6 +164,7 @@ def set_proxy_env_var(proxy_var: str, urllib_var: Optional[str]): 'StoreType', 'ClusterStatus', 'JobStatus', + 'ManagedJobStatus', # APIs 'Dag', 'Task', diff --git a/sky/adaptors/kubernetes.py b/sky/adaptors/kubernetes.py index ea8fb194efa..001d397ac9e 100644 --- a/sky/adaptors/kubernetes.py +++ b/sky/adaptors/kubernetes.py @@ -19,6 +19,13 @@ # Timeout to use for API calls API_TIMEOUT = 5 +DEFAULT_IN_CLUSTER_REGION = 'in-cluster' +# The name for the environment variable that stores the in-cluster context name +# for Kubernetes clusters. This is used to associate a name with the current +# context when running with in-cluster auth. If not set, the context name is +# set to DEFAULT_IN_CLUSTER_REGION. +IN_CLUSTER_CONTEXT_NAME_ENV_VAR = 'SKYPILOT_IN_CLUSTER_CONTEXT_NAME' + def _decorate_methods(obj: Any, decorator: Callable, decoration_type: str): for attr_name in dir(obj): @@ -57,16 +64,8 @@ def wrapped(*args, **kwargs): def _load_config(context: Optional[str] = None): urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - try: - # Load in-cluster config if running in a pod - # Kubernetes set environment variables for service discovery do not - # show up in SkyPilot tasks. For now, we work around by using - # DNS name instead of environment variables. - # See issue: https://github.com/skypilot-org/skypilot/issues/2287 - os.environ['KUBERNETES_SERVICE_HOST'] = 'kubernetes.default.svc' - os.environ['KUBERNETES_SERVICE_PORT'] = '443' - kubernetes.config.load_incluster_config() - except kubernetes.config.config_exception.ConfigException: + + def _load_config_from_kubeconfig(context: Optional[str] = None): try: kubernetes.config.load_kube_config(context=context) except kubernetes.config.config_exception.ConfigException as e: @@ -90,6 +89,21 @@ def _load_config(context: Optional[str] = None): with ux_utils.print_exception_no_traceback(): raise ValueError(err_str) from None + if context == in_cluster_context_name() or context is None: + try: + # Load in-cluster config if running in a pod and context is None. + # Kubernetes set environment variables for service discovery do not + # show up in SkyPilot tasks. For now, we work around by using + # DNS name instead of environment variables. 
+ # See issue: https://github.com/skypilot-org/skypilot/issues/2287 + os.environ['KUBERNETES_SERVICE_HOST'] = 'kubernetes.default.svc' + os.environ['KUBERNETES_SERVICE_PORT'] = '443' + kubernetes.config.load_incluster_config() + except kubernetes.config.config_exception.ConfigException: + _load_config_from_kubeconfig() + else: + _load_config_from_kubeconfig(context) + @_api_logging_decorator('urllib3', logging.ERROR) @functools.lru_cache() @@ -154,3 +168,13 @@ def max_retry_error(): def stream(): return kubernetes.stream.stream + + +def in_cluster_context_name() -> Optional[str]: + """Returns the name of the in-cluster context from the environment. + + If the environment variable is not set, returns the default in-cluster + context name. + """ + return (os.environ.get(IN_CLUSTER_CONTEXT_NAME_ENV_VAR) or + DEFAULT_IN_CLUSTER_REGION) diff --git a/sky/authentication.py b/sky/authentication.py index 41a7d02dfb7..2eb65bd9f6f 100644 --- a/sky/authentication.py +++ b/sky/authentication.py @@ -380,8 +380,8 @@ def setup_kubernetes_authentication(config: Dict[str, Any]) -> Dict[str, Any]: secret_field_name = clouds.Kubernetes().ssh_key_secret_field_name context = config['provider'].get( 'context', kubernetes_utils.get_current_kube_config_context_name()) - if context == kubernetes_utils.IN_CLUSTER_REGION: - # If the context is set to IN_CLUSTER_REGION, we are running in a pod + if context == kubernetes.in_cluster_context_name(): + # If the context is an in-cluster context name, we are running in a pod # with in-cluster configuration. We need to set the context to None # to use the mounted service account. context = None diff --git a/sky/backends/backend.py b/sky/backends/backend.py index 10b51b06038..d5fd6f19925 100644 --- a/sky/backends/backend.py +++ b/sky/backends/backend.py @@ -45,20 +45,45 @@ def check_resources_fit_cluster(self, handle: _ResourceHandleType, @timeline.event @usage_lib.messages.usage.update_runtime('provision') def provision( - self, - task: 'task_lib.Task', - to_provision: Optional['resources.Resources'], - dryrun: bool, - stream_logs: bool, - cluster_name: Optional[str] = None, - retry_until_up: bool = False) -> Optional[_ResourceHandleType]: + self, + task: 'task_lib.Task', + to_provision: Optional['resources.Resources'], + dryrun: bool, + stream_logs: bool, + cluster_name: Optional[str] = None, + retry_until_up: bool = False, + skip_unnecessary_provisioning: bool = False, + ) -> Optional[_ResourceHandleType]: + """Provisions resources for the given task. + + Args: + task: The task to provision resources for. + to_provision: Resource config to provision. Should only be None if + cluster_name refers to an existing cluster, whose resources will + be used. + dryrun: If True, don't actually provision anything. + stream_logs: If True, stream additional logs to console. + cluster_name: Name of the cluster to provision. If None, a name will + be auto-generated. If the name refers to an existing cluster, + the existing cluster will be reused and re-provisioned. + retry_until_up: If True, retry provisioning until resources are + successfully launched. + skip_unnecessary_provisioning: If True, compare the cluster config to + the existing cluster_name's config. Skip provisioning if no + updates are needed for the existing cluster. + + Returns: + A ResourceHandle object for the provisioned resources, or None if + dryrun is True.
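For context, a hedged sketch of how a caller might exercise the new `skip_unnecessary_provisioning` flag; the task and backend construction here are illustrative assumptions, not code from this diff.

```python
# Hypothetical call site, shown only to illustrate the new flag; the Task /
# backend setup is an assumption for demonstration purposes.
import sky
from sky.backends import cloud_vm_ray_backend

task = sky.Task(run='echo hello')
task.set_resources(sky.Resources(cpus='2+'))

backend = cloud_vm_ray_backend.CloudVmRayBackend()
# Re-launching an existing cluster: if its stored config hash still matches,
# provisioning is skipped and the existing handle is returned.
handle = backend.provision(task,
                           to_provision=None,
                           dryrun=False,
                           stream_logs=True,
                           cluster_name='my-cluster',
                           retry_until_up=False,
                           skip_unnecessary_provisioning=True)
```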
+ """ if cluster_name is None: cluster_name = sky.backends.backend_utils.generate_cluster_name() usage_lib.record_cluster_name_for_current_operation(cluster_name) usage_lib.messages.usage.update_actual_task(task) with rich_utils.safe_status(ux_utils.spinner_message('Launching')): return self._provision(task, to_provision, dryrun, stream_logs, - cluster_name, retry_until_up) + cluster_name, retry_until_up, + skip_unnecessary_provisioning) @timeline.event @usage_lib.messages.usage.update_runtime('sync_workdir') @@ -126,13 +151,15 @@ def register_info(self, **kwargs) -> None: # --- Implementations of the APIs --- def _provision( - self, - task: 'task_lib.Task', - to_provision: Optional['resources.Resources'], - dryrun: bool, - stream_logs: bool, - cluster_name: str, - retry_until_up: bool = False) -> Optional[_ResourceHandleType]: + self, + task: 'task_lib.Task', + to_provision: Optional['resources.Resources'], + dryrun: bool, + stream_logs: bool, + cluster_name: str, + retry_until_up: bool = False, + skip_unnecessary_provisioning: bool = False, + ) -> Optional[_ResourceHandleType]: raise NotImplementedError def _sync_workdir(self, handle: _ResourceHandleType, workdir: Path) -> None: diff --git a/sky/backends/backend_utils.py b/sky/backends/backend_utils.py index e4633ef0671..7292001cc09 100644 --- a/sky/backends/backend_utils.py +++ b/sky/backends/backend_utils.py @@ -3,6 +3,7 @@ import enum import fnmatch import functools +import hashlib import os import pathlib import pprint @@ -100,6 +101,10 @@ CLUSTER_STATUS_LOCK_PATH = os.path.expanduser('~/.sky/.{}.lock') CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS = 20 +# Time that must elapse since the last status check before we should re-check if +# the cluster has been terminated or autostopped. +_CLUSTER_STATUS_CACHE_DURATION_SECONDS = 2 + # Filelocks for updating cluster's file_mounts. CLUSTER_FILE_MOUNTS_LOCK_PATH = os.path.expanduser( '~/.sky/.{}_file_mounts.lock') @@ -640,11 +645,17 @@ def write_cluster_config( keep_launch_fields_in_existing_config: bool = True) -> Dict[str, str]: """Fills in cluster configuration templates and writes them out. - Returns: {provisioner: path to yaml, the provisioning spec}. - 'provisioner' can be - - 'ray' - - 'tpu-create-script' (if TPU is requested) - - 'tpu-delete-script' (if TPU is requested) + Returns: + Dict with the following keys: + - 'ray': Path to the generated Ray yaml config file + - 'cluster_name': Name of the cluster + - 'cluster_name_on_cloud': Name of the cluster as it appears in the + cloud provider + - 'config_hash': Hash of the cluster config and file mounts contents. + Can be missing if we unexpectedly failed to calculate the hash for + some reason. In that case we will continue without the optimization to + skip provisioning. + Raises: exceptions.ResourcesUnavailableError: if the region/zones requested does not appear in the catalog, or an ssh_proxy_command is specified but @@ -679,33 +690,68 @@ def write_cluster_config( resources_utils.ClusterName( cluster_name, cluster_name_on_cloud, - ), region, zones, dryrun) + ), region, zones, num_nodes, dryrun) config_dict = {} specific_reservations = set( skypilot_config.get_nested( (str(to_provision.cloud).lower(), 'specific_reservations'), set())) + # Remote identity handling can have 4 cases: + # 1. LOCAL_CREDENTIALS (default for most clouds): Upload local credentials + # 2. SERVICE_ACCOUNT: SkyPilot creates and manages a service account + # 3. Custom service account: Use specified service account + # 4. 
NO_UPLOAD: Do not upload any credentials + # + # We need to upload credentials only if LOCAL_CREDENTIALS is specified. In + # other cases, we exclude the cloud from credential file uploads after + # running required checks. assert cluster_name is not None - excluded_clouds = [] + excluded_clouds = set() remote_identity_config = skypilot_config.get_nested( (str(cloud).lower(), 'remote_identity'), None) remote_identity = schemas.get_default_remote_identity(str(cloud).lower()) if isinstance(remote_identity_config, str): remote_identity = remote_identity_config if isinstance(remote_identity_config, list): + # Some clouds (e.g., AWS) support specifying multiple service accounts + # chosen based on the cluster name. Do the matching here to pick the + # correct one. for profile in remote_identity_config: if fnmatch.fnmatchcase(cluster_name, list(profile.keys())[0]): remote_identity = list(profile.values())[0] break if remote_identity != schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value: - if not cloud.supports_service_account_on_remote(): + # If LOCAL_CREDENTIALS is not specified, we add the cloud to the + # excluded_clouds set, but we must also check if the cloud supports + # service accounts. + if remote_identity == schemas.RemoteIdentityOptions.NO_UPLOAD.value: + # If NO_UPLOAD is specified, fall back to default remote identity + # for downstream logic but add it to excluded_clouds to skip + # credential file uploads. + remote_identity = schemas.get_default_remote_identity( + str(cloud).lower()) + elif not cloud.supports_service_account_on_remote(): raise exceptions.InvalidCloudConfigs( 'remote_identity: SERVICE_ACCOUNT is specified in ' f'{skypilot_config.loaded_config_path!r} for {cloud}, but it ' 'is not supported by this cloud. Remove the config or set: ' '`remote_identity: LOCAL_CREDENTIALS`.') - excluded_clouds = [cloud] + if isinstance(cloud, clouds.Kubernetes): + if skypilot_config.get_nested( + ('kubernetes', 'allowed_contexts'), None) is None: + excluded_clouds.add(cloud) + else: + excluded_clouds.add(cloud) + + for cloud_str, cloud_obj in cloud_registry.CLOUD_REGISTRY.items(): + remote_identity_config = skypilot_config.get_nested( + (cloud_str.lower(), 'remote_identity'), None) + if remote_identity_config: + if (remote_identity_config == + schemas.RemoteIdentityOptions.NO_UPLOAD.value): + excluded_clouds.add(cloud_obj) + credentials = sky_check.get_cloud_credential_file_mounts(excluded_clouds) auth_config = {'ssh_private_key': auth.PRIVATE_SSH_KEY_PATH} @@ -810,7 +856,11 @@ def write_cluster_config( '{sky_wheel_hash}', wheel_hash).replace('{cloud}', str(cloud).lower())), - + 'skypilot_wheel_installation_commands': + constants.SKYPILOT_WHEEL_INSTALLATION_COMMANDS.replace( + '{sky_wheel_hash}', + wheel_hash).replace('{cloud}', + str(cloud).lower()), # Port of Ray (GCS server). # Ray's default port 6379 is conflicted with Redis. 'ray_port': constants.SKY_REMOTE_RAY_PORT, @@ -860,6 +910,12 @@ def write_cluster_config( if dryrun: # If dryrun, return the unfinished tmp yaml path. 
config_dict['ray'] = tmp_yaml_path + try: + config_dict['config_hash'] = _deterministic_cluster_yaml_hash( + tmp_yaml_path) + except Exception as e: # pylint: disable=broad-except + logger.warning(f'Failed to calculate config_hash: {e}') + logger.debug('Full exception:', exc_info=e) return config_dict _add_auth_to_cluster_config(cloud, tmp_yaml_path) @@ -882,6 +938,17 @@ def write_cluster_config( yaml_config = common_utils.read_yaml(tmp_yaml_path) config_dict['cluster_name_on_cloud'] = yaml_config['cluster_name'] + # Make sure to do this before we optimize file mounts. Optimization is + # non-deterministic, but everything else before this point should be + # deterministic. + try: + config_dict['config_hash'] = _deterministic_cluster_yaml_hash( + tmp_yaml_path) + except Exception as e: # pylint: disable=broad-except + logger.warning('Failed to calculate config_hash: ' + f'{common_utils.format_exception(e)}') + logger.debug('Full exception:', exc_info=e) + # Optimization: copy the contents of source files in file_mounts to a # special dir, and upload that as the only file_mount instead. Delay # calling this optimization until now, when all source files have been @@ -990,6 +1057,115 @@ def get_ready_nodes_counts(pattern, output): return ready_head, ready_workers +@timeline.event +def _deterministic_cluster_yaml_hash(yaml_path: str) -> str: + """Hash the cluster yaml and contents of file mounts to a unique string. + + Two invocations of this function should return the same string if and only + if the contents of the yaml are the same and the file contents of all the + file_mounts specified in the yaml are the same. + + Limitations: + - This function can be expensive if the file mounts are large. (E.g. a few + seconds for ~1GB.) This should be okay since we expect that the + file_mounts in the cluster yaml (the wheel and cloud credentials) will be + small. + - Symbolic links are not explicitly handled. Some symbolic link changes may + not be detected. + + Implementation: We create a byte sequence that captures the state of the + yaml file and all the files in the file mounts, then hash the byte sequence. + + The format of the byte sequence is: + 32 bytes - sha256 hash of the yaml file + for each file mount: + file mount remote destination (UTF-8), \0 + if the file mount source is a file: + 'file' encoded to UTF-8 + 32 byte sha256 hash of the file contents + if the file mount source is a directory: + 'dir' encoded to UTF-8 + for each directory and subdirectory withinin the file mount (starting from + the root and descending recursively): + name of the directory (UTF-8), \0 + name of each subdirectory within the directory (UTF-8) terminated by \0 + \0 + for each file in the directory: + name of the file (UTF-8), \0 + 32 bytes - sha256 hash of the file contents + \0 + if the file mount source is something else or does not exist, nothing + \0\0 + + Rather than constructing the whole byte sequence, which may be quite large, + we construct it incrementally by using hash.update() to add new bytes. + """ + + def _hash_file(path: str) -> bytes: + return common_utils.hash_file(path, 'sha256').digest() + + config_hash = hashlib.sha256() + + config_hash.update(_hash_file(yaml_path)) + + yaml_config = common_utils.read_yaml(yaml_path) + file_mounts = yaml_config.get('file_mounts', {}) + # Remove the file mounts added by the newline. 
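`common_utils.hash_file` is used by the helper above but its body is not part of this diff; a minimal sketch of the chunked, incremental hashing it is assumed to perform:

```python
# Assumed behavior of a hash_file-style helper (not the actual SkyPilot
# implementation): stream the file through sha256 in chunks so large file
# mounts never need to be held in memory, then return the hash object.
import hashlib

def hash_file_sketch(path: str, chunk_size: int = 1024 * 1024):
    h = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h  # caller can take .digest() or .hexdigest()
```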
+ if '' in file_mounts: + assert file_mounts[''] == '', file_mounts[''] + file_mounts.pop('') + + for dst, src in sorted(file_mounts.items()): + expanded_src = os.path.expanduser(src) + config_hash.update(dst.encode('utf-8') + b'\0') + + # If the file mount source is a symlink, this should be true. In that + # case we hash the contents of the symlink destination. + if os.path.isfile(expanded_src): + config_hash.update('file'.encode('utf-8')) + config_hash.update(_hash_file(expanded_src)) + + # This can also be a symlink to a directory. os.walk will treat it as a + # normal directory and list the contents of the symlink destination. + elif os.path.isdir(expanded_src): + config_hash.update('dir'.encode('utf-8')) + + # Aside from expanded_src, os.walk will list symlinks to directories + # but will not recurse into them. + for (dirpath, dirnames, filenames) in os.walk(expanded_src): + config_hash.update(dirpath.encode('utf-8') + b'\0') + + # Note: inplace sort will also affect the traversal order of + # os.walk. We need it so that the os.walk order is + # deterministic. + dirnames.sort() + # This includes symlinks to directories. os.walk will recurse + # into all the directories but not the symlinks. We don't hash + # the link destination, so if a symlink to a directory changes, + # we won't notice. + for dirname in dirnames: + config_hash.update(dirname.encode('utf-8') + b'\0') + config_hash.update(b'\0') + + filenames.sort() + # This includes symlinks to files. We could hash the symlink + # destination itself but instead just hash the destination + # contents. + for filename in filenames: + config_hash.update(filename.encode('utf-8') + b'\0') + config_hash.update( + _hash_file(os.path.join(dirpath, filename))) + config_hash.update(b'\0') + + else: + logger.debug( + f'Unexpected file_mount that is not a file or dir: {src}') + + config_hash.update(b'\0\0') + + return config_hash.hexdigest() + + def get_docker_user(ip: str, cluster_config_file: str) -> str: """Find docker container username.""" ssh_credentials = ssh_credential_from_yaml(cluster_config_file) @@ -1157,18 +1333,18 @@ def ssh_credential_from_yaml( def parallel_data_transfer_to_nodes( - runners: List[command_runner.CommandRunner], - source: Optional[str], - target: str, - cmd: Optional[str], - run_rsync: bool, - *, - action_message: str, - # Advanced options. - log_path: str = os.devnull, - stream_logs: bool = False, - source_bashrc: bool = False, -): + runners: List[command_runner.CommandRunner], + source: Optional[str], + target: str, + cmd: Optional[str], + run_rsync: bool, + *, + action_message: str, + # Advanced options. + log_path: str = os.devnull, + stream_logs: bool = False, + source_bashrc: bool = False, + num_threads: Optional[int] = None): """Runs a command on all nodes and optionally runs rsync from src->dst. Args: @@ -1180,6 +1356,7 @@ def parallel_data_transfer_to_nodes( log_path: str; Path to the log file stream_logs: bool; Whether to stream logs to stdout source_bashrc: bool; Source bashrc before running the command. + num_threads: Optional[int]; Number of threads to use. 
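`subprocess_utils.run_in_parallel` and the per-cloud thread caps are referenced here but not defined in this diff; a rough sketch of the bounded parallel map they are assumed to provide:

```python
# Illustrative stand-in for run_in_parallel(fn, items, num_threads): map fn
# over items on a bounded thread pool. Capping num_threads (e.g. lower for
# clouds with slow CLIs) keeps parallel rsync/SSH transfers from
# overwhelming the local machine.
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, List, Optional, TypeVar

_T = TypeVar('_T')
_R = TypeVar('_R')

def run_in_parallel_sketch(fn: Callable[[_T], _R],
                           items: Iterable[_T],
                           num_threads: Optional[int] = None) -> List[_R]:
    items = list(items)
    if not items:
        return []
    workers = num_threads if num_threads is not None else len(items)
    with ThreadPoolExecutor(max_workers=max(1, workers)) as pool:
        return list(pool.map(fn, items))
```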
""" style = colorama.Style @@ -1220,7 +1397,7 @@ def _sync_node(runner: 'command_runner.CommandRunner') -> None: message = (f' {style.DIM}{action_message} (to {num_nodes} node{plural})' f': {origin_source} -> {target}{style.RESET_ALL}') logger.info(message) - subprocess_utils.run_in_parallel(_sync_node, runners) + subprocess_utils.run_in_parallel(_sync_node, runners, num_threads) def check_local_gpus() -> bool: @@ -1418,6 +1595,7 @@ def check_network_connection(): 'Network seems down.') from e +@timeline.event def check_owner_identity(cluster_name: str) -> None: """Check if current user is the same as the user who created the cluster. @@ -1567,14 +1745,14 @@ def check_can_clone_disk_and_override_task( The task to use and the resource handle of the source cluster. Raises: - ValueError: If the source cluster does not exist. + exceptions.ClusterDoesNotExist: If the source cluster does not exist. exceptions.NotSupportedError: If the source cluster is not valid or the task is not compatible to clone disk from the source cluster. """ source_cluster_status, handle = refresh_cluster_status_handle(cluster_name) if source_cluster_status is None: with ux_utils.print_exception_no_traceback(): - raise ValueError( + raise exceptions.ClusterDoesNotExist( f'Cannot find cluster {cluster_name!r} to clone disk from.') if not isinstance(handle, backends.CloudVmRayResourceHandle): @@ -1668,11 +1846,27 @@ def check_can_clone_disk_and_override_task( def _update_cluster_status_no_lock( cluster_name: str) -> Optional[Dict[str, Any]]: - """Updates the status of the cluster. + """Update the cluster status. + + The cluster status is updated by checking ray cluster and real status from + cloud. + + The function will update the cached cluster status in the global state. For + the design of the cluster status and transition, please refer to the + sky/design_docs/cluster_status.md + + Returns: + If the cluster is terminated or does not exist, return None. Otherwise + returns the input record with status and handle potentially updated. Raises: + exceptions.ClusterOwnerIdentityMismatchError: if the current user is + not the same as the user who created the cluster. + exceptions.CloudUserIdentityError: if we fail to get the current user + identity. exceptions.ClusterStatusFetchingError: the cluster status cannot be - fetched from the cloud provider. + fetched from the cloud provider or there are leaked nodes causing + the node number larger than expected. """ record = global_user_state.get_cluster_from_name(cluster_name) if record is None: @@ -1892,52 +2086,22 @@ def run_ray_status_to_check_ray_cluster_healthy() -> bool: return global_user_state.get_cluster_from_name(cluster_name) -def _update_cluster_status( - cluster_name: str, - acquire_per_cluster_status_lock: bool, - cluster_status_lock_timeout: int = CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS -) -> Optional[Dict[str, Any]]: - """Update the cluster status. - - The cluster status is updated by checking ray cluster and real status from - cloud. - - The function will update the cached cluster status in the global state. For - the design of the cluster status and transition, please refer to the - sky/design_docs/cluster_status.md - - Args: - cluster_name: The name of the cluster. - acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock - before updating the status. - cluster_status_lock_timeout: The timeout to acquire the per-cluster - lock. - - Returns: - If the cluster is terminated or does not exist, return None. 
Otherwise - returns the input record with status and handle potentially updated. +def _must_refresh_cluster_status( + record: Dict[str, Any], + force_refresh_statuses: Optional[Set[status_lib.ClusterStatus]] +) -> bool: + force_refresh_for_cluster = (force_refresh_statuses is not None and + record['status'] in force_refresh_statuses) - Raises: - exceptions.ClusterOwnerIdentityMismatchError: if the current user is - not the same as the user who created the cluster. - exceptions.CloudUserIdentityError: if we fail to get the current user - identity. - exceptions.ClusterStatusFetchingError: the cluster status cannot be - fetched from the cloud provider or there are leaked nodes causing - the node number larger than expected. - """ - if not acquire_per_cluster_status_lock: - return _update_cluster_status_no_lock(cluster_name) + use_spot = record['handle'].launched_resources.use_spot + has_autostop = (record['status'] != status_lib.ClusterStatus.STOPPED and + record['autostop'] >= 0) + recently_refreshed = (record['status_updated_at'] is not None and + time.time() - record['status_updated_at'] < + _CLUSTER_STATUS_CACHE_DURATION_SECONDS) + is_stale = (use_spot or has_autostop) and not recently_refreshed - try: - with filelock.FileLock(CLUSTER_STATUS_LOCK_PATH.format(cluster_name), - timeout=cluster_status_lock_timeout): - return _update_cluster_status_no_lock(cluster_name) - except filelock.Timeout: - logger.debug('Refreshing status: Failed get the lock for cluster ' - f'{cluster_name!r}. Using the cached status.') - record = global_user_state.get_cluster_from_name(cluster_name) - return record + return force_refresh_for_cluster or is_stale def refresh_cluster_record( @@ -1955,16 +2119,22 @@ def refresh_cluster_record( Args: cluster_name: The name of the cluster. - force_refresh_statuses: if specified, refresh the cluster if it has one of - the specified statuses. Additionally, clusters satisfying the - following conditions will always be refreshed no matter the - argument is specified or not: - 1. is a spot cluster, or - 2. is a non-spot cluster, is not STOPPED, and autostop is set. + force_refresh_statuses: if specified, refresh the cluster if it has one + of the specified statuses. Additionally, clusters satisfying the + following conditions will be refreshed whether the argument is + specified or not: + - the latest available status update is more than + _CLUSTER_STATUS_CACHE_DURATION_SECONDS old, and one of: + 1. the cluster is a spot cluster, or + 2. cluster autostop is set and the cluster is not STOPPED. acquire_per_cluster_status_lock: Whether to acquire the per-cluster lock - before updating the status. + before updating the status. Even if this is True, the lock may not be + acquired if the status does not need to be refreshed. cluster_status_lock_timeout: The timeout to acquire the per-cluster - lock. If timeout, the function will use the cached status. + lock. If timeout, the function will use the cached status. If the + value is <0, do not timeout (wait for the lock indefinitely). By + default, this is set to CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS. Warning: + if correctness is required, you must set this to -1. Returns: If the cluster is terminated or does not exist, return None.
@@ -1985,19 +2155,58 @@ def refresh_cluster_record( return None check_owner_identity(cluster_name) - handle = record['handle'] - if isinstance(handle, backends.CloudVmRayResourceHandle): - use_spot = handle.launched_resources.use_spot - has_autostop = (record['status'] != status_lib.ClusterStatus.STOPPED and - record['autostop'] >= 0) - force_refresh_for_cluster = (force_refresh_statuses is not None and - record['status'] in force_refresh_statuses) - if force_refresh_for_cluster or has_autostop or use_spot: - record = _update_cluster_status( - cluster_name, - acquire_per_cluster_status_lock=acquire_per_cluster_status_lock, - cluster_status_lock_timeout=cluster_status_lock_timeout) - return record + if not isinstance(record['handle'], backends.CloudVmRayResourceHandle): + return record + + # The loop logic allows us to notice if the status was updated in the + # global_user_state by another process and stop trying to get the lock. + # The core loop logic is adapted from FileLock's implementation. + lock = filelock.FileLock(CLUSTER_STATUS_LOCK_PATH.format(cluster_name)) + start_time = time.perf_counter() + + # Loop until we have an up-to-date status or until we acquire the lock. + while True: + # Check to see if we can return the cached status. + if not _must_refresh_cluster_status(record, force_refresh_statuses): + return record + + if not acquire_per_cluster_status_lock: + return _update_cluster_status_no_lock(cluster_name) + + # Try to acquire the lock so we can fetch the status. + try: + with lock.acquire(blocking=False): + # Lock acquired. + + # Check the cluster status again, since it could have been + # updated between our last check and acquiring the lock. + record = global_user_state.get_cluster_from_name(cluster_name) + if record is None or not _must_refresh_cluster_status( + record, force_refresh_statuses): + return record + + # Update and return the cluster status. + return _update_cluster_status_no_lock(cluster_name) + except filelock.Timeout: + # lock.acquire() will throw a Timeout exception if the lock is not + # available and we have blocking=False. + pass + + # Logic adapted from FileLock.acquire(). + # If cluster_status_lock_timeout is <0, we will never hit this. No timeout. + # Otherwise, if we have timed out, return the cached status. This has + # the potential to cause correctness issues, but if so it is the + # caller's responsibility to set the timeout to -1. + if 0 <= cluster_status_lock_timeout < time.perf_counter() - start_time: + logger.debug('Refreshing status: Failed to get the lock for cluster ' + f'{cluster_name!r}. Using the cached status.') + return record + time.sleep(0.05) + + # Refresh for next loop iteration. + record = global_user_state.get_cluster_from_name(cluster_name) + if record is None: + return None @timeline.event @@ -2060,7 +2269,7 @@ def check_cluster_available( """Check if the cluster is available. Raises: - ValueError: if the cluster does not exist. + exceptions.ClusterDoesNotExist: if the cluster does not exist. exceptions.ClusterNotUpError: if the cluster is not UP. exceptions.NotSupportedError: if the cluster is not based on CloudVmRayBackend.
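Because callers that previously caught ValueError will no longer see it, here is a hedged sketch of caller-side handling; only the exception types come from this diff, the exact keyword arguments of check_cluster_available are assumed.

```python
# Illustrative handling of the ValueError -> ClusterDoesNotExist change.
# The call signature below is an assumption for demonstration purposes.
from sky import exceptions
from sky.backends import backend_utils

try:
    handle = backend_utils.check_cluster_available(
        'my-cluster', operation='tailing logs')
except exceptions.ClusterDoesNotExist:
    print('Cluster my-cluster was never launched or has been terminated.')
except exceptions.ClusterNotUpError:
    print('Cluster my-cluster exists but is not UP yet.')
```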
@@ -2125,7 +2334,8 @@ def check_cluster_available( error_msg += message with ux_utils.print_exception_no_traceback(): - raise ValueError(f'{colorama.Fore.YELLOW}{error_msg}{reset}') + raise exceptions.ClusterDoesNotExist( + f'{colorama.Fore.YELLOW}{error_msg}{reset}') assert cluster_status is not None, 'handle is not None but status is None' backend = get_backend_from_handle(handle) if check_cloud_vm_ray_backend and not isinstance( @@ -2603,15 +2813,18 @@ def check_stale_runtime_on_remote(returncode: int, stderr: str, pattern = re.compile(r'AttributeError: module \'sky\.(.*)\' has no ' r'attribute \'(.*)\'') if returncode != 0: + # TODO(zhwu): Backward compatibility for old SkyPilot runtime version on + # the remote cluster. Remove this after 0.10.0 is released. attribute_error = re.findall(pattern, stderr) - if attribute_error: + if attribute_error or 'SkyPilot runtime is too old' in stderr: with ux_utils.print_exception_no_traceback(): raise RuntimeError( f'{colorama.Fore.RED}SkyPilot runtime needs to be updated ' - 'on the remote cluster. To update, run (existing jobs are ' - f'not interrupted): {colorama.Style.BRIGHT}sky start -f -y ' + f'on the remote cluster: {cluster_name}. To update, run ' + '(existing jobs will not be interrupted): ' + f'{colorama.Style.BRIGHT}sky start -f -y ' f'{cluster_name}{colorama.Style.RESET_ALL}' - f'\n--- Details ---\n{stderr.strip()}\n') + f'\n--- Details ---\n{stderr.strip()}\n') from None def get_endpoints(cluster: str, diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 68d9bfecfb0..b7ef8850132 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -269,6 +269,13 @@ def add_prologue(self, job_id: int) -> None: import time from typing import Dict, List, Optional, Tuple, Union + # Set the environment variables to avoid deduplicating logs and + # scheduler events. This should be set in driver code, since we are + # not using `ray job submit` anymore, and the environment variables + # from the ray cluster is not inherited. + os.environ['RAY_DEDUP_LOGS'] = '0' + os.environ['RAY_SCHEDULER_EVENTS'] = '0' + import ray import ray.util as ray_util @@ -276,6 +283,7 @@ def add_prologue(self, job_id: int) -> None: from sky.skylet import constants from sky.skylet import job_lib from sky.utils import log_utils + from sky.utils import subprocess_utils SKY_REMOTE_WORKDIR = {constants.SKY_REMOTE_WORKDIR!r} @@ -293,6 +301,8 @@ def add_prologue(self, job_id: int) -> None: ) def get_or_fail(futures, pg) -> List[int]: \"\"\"Wait for tasks, if any fails, cancel all unready.\"\"\" + if not futures: + return [] returncodes = [1] * len(futures) # Wait for 1 task to be ready. ready = [] @@ -1145,6 +1155,7 @@ def __init__( prev_cluster_status: Optional[status_lib.ClusterStatus], prev_handle: Optional['CloudVmRayResourceHandle'], prev_cluster_ever_up: bool, + prev_config_hash: Optional[str], ) -> None: assert cluster_name is not None, 'cluster_name must be specified.' 
self.cluster_name = cluster_name @@ -1153,6 +1164,7 @@ def __init__( self.prev_cluster_status = prev_cluster_status self.prev_handle = prev_handle self.prev_cluster_ever_up = prev_cluster_ever_up + self.prev_config_hash = prev_config_hash def __init__(self, log_dir: str, @@ -1314,8 +1326,21 @@ def _retry_zones( prev_cluster_status: Optional[status_lib.ClusterStatus], prev_handle: Optional['CloudVmRayResourceHandle'], prev_cluster_ever_up: bool, + skip_if_config_hash_matches: Optional[str], ) -> Dict[str, Any]: - """The provision retry loop.""" + """The provision retry loop. + + Returns a config_dict with the following fields: + All fields from backend_utils.write_cluster_config(). See its + docstring. + - 'provisioning_skipped': True if provisioning was short-circuited + by skip_if_config_hash_matches, False otherwise. + - 'handle': The provisioned cluster handle. + - 'provision_record': (Only if using the new skypilot provisioner) The + record returned by provisioner.bulk_provision(). + - 'resources_vars': (Only if using the new skypilot provisioner) The + resources variables given by make_deploy_resources_variables(). + """ # Get log_path name log_path = os.path.join(self.log_dir, 'provision.log') log_abs_path = os.path.abspath(log_path) @@ -1424,8 +1449,18 @@ def _retry_zones( raise exceptions.ResourcesUnavailableError( f'Failed to provision on cloud {to_provision.cloud} due to ' f'invalid cloud config: {common_utils.format_exception(e)}') + + if ('config_hash' in config_dict and + skip_if_config_hash_matches == config_dict['config_hash']): + logger.debug('Skipping provisioning of cluster with matching ' + 'config hash.') + config_dict['provisioning_skipped'] = True + return config_dict + config_dict['provisioning_skipped'] = False + if dryrun: return config_dict + cluster_config_file = config_dict['ray'] launched_resources = to_provision.copy(region=region.name) @@ -1527,7 +1562,7 @@ def _retry_zones( to_provision, resources_utils.ClusterName( cluster_name, handle.cluster_name_on_cloud), - region, zones)) + region, zones, num_nodes)) config_dict['provision_record'] = provision_record config_dict['resources_vars'] = resources_vars config_dict['handle'] = handle @@ -1937,8 +1972,13 @@ def provision_with_retries( to_provision_config: ToProvisionConfig, dryrun: bool, stream_logs: bool, + skip_unnecessary_provisioning: bool, ) -> Dict[str, Any]: - """Provision with retries for all launchable resources.""" + """Provision with retries for all launchable resources. + + Returns the config_dict from _retry_zones() - see its docstring for + details. 
+ """ cluster_name = to_provision_config.cluster_name to_provision = to_provision_config.resources num_nodes = to_provision_config.num_nodes @@ -1947,6 +1987,8 @@ def provision_with_retries( prev_cluster_ever_up = to_provision_config.prev_cluster_ever_up launchable_retries_disabled = (self._dag is None or self._optimize_target is None) + skip_if_config_hash_matches = (to_provision_config.prev_config_hash if + skip_unnecessary_provisioning else None) failover_history: List[Exception] = list() @@ -1986,7 +2028,8 @@ def provision_with_retries( cloud_user_identity=cloud_user, prev_cluster_status=prev_cluster_status, prev_handle=prev_handle, - prev_cluster_ever_up=prev_cluster_ever_up) + prev_cluster_ever_up=prev_cluster_ever_up, + skip_if_config_hash_matches=skip_if_config_hash_matches) if dryrun: return config_dict except (exceptions.InvalidClusterNameError, @@ -2687,14 +2730,21 @@ def check_resources_fit_cluster( return valid_resource def _provision( - self, - task: task_lib.Task, - to_provision: Optional[resources_lib.Resources], - dryrun: bool, - stream_logs: bool, - cluster_name: str, - retry_until_up: bool = False) -> Optional[CloudVmRayResourceHandle]: - """Provisions using 'ray up'. + self, + task: task_lib.Task, + to_provision: Optional[resources_lib.Resources], + dryrun: bool, + stream_logs: bool, + cluster_name: str, + retry_until_up: bool = False, + skip_unnecessary_provisioning: bool = False, + ) -> Optional[CloudVmRayResourceHandle]: + """Provisions the cluster, or re-provisions an existing cluster. + + Use the SKYPILOT provisioner if it's supported by the cloud, otherwise + use 'ray up'. + + See also docstring for Backend.provision(). Raises: exceptions.ClusterOwnerIdentityMismatchError: if the cluster @@ -2779,7 +2829,8 @@ def _provision( rich_utils.force_update_status( ux_utils.spinner_message('Launching', log_path)) config_dict = retry_provisioner.provision_with_retries( - task, to_provision_config, dryrun, stream_logs) + task, to_provision_config, dryrun, stream_logs, + skip_unnecessary_provisioning) break except exceptions.ResourcesUnavailableError as e: # Do not remove the stopped cluster from the global state @@ -2829,11 +2880,23 @@ def _provision( record = global_user_state.get_cluster_from_name(cluster_name) return record['handle'] if record is not None else None + if config_dict['provisioning_skipped']: + # Skip further provisioning. + # In this case, we won't have certain fields in the config_dict + # ('handle', 'provision_record', 'resources_vars') + # We need to return the handle - but it should be the existing + # handle for the cluster. + record = global_user_state.get_cluster_from_name(cluster_name) + assert record is not None and record['handle'] is not None, ( + cluster_name, record) + return record['handle'] + if 'provision_record' in config_dict: # New provisioner is used here. handle = config_dict['handle'] provision_record = config_dict['provision_record'] resources_vars = config_dict['resources_vars'] + config_hash = config_dict.get('config_hash', None) # Setup SkyPilot runtime after the cluster is provisioned # 1. Wait for SSH to be ready. 
@@ -2868,7 +2931,7 @@ def _provision( self._update_after_cluster_provisioned( handle, to_provision_config.prev_handle, task, prev_cluster_status, handle.external_ips(), - handle.external_ssh_ports(), lock_path) + handle.external_ssh_ports(), lock_path, config_hash) return handle cluster_config_file = config_dict['ray'] @@ -2940,7 +3003,8 @@ def _get_zone(runner): self._update_after_cluster_provisioned( handle, to_provision_config.prev_handle, task, - prev_cluster_status, ip_list, ssh_port_list, lock_path) + prev_cluster_status, ip_list, ssh_port_list, lock_path, + config_hash) return handle def _open_ports(self, handle: CloudVmRayResourceHandle) -> None: @@ -2958,8 +3022,8 @@ def _update_after_cluster_provisioned( prev_handle: Optional[CloudVmRayResourceHandle], task: task_lib.Task, prev_cluster_status: Optional[status_lib.ClusterStatus], - ip_list: List[str], ssh_port_list: List[int], - lock_path: str) -> None: + ip_list: List[str], ssh_port_list: List[int], lock_path: str, + config_hash: str) -> None: usage_lib.messages.usage.update_cluster_resources( handle.launched_nodes, handle.launched_resources) usage_lib.messages.usage.update_final_cluster_status( @@ -3019,6 +3083,7 @@ def _update_after_cluster_provisioned( handle, set(task.resources), ready=True, + config_hash=config_hash, ) usage_lib.messages.usage.update_final_cluster_status( status_lib.ClusterStatus.UP) @@ -3085,9 +3150,12 @@ def _sync_workdir_node(runner: command_runner.CommandRunner) -> None: f'{workdir} -> {SKY_REMOTE_WORKDIR}{style.RESET_ALL}') os.makedirs(os.path.expanduser(self.log_dir), exist_ok=True) os.system(f'touch {log_path}') + num_threads = subprocess_utils.get_parallel_threads( + str(handle.launched_resources.cloud)) with rich_utils.safe_status( ux_utils.spinner_message('Syncing workdir', log_path)): - subprocess_utils.run_in_parallel(_sync_workdir_node, runners) + subprocess_utils.run_in_parallel(_sync_workdir_node, runners, + num_threads) logger.info(ux_utils.finishing_message('Workdir synced.', log_path)) def _sync_file_mounts( @@ -3275,14 +3343,13 @@ def _exec_code_on_head( encoded_script = shlex.quote(codegen) create_script_code = (f'{{ echo {encoded_script} > {script_path}; }}') job_submit_cmd = ( - f'RAY_DASHBOARD_PORT=$({constants.SKY_PYTHON_CMD} -c "from sky.skylet import job_lib; print(job_lib.get_job_submission_port())" 2> /dev/null || echo 8265);' # pylint: disable=line-too-long - f'{cd} && {constants.SKY_RAY_CMD} job submit ' - '--address=http://127.0.0.1:$RAY_DASHBOARD_PORT ' - f'--submission-id {job_id}-$(whoami) --no-wait ' - f'"{constants.SKY_PYTHON_CMD} -u {script_path} ' + # JOB_CMD_IDENTIFIER is used for identifying the process retrieved + # with pid is the same driver process. + f'{job_lib.JOB_CMD_IDENTIFIER.format(job_id)} && ' + f'{cd} && {constants.SKY_PYTHON_CMD} -u {script_path}' # Do not use &>, which is not POSIX and may not work. # Note that the order of ">filename 2>&1" matters. - f'> {remote_log_path} 2>&1"') + f'> {remote_log_path} 2>&1') code = job_lib.JobLibCodeGen.queue_job(job_id, job_submit_cmd) job_submit_cmd = ' && '.join([mkdir_code, create_script_code, code]) @@ -3330,6 +3397,10 @@ def _dump_code_to_file(codegen: str) -> None: job_submit_cmd, stream_logs=False, require_outputs=True) + # Happens when someone calls `sky exec` but remote is outdated for + # running a job. Necessitating calling `sky launch`. 
+ backend_utils.check_stale_runtime_on_remote(returncode, stderr, + handle.cluster_name) if returncode == 255 and 'too long' in stdout + stderr: # If the generated script is too long, we retry it with dumping # the script to a file and running it with SSH. We use a general @@ -3344,10 +3415,6 @@ def _dump_code_to_file(codegen: str) -> None: stream_logs=False, require_outputs=True) - # Happens when someone calls `sky exec` but remote is outdated - # necessitating calling `sky launch`. - backend_utils.check_stale_runtime_on_remote(returncode, stdout, - handle.cluster_name) subprocess_utils.handle_returncode(returncode, job_submit_cmd, f'Failed to submit job {job_id}.', @@ -3417,6 +3484,10 @@ def _add_job(self, handle: CloudVmRayResourceHandle, stream_logs=False, require_outputs=True, separate_stderr=True) + # Happens when someone calls `sky exec` but remote is outdated for + # adding a job. Necessitating calling `sky launch`. + backend_utils.check_stale_runtime_on_remote(returncode, stderr, + handle.cluster_name) # TODO(zhwu): this sometimes will unexpectedly fail, we can add # retry for this, after we figure out the reason. subprocess_utils.handle_returncode(returncode, code, @@ -3446,15 +3517,33 @@ def _execute( Returns: Job id if the task is submitted to the cluster, None otherwise. """ - if task.run is None: + if task.run is None and self._setup_cmd is None: + # This message is fine without mentioning setup, as there are three + # cases when run section is empty: + # 1. setup specified, no --detach-setup: setup is executed and this + # message is fine for saying no run command specified. + # 2. setup specified, with --detach-setup: setup is executed in + # detached mode and this message will not be shown. + # 3. no setup specified: this message is fine as a user is likely + # creating a cluster only, and ok with the empty run command. logger.info('Run commands not specified or empty.') return None - # Check the task resources vs the cluster resources. Since `sky exec` - # will not run the provision and _check_existing_cluster - # We need to check ports here since sky.exec shouldn't change resources - valid_resource = self.check_resources_fit_cluster(handle, - task, - check_ports=True) + if task.run is None: + # If the task has no run command, we still need to execute the + # generated ray driver program to run the setup command in detached + # mode. + # In this case, we reset the resources for the task, so that the + # detached setup does not need to wait for the task resources to be + # ready (which is not used for setup anyway). + valid_resource = sky.Resources() + else: + # Check the task resources vs the cluster resources. Since + # `sky exec` will not run the provision and _check_existing_cluster + # We need to check ports here since sky.exec shouldn't change + # resources. + valid_resource = self.check_resources_fit_cluster(handle, + task, + check_ports=True) task_copy = copy.copy(task) # Handle multiple resources exec case. task_copy.set_resources(valid_resource) @@ -3554,7 +3643,7 @@ def _teardown(self, backend_utils.CLUSTER_STATUS_LOCK_PATH.format(cluster_name)) try: - with filelock.FileLock( + with timeline.FileLockEvent( lock_path, backend_utils.CLUSTER_STATUS_LOCK_TIMEOUT_SECONDS): self.teardown_no_lock( @@ -3711,7 +3800,8 @@ def tail_logs(self, handle: CloudVmRayResourceHandle, job_id: Optional[int], managed_job_id: Optional[int] = None, - follow: bool = True) -> int: + follow: bool = True, + tail: int = 0) -> int: """Tail the logs of a job. 
Args: @@ -3719,10 +3809,13 @@ def tail_logs(self, job_id: The job ID to tail the logs of. managed_job_id: The managed job ID for display purpose only. follow: Whether to follow the logs. + tail: The number of lines to display from the end of the + log file. If 0, print all lines. """ code = job_lib.JobLibCodeGen.tail_logs(job_id, managed_job_id=managed_job_id, - follow=follow) + follow=follow, + tail=tail) if job_id is None and managed_job_id is None: logger.info( 'Job ID not provided. Streaming the logs of the latest job.') @@ -3975,25 +4068,6 @@ def teardown_no_lock(self, stdout = '' stderr = str(e) - # Apr, 2023 by Hysun(hysun.he@oracle.com): Added support for OCI - # May, 2023 by Hysun: Allow terminate INIT cluster which may have - # some instances provisioning in background but not completed. - elif (isinstance(cloud, clouds.OCI) and terminate and - prev_cluster_status in (status_lib.ClusterStatus.STOPPED, - status_lib.ClusterStatus.INIT)): - region = config['provider']['region'] - - # pylint: disable=import-outside-toplevel - from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME - - from sky.skylet.providers.oci.query_helper import oci_query_helper - - # 0: All terminated successfully, failed count otherwise - returncode = oci_query_helper.terminate_instances_by_tags( - {TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud}, region) - - # To avoid undefined local variables error. - stdout = stderr = '' else: config['provider']['cache_stopped_nodes'] = not terminate with tempfile.NamedTemporaryFile('w', @@ -4329,6 +4403,7 @@ def _check_existing_cluster( # cluster is terminated (through console or auto-dwon), the record will # become None and the cluster_ever_up should be considered as False. cluster_ever_up = record is not None and record['cluster_ever_up'] + prev_config_hash = record['config_hash'] if record is not None else None logger.debug(f'cluster_ever_up: {cluster_ever_up}') logger.debug(f'record: {record}') @@ -4367,7 +4442,8 @@ def _check_existing_cluster( handle.launched_nodes, prev_cluster_status=prev_cluster_status, prev_handle=handle, - prev_cluster_ever_up=cluster_ever_up) + prev_cluster_ever_up=cluster_ever_up, + prev_config_hash=prev_config_hash) usage_lib.messages.usage.set_new_cluster() # Use the task_cloud, because the cloud in `to_provision` can be changed # later during the retry. @@ -4408,7 +4484,8 @@ def _check_existing_cluster( task.num_nodes, prev_cluster_status=None, prev_handle=None, - prev_cluster_ever_up=False) + prev_cluster_ever_up=False, + prev_config_hash=prev_config_hash) def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, file_mounts: Optional[Dict[Path, Path]]): @@ -4427,6 +4504,8 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, start = time.time() runners = handle.get_command_runners() log_path = os.path.join(self.log_dir, 'file_mounts.log') + num_threads = subprocess_utils.get_max_workers_for_file_mounts( + file_mounts, str(handle.launched_resources.cloud)) # Check the files and warn for dst, src in file_mounts.items(): @@ -4488,6 +4567,7 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, action_message='Syncing', log_path=log_path, stream_logs=False, + num_threads=num_threads, ) continue @@ -4524,6 +4604,7 @@ def _execute_file_mounts(self, handle: CloudVmRayResourceHandle, # Need to source bashrc, as the cloud specific CLI or SDK may # require PATH in bashrc. source_bashrc=True, + num_threads=num_threads, ) # (2) Run the commands to create symlinks on all the nodes. 
symlink_command = ' && '.join(symlink_commands) @@ -4542,7 +4623,8 @@ def _symlink_node(runner: command_runner.CommandRunner): 'Failed to create symlinks. The target destination ' f'may already exist. Log: {log_path}') - subprocess_utils.run_in_parallel(_symlink_node, runners) + subprocess_utils.run_in_parallel(_symlink_node, runners, + num_threads) end = time.time() logger.debug(f'File mount sync took {end - start} seconds.') logger.info(ux_utils.finishing_message('Files synced.', log_path)) @@ -4571,6 +4653,8 @@ def _execute_storage_mounts( return start = time.time() runners = handle.get_command_runners() + num_threads = subprocess_utils.get_parallel_threads( + str(handle.launched_resources.cloud)) log_path = os.path.join(self.log_dir, 'storage_mounts.log') plural = 's' if len(storage_mounts) > 1 else '' @@ -4609,6 +4693,7 @@ def _execute_storage_mounts( # Need to source bashrc, as the cloud specific CLI or SDK # may require PATH in bashrc. source_bashrc=True, + num_threads=num_threads, ) except exceptions.CommandError as e: if e.returncode == exceptions.MOUNT_PATH_NON_EMPTY_CODE: diff --git a/sky/backends/local_docker_backend.py b/sky/backends/local_docker_backend.py index 2cc3f3347a5..c10e51e7975 100644 --- a/sky/backends/local_docker_backend.py +++ b/sky/backends/local_docker_backend.py @@ -131,13 +131,14 @@ def check_resources_fit_cluster(self, handle: 'LocalDockerResourceHandle', pass def _provision( - self, - task: 'task_lib.Task', - to_provision: Optional['resources.Resources'], - dryrun: bool, - stream_logs: bool, - cluster_name: str, - retry_until_up: bool = False + self, + task: 'task_lib.Task', + to_provision: Optional['resources.Resources'], + dryrun: bool, + stream_logs: bool, + cluster_name: str, + retry_until_up: bool = False, + skip_unnecessary_provisioning: bool = False, ) -> Optional[LocalDockerResourceHandle]: """Builds docker image for the task and returns cluster name as handle. @@ -153,6 +154,9 @@ def _provision( logger.warning( f'Retrying until up is not supported in backend: {self.NAME}. ' 'Ignored the flag.') + if skip_unnecessary_provisioning: + logger.warning(f'skip_unnecessary_provisioning is not supported in ' + f'backend: {self.NAME}. Ignored the flag.') if stream_logs: logger.info( 'Streaming build logs is not supported in LocalDockerBackend. 
' diff --git a/sky/cli.py b/sky/cli.py index 29e3b2e51cf..1faf0003ff9 100644 --- a/sky/cli.py +++ b/sky/cli.py @@ -486,7 +486,7 @@ def _parse_override_params( image_id: Optional[str] = None, disk_size: Optional[int] = None, disk_tier: Optional[str] = None, - ports: Optional[Tuple[str]] = None) -> Dict[str, Any]: + ports: Optional[Tuple[str, ...]] = None) -> Dict[str, Any]: """Parses the override parameters into a dictionary.""" override_params: Dict[str, Any] = {} if cloud is not None: @@ -539,7 +539,14 @@ def _parse_override_params( else: override_params['disk_tier'] = disk_tier if ports: - override_params['ports'] = ports + if any(p.lower() == 'none' for p in ports): + if len(ports) > 1: + with ux_utils.print_exception_no_traceback(): + raise ValueError('Cannot specify both "none" and other ' + 'ports.') + override_params['ports'] = None + else: + override_params['ports'] = ports return override_params @@ -730,7 +737,7 @@ def _make_task_or_dag_from_entrypoint_with_overrides( image_id: Optional[str] = None, disk_size: Optional[int] = None, disk_tier: Optional[str] = None, - ports: Optional[Tuple[str]] = None, + ports: Optional[Tuple[str, ...]] = None, env: Optional[List[Tuple[str, str]]] = None, field_to_ignore: Optional[List[str]] = None, # job launch specific @@ -1084,7 +1091,7 @@ def launch( env: List[Tuple[str, str]], disk_size: Optional[int], disk_tier: Optional[str], - ports: Tuple[str], + ports: Tuple[str, ...], idle_minutes_to_autostop: Optional[int], down: bool, # pylint: disable=redefined-outer-name retry_until_up: bool, @@ -2022,6 +2029,12 @@ def queue(clusters: List[str], skip_finished: bool, all_users: bool): help=('Follow the logs of a job. ' 'If --no-follow is specified, print the log so far and exit. ' '[default: --follow]')) +@click.option( + '--tail', + default=0, + type=int, + help=('The number of lines to display from the end of the log file. ' + 'Default is 0, which means print all lines.')) @click.argument('cluster', required=True, type=str, @@ -2035,6 +2048,7 @@ def logs( sync_down: bool, status: bool, # pylint: disable=redefined-outer-name follow: bool, + tail: int, ): # NOTE(dev): Keep the docstring consistent between the Python API and CLI. """Tail the log of a job. @@ -2101,7 +2115,7 @@ def logs( click.secho(f'Job {id_str}not found', fg='red') sys.exit(1) - core.tail_logs(cluster, job_id, follow) + core.tail_logs(cluster, job_id, follow, tail) @cli.command() @@ -3095,6 +3109,7 @@ def show_gpus( kubernetes_autoscaling = kubernetes_utils.get_autoscaler_type() is not None kubernetes_is_enabled = sky_clouds.cloud_in_iterable( sky_clouds.Kubernetes(), global_user_state.get_cached_enabled_clouds()) + no_permissions_str = '' def _list_to_str(lst): return ', '.join([str(e) for e in lst]) @@ -3135,13 +3150,16 @@ def _get_kubernetes_realtime_gpu_table( 'in Kubernetes cluster. 
') debug_msg = ('To show available accelerators on kubernetes,' ' run: sky show-gpus --cloud kubernetes ') - full_err_msg = (err_msg + kubernetes_utils.NO_GPU_HELP_MESSAGE + + full_err_msg = (err_msg + + kubernetes_utils.NO_ACCELERATOR_HELP_MESSAGE + debug_msg) raise ValueError(full_err_msg) for gpu, _ in sorted(counts.items()): + available_qty = available[gpu] if available[gpu] != -1 else ( + no_permissions_str) realtime_gpu_table.add_row([ gpu, - _list_to_str(counts.pop(gpu)), capacity[gpu], available[gpu] + _list_to_str(counts.pop(gpu)), capacity[gpu], available_qty ]) return realtime_gpu_table @@ -3151,10 +3169,12 @@ def _get_kubernetes_node_info_table(context: Optional[str]): node_info_dict = kubernetes_utils.get_kubernetes_node_info(context) for node_name, node_info in node_info_dict.items(): + available = node_info.free[ + 'accelerators_available'] if node_info.free[ + 'accelerators_available'] != -1 else no_permissions_str node_table.add_row([ - node_name, node_info.gpu_type, - node_info.total['nvidia.com/gpu'], - node_info.free['nvidia.com/gpu'] + node_name, node_info.accelerator_type, + node_info.total['accelerator_count'], available ]) return node_table @@ -3209,8 +3229,18 @@ def _output(): yield from k8s_realtime_table.get_string() k8s_node_table = _get_kubernetes_node_info_table(context) yield '\n\n' + # TODO(Doyoung): Update the message with the multi-host TPU + # support. + k8s_per_node_acc_message = ( + 'Kubernetes per node accelerator availability ') + if kubernetes_utils.multi_host_tpu_exists_in_cluster( + context): + k8s_per_node_acc_message += ( + '(Note: Multi-host TPUs are detected and excluded ' + 'from the display as multi-host TPUs are not ' + 'supported.)') yield (f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}' - f'Kubernetes per node GPU availability' + f'{k8s_per_node_acc_message}' f'{colorama.Style.RESET_ALL}\n') yield from k8s_node_table.get_string() if kubernetes_autoscaling: @@ -3676,13 +3706,24 @@ def jobs_launch( dag_utils.maybe_infer_and_fill_dag_and_task_names(dag) dag_utils.fill_default_config_in_dag_for_job_launch(dag) - click.secho(f'Managed job {dag.name!r} will be launched on (estimated):', - fg='cyan') dag, _ = admin_policy_utils.apply( dag, use_mutated_config_in_current_request=False) - dag = sky.optimize(dag) - if not yes: + if yes: + # Skip resource preview if -y is set, since we are probably running in + # a script and the user won't have a chance to review it anyway. + # This can save a couple of seconds. + click.secho( + f'Resources for managed job {dag.name!r} will be computed on the ' + 'managed jobs controller, since --yes is set.', + fg='cyan') + + else: + click.secho( + f'Managed job {dag.name!r} will be launched on (estimated):', + fg='cyan') + dag = sky.optimize(dag) + prompt = f'Launching a managed job {dag.name!r}. Proceed?' if prompt is not None: click.confirm(prompt, default=True, abort=True, show_default=True) @@ -3873,16 +3914,25 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool): default=False, help=('Show the controller logs of this job; useful for debugging ' 'launching/recoveries, etc.')) +@click.option( + '--refresh', + '-r', + default=False, + is_flag=True, + required=False, + help='Query the latest job logs, restarting the jobs controller if stopped.' 
+) @click.argument('job_id', required=False, type=int) @usage_lib.entrypoint def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool, - controller: bool): + controller: bool, refresh: bool): """Tail the log of a managed job.""" try: managed_jobs.tail_logs(name=name, job_id=job_id, follow=follow, - controller=controller) + controller=controller, + refresh=refresh) except exceptions.ClusterNotUpError: with ux_utils.print_exception_no_traceback(): raise diff --git a/sky/clouds/aws.py b/sky/clouds/aws.py index 4a9f2d63f35..c42d67f8ba4 100644 --- a/sky/clouds/aws.py +++ b/sky/clouds/aws.py @@ -401,6 +401,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Any]: del dryrun # unused assert zones is not None, (region, zones) @@ -663,6 +664,7 @@ def _is_access_key_of_type(type_str: str) -> bool: return AWSIdentityType.SHARED_CREDENTIALS_FILE @classmethod + @functools.lru_cache(maxsize=1) # Cache since getting identity is slow. def get_user_identities(cls) -> Optional[List[List[str]]]: """Returns a [UserId, Account] list that uniquely identifies the user. diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index 9d399869666..eb76d2b5e48 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -302,6 +302,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Any]: assert zones is None, ('Azure does not support zones', zones) diff --git a/sky/clouds/cloud.py b/sky/clouds/cloud.py index 4028c1fef59..455baeaf5d9 100644 --- a/sky/clouds/cloud.py +++ b/sky/clouds/cloud.py @@ -18,6 +18,7 @@ from sky.clouds import service_catalog from sky.utils import log_utils from sky.utils import resources_utils +from sky.utils import timeline from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -282,6 +283,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'Region', zones: Optional[List['Zone']], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: """Converts planned sky.Resources to cloud-specific resource variables. 
@@ -366,6 +368,7 @@ def is_label_valid(cls, label_key: str, del label_key, label_value return True, None + @timeline.event def get_feasible_launchable_resources( self, resources: 'resources_lib.Resources', diff --git a/sky/clouds/cudo.py b/sky/clouds/cudo.py index 6f02e007049..145a5d1c26e 100644 --- a/sky/clouds/cudo.py +++ b/sky/clouds/cudo.py @@ -196,6 +196,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: del zones, cluster_name # unused diff --git a/sky/clouds/fluidstack.py b/sky/clouds/fluidstack.py index 31e2112f8f7..2668ea3e5e0 100644 --- a/sky/clouds/fluidstack.py +++ b/sky/clouds/fluidstack.py @@ -176,6 +176,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: clouds.Region, zones: Optional[List[clouds.Zone]], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: diff --git a/sky/clouds/gcp.py b/sky/clouds/gcp.py index 0e20fdc9789..8a28a35505e 100644 --- a/sky/clouds/gcp.py +++ b/sky/clouds/gcp.py @@ -417,6 +417,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: assert zones is not None, (region, zones) diff --git a/sky/clouds/ibm.py b/sky/clouds/ibm.py index 0ac3c36cc48..13f6a27e78a 100644 --- a/sky/clouds/ibm.py +++ b/sky/clouds/ibm.py @@ -170,6 +170,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: """Converts planned sky.Resources to cloud-specific resource variables. diff --git a/sky/clouds/kubernetes.py b/sky/clouds/kubernetes.py index d930a24271f..471639626eb 100644 --- a/sky/clouds/kubernetes.py +++ b/sky/clouds/kubernetes.py @@ -10,8 +10,10 @@ from sky import skypilot_config from sky.adaptors import kubernetes from sky.clouds import service_catalog +from sky.provision import instance_setup from sky.provision.kubernetes import network_utils from sky.provision.kubernetes import utils as kubernetes_utils +from sky.skylet import constants from sky.utils import common_utils from sky.utils import resources_utils from sky.utils import schemas @@ -39,6 +41,8 @@ class Kubernetes(clouds.Cloud): SKY_SSH_KEY_SECRET_NAME = 'sky-ssh-keys' SKY_SSH_JUMP_NAME = 'sky-ssh-jump-pod' + LEGACY_SINGLETON_REGION = 'kubernetes' + # Limit the length of the cluster name to avoid exceeding the limit of 63 # characters for Kubernetes resources. We limit to 42 characters (63-21) to # allow additional characters for creating ingress services to expose ports. @@ -52,7 +56,6 @@ class Kubernetes(clouds.Cloud): _DEFAULT_MEMORY_CPU_RATIO = 1 _DEFAULT_MEMORY_CPU_RATIO_WITH_GPU = 4 # Allocate more memory for GPU tasks _REPR = 'Kubernetes' - _LEGACY_SINGLETON_REGION = 'kubernetes' _CLOUD_UNSUPPORTED_FEATURES = { # TODO(romilb): Stopping might be possible to implement with # container checkpointing introduced in Kubernetes v1.25. See: @@ -128,32 +131,30 @@ def _log_skipped_contexts_once(cls, skipped_contexts: Tuple[str, 'Ignoring these contexts.') @classmethod - def _existing_allowed_contexts(cls) -> List[Optional[str]]: + def _existing_allowed_contexts(cls) -> List[str]: """Get existing allowed contexts. 
If None is returned in the list, it means that we are running in a pod with in-cluster auth. In this case, we specify None context, which will use the service account mounted in the pod. """ - all_contexts = kubernetes_utils.get_all_kube_config_context_names() + all_contexts = kubernetes_utils.get_all_kube_context_names() if len(all_contexts) == 0: return [] - if all_contexts == [None]: - # If only one context is found and it is None, we are running in a - # pod with in-cluster auth. In this case, we allow it to be used - # without checking against allowed_contexts. - # TODO(romilb): We may want check in-cluster auth against - # allowed_contexts in the future by adding a special context name - # for in-cluster auth. - return [None] + all_contexts = set(all_contexts) allowed_contexts = skypilot_config.get_nested( ('kubernetes', 'allowed_contexts'), None) if allowed_contexts is None: + # Try kubeconfig if present current_context = ( kubernetes_utils.get_current_kube_config_context_name()) + if (current_context is None and + kubernetes_utils.is_incluster_config_available()): + # If no kubeconfig contexts found, use in-cluster if available + current_context = kubernetes.in_cluster_context_name() allowed_contexts = [] if current_context is not None: allowed_contexts = [current_context] @@ -178,13 +179,7 @@ def regions_with_offering(cls, instance_type: Optional[str], regions = [] for context in existing_contexts: - if context is None: - # If running in-cluster, we allow the region to be set to the - # singleton region since there is no context name available. - regions.append(clouds.Region( - kubernetes_utils.IN_CLUSTER_REGION)) - else: - regions.append(clouds.Region(context)) + regions.append(clouds.Region(context)) if region is not None: regions = [r for r in regions if r.name == region] @@ -311,12 +306,34 @@ def get_image_size(cls, image_id: str, region: Optional[str]) -> int: # we don't have a notion of disk size in Kubernetes. return 0 + @staticmethod + def _calculate_provision_timeout(num_nodes: int) -> int: + """Calculate provision timeout based on number of nodes. + + The timeout scales linearly with the number of nodes to account for + scheduling overhead, but is capped to avoid excessive waiting. + + Args: + num_nodes: Number of nodes being provisioned + + Returns: + Timeout in seconds + """ + base_timeout = 10 # Base timeout for single node + per_node_timeout = 0.2 # Additional seconds per node + max_timeout = 60 # Cap at 1 minute + + return int( + min(base_timeout + (per_node_timeout * (num_nodes - 1)), + max_timeout)) + def make_deploy_resources_variables( self, resources: 'resources_lib.Resources', cluster_name: resources_utils.ClusterName, region: Optional['clouds.Region'], zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del cluster_name, zones, dryrun # Unused. if region is None: @@ -362,23 +379,48 @@ def make_deploy_resources_variables( k8s_acc_label_key = None k8s_acc_label_value = None + k8s_topology_label_key = None + k8s_topology_label_value = None + k8s_resource_key = None + tpu_requested = False - # If GPUs are requested, set node label to match the GPU type. + # If GPU/TPUs are requested, set node label to match the GPU/TPU type. 
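[Editor's note on `_calculate_provision_timeout` introduced above: it scales linearly at 10s base plus 0.2s per additional node, capped at 60s, and a user-set `kubernetes.provision_timeout` in config still overrides the computed default. The snippet below just mirrors the formula to show the values it produces.]

# Mirror of the linear-scaling timeout formula above, with sample outputs.
def calculate_provision_timeout(num_nodes: int) -> int:
    base_timeout = 10  # Base timeout for a single node.
    per_node_timeout = 0.2  # Additional seconds per extra node.
    max_timeout = 60  # Cap at 1 minute.
    return int(min(base_timeout + per_node_timeout * (num_nodes - 1),
                   max_timeout))


for n in (1, 10, 100, 500):
    print(n, calculate_provision_timeout(n))
# 1 -> 10s, 10 -> 11s, 100 -> 29s, 500 -> capped at 60s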
if acc_count > 0 and acc_type is not None: - k8s_acc_label_key, k8s_acc_label_value = \ - kubernetes_utils.get_gpu_label_key_value(context, acc_type) + (k8s_acc_label_key, k8s_acc_label_value, k8s_topology_label_key, + k8s_topology_label_value) = ( + kubernetes_utils.get_accelerator_label_key_value( + context, acc_type, acc_count)) + if (k8s_acc_label_key == + kubernetes_utils.GKELabelFormatter.TPU_LABEL_KEY): + tpu_requested = True + k8s_resource_key = kubernetes_utils.TPU_RESOURCE_KEY + else: + k8s_resource_key = kubernetes_utils.GPU_RESOURCE_KEY port_mode = network_utils.get_port_mode(None) remote_identity = skypilot_config.get_nested( ('kubernetes', 'remote_identity'), schemas.get_default_remote_identity('kubernetes')) - if (remote_identity == + + if isinstance(remote_identity, dict): + # If remote_identity is a dict, use the service account for the + # current context + k8s_service_account_name = remote_identity.get(context, None) + if k8s_service_account_name is None: + err_msg = (f'Context {context!r} not found in ' + 'remote identities from config.yaml') + raise ValueError(err_msg) + else: + # If remote_identity is not a dict, use + k8s_service_account_name = remote_identity + + if (k8s_service_account_name == schemas.RemoteIdentityOptions.LOCAL_CREDENTIALS.value): # SA name doesn't matter since automounting credentials is disabled k8s_service_account_name = 'default' k8s_automount_sa_token = 'false' - elif (remote_identity == + elif (k8s_service_account_name == schemas.RemoteIdentityOptions.SERVICE_ACCOUNT.value): # Use the default service account k8s_service_account_name = ( @@ -386,7 +428,6 @@ def make_deploy_resources_variables( k8s_automount_sa_token = 'true' else: # User specified a custom service account - k8s_service_account_name = remote_identity k8s_automount_sa_token = 'true' fuse_device_required = bool(resources.requires_fuse) @@ -401,12 +442,30 @@ def make_deploy_resources_variables( # Larger timeout may be required for autoscaling clusters, since # autoscaler may take some time to provision new nodes. # Note that this timeout includes time taken by the Kubernetes scheduler - # itself, which can be upto 2-3 seconds. - # For non-autoscaling clusters, we conservatively set this to 10s. + # itself, which can be upto 2-3 seconds, and up to 10-15 seconds when + # scheduling 100s of pods. + # We use a linear scaling formula to determine the timeout based on the + # number of nodes. + + timeout = self._calculate_provision_timeout(num_nodes) timeout = skypilot_config.get_nested( ('kubernetes', 'provision_timeout'), - 10, + timeout, override_configs=resources.cluster_config_overrides) + + # Set environment variables for the pod. Note that SkyPilot env vars + # are set separately when the task is run. These env vars are + # independent of the SkyPilot task to be run. + k8s_env_vars = {kubernetes.IN_CLUSTER_CONTEXT_NAME_ENV_VAR: context} + + # We specify object-store-memory to be 500MB to avoid taking up too + # much memory on the head node. 'num-cpus' should be set to limit + # the CPU usage on the head pod, otherwise the ray cluster will use the + # CPU resources on the node instead within the pod. 
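[Editor's note on the `remote_identity` handling added above: `kubernetes.remote_identity` may now be either a single value or a per-context mapping to a service account. A standalone sketch of that resolution logic follows; the config values are hypothetical examples.]

# Sketch of resolving a per-context remote identity, following the branch
# added in the hunk above. Config values here are made up for illustration.
from typing import Dict, Union


def resolve_remote_identity(remote_identity: Union[str, Dict[str, str]],
                            context: str) -> str:
    if isinstance(remote_identity, dict):
        # Per-context mapping: look up the service account for this context.
        service_account = remote_identity.get(context)
        if service_account is None:
            raise ValueError(f'Context {context!r} not found in '
                             'remote identities from config.yaml')
        return service_account
    # A plain string applies to every context.
    return remote_identity


print(resolve_remote_identity('SERVICE_ACCOUNT', 'my-context'))
print(resolve_remote_identity({'my-context': 'my-sa'}, 'my-context'))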
+ custom_ray_options = { + 'object-store-memory': 500000000, + 'num-cpus': str(int(cpus)), + } deploy_vars = { 'instance_type': resources.instance_type, 'custom_resources': custom_resources, @@ -428,7 +487,18 @@ def make_deploy_resources_variables( 'k8s_skypilot_system_namespace': _SKYPILOT_SYSTEM_NAMESPACE, 'k8s_spot_label_key': spot_label_key, 'k8s_spot_label_value': spot_label_value, + 'tpu_requested': tpu_requested, + 'k8s_topology_label_key': k8s_topology_label_key, + 'k8s_topology_label_value': k8s_topology_label_value, + 'k8s_resource_key': k8s_resource_key, + 'k8s_env_vars': k8s_env_vars, 'image_id': image_id, + 'ray_installation_commands': constants.RAY_INSTALLATION_COMMANDS, + 'ray_head_start_command': instance_setup.ray_head_start_command( + custom_resources, custom_ray_options), + 'skypilot_ray_port': constants.SKY_REMOTE_RAY_PORT, + 'ray_worker_start_command': instance_setup.ray_worker_start_command( + custom_resources, custom_ray_options, no_restart=False), } # Add kubecontext if it is set. It may be None if SkyPilot is running @@ -520,7 +590,11 @@ def _make(instance_list): @classmethod def check_credentials(cls) -> Tuple[bool, Optional[str]]: # Test using python API - existing_allowed_contexts = cls._existing_allowed_contexts() + try: + existing_allowed_contexts = cls._existing_allowed_contexts() + except ImportError as e: + return (False, + f'{common_utils.format_exception(e, use_bracket=True)}') if not existing_allowed_contexts: if skypilot_config.loaded_config_path() is None: check_skypilot_config_msg = '' @@ -557,22 +631,19 @@ def instance_type_exists(self, instance_type: str) -> bool: instance_type) def validate_region_zone(self, region: Optional[str], zone: Optional[str]): - if region == self._LEGACY_SINGLETON_REGION: + if region == self.LEGACY_SINGLETON_REGION: # For backward compatibility, we allow the region to be set to the # legacy singleton region. # TODO: Remove this after 0.9.0. return region, zone - if region == kubernetes_utils.IN_CLUSTER_REGION: + if region == kubernetes.in_cluster_context_name(): # If running incluster, we set region to IN_CLUSTER_REGION # since there is no context name available. return region, zone - all_contexts = kubernetes_utils.get_all_kube_config_context_names() - if all_contexts == [None]: - # If [None] context is returned, use the singleton region since we - # are running in a pod with in-cluster auth. - all_contexts = [kubernetes_utils.IN_CLUSTER_REGION] + all_contexts = kubernetes_utils.get_all_kube_context_names() + if region not in all_contexts: raise ValueError( f'Context {region} not found in kubeconfig. Kubernetes only ' diff --git a/sky/clouds/lambda_cloud.py b/sky/clouds/lambda_cloud.py index 055a5338750..11ec96a78c1 100644 --- a/sky/clouds/lambda_cloud.py +++ b/sky/clouds/lambda_cloud.py @@ -157,6 +157,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del cluster_name, dryrun # Unused. assert zones is None, 'Lambda does not support zones.' 
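[Editor's note on the `custom_ray_options` passed to `instance_setup.ray_head_start_command` above: the exact command SkyPilot renders is not shown in this diff, but options like `object-store-memory` and `num-cpus` conventionally map onto standard `ray start` flags. The sketch below is an illustration under that assumption, not the real helper.]

# Hypothetical rendering of custom Ray options into a `ray start` command.
import json
from typing import Dict, Optional


def ray_head_start_command_sketch(custom_resources: Optional[str],
                                  custom_ray_options: Dict[str, object],
                                  port: int) -> str:
    cmd = [f'ray start --head --port={port}']
    if custom_resources:
        # e.g. '{"A100": 8}', so Ray can schedule tasks onto accelerators.
        cmd.append(f"--resources='{custom_resources}'")
    for key, value in custom_ray_options.items():
        cmd.append(f'--{key}={value}')
    return ' '.join(cmd)


print(ray_head_start_command_sketch(
    json.dumps({'A100': 8}),
    {'object-store-memory': 500000000, 'num-cpus': '8'},
    port=6380))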
diff --git a/sky/clouds/oci.py b/sky/clouds/oci.py index 93a70c5ac37..95f4efe95e3 100644 --- a/sky/clouds/oci.py +++ b/sky/clouds/oci.py @@ -31,6 +31,7 @@ from sky.adaptors import oci as oci_adaptor from sky.clouds import service_catalog from sky.clouds.utils import oci_utils +from sky.provision.oci.query_utils import query_helper from sky.utils import common_utils from sky.utils import resources_utils from sky.utils import ux_utils @@ -60,6 +61,9 @@ class OCI(clouds.Cloud): {resources_utils.DiskTier.ULTRA}) _BEST_DISK_TIER = resources_utils.DiskTier.HIGH + PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT + STATUS_VERSION = clouds.StatusVersion.SKYPILOT + @classmethod def _unsupported_features_for_resources( cls, resources: 'resources_lib.Resources' @@ -71,8 +75,6 @@ def _unsupported_features_for_resources( (f'Docker image is currently not supported on {cls._REPR}. ' 'You can try running docker command inside the ' '`run` section in task.yaml.'), - clouds.CloudImplementationFeatures.OPEN_PORTS: - (f'Opening ports is currently not supported on {cls._REPR}.'), } if resources.use_spot: features[clouds.CloudImplementationFeatures.STOP] = ( @@ -206,6 +208,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: Optional['clouds.Region'], zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del cluster_name, dryrun # Unused. assert region is not None, resources @@ -433,7 +436,7 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]: return True, None except (oci_adaptor.oci.exceptions.ConfigFileNotFound, oci_adaptor.oci.exceptions.InvalidConfig, - oci_adaptor.service_exception()) as e: + oci_adaptor.oci.exceptions.ServiceError) as e: return False, ( f'OCI credential is not correctly set. ' f'Check the credential file at {conf_file}\n' @@ -597,25 +600,11 @@ def query_status(cls, name: str, tag_filters: Dict[str, str], region: Optional[str], zone: Optional[str], **kwargs) -> List[status_lib.ClusterStatus]: del zone, kwargs # Unused. 
- # Check the lifecycleState definition from the page - # https://docs.oracle.com/en-us/iaas/api/#/en/iaas/latest/Instance/ - status_map = { - 'PROVISIONING': status_lib.ClusterStatus.INIT, - 'STARTING': status_lib.ClusterStatus.INIT, - 'RUNNING': status_lib.ClusterStatus.UP, - 'STOPPING': status_lib.ClusterStatus.STOPPED, - 'STOPPED': status_lib.ClusterStatus.STOPPED, - 'TERMINATED': None, - 'TERMINATING': None, - } - - # pylint: disable=import-outside-toplevel - from sky.skylet.providers.oci.query_helper import oci_query_helper status_list = [] try: - vms = oci_query_helper.query_instances_by_tags( - tag_filters=tag_filters, region=region) + vms = query_helper.query_instances_by_tags(tag_filters=tag_filters, + region=region) except Exception as e: # pylint: disable=broad-except with ux_utils.print_exception_no_traceback(): raise exceptions.ClusterStatusFetchingError( @@ -625,9 +614,9 @@ def query_status(cls, name: str, tag_filters: Dict[str, str], for node in vms: vm_status = node.lifecycle_state - if vm_status in status_map: - sky_status = status_map[vm_status] - if sky_status is not None: - status_list.append(sky_status) + sky_status = oci_utils.oci_config.STATE_MAPPING_OCI_TO_SKY.get( + vm_status, None) + if sky_status is not None: + status_list.append(sky_status) return status_list diff --git a/sky/clouds/paperspace.py b/sky/clouds/paperspace.py index 4047a2f5926..69a0d69ca61 100644 --- a/sky/clouds/paperspace.py +++ b/sky/clouds/paperspace.py @@ -175,6 +175,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del zones, dryrun, cluster_name diff --git a/sky/clouds/runpod.py b/sky/clouds/runpod.py index 0d693fd9f60..487793ecf97 100644 --- a/sky/clouds/runpod.py +++ b/sky/clouds/runpod.py @@ -160,6 +160,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del zones, dryrun, cluster_name # unused diff --git a/sky/clouds/scp.py b/sky/clouds/scp.py index d0ad611bf0c..4a6b8564a97 100644 --- a/sky/clouds/scp.py +++ b/sky/clouds/scp.py @@ -181,6 +181,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False) -> Dict[str, Optional[str]]: del cluster_name, dryrun # Unused. assert zones is None, 'SCP does not support zones.' diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index 4deab8ac204..d28b530ff06 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -324,9 +324,8 @@ def get_common_gpus() -> List[str]: 'A100', 'A100-80GB', 'H100', - 'K80', 'L4', - 'M60', + 'L40S', 'P100', 'T4', 'V100', @@ -337,13 +336,13 @@ def get_common_gpus() -> List[str]: def get_tpus() -> List[str]: """Returns a list of TPU names.""" # TODO(wei-lin): refactor below hard-coded list. - # There are many TPU configurations available, we show the three smallest - # and the largest configuration for the latest gen TPUs. + # There are many TPU configurations available, we show the some smallest + # ones for each generation, and people should find larger ones with + # sky show-gpus tpu. 
return [ - 'tpu-v2-512', 'tpu-v3-2048', 'tpu-v4-8', 'tpu-v4-16', 'tpu-v4-32', - 'tpu-v4-3968', 'tpu-v5litepod-1', 'tpu-v5litepod-4', 'tpu-v5litepod-8', - 'tpu-v5litepod-256', 'tpu-v5p-8', 'tpu-v5p-32', 'tpu-v5p-128', - 'tpu-v5p-12288' + 'tpu-v2-8', 'tpu-v3-8', 'tpu-v4-8', 'tpu-v4-16', 'tpu-v4-32', + 'tpu-v5litepod-1', 'tpu-v5litepod-4', 'tpu-v5litepod-8', 'tpu-v5p-8', + 'tpu-v5p-16', 'tpu-v5p-32', 'tpu-v6e-1', 'tpu-v6e-4', 'tpu-v6e-8' ] diff --git a/sky/clouds/service_catalog/aws_catalog.py b/sky/clouds/service_catalog/aws_catalog.py index 918a4070414..bbd48863755 100644 --- a/sky/clouds/service_catalog/aws_catalog.py +++ b/sky/clouds/service_catalog/aws_catalog.py @@ -20,6 +20,7 @@ from sky.utils import common_utils from sky.utils import resources_utils from sky.utils import rich_utils +from sky.utils import timeline from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -100,6 +101,7 @@ def _get_az_mappings(aws_user_hash: str) -> Optional['pd.DataFrame']: return az_mappings +@timeline.event def _fetch_and_apply_az_mapping(df: common.LazyDataFrame) -> 'pd.DataFrame': """Maps zone IDs (use1-az1) to zone names (us-east-1x). diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py index 1082b4e9efd..67c6e09b27e 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -15,6 +15,7 @@ from sky.clouds import cloud as cloud_lib from sky.clouds import cloud_registry from sky.clouds.service_catalog import constants +from sky.utils import common_utils from sky.utils import rich_utils from sky.utils import ux_utils @@ -69,8 +70,7 @@ def is_catalog_modified(filename: str) -> bool: meta_path = os.path.join(_ABSOLUTE_VERSIONED_CATALOG_DIR, '.meta', filename) md5_filepath = meta_path + '.md5' if os.path.exists(md5_filepath): - with open(catalog_path, 'rb') as f: - file_md5 = hashlib.md5(f.read()).hexdigest() + file_md5 = common_utils.hash_file(catalog_path, 'md5').hexdigest() with open(md5_filepath, 'r', encoding='utf-8') as f: last_md5 = f.read() return file_md5 != last_md5 @@ -203,7 +203,8 @@ def _update_catalog(): f'Updating {cloud} catalog: {filename}') + f'{update_frequency_str}'): try: - r = requests.get(url) + r = requests.get(url=url, + headers={'User-Agent': 'SkyPilot/0.7'}) r.raise_for_status() except requests.exceptions.RequestException as e: error_str = (f'Failed to fetch {cloud} catalog ' diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py index f646cac339a..4aef41f9c90 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_azure.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_azure.py @@ -64,7 +64,7 @@ 'standardNVSv2Family': 'M60', 'standardNVSv3Family': 'M60', 'standardNVPromoFamily': 'M60', - 'standardNVSv4Family': 'Radeon MI25', + 'standardNVSv4Family': 'MI25', 'standardNDSFamily': 'P40', 'StandardNVADSA10v5Family': 'A10', 'StandardNCadsH100v5Family': 'H100', diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py b/sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py index cf943541e08..7a8b7e42e79 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_fluidstack.py @@ -15,6 +15,26 @@ DEFAULT_FLUIDSTACK_API_KEY_PATH = os.path.expanduser('~/.fluidstack/api_key') plan_vcpus_memory = [{ + 'gpu_type': 'H100_SXM5_80GB', + 'gpu_count': 1, + 'min_cpu_count': 52, + 'min_memory': 450 +}, { + 'gpu_type': 'H100_SXM5_80GB', + 'gpu_count': 2, + 
'min_cpu_count': 52, + 'min_memory': 450 +}, { + 'gpu_type': 'H100_SXM5_80GB', + 'gpu_count': 4, + 'min_cpu_count': 104, + 'min_memory': 900 +}, { + 'gpu_type': 'H100_SXM5_80GB', + 'gpu_count': 8, + 'min_cpu_count': 192, + 'min_memory': 1800 +}, { 'gpu_type': 'RTX_A6000_48GB', 'gpu_count': 2, 'min_cpu_count': 12, @@ -150,7 +170,8 @@ 'H100_PCIE_80GB': 'H100', 'H100_NVLINK_80GB': 'H100', 'A100_NVLINK_80GB': 'A100-80GB', - 'A100_SXM4_80GB': 'A100-80GB', + 'A100_SXM4_80GB': 'A100-80GB-SXM', + 'H100_SXM5_80GB': 'H100-SXM', 'A100_PCIE_80GB': 'A100-80GB', 'A100_SXM4_40GB': 'A100', 'A100_PCIE_40GB': 'A100', diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py b/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py index 8cc9fc6c127..e0ec7f66042 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_gcp.py @@ -380,7 +380,7 @@ def get_vm_df(skus: List[Dict[str, Any]], region_prefix: str) -> 'pd.DataFrame': df = df[~df['AvailabilityZone'].str.startswith(tuple(TPU_V4_ZONES))] # TODO(woosuk): Make this more efficient. - def get_vm_price(row: pd.Series, spot: bool) -> float: + def get_vm_price(row: pd.Series, spot: bool) -> Optional[float]: series = row['InstanceType'].split('-')[0].lower() ondemand_or_spot = 'OnDemand' if not spot else 'Preemptible' @@ -431,12 +431,26 @@ def get_vm_price(row: pd.Series, spot: bool) -> float: if series in ['f1', 'g1']: memory_price = 0.0 - assert cpu_price is not None, row - assert memory_price is not None, row + # TODO(tian): (2024/11/10) Some SKUs are missing in the SKUs API. We + # skip them in the catalog for now. We should investigate why they are + # missing and add them back. + if cpu_price is None or memory_price is None: + return None return cpu_price + memory_price df['Price'] = df.apply(lambda row: get_vm_price(row, spot=False), axis=1) df['SpotPrice'] = df.apply(lambda row: get_vm_price(row, spot=True), axis=1) + dropped_rows = df[df['Price'].isna() & df['SpotPrice'].isna()] + dropped_info = (dropped_rows[['InstanceType', + 'AvailabilityZone']].drop_duplicates()) + az2missing = dropped_info.groupby('AvailabilityZone').apply( + lambda x: x['InstanceType'].tolist()) + print('Price not found for the following zones and instance types. 
' + 'Dropping them.') + for az, instances in az2missing.items(): + print('-' * 30, az, '-' * 30) + print(', '.join(instances)) + df = df.dropna(subset=['Price', 'SpotPrice'], how='all') df = df.reset_index(drop=True) df = df.sort_values(['InstanceType', 'Region', 'AvailabilityZone']) return df diff --git a/sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py b/sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py index e4bb6e8547a..008bfe6abeb 100644 --- a/sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py +++ b/sky/clouds/service_catalog/data_fetchers/fetch_lambda_cloud.py @@ -46,6 +46,7 @@ 'RTX6000': 24576, 'V100': 16384, 'H100': 81920, + 'GH200': 98304, 'GENERAL': None } diff --git a/sky/clouds/service_catalog/kubernetes_catalog.py b/sky/clouds/service_catalog/kubernetes_catalog.py index 7ff8f49c621..2c7eafc20e5 100644 --- a/sky/clouds/service_catalog/kubernetes_catalog.py +++ b/sky/clouds/service_catalog/kubernetes_catalog.py @@ -10,6 +10,7 @@ from sky import check as sky_check from sky import sky_logging from sky.adaptors import common as adaptors_common +from sky.adaptors import kubernetes from sky.clouds import Kubernetes from sky.clouds.service_catalog import CloudFilter from sky.clouds.service_catalog import common @@ -22,6 +23,8 @@ else: pd = adaptors_common.LazyImport('pandas') +logger = sky_logging.init_logger(__name__) + _PULL_FREQUENCY_HOURS = 7 # We keep pull_frequency_hours so we can remotely update the default image paths @@ -62,9 +65,14 @@ def list_accelerators( # TODO(romilb): We should consider putting a lru_cache() with TTL to # avoid multiple calls to kubernetes API in a short period of time (e.g., # from the optimizer). - return list_accelerators_realtime(gpus_only, name_filter, region_filter, - quantity_filter, case_sensitive, - all_regions, require_price)[0] + return _list_accelerators(gpus_only, + name_filter, + region_filter, + quantity_filter, + case_sensitive, + all_regions, + require_price, + realtime=False)[0] def list_accelerators_realtime( @@ -77,6 +85,37 @@ def list_accelerators_realtime( require_price: bool = True ) -> Tuple[Dict[str, List[common.InstanceTypeInfo]], Dict[str, int], Dict[str, int]]: + return _list_accelerators(gpus_only, + name_filter, + region_filter, + quantity_filter, + case_sensitive, + all_regions, + require_price, + realtime=True) + + +def _list_accelerators( + gpus_only: bool, + name_filter: Optional[str], + region_filter: Optional[str], + quantity_filter: Optional[int], + case_sensitive: bool = True, + all_regions: bool = False, + require_price: bool = True, + realtime: bool = False +) -> Tuple[Dict[str, List[common.InstanceTypeInfo]], Dict[str, int], Dict[str, + int]]: + """List accelerators in the Kubernetes cluster. + + If realtime is True, the function will query the cluster to fetch real-time + GPU usage, which is returned in total_accelerators_available. Note that + this may require an expensive list_pod_for_all_namespaces call, which + requires cluster-wide pod read permissions. + + If the user does not have sufficient permissions to list pods in all + namespaces, the function will return free GPUs as -1. + """ # TODO(romilb): This should be refactored to use get_kubernetes_node_info() # function from kubernetes_utils. del all_regions, require_price # Unused. 
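[Editor's note on the GCP catalog change above: `get_vm_price` now returns None for SKUs missing from the SKUs API, and rows where both the on-demand and spot price are missing are reported per zone and then dropped. A toy pandas reproduction of that report-then-drop step, using made-up prices:]

# Toy reproduction of the "report missing prices, then drop" step above.
import pandas as pd

df = pd.DataFrame({
    'InstanceType': ['n2-standard-4', 'n2-standard-4', 'c3-standard-8'],
    'AvailabilityZone': ['us-central1-a', 'europe-west4-a', 'us-central1-a'],
    'Price': [0.19, None, None],
    'SpotPrice': [0.05, None, None],
})

dropped_rows = df[df['Price'].isna() & df['SpotPrice'].isna()]
dropped_info = dropped_rows[['InstanceType',
                             'AvailabilityZone']].drop_duplicates()
az2missing = dropped_info.groupby('AvailabilityZone').apply(
    lambda x: x['InstanceType'].tolist())
print('Price not found for the following zones and instance types. '
      'Dropping them.')
for az, instances in az2missing.items():
    print('-' * 30, az, '-' * 30)
    print(', '.join(instances))

# Keep a row if at least one of the two prices is present.
df = df.dropna(subset=['Price', 'SpotPrice'], how='all').reset_index(drop=True)
print(df)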
@@ -96,19 +135,31 @@ def list_accelerators_realtime( ) or not kubernetes_utils.check_credentials(context)[0]: return {}, {}, {} - has_gpu = kubernetes_utils.detect_gpu_resource(context) + has_gpu = kubernetes_utils.detect_accelerator_resource(context) if not has_gpu: return {}, {}, {} - label_formatter, _ = kubernetes_utils.detect_gpu_label_formatter(context) - if not label_formatter: + lf, _ = kubernetes_utils.detect_gpu_label_formatter(context) + if not lf: return {}, {}, {} accelerators_qtys: Set[Tuple[str, int]] = set() - key = label_formatter.get_label_key() + keys = lf.get_label_keys() nodes = kubernetes_utils.get_kubernetes_nodes(context) - # Get the pods to get the real-time GPU usage - pods = kubernetes_utils.get_all_pods_in_kubernetes_cluster(context) + pods = None + if realtime: + # Get the pods to get the real-time GPU usage + try: + pods = kubernetes_utils.get_all_pods_in_kubernetes_cluster(context) + except kubernetes.api_exception() as e: + if e.status == 403: + logger.warning( + 'Failed to get pods in the Kubernetes cluster ' + '(forbidden). Please check if your account has ' + 'necessary permissions to list pods. Realtime GPU ' + 'availability information may be incorrect.') + else: + raise # Total number of GPUs in the cluster total_accelerators_capacity: Dict[str, int] = {} # Total number of GPUs currently available in the cluster @@ -116,62 +167,84 @@ def list_accelerators_realtime( min_quantity_filter = quantity_filter if quantity_filter else 1 for node in nodes: - if key in node.metadata.labels: - allocated_qty = 0 - accelerator_name = label_formatter.get_accelerator_from_label_value( - node.metadata.labels.get(key)) - - # Check if name_filter regex matches the accelerator_name - regex_flags = 0 if case_sensitive else re.IGNORECASE - if name_filter and not re.match( - name_filter, accelerator_name, flags=regex_flags): - continue - - accelerator_count = int( - node.status.allocatable.get('nvidia.com/gpu', 0)) - - # Generate the GPU quantities for the accelerators - if accelerator_name and accelerator_count > 0: - count = 1 - while count <= accelerator_count: - accelerators_qtys.add((accelerator_name, count)) - count *= 2 - # Add the accelerator count if it's not already in the set - # (e.g., if there's 12 GPUs, we should have qtys 1, 2, 4, 8, 12) - if accelerator_count not in accelerators_qtys: - accelerators_qtys.add((accelerator_name, accelerator_count)) - - for pod in pods: - # Get all the pods running on the node - if (pod.spec.node_name == node.metadata.name and - pod.status.phase in ['Running', 'Pending']): - # Iterate over all the containers in the pod and sum the - # GPU requests - for container in pod.spec.containers: - if container.resources.requests: - allocated_qty += int( - container.resources.requests.get( - 'nvidia.com/gpu', 0)) - - accelerators_available = accelerator_count - allocated_qty - - if accelerator_count >= min_quantity_filter: - quantized_count = (min_quantity_filter * - (accelerator_count // min_quantity_filter)) - if accelerator_name not in total_accelerators_capacity: - total_accelerators_capacity[ - accelerator_name] = quantized_count - else: - total_accelerators_capacity[ - accelerator_name] += quantized_count - - if accelerator_name not in total_accelerators_available: - total_accelerators_available[accelerator_name] = 0 - if accelerators_available >= min_quantity_filter: - quantized_availability = min_quantity_filter * ( - accelerators_available // min_quantity_filter) - total_accelerators_available[ - accelerator_name] += 
quantized_availability + for key in keys: + if key in node.metadata.labels: + allocated_qty = 0 + accelerator_name = lf.get_accelerator_from_label_value( + node.metadata.labels.get(key)) + + # Exclude multi-host TPUs from being processed. + # TODO(Doyoung): Remove the logic when adding support for + # multi-host TPUs. + if kubernetes_utils.is_multi_host_tpu(node.metadata.labels): + continue + + # Check if name_filter regex matches the accelerator_name + regex_flags = 0 if case_sensitive else re.IGNORECASE + if name_filter and not re.match( + name_filter, accelerator_name, flags=regex_flags): + continue + + # Generate the accelerator quantities + accelerator_count = ( + kubernetes_utils.get_node_accelerator_count( + node.status.allocatable)) + + if accelerator_name and accelerator_count > 0: + # TPUs are counted in a different way compared to GPUs. + # Multi-node GPUs can be split into smaller units and be + # provisioned, but TPUs are considered as an atomic unit. + if kubernetes_utils.is_tpu_on_gke(accelerator_name): + accelerators_qtys.add( + (accelerator_name, accelerator_count)) + else: + count = 1 + while count <= accelerator_count: + accelerators_qtys.add((accelerator_name, count)) + count *= 2 + # Add the accelerator count if it's not already in the + # set (e.g., if there's 12 GPUs, we should have qtys 1, + # 2, 4, 8, 12) + if accelerator_count not in accelerators_qtys: + accelerators_qtys.add( + (accelerator_name, accelerator_count)) + + if accelerator_count >= min_quantity_filter: + quantized_count = ( + min_quantity_filter * + (accelerator_count // min_quantity_filter)) + if accelerator_name not in total_accelerators_capacity: + total_accelerators_capacity[ + accelerator_name] = quantized_count + else: + total_accelerators_capacity[ + accelerator_name] += quantized_count + + if pods is None: + # If we can't get the pods, we can't get the GPU usage + total_accelerators_available[accelerator_name] = -1 + continue + + for pod in pods: + # Get all the pods running on the node + if (pod.spec.node_name == node.metadata.name and + pod.status.phase in ['Running', 'Pending']): + # Iterate over all the containers in the pod and sum + # the GPU requests + for container in pod.spec.containers: + if container.resources.requests: + allocated_qty += ( + kubernetes_utils.get_node_accelerator_count( + container.resources.requests)) + + accelerators_available = accelerator_count - allocated_qty + + if accelerators_available >= min_quantity_filter: + quantized_availability = min_quantity_filter * ( + accelerators_available // min_quantity_filter) + total_accelerators_available[accelerator_name] = ( + total_accelerators_available.get(accelerator_name, 0) + + quantized_availability) result = [] diff --git a/sky/clouds/service_catalog/oci_catalog.py b/sky/clouds/service_catalog/oci_catalog.py index c8e475df871..b93e9d622e1 100644 --- a/sky/clouds/service_catalog/oci_catalog.py +++ b/sky/clouds/service_catalog/oci_catalog.py @@ -66,7 +66,7 @@ def _get_df() -> 'pd.DataFrame': logger.debug(f'It is OK goes here when testing: {str(e)}') subscribed_regions = [] - except oci_adaptor.service_exception() as e: + except oci_adaptor.oci.exceptions.ServiceError as e: # Should never expect going here. However, we still catch # it so that if any OCI call failed, the program can still # proceed with try-and-error way. 
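[Editor's note on the realtime accelerator accounting in `_list_accelerators` above: it reduces to three small rules. Advertise power-of-two GPU counts up to the node total (TPUs on GKE stay atomic), quantize capacity and availability down to multiples of the requested quantity, and report -1 free when pods could not be listed (the 403 fallback handled earlier). A standalone sketch with Kubernetes API objects replaced by plain inputs:]

# Standalone sketch of the counting rules above. `allocated` stands in for
# the sum of container GPU requests on the node; pods_visible=False models
# the "cannot list pods" (403) fallback.
from typing import Dict, Set


def candidate_quantities(total: int, is_tpu: bool = False) -> Set[int]:
    if is_tpu:
        return {total}  # TPUs on GKE are treated as an atomic unit.
    qtys, count = set(), 1
    while count <= total:
        qtys.add(count)
        count *= 2
    qtys.add(total)  # e.g. 12 GPUs -> {1, 2, 4, 8, 12}
    return qtys


def accumulate_node(capacity: Dict[str, int], available: Dict[str, int],
                    acc_name: str, acc_count: int, allocated: int,
                    min_qty: int, pods_visible: bool) -> None:
    if acc_count >= min_qty:
        quantized_count = min_qty * (acc_count // min_qty)
        capacity[acc_name] = capacity.get(acc_name, 0) + quantized_count
    if not pods_visible:
        available[acc_name] = -1  # Usage unknown without pod list permission.
        return
    free = acc_count - allocated
    if free >= min_qty:
        available[acc_name] = available.get(acc_name, 0) + min_qty * (
            free // min_qty)


capacity: Dict[str, int] = {}
available: Dict[str, int] = {}
accumulate_node(capacity, available, 'H100', acc_count=8, allocated=3,
                min_qty=1, pods_visible=True)
print(sorted(candidate_quantities(12)), capacity, available)
# [1, 2, 4, 8, 12] {'H100': 8} {'H100': 5}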
diff --git a/sky/clouds/utils/gcp_utils.py b/sky/clouds/utils/gcp_utils.py index cfb893c8cb4..e899c60fa4c 100644 --- a/sky/clouds/utils/gcp_utils.py +++ b/sky/clouds/utils/gcp_utils.py @@ -17,6 +17,7 @@ from sky import sky_logging from sky import skypilot_config from sky.provision.gcp import constants +from sky.provision.kubernetes import utils as kubernetes_utils from sky.utils import subprocess_utils if typing.TYPE_CHECKING: @@ -35,7 +36,10 @@ def is_tpu(resources: Optional['resources_lib.Resources']) -> bool: def is_tpu_vm(resources: Optional['resources_lib.Resources']) -> bool: if not is_tpu(resources): return False - assert resources is not None + assert (resources is not None and len(resources.accelerators) == 1) + acc, _ = list(resources.accelerators.items())[0] + if kubernetes_utils.is_tpu_on_gke(acc): + return False if resources.accelerator_args is None: return True return resources.accelerator_args.get('tpu_vm', True) diff --git a/sky/clouds/utils/oci_utils.py b/sky/clouds/utils/oci_utils.py index 86647071f3e..0cd4f33e647 100644 --- a/sky/clouds/utils/oci_utils.py +++ b/sky/clouds/utils/oci_utils.py @@ -4,14 +4,17 @@ - Zhanghao Wu @ Oct 2023: Formatting and refactoring - Hysun He (hysun.he@oracle.com) @ Oct, 2024: Add default image OS configuration. + - Hysun He (hysun.he@oracle.com) @ Nov.12, 2024: Add the constant + SERVICE_PORT_RULE_TAG """ -import logging import os +from sky import sky_logging from sky import skypilot_config +from sky import status_lib from sky.utils import resources_utils -logger = logging.getLogger(__name__) +logger = sky_logging.init_logger(__name__) class OCIConfig: @@ -41,6 +44,9 @@ class OCIConfig: VCN_CIDR_INTERNET = '0.0.0.0/0' VCN_CIDR = '192.168.0.0/16' VCN_SUBNET_CIDR = '192.168.0.0/18' + SERVICE_PORT_RULE_TAG = 'SkyServe-Service-Port' + # NSG name template + NSG_NAME_TEMPLATE = 'nsg_{cluster_name}' MAX_RETRY_COUNT = 3 RETRY_INTERVAL_BASE_SECONDS = 5 @@ -77,6 +83,19 @@ class OCIConfig: resources_utils.DiskTier.HIGH: DISK_TIER_HIGH, } + # Oracle instance's lifecycle state to sky state mapping. + # For Oracle VM instance's lifecyle state, please refer to the link: + # https://docs.oracle.com/en-us/iaas/api/#/en/iaas/latest/Instance/ + STATE_MAPPING_OCI_TO_SKY = { + 'PROVISIONING': status_lib.ClusterStatus.INIT, + 'STARTING': status_lib.ClusterStatus.INIT, + 'RUNNING': status_lib.ClusterStatus.UP, + 'STOPPING': status_lib.ClusterStatus.STOPPED, + 'STOPPED': status_lib.ClusterStatus.STOPPED, + 'TERMINATED': None, + 'TERMINATING': None, + } + @classmethod def get_compartment(cls, region): # Allow task(cluster)-specific compartment/VCN parameters. diff --git a/sky/clouds/vsphere.py b/sky/clouds/vsphere.py index 88d5df3232a..92e62a8a240 100644 --- a/sky/clouds/vsphere.py +++ b/sky/clouds/vsphere.py @@ -173,6 +173,7 @@ def make_deploy_resources_variables( cluster_name: resources_utils.ClusterName, region: 'clouds.Region', zones: Optional[List['clouds.Zone']], + num_nodes: int, dryrun: bool = False, ) -> Dict[str, Optional[str]]: # TODO get image id here. 
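[Editor's note on `STATE_MAPPING_OCI_TO_SKY` defined above: the refactored `OCI.query_status` earlier in this diff consumes this table directly, and terminated/terminating instances are dropped by mapping them to None. A self-contained illustration of that lookup, with a stand-in enum in place of `sky.status_lib.ClusterStatus`:]

# Standalone illustration of the lifecycle-state lookup used by query_status.
import enum
from typing import List, Optional


class ClusterStatus(enum.Enum):  # Stand-in for sky.status_lib.ClusterStatus.
    INIT = 'INIT'
    UP = 'UP'
    STOPPED = 'STOPPED'


STATE_MAPPING_OCI_TO_SKY = {
    'PROVISIONING': ClusterStatus.INIT,
    'STARTING': ClusterStatus.INIT,
    'RUNNING': ClusterStatus.UP,
    'STOPPING': ClusterStatus.STOPPED,
    'STOPPED': ClusterStatus.STOPPED,
    'TERMINATED': None,
    'TERMINATING': None,
}


def to_sky_statuses(lifecycle_states: List[str]) -> List[ClusterStatus]:
    statuses = []
    for state in lifecycle_states:
        sky_status: Optional[ClusterStatus] = STATE_MAPPING_OCI_TO_SKY.get(
            state, None)
        if sky_status is not None:  # Terminated nodes are simply dropped.
            statuses.append(sky_status)
    return statuses


print(to_sky_statuses(['RUNNING', 'TERMINATING', 'STOPPED']))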
diff --git a/sky/core.py b/sky/core.py index 496b8b8ad5e..9f1288d7fb6 100644 --- a/sky/core.py +++ b/sky/core.py @@ -268,7 +268,8 @@ def _start( cluster_status, handle = backend_utils.refresh_cluster_status_handle( cluster_name) if handle is None: - raise ValueError(f'Cluster {cluster_name!r} does not exist.') + raise exceptions.ClusterDoesNotExist( + f'Cluster {cluster_name!r} does not exist.') if not force and cluster_status == status_lib.ClusterStatus.UP: sky_logging.print(f'Cluster {cluster_name!r} is already up.') return handle @@ -359,12 +360,13 @@ def start( Useful for upgrading SkyPilot runtime. Raises: - ValueError: argument values are invalid: (1) the specified cluster does - not exist; (2) if ``down`` is set to True but - ``idle_minutes_to_autostop`` is None; (3) if the specified cluster is - the managed jobs controller, and either ``idle_minutes_to_autostop`` - is not None or ``down`` is True (omit them to use the default - autostop settings). + ValueError: argument values are invalid: (1) if ``down`` is set to True + but ``idle_minutes_to_autostop`` is None; (2) if the specified + cluster is the managed jobs controller, and either + ``idle_minutes_to_autostop`` is not None or ``down`` is True (omit + them to use the default autostop settings). + sky.exceptions.ClusterDoesNotExist: the specified cluster does not + exist. sky.exceptions.NotSupportedError: if the cluster to restart was launched using a non-default backend that does not support this operation. @@ -412,7 +414,8 @@ def stop(cluster_name: str, purge: bool = False) -> None: related resources. Raises: - ValueError: the specified cluster does not exist. + sky.exceptions.ClusterDoesNotExist: the specified cluster does not + exist. RuntimeError: failed to stop the cluster. sky.exceptions.NotSupportedError: if the specified cluster is a spot cluster, or a TPU VM Pod cluster, or the managed jobs controller. @@ -423,7 +426,8 @@ def stop(cluster_name: str, purge: bool = False) -> None: f'is not supported.') handle = global_user_state.get_handle_from_cluster_name(cluster_name) if handle is None: - raise ValueError(f'Cluster {cluster_name!r} does not exist.') + raise exceptions.ClusterDoesNotExist( + f'Cluster {cluster_name!r} does not exist.') backend = backend_utils.get_backend_from_handle(handle) @@ -467,14 +471,16 @@ def down(cluster_name: str, purge: bool = False) -> None: resources. Raises: - ValueError: the specified cluster does not exist. + sky.exceptions.ClusterDoesNotExist: the specified cluster does not + exist. RuntimeError: failed to tear down the cluster. sky.exceptions.NotSupportedError: the specified cluster is the managed jobs controller. """ handle = global_user_state.get_handle_from_cluster_name(cluster_name) if handle is None: - raise ValueError(f'Cluster {cluster_name!r} does not exist.') + raise exceptions.ClusterDoesNotExist( + f'Cluster {cluster_name!r} does not exist.') usage_lib.record_cluster_name_for_current_operation(cluster_name) backend = backend_utils.get_backend_from_handle(handle) @@ -521,7 +527,7 @@ def autostop( rather than autostop (restartable). Raises: - ValueError: if the cluster does not exist. + sky.exceptions.ClusterDoesNotExist: if the cluster does not exist. sky.exceptions.ClusterNotUpError: if the cluster is not UP. sky.exceptions.NotSupportedError: if the cluster is not based on CloudVmRayBackend or the cluster is TPU VM Pod. @@ -615,7 +621,7 @@ def queue(cluster_name: str, } ] raises: - ValueError: if the cluster does not exist. 
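[Editor's note on the new exception raised throughout `sky/core.py` above: as the `sky/exceptions.py` hunk later in this diff shows, `ClusterDoesNotExist` subclasses `ValueError`, so callers that still catch `ValueError` keep working. A tiny demonstration of that compatibility:]

# The new exception extends ValueError, so pre-existing handlers still catch it.
class ClusterDoesNotExist(ValueError):
    """Raised when operating on a cluster that does not exist."""


def down(cluster_name: str) -> None:
    raise ClusterDoesNotExist(f'Cluster {cluster_name!r} does not exist.')


try:
    down('my-cluster')
except ValueError as e:  # Old-style handler: still works.
    print(f'caught via ValueError: {e}')

try:
    down('my-cluster')
except ClusterDoesNotExist as e:  # New, more precise handler.
    print(f'caught via ClusterDoesNotExist: {e}')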
+ sky.exceptions.ClusterDoesNotExist: if the cluster does not exist. sky.exceptions.ClusterNotUpError: if the cluster is not UP. sky.exceptions.NotSupportedError: if the cluster is not based on CloudVmRayBackend. @@ -674,7 +680,8 @@ def cancel( worker node is preempted in the spot cluster. Raises: - ValueError: if arguments are invalid, or the cluster does not exist. + ValueError: if arguments are invalid. + sky.exceptions.ClusterDoesNotExist: if the cluster does not exist. sky.exceptions.ClusterNotUpError: if the cluster is not UP. sky.exceptions.NotSupportedError: if the specified cluster is a controller that does not support this operation. @@ -742,15 +749,16 @@ def cancel( @usage_lib.entrypoint def tail_logs(cluster_name: str, job_id: Optional[int], - follow: bool = True) -> None: + follow: bool = True, + tail: int = 0) -> None: # NOTE(dev): Keep the docstring consistent between the Python API and CLI. """Tail the logs of a job. Please refer to the sky.cli.tail_logs for the document. Raises: - ValueError: arguments are invalid or the cluster is not supported or - the cluster does not exist. + ValueError: if arguments are invalid or the cluster is not supported. + sky.exceptions.ClusterDoesNotExist: if the cluster does not exist. sky.exceptions.ClusterNotUpError: if the cluster is not UP. sky.exceptions.NotSupportedError: if the cluster is not based on CloudVmRayBackend. @@ -775,7 +783,7 @@ def tail_logs(cluster_name: str, f'{colorama.Style.RESET_ALL}') usage_lib.record_cluster_name_for_current_operation(cluster_name) - backend.tail_logs(handle, job_id, follow=follow) + backend.tail_logs(handle, job_id, follow=follow, tail=tail) @usage_lib.entrypoint @@ -792,7 +800,7 @@ def download_logs( Returns: Dict[str, str]: a mapping of job_id to local log path. Raises: - ValueError: if the cluster does not exist. + sky.exceptions.ClusterDoesNotExist: if the cluster does not exist. sky.exceptions.ClusterNotUpError: if the cluster is not UP. sky.exceptions.NotSupportedError: if the cluster is not based on CloudVmRayBackend. @@ -837,7 +845,7 @@ def job_status(cluster_name: str, If job_ids is None and there is no job on the cluster, it will return {None: None}. Raises: - ValueError: if the cluster does not exist. + sky.exceptions.ClusterDoesNotExist: if the cluster does not exist. sky.exceptions.ClusterNotUpError: if the cluster is not UP. sky.exceptions.NotSupportedError: if the cluster is not based on CloudVmRayBackend. diff --git a/sky/dag.py b/sky/dag.py index dc580d0cf6c..19df5107085 100644 --- a/sky/dag.py +++ b/sky/dag.py @@ -56,21 +56,25 @@ def get_graph(self): return self.graph def is_chain(self) -> bool: - # NOTE: this method assumes that the graph has no cycle. 
- is_chain = True - visited_zero_out_degree = False - for node in self.graph.nodes: - out_degree = self.graph.out_degree(node) - if out_degree > 1: - is_chain = False - break - elif out_degree == 0: - if visited_zero_out_degree: - is_chain = False - break - else: - visited_zero_out_degree = True - return is_chain + """Check if the DAG is a linear chain of tasks.""" + + nodes = list(self.graph.nodes) + + if len(nodes) == 0: + return True + + in_degrees = [self.graph.in_degree(node) for node in nodes] + out_degrees = [self.graph.out_degree(node) for node in nodes] + + # Check out-degrees: all <= 1 and exactly one node has out_degree == 0 + out_degree_condition = (all(degree <= 1 for degree in out_degrees) and + sum(degree == 0 for degree in out_degrees) == 1) + + # Check in-degrees: all <= 1 and exactly one node has in_degree == 0 + in_degree_condition = (all(degree <= 1 for degree in in_degrees) and + sum(degree == 0 for degree in in_degrees) == 1) + + return out_degree_condition and in_degree_condition class _DagContext(threading.local): diff --git a/sky/exceptions.py b/sky/exceptions.py index c1ade2eb02a..40d2b4d867b 100644 --- a/sky/exceptions.py +++ b/sky/exceptions.py @@ -132,6 +132,13 @@ class ClusterSetUpError(Exception): pass +class ClusterDoesNotExist(ValueError): + """Raise when trying to operate on a cluster that does not exist.""" + # This extends ValueError for compatibility reasons - we used to throw + # ValueError instead of this. + pass + + class NotSupportedError(Exception): """Raised when a feature is not supported.""" pass diff --git a/sky/execution.py b/sky/execution.py index df3cdd5efdb..7392d510b17 100644 --- a/sky/execution.py +++ b/sky/execution.py @@ -11,10 +11,10 @@ from sky import admin_policy from sky import backends from sky import clouds -from sky import exceptions from sky import global_user_state from sky import optimizer from sky import sky_logging +from sky import status_lib from sky.backends import backend_utils from sky.usage import usage_lib from sky.utils import admin_policy_utils @@ -108,6 +108,7 @@ def _execute( idle_minutes_to_autostop: Optional[int] = None, no_setup: bool = False, clone_disk_from: Optional[str] = None, + skip_unnecessary_provisioning: bool = False, # Internal only: # pylint: disable=invalid-name _is_launched_by_jobs_controller: bool = False, @@ -128,8 +129,9 @@ def _execute( Note that if errors occur during provisioning/data syncing/setting up, the cluster will not be torn down for debugging purposes. stream_logs: bool; whether to stream all tasks' outputs to the client. - handle: Optional[backends.ResourceHandle]; if provided, execution will use - an existing backend cluster handle instead of provisioning a new one. + handle: Optional[backends.ResourceHandle]; if provided, execution will + attempt to use an existing backend cluster handle instead of + provisioning a new one. backend: Backend; backend to use for executing the tasks. Defaults to CloudVmRayBackend() retry_until_up: bool; whether to retry the provisioning until the cluster @@ -150,6 +152,11 @@ def _execute( idle_minutes_to_autostop: int; if provided, the cluster will be set to autostop after this many minutes of idleness. no_setup: bool; whether to skip setup commands or not when (re-)launching. + clone_disk_from: Optional[str]; if set, clone the disk from the specified + cluster. + skip_unecessary_provisioning: bool; if True, compare the calculated + cluster config to the current cluster's config. If they match, shortcut + provisioning even if we have Stage.PROVISION. 
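[Editor's note on the rewritten `Dag.is_chain` above: the new check requires every node to have in-degree and out-degree at most 1, with exactly one source and exactly one sink, so branching DAGs, cycles, and disconnected pieces all fail, unlike the removed out-degree-only version. A standalone networkx sketch of the same conditions:]

# Standalone version of the chain check above, written against networkx.
import networkx as nx


def is_chain(graph: nx.DiGraph) -> bool:
    nodes = list(graph.nodes)
    if len(nodes) == 0:
        return True
    in_degrees = [graph.in_degree(node) for node in nodes]
    out_degrees = [graph.out_degree(node) for node in nodes]
    # No branching anywhere, exactly one sink, and exactly one source.
    out_ok = (all(d <= 1 for d in out_degrees) and
              sum(d == 0 for d in out_degrees) == 1)
    in_ok = (all(d <= 1 for d in in_degrees) and
             sum(d == 0 for d in in_degrees) == 1)
    return out_ok and in_ok


chain = nx.DiGraph([('a', 'b'), ('b', 'c')])
fan_out = nx.DiGraph([('a', 'b'), ('a', 'c')])
two_pieces = nx.DiGraph([('a', 'b'), ('c', 'd')])
print(is_chain(chain), is_chain(fan_out), is_chain(two_pieces))
# True False False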
Returns: job_id: Optional[int]; the job ID of the submitted job. None if the @@ -267,6 +274,13 @@ def _execute( # no-credential machine should not enter optimize(), which # would directly error out ('No cloud is enabled...'). Fix # by moving `sky check` checks out of optimize()? + + controller = controller_utils.Controllers.from_name( + cluster_name) + if controller is not None: + logger.info( + f'Choosing resources for {controller.value.name}...' + ) dag = sky.optimize(dag, minimize=optimize_target) task = dag.tasks[0] # Keep: dag may have been deep-copied. assert task.best_resources is not None, task @@ -281,13 +295,18 @@ def _execute( try: if Stage.PROVISION in stages: - if handle is None: - handle = backend.provision(task, - task.best_resources, - dryrun=dryrun, - stream_logs=stream_logs, - cluster_name=cluster_name, - retry_until_up=retry_until_up) + assert handle is None or skip_unnecessary_provisioning, ( + 'Provisioning requested, but handle is already set. PROVISION ' + 'should be excluded from stages or ' + 'skip_unecessary_provisioning should be set. ') + handle = backend.provision( + task, + task.best_resources, + dryrun=dryrun, + stream_logs=stream_logs, + cluster_name=cluster_name, + retry_until_up=retry_until_up, + skip_unnecessary_provisioning=skip_unnecessary_provisioning) if handle is None: assert dryrun, ('If not dryrun, handle must be set or ' @@ -298,7 +317,8 @@ def _execute( do_workdir = (Stage.SYNC_WORKDIR in stages and not dryrun and task.workdir is not None) do_file_mounts = (Stage.SYNC_FILE_MOUNTS in stages and not dryrun and - task.file_mounts is not None) + (task.file_mounts is not None or + task.storage_mounts is not None)) if do_workdir or do_file_mounts: logger.info(ux_utils.starting_message('Mounting files.')) @@ -461,30 +481,50 @@ def launch( handle = None stages = None + skip_unnecessary_provisioning = False # Check if cluster exists and we are doing fast provisioning if fast and cluster_name is not None: - maybe_handle = global_user_state.get_handle_from_cluster_name( - cluster_name) - if maybe_handle is not None: - try: - # This will throw if the cluster is not available - backend_utils.check_cluster_available( + cluster_status, maybe_handle = ( + backend_utils.refresh_cluster_status_handle(cluster_name)) + if cluster_status == status_lib.ClusterStatus.INIT: + # If the cluster is INIT, it may be provisioning. We want to prevent + # concurrent calls from queueing up many sequential reprovision + # attempts. Since provisioning will hold the cluster status lock, we + # wait to hold that lock by force refreshing the status. This will + # block until the cluster finishes provisioning, then correctly see + # that it is UP. + # TODO(cooperc): If multiple processes launched in parallel see that + # the cluster is STOPPED or does not exist, they will still all try + # to provision it, since we do not hold the lock continuously from + # the status check until the provision call. Fixing this requires a + # bigger refactor. 
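[Editor's note on `skip_unnecessary_provisioning` described above: the hunk states that the freshly calculated cluster config is compared against the existing cluster's config, but the comparison itself is not shown in this section. Given the new `config_hash` column added to the clusters table later in this diff, a hash comparison along the following lines is a plausible reading; treat it as an assumption, not the actual backend code.]

# Hedged sketch of the "skip provisioning if the config is unchanged" idea.
import hashlib
import json
from typing import Optional


def config_hash(cluster_config: dict) -> str:
    canonical = json.dumps(cluster_config, sort_keys=True).encode()
    return hashlib.sha256(canonical).hexdigest()


def should_skip_provisioning(new_config: dict,
                             recorded_hash: Optional[str]) -> bool:
    return (recorded_hash is not None and
            config_hash(new_config) == recorded_hash)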
+ cluster_status, maybe_handle = ( + backend_utils.refresh_cluster_status_handle( cluster_name, - operation='executing tasks', - check_cloud_vm_ray_backend=False, - dryrun=dryrun) - handle = maybe_handle - # Get all stages - stages = [ - Stage.SYNC_WORKDIR, - Stage.SYNC_FILE_MOUNTS, - Stage.PRE_EXEC, - Stage.EXEC, - Stage.DOWN, - ] - except exceptions.ClusterNotUpError: - # Proceed with normal provisioning - pass + force_refresh_statuses=[ + # If the cluster is INIT, we want to try to grab the + # status lock, which should block until provisioning is + # finished. + status_lib.ClusterStatus.INIT, + ], + # Wait indefinitely to obtain the lock, so that we don't + # have multiple processes launching the same cluster at + # once. + cluster_status_lock_timeout=-1, + )) + if cluster_status == status_lib.ClusterStatus.UP: + handle = maybe_handle + stages = [ + # Provisioning will be short-circuited if the existing + # cluster config hash matches the calculated one. + Stage.PROVISION, + Stage.SYNC_WORKDIR, + Stage.SYNC_FILE_MOUNTS, + Stage.PRE_EXEC, + Stage.EXEC, + Stage.DOWN, + ] + skip_unnecessary_provisioning = True return _execute( entrypoint=entrypoint, @@ -502,6 +542,7 @@ def launch( idle_minutes_to_autostop=idle_minutes_to_autostop, no_setup=no_setup, clone_disk_from=clone_disk_from, + skip_unnecessary_provisioning=skip_unnecessary_provisioning, _is_launched_by_jobs_controller=_is_launched_by_jobs_controller, _is_launched_by_sky_serve_controller= _is_launched_by_sky_serve_controller, @@ -558,8 +599,9 @@ def exec( # pylint: disable=redefined-builtin submitted. Raises: - ValueError: if the specified cluster does not exist or is not in UP - status. + ValueError: if the specified cluster is not in UP status. + sky.exceptions.ClusterDoesNotExist: if the specified cluster does not + exist. sky.exceptions.NotSupportedError: if the specified cluster is a controller that does not support this operation. diff --git a/sky/global_user_state.py b/sky/global_user_state.py index 7c040ea55fc..2a5cbc7eb3f 100644 --- a/sky/global_user_state.py +++ b/sky/global_user_state.py @@ -60,7 +60,9 @@ def create_table(cursor, conn): owner TEXT DEFAULT null, cluster_hash TEXT DEFAULT null, storage_mounts_metadata BLOB DEFAULT null, - cluster_ever_up INTEGER DEFAULT 0)""") + cluster_ever_up INTEGER DEFAULT 0, + status_updated_at INTEGER DEFAULT null, + config_hash TEXT DEFAULT null)""") # Table for Cluster History # usage_intervals: List[Tuple[int, int]] @@ -130,6 +132,13 @@ def create_table(cursor, conn): # clusters were never really UP, setting it to 1 means they won't be # auto-deleted during any failover. value_to_replace_existing_entries=1) + + db_utils.add_column_to_table(cursor, conn, 'clusters', 'status_updated_at', + 'INTEGER DEFAULT null') + + db_utils.add_column_to_table(cursor, conn, 'clusters', 'config_hash', + 'TEXT DEFAULT null') + conn.commit() @@ -140,7 +149,8 @@ def add_or_update_cluster(cluster_name: str, cluster_handle: 'backends.ResourceHandle', requested_resources: Optional[Set[Any]], ready: bool, - is_launch: bool = True): + is_launch: bool = True, + config_hash: Optional[str] = None): """Adds or updates cluster_name -> cluster_handle mapping. Args: @@ -159,6 +169,7 @@ def add_or_update_cluster(cluster_name: str, status = status_lib.ClusterStatus.INIT if ready: status = status_lib.ClusterStatus.UP + status_updated_at = int(time.time()) # TODO (sumanth): Cluster history table will have multiple entries # when the cluster failover through multiple regions (one entry per region). 
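[Editor's note on the schema changes above: `status_updated_at` and `config_hash` are added to the existing clusters table through `db_utils.add_column_to_table`, which keeps databases created by older versions working. The helper itself (including options like `value_to_replace_existing_entries`) is not shown in this diff, so the sketch below is only a plausible shape for an idempotent column addition.]

# Plausible sketch of an idempotent "add column if missing" helper.
import sqlite3


def add_column_to_table(cursor: sqlite3.Cursor, conn: sqlite3.Connection,
                        table: str, column: str, column_type: str) -> None:
    existing = {
        row[1] for row in cursor.execute(f'PRAGMA table_info({table})')
    }
    if column not in existing:
        cursor.execute(
            f'ALTER TABLE {table} ADD COLUMN {column} {column_type}')
    conn.commit()


conn = sqlite3.connect(':memory:')
cursor = conn.cursor()
cursor.execute('CREATE TABLE clusters (name TEXT PRIMARY KEY)')
add_column_to_table(cursor, conn, 'clusters', 'status_updated_at',
                    'INTEGER DEFAULT null')
add_column_to_table(cursor, conn, 'clusters', 'config_hash',
                    'TEXT DEFAULT null')
print([row[1] for row in cursor.execute('PRAGMA table_info(clusters)')])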
@@ -191,7 +202,8 @@ def add_or_update_cluster(cluster_name: str, # specified. '(name, launched_at, handle, last_use, status, ' 'autostop, to_down, metadata, owner, cluster_hash, ' - 'storage_mounts_metadata, cluster_ever_up) ' + 'storage_mounts_metadata, cluster_ever_up, status_updated_at, ' + 'config_hash) ' 'VALUES (' # name '?, ' @@ -228,7 +240,11 @@ def add_or_update_cluster(cluster_name: str, 'COALESCE(' '(SELECT storage_mounts_metadata FROM clusters WHERE name=?), null), ' # cluster_ever_up - '((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?)' + '((SELECT cluster_ever_up FROM clusters WHERE name=?) OR ?),' + # status_updated_at + '?,' + # config_hash + 'COALESCE(?, (SELECT config_hash FROM clusters WHERE name=?))' ')', ( # name @@ -260,6 +276,11 @@ def add_or_update_cluster(cluster_name: str, # cluster_ever_up cluster_name, int(ready), + # status_updated_at + status_updated_at, + # config_hash + config_hash, + cluster_name, )) launched_nodes = getattr(cluster_handle, 'launched_nodes', None) @@ -330,11 +351,13 @@ def remove_cluster(cluster_name: str, terminate: bool) -> None: # stopped VM, which leads to timeout. if hasattr(handle, 'stable_internal_external_ips'): handle.stable_internal_external_ips = None + current_time = int(time.time()) _DB.cursor.execute( - 'UPDATE clusters SET handle=(?), status=(?) ' - 'WHERE name=(?)', ( + 'UPDATE clusters SET handle=(?), status=(?), ' + 'status_updated_at=(?) WHERE name=(?)', ( pickle.dumps(handle), status_lib.ClusterStatus.STOPPED.value, + current_time, cluster_name, )) _DB.conn.commit() @@ -359,10 +382,10 @@ def get_glob_cluster_names(cluster_name: str) -> List[str]: def set_cluster_status(cluster_name: str, status: status_lib.ClusterStatus) -> None: - _DB.cursor.execute('UPDATE clusters SET status=(?) WHERE name=(?)', ( - status.value, - cluster_name, - )) + current_time = int(time.time()) + _DB.cursor.execute( + 'UPDATE clusters SET status=(?), status_updated_at=(?) WHERE name=(?)', + (status.value, current_time, cluster_name)) count = _DB.cursor.rowcount _DB.conn.commit() assert count <= 1, count @@ -570,15 +593,18 @@ def _load_storage_mounts_metadata( def get_cluster_from_name( cluster_name: Optional[str]) -> Optional[Dict[str, Any]]: - rows = _DB.cursor.execute('SELECT * FROM clusters WHERE name=(?)', - (cluster_name,)).fetchall() + rows = _DB.cursor.execute( + 'SELECT name, launched_at, handle, last_use, status, autostop, ' + 'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, ' + 'cluster_ever_up, status_updated_at, config_hash ' + 'FROM clusters WHERE name=(?)', (cluster_name,)).fetchall() for row in rows: # Explicitly specify the number of fields to unpack, so that # we can add new fields to the database in the future without # breaking the previous code. 
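[Editor's note on the upsert above: `config_hash` is written as `COALESCE(?, (SELECT config_hash FROM clusters WHERE name=?))`, so passing NULL preserves whatever hash the existing row already carries. A minimal sqlite demonstration of that pattern in isolation:]

# Isolated demonstration of the COALESCE upsert trick used above.
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE clusters (name TEXT PRIMARY KEY, config_hash TEXT)')
conn.execute('INSERT INTO clusters VALUES (?, ?)', ('my-cluster', 'abc123'))


def upsert(name, new_hash):
    conn.execute(
        'INSERT OR REPLACE INTO clusters (name, config_hash) VALUES '
        '(?, COALESCE(?, (SELECT config_hash FROM clusters WHERE name=?)))',
        (name, new_hash, name))


upsert('my-cluster', None)  # No new hash provided: keep 'abc123'.
print(conn.execute('SELECT config_hash FROM clusters').fetchone())
upsert('my-cluster', 'def456')  # New hash provided: overwrite.
print(conn.execute('SELECT config_hash FROM clusters').fetchone())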
(name, launched_at, handle, last_use, status, autostop, metadata, - to_down, owner, cluster_hash, storage_mounts_metadata, - cluster_ever_up) = row[:12] + to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up, + status_updated_at, config_hash) = row[:14] # TODO: use namedtuple instead of dict record = { 'name': name, @@ -594,6 +620,8 @@ def get_cluster_from_name( 'storage_mounts_metadata': _load_storage_mounts_metadata(storage_mounts_metadata), 'cluster_ever_up': bool(cluster_ever_up), + 'status_updated_at': status_updated_at, + 'config_hash': config_hash, } return record return None @@ -601,12 +629,15 @@ def get_cluster_from_name( def get_clusters() -> List[Dict[str, Any]]: rows = _DB.cursor.execute( - 'select * from clusters order by launched_at desc').fetchall() + 'select name, launched_at, handle, last_use, status, autostop, ' + 'metadata, to_down, owner, cluster_hash, storage_mounts_metadata, ' + 'cluster_ever_up, status_updated_at, config_hash ' + 'from clusters order by launched_at desc').fetchall() records = [] for row in rows: (name, launched_at, handle, last_use, status, autostop, metadata, - to_down, owner, cluster_hash, storage_mounts_metadata, - cluster_ever_up) = row[:12] + to_down, owner, cluster_hash, storage_mounts_metadata, cluster_ever_up, + status_updated_at, config_hash) = row[:14] # TODO: use namedtuple instead of dict record = { 'name': name, @@ -622,6 +653,8 @@ def get_clusters() -> List[Dict[str, Any]]: 'storage_mounts_metadata': _load_storage_mounts_metadata(storage_mounts_metadata), 'cluster_ever_up': bool(cluster_ever_up), + 'status_updated_at': status_updated_at, + 'config_hash': config_hash, } records.append(record) diff --git a/sky/jobs/controller.py b/sky/jobs/controller.py index 5219c564500..72dce3e50d7 100644 --- a/sky/jobs/controller.py +++ b/sky/jobs/controller.py @@ -6,7 +6,7 @@ import time import traceback import typing -from typing import Tuple +from typing import Optional, Tuple import filelock @@ -87,18 +87,28 @@ def __init__(self, job_id: int, dag_yaml: str, task.update_envs(task_envs) def _download_log_and_stream( - self, - handle: cloud_vm_ray_backend.CloudVmRayResourceHandle) -> None: - """Downloads and streams the logs of the latest job. + self, task_id: Optional[int], + handle: Optional[cloud_vm_ray_backend.CloudVmRayResourceHandle] + ) -> None: + """Downloads and streams the logs of the current job with given task ID. We do not stream the logs from the cluster directly, as the donwload and stream should be faster, and more robust against preemptions or ssh disconnection during the streaming. """ + if handle is None: + logger.info(f'Cluster for job {self._job_id} is not found. 
' + 'Skipping downloading and streaming the logs.') + return managed_job_logs_dir = os.path.join(constants.SKY_LOGS_DIRECTORY, 'managed_jobs') - controller_utils.download_and_stream_latest_job_log( + log_file = controller_utils.download_and_stream_latest_job_log( self._backend, handle, managed_job_logs_dir) + if log_file is not None: + # Set the path of the log file for the current task, so it can be + # accessed even after the job is finished + managed_job_state.set_local_log_file(self._job_id, task_id, + log_file) logger.info(f'\n== End of logs (ID: {self._job_id}) ==') def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: @@ -213,7 +223,8 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: if job_status == job_lib.JobStatus.SUCCEEDED: end_time = managed_job_utils.get_job_timestamp( self._backend, cluster_name, get_end_time=True) - # The job is done. + # The job is done. Set the job to SUCCEEDED first before start + # downloading and streaming the logs to make it more responsive. managed_job_state.set_succeeded(self._job_id, task_id, end_time=end_time, @@ -221,12 +232,21 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: logger.info( f'Managed job {self._job_id} (task: {task_id}) SUCCEEDED. ' f'Cleaning up the cluster {cluster_name}.') + clusters = backend_utils.get_clusters( + cluster_names=[cluster_name], + refresh=False, + include_controller=False) + if clusters: + assert len(clusters) == 1, (clusters, cluster_name) + handle = clusters[0].get('handle') + # Best effort to download and stream the logs. + self._download_log_and_stream(task_id, handle) # Only clean up the cluster, not the storages, because tasks may # share storages. recovery_strategy.terminate_cluster(cluster_name=cluster_name) return True - # For single-node jobs, nonterminated job_status indicates a + # For single-node jobs, non-terminated job_status indicates a # healthy cluster. We can safely continue monitoring. # For multi-node jobs, since the job may not be set to FAILED # immediately (depending on user program) when only some of the @@ -278,7 +298,7 @@ def _run_one_task(self, task_id: int, task: 'sky.Task') -> bool: 'The user job failed. 
Please check the logs below.\n' f'== Logs of the user job (ID: {self._job_id}) ==\n') - self._download_log_and_stream(handle) + self._download_log_and_stream(task_id, handle) managed_job_status = ( managed_job_state.ManagedJobStatus.FAILED) if job_status == job_lib.JobStatus.FAILED_SETUP: diff --git a/sky/jobs/core.py b/sky/jobs/core.py index 9532bd0fa19..9cde3443816 100644 --- a/sky/jobs/core.py +++ b/sky/jobs/core.py @@ -1,6 +1,7 @@ """SDK functions for managed jobs.""" import os import tempfile +import typing from typing import Any, Dict, List, Optional, Union import uuid @@ -26,9 +27,14 @@ from sky.utils import dag_utils from sky.utils import rich_utils from sky.utils import subprocess_utils +from sky.utils import timeline from sky.utils import ux_utils +if typing.TYPE_CHECKING: + from sky.backends import cloud_vm_ray_backend + +@timeline.event @usage_lib.entrypoint def launch( task: Union['sky.Task', 'sky.Dag'], @@ -131,7 +137,6 @@ def launch( controller_task.set_resources(controller_resources) controller_task.managed_job_dag = dag - assert len(controller_task.resources) == 1, controller_task sky_logging.print( f'{colorama.Fore.YELLOW}' @@ -224,6 +229,40 @@ def queue_from_kubernetes_pod( return jobs +def _maybe_restart_controller( + refresh: bool, stopped_message: str, spinner_message: str +) -> 'cloud_vm_ray_backend.CloudVmRayResourceHandle': + """Restart controller if refresh is True and it is stopped.""" + jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER + if refresh: + stopped_message = '' + try: + handle = backend_utils.is_controller_accessible( + controller=jobs_controller_type, stopped_message=stopped_message) + except exceptions.ClusterNotUpError as e: + if not refresh: + raise + handle = None + controller_status = e.cluster_status + + if handle is not None: + return handle + + sky_logging.print(f'{colorama.Fore.YELLOW}' + f'Restarting {jobs_controller_type.value.name}...' + f'{colorama.Style.RESET_ALL}') + + rich_utils.force_update_status( + ux_utils.spinner_message(f'{spinner_message} - restarting ' + 'controller')) + handle = sky.start(jobs_controller_type.value.cluster_name) + controller_status = status_lib.ClusterStatus.UP + rich_utils.force_update_status(ux_utils.spinner_message(spinner_message)) + + assert handle is not None, (controller_status, refresh) + return handle + + @usage_lib.entrypoint def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]: # NOTE(dev): Keep the docstring consistent between the Python API and CLI. @@ -251,34 +290,11 @@ def queue(refresh: bool, skip_finished: bool = False) -> List[Dict[str, Any]]: does not exist. RuntimeError: if failed to get the managed jobs with ssh. """ - jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER - stopped_message = '' - if not refresh: - stopped_message = 'No in-progress managed jobs.' - try: - handle = backend_utils.is_controller_accessible( - controller=jobs_controller_type, stopped_message=stopped_message) - except exceptions.ClusterNotUpError as e: - if not refresh: - raise - handle = None - controller_status = e.cluster_status - - if refresh and handle is None: - sky_logging.print(f'{colorama.Fore.YELLOW}' - 'Restarting controller for latest status...' 
- f'{colorama.Style.RESET_ALL}') - - rich_utils.force_update_status( - ux_utils.spinner_message('Checking managed jobs - restarting ' - 'controller')) - handle = sky.start(jobs_controller_type.value.cluster_name) - controller_status = status_lib.ClusterStatus.UP - rich_utils.force_update_status( - ux_utils.spinner_message('Checking managed jobs')) - - assert handle is not None, (controller_status, refresh) - + handle = _maybe_restart_controller(refresh, + stopped_message='No in-progress ' + 'managed jobs.', + spinner_message='Checking ' + 'managed jobs') backend = backend_utils.get_backend_from_handle(handle) assert isinstance(backend, backends.CloudVmRayBackend) @@ -370,7 +386,7 @@ def cancel(name: Optional[str] = None, @usage_lib.entrypoint def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool, - controller: bool) -> None: + controller: bool, refresh: bool) -> None: # NOTE(dev): Keep the docstring consistent between the Python API and CLI. """Tail logs of managed jobs. @@ -381,15 +397,26 @@ def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool, sky.exceptions.ClusterNotUpError: the jobs controller is not up. """ # TODO(zhwu): Automatically restart the jobs controller + if name is not None and job_id is not None: + with ux_utils.print_exception_no_traceback(): + raise ValueError('Cannot specify both name and job_id.') + jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER - handle = backend_utils.is_controller_accessible( - controller=jobs_controller_type, + job_name_or_id_str = '' + if job_id is not None: + job_name_or_id_str = str(job_id) + elif name is not None: + job_name_or_id_str = f'-n {name}' + else: + job_name_or_id_str = '' + handle = _maybe_restart_controller( + refresh, stopped_message=( - 'Please restart the jobs controller with ' - f'`sky start {jobs_controller_type.value.cluster_name}`.')) + f'{jobs_controller_type.value.name.capitalize()} is stopped. To ' + f'get the logs, run: {colorama.Style.BRIGHT}sky jobs logs ' + f'-r {job_name_or_id_str}{colorama.Style.RESET_ALL}'), + spinner_message='Retrieving job logs') - if name is not None and job_id is not None: - raise ValueError('Cannot specify both name and job_id.') backend = backend_utils.get_backend_from_handle(handle) assert isinstance(backend, backends.CloudVmRayBackend), backend diff --git a/sky/jobs/recovery_strategy.py b/sky/jobs/recovery_strategy.py index 09e4bd8ed6e..4fda1a07e08 100644 --- a/sky/jobs/recovery_strategy.py +++ b/sky/jobs/recovery_strategy.py @@ -50,8 +50,9 @@ def terminate_cluster(cluster_name: str, max_retry: int = 3) -> None: usage_lib.messages.usage.set_internal() sky.down(cluster_name) return - except ValueError: + except exceptions.ClusterDoesNotExist: # The cluster is already down. 
+ logger.debug(f'The cluster {cluster_name} is already down.') return except Exception as e: # pylint: disable=broad-except retry_cnt += 1 diff --git a/sky/jobs/state.py b/sky/jobs/state.py index 6a0e3caeda3..9a5ab4b3cad 100644 --- a/sky/jobs/state.py +++ b/sky/jobs/state.py @@ -66,7 +66,8 @@ def create_table(cursor, conn): spot_job_id INTEGER, task_id INTEGER DEFAULT 0, task_name TEXT, - specs TEXT)""") + specs TEXT, + local_log_file TEXT DEFAULT NULL)""") conn.commit() db_utils.add_column_to_table(cursor, conn, 'spot', 'failure_reason', 'TEXT') @@ -103,6 +104,8 @@ def create_table(cursor, conn): value_to_replace_existing_entries=json.dumps({ 'max_restarts_on_errors': 0, })) + db_utils.add_column_to_table(cursor, conn, 'spot', 'local_log_file', + 'TEXT DEFAULT NULL') # `job_info` contains the mapping from job_id to the job_name. # In the future, it may contain more information about each job. @@ -157,6 +160,7 @@ def _get_db_path() -> str: 'task_id', 'task_name', 'specs', + 'local_log_file', # columns from the job_info table '_job_info_job_id', # This should be the same as job_id 'job_name', @@ -512,6 +516,20 @@ def set_cancelled(job_id: int, callback_func: CallbackType): callback_func('CANCELLED') +def set_local_log_file(job_id: int, task_id: Optional[int], + local_log_file: str): + """Set the local log file for a job.""" + filter_str = 'spot_job_id=(?)' + filter_args = [local_log_file, job_id] + if task_id is not None: + filter_str += ' AND task_id=(?)' + filter_args.append(task_id) + with db_utils.safe_cursor(_DB_PATH) as cursor: + cursor.execute( + 'UPDATE spot SET local_log_file=(?) ' + f'WHERE {filter_str}', filter_args) + + # ======== utility functions ======== def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]: """Get non-terminal job ids by name.""" @@ -662,3 +680,17 @@ def get_task_specs(job_id: int, task_id: int) -> Dict[str, Any]: WHERE spot_job_id=(?) AND task_id=(?)""", (job_id, task_id)).fetchone() return json.loads(task_specs[0]) + + +def get_local_log_file(job_id: int, task_id: Optional[int]) -> Optional[str]: + """Get the local log directory for a job.""" + filter_str = 'spot_job_id=(?)' + filter_args = [job_id] + if task_id is not None: + filter_str += ' AND task_id=(?)' + filter_args.append(task_id) + with db_utils.safe_cursor(_DB_PATH) as cursor: + local_log_file = cursor.execute( + f'SELECT local_log_file FROM spot ' + f'WHERE {filter_str}', filter_args).fetchone() + return local_log_file[-1] if local_log_file else None diff --git a/sky/jobs/utils.py b/sky/jobs/utils.py index 896740f6ed6..267c205285b 100644 --- a/sky/jobs/utils.py +++ b/sky/jobs/utils.py @@ -85,7 +85,8 @@ def get_job_status(backend: 'backends.CloudVmRayBackend', cluster_name: str) -> Optional['job_lib.JobStatus']: """Check the status of the job running on a managed job cluster. - It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_SETUP or CANCELLED. + It can be None, INIT, RUNNING, SUCCEEDED, FAILED, FAILED_DRIVER, + FAILED_SETUP or CANCELLED. 
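The new `local_log_file` helpers above assemble their WHERE clause dynamically, so one function serves both "this specific task" and "any task of this job". A rough usage sketch, assuming the managed-jobs state module is importable as below; the job ID and path are invented for illustration, and no matching row needs to exist for the calls to succeed.

from sky.jobs import state as managed_job_state

# Record where the controller downloaded the logs for task 0 of job 42.
# (Both the job ID and the path are made up for this example.)
managed_job_state.set_local_log_file(
    job_id=42, task_id=0,
    local_log_file='~/sky_logs/managed_jobs/job-42.log')

# Later, e.g. when `sky jobs logs` runs after the job has finished, the path
# can be looked up again. Passing task_id=None drops the task filter and
# returns the log file of whichever task row matches the job ID.
log_file = managed_job_state.get_local_log_file(job_id=42, task_id=None)
print(log_file)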
""" handle = global_user_state.get_handle_from_cluster_name(cluster_name) assert isinstance(handle, backends.CloudVmRayResourceHandle), handle @@ -326,10 +327,24 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str: if managed_job_status.is_failed(): job_msg = ('\nFailure reason: ' f'{managed_job_state.get_failure_reason(job_id)}') + log_file = managed_job_state.get_local_log_file(job_id, None) + if log_file is not None: + with open(log_file, 'r', encoding='utf-8') as f: + # Stream the logs to the console without reading the whole + # file into memory. + start_streaming = False + for line in f: + if log_lib.LOG_FILE_START_STREAMING_AT in line: + start_streaming = True + if start_streaming: + print(line, end='', flush=True) + return '' return (f'{colorama.Fore.YELLOW}' f'Job {job_id} is already in terminal state ' - f'{managed_job_status.value}. Logs will not be shown.' - f'{colorama.Style.RESET_ALL}{job_msg}') + f'{managed_job_status.value}. For more details, run: ' + f'sky jobs logs --controller {job_id}' + f'{colorama.Style.RESET_ALL}' + f'{job_msg}') backend = backends.CloudVmRayBackend() task_id, managed_job_status = ( managed_job_state.get_latest_task_id_status(job_id)) @@ -866,7 +881,7 @@ def stream_logs(cls, code += inspect.getsource(stream_logs) code += textwrap.dedent(f"""\ - msg = stream_logs({job_id!r}, {job_name!r}, + msg = stream_logs({job_id!r}, {job_name!r}, follow={follow}, controller={controller}) print(msg, flush=True) """) @@ -883,7 +898,7 @@ def set_pending(cls, job_id: int, managed_job_dag: 'dag_lib.Dag') -> str: resources_str = backend_utils.get_task_resources_str( task, is_managed_job=True) code += textwrap.dedent(f"""\ - managed_job_state.set_pending({job_id}, {task_id}, + managed_job_state.set_pending({job_id}, {task_id}, {task.name!r}, {resources_str!r}) """) return cls._build(code) diff --git a/sky/optimizer.py b/sky/optimizer.py index 0f931e15079..2f70dd39429 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -22,6 +22,7 @@ from sky.utils import resources_utils from sky.utils import rich_utils from sky.utils import subprocess_utils +from sky.utils import timeline from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -105,6 +106,7 @@ def _egress_time(src_cloud: clouds.Cloud, dst_cloud: clouds.Cloud, return egress_time @staticmethod + @timeline.event def optimize(dag: 'dag_lib.Dag', minimize: OptimizeTarget = OptimizeTarget.COST, blocked_resources: Optional[Iterable[ diff --git a/sky/provision/__init__.py b/sky/provision/__init__.py index bbe92b68c3a..02a627b08a3 100644 --- a/sky/provision/__init__.py +++ b/sky/provision/__init__.py @@ -20,9 +20,11 @@ from sky.provision import gcp from sky.provision import kubernetes from sky.provision import lambda_cloud +from sky.provision import oci from sky.provision import runpod from sky.provision import vsphere from sky.utils import command_runner +from sky.utils import timeline logger = sky_logging.init_logger(__name__) @@ -58,6 +60,7 @@ def _wrapper(*args, **kwargs): # pylint: disable=unused-argument +@timeline.event @_route_to_cloud_impl def query_instances( provider_name: str, diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index 60159232787..700d31c597f 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -305,7 +305,8 @@ def _create_vm( network_profile=network_profile, identity=compute.VirtualMachineIdentity( type='UserAssigned', - user_assigned_identities={provider_config['msi']: {}})) + 
user_assigned_identities={provider_config['msi']: {}}), + priority=node_config['azure_arm_parameters'].get('priority', None)) vm_poller = compute_client.virtual_machines.begin_create_or_update( resource_group_name=provider_config['resource_group'], vm_name=vm_name, diff --git a/sky/provision/docker_utils.py b/sky/provision/docker_utils.py index 3ee5d4dfc0c..c55508ab41a 100644 --- a/sky/provision/docker_utils.py +++ b/sky/provision/docker_utils.py @@ -20,7 +20,7 @@ '{ if [ $(id -u) -ne 0 ]; then echo "sudo"; else echo ""; fi; } && ' 'printenv | while IFS=\'=\' read -r key value; do echo "export $key=\\\"$value\\\""; done > ' # pylint: disable=line-too-long '~/container_env_var.sh && ' - '$(prefix_cmd) mv ~/container_env_var.sh /etc/profile.d/container_env_var.sh' + '$(prefix_cmd) mv ~/container_env_var.sh /etc/profile.d/container_env_var.sh;' ) # Docker daemon may not be ready when the machine is firstly started. The error diff --git a/sky/provision/fluidstack/instance.py b/sky/provision/fluidstack/instance.py index 538aafc8887..7fa6cb0463b 100644 --- a/sky/provision/fluidstack/instance.py +++ b/sky/provision/fluidstack/instance.py @@ -79,9 +79,7 @@ def run_instances(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """Runs instances for the given cluster.""" - pending_status = [ - 'pending', - ] + pending_status = ['pending', 'provisioning'] while True: instances = _filter_instances(cluster_name_on_cloud, pending_status) if len(instances) > config.count: diff --git a/sky/provision/instance_setup.py b/sky/provision/instance_setup.py index 8c390adaf87..86d1c59f36c 100644 --- a/sky/provision/instance_setup.py +++ b/sky/provision/instance_setup.py @@ -4,7 +4,6 @@ import hashlib import json import os -import resource import time from typing import Any, Callable, Dict, List, Optional, Tuple @@ -20,6 +19,7 @@ from sky.utils import command_runner from sky.utils import common_utils from sky.utils import subprocess_utils +from sky.utils import timeline from sky.utils import ux_utils logger = sky_logging.init_logger(__name__) @@ -115,7 +115,8 @@ def _parallel_ssh_with_cache(func, if max_workers is None: # Not using the default value of `max_workers` in ThreadPoolExecutor, # as 32 is too large for some machines. 
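Right below, `_parallel_ssh_with_cache` starts passing the provider name to `subprocess_utils.get_parallel_threads()`, and the Kubernetes provisioner later derives `_NUM_THREADS` the same way, so parallelism can be capped per cloud. The actual per-provider limits are not visible in this diff; the following is only a hedged sketch of the general shape of such a helper, with invented numbers.

import os
from typing import Optional


def get_parallel_threads(cloud: Optional[str] = None) -> int:
    """Illustrative only: cap worker threads, tighter for API-heavy clouds."""
    default_cap = 32
    # Hypothetical per-cloud caps; the real values live in subprocess_utils.
    per_cloud_cap = {'kubernetes': 8}
    cap = per_cloud_cap.get(cloud, default_cap)
    return max(1, min(cap, (os.cpu_count() or 1) * 4))


print(get_parallel_threads())              # generic default
print(get_parallel_threads('kubernetes'))  # stricter cap for the k8s API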
- max_workers = subprocess_utils.get_parallel_threads() + max_workers = subprocess_utils.get_parallel_threads( + cluster_info.provider_name) with futures.ThreadPoolExecutor(max_workers=max_workers) as pool: results = [] runners = provision.get_command_runners(cluster_info.provider_name, @@ -170,6 +171,7 @@ def _initialize_docker(runner: command_runner.CommandRunner, log_path: str): @common.log_function_start_end +@timeline.event def setup_runtime_on_cluster(cluster_name: str, setup_commands: List[str], cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, Any]) -> None: @@ -245,20 +247,9 @@ def _ray_gpu_options(custom_resource: str) -> str: return f' --num-gpus={acc_count}' -@common.log_function_start_end -@_auto_retry() -def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], - cluster_info: common.ClusterInfo, - ssh_credentials: Dict[str, Any]) -> None: - """Start Ray on the head node.""" - runners = provision.get_command_runners(cluster_info.provider_name, - cluster_info, **ssh_credentials) - head_runner = runners[0] - assert cluster_info.head_instance_id is not None, (cluster_name, - cluster_info) - - # Log the head node's output to the provision.log - log_path_abs = str(provision_logging.get_log_path()) +def ray_head_start_command(custom_resource: Optional[str], + custom_ray_options: Optional[Dict[str, Any]]) -> str: + """Returns the command to start Ray on the head node.""" ray_options = ( # --disable-usage-stats in `ray start` saves 10 seconds of idle wait. f'--disable-usage-stats ' @@ -270,23 +261,14 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], if custom_resource: ray_options += f' --resources=\'{custom_resource}\'' ray_options += _ray_gpu_options(custom_resource) - - if cluster_info.custom_ray_options: - if 'use_external_ip' in cluster_info.custom_ray_options: - cluster_info.custom_ray_options.pop('use_external_ip') - for key, value in cluster_info.custom_ray_options.items(): + if custom_ray_options: + if 'use_external_ip' in custom_ray_options: + custom_ray_options.pop('use_external_ip') + for key, value in custom_ray_options.items(): ray_options += f' --{key}={value}' - # Unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY to avoid using credentials - # from environment variables set by user. SkyPilot's ray cluster should use - # the `~/.aws/` credentials, as that is the one used to create the cluster, - # and the autoscaler module started by the `ray start` command should use - # the same credentials. Otherwise, `ray status` will fail to fetch the - # available nodes. - # Reference: https://github.com/skypilot-org/skypilot/issues/2441 cmd = ( f'{constants.SKY_RAY_CMD} stop; ' - 'unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY; ' 'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ' # worker_maximum_startup_concurrency controls the maximum number of # workers that can be started concurrently. 
However, it also controls @@ -305,6 +287,62 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], 'RAY_worker_maximum_startup_concurrency=$(( 3 * $(nproc --all) )) ' f'{constants.SKY_RAY_CMD} start --head {ray_options} || exit 1;' + _RAY_PRLIMIT + _DUMP_RAY_PORTS + RAY_HEAD_WAIT_INITIALIZED_COMMAND) + return cmd + + +def ray_worker_start_command(custom_resource: Optional[str], + custom_ray_options: Optional[Dict[str, Any]], + no_restart: bool) -> str: + """Returns the command to start Ray on the worker node.""" + # We need to use the ray port in the env variable, because the head node + # determines the port to be used for the worker node. + ray_options = ('--address=${SKYPILOT_RAY_HEAD_IP}:${SKYPILOT_RAY_PORT} ' + '--object-manager-port=8076') + + if custom_resource: + ray_options += f' --resources=\'{custom_resource}\'' + ray_options += _ray_gpu_options(custom_resource) + + if custom_ray_options: + for key, value in custom_ray_options.items(): + ray_options += f' --{key}={value}' + + cmd = ( + 'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ' + f'{constants.SKY_RAY_CMD} start --disable-usage-stats {ray_options} || ' + 'exit 1;' + _RAY_PRLIMIT) + if no_restart: + # We do not use ray status to check whether ray is running, because + # on worker node, if the user started their own ray cluster, ray status + # will return 0, i.e., we don't know skypilot's ray cluster is running. + # Instead, we check whether the raylet process is running on gcs address + # that is connected to the head with the correct port. + cmd = ( + f'ps aux | grep "ray/raylet/raylet" | ' + 'grep "gcs-address=${SKYPILOT_RAY_HEAD_IP}:${SKYPILOT_RAY_PORT}" ' + f'|| {{ {cmd} }}') + else: + cmd = f'{constants.SKY_RAY_CMD} stop; ' + cmd + return cmd + + +@common.log_function_start_end +@_auto_retry() +@timeline.event +def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], + cluster_info: common.ClusterInfo, + ssh_credentials: Dict[str, Any]) -> None: + """Start Ray on the head node.""" + runners = provision.get_command_runners(cluster_info.provider_name, + cluster_info, **ssh_credentials) + head_runner = runners[0] + assert cluster_info.head_instance_id is not None, (cluster_name, + cluster_info) + + # Log the head node's output to the provision.log + log_path_abs = str(provision_logging.get_log_path()) + cmd = ray_head_start_command(custom_resource, + cluster_info.custom_ray_options) logger.info(f'Running command on head node: {cmd}') # TODO(zhwu): add the output to log files. 
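Extracting `ray_head_start_command()` and `ray_worker_start_command()` means the `ray start` shell strings can be built and inspected without touching a live cluster. A small sketch of calling the two builders directly, assuming this patch is applied so they are importable from `sky.provision.instance_setup`; the accelerator resource, head IP, and port are made up.

import json

from sky.provision import instance_setup

custom_resource = json.dumps({'A100': 8})  # hypothetical accelerator resource

head_cmd = instance_setup.ray_head_start_command(
    custom_resource=custom_resource, custom_ray_options={'num-cpus': 16})
worker_cmd = instance_setup.ray_worker_start_command(
    custom_resource=custom_resource, custom_ray_options=None, no_restart=True)

# The worker command reads the head address from environment variables, so
# callers prepend the exports, mirroring start_ray_on_worker_nodes() above.
full_worker_cmd = ('export SKYPILOT_RAY_HEAD_IP="10.0.0.1"; '
                   'export SKYPILOT_RAY_PORT=6380; ' + worker_cmd)
print(head_cmd)
print(full_worker_cmd)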
returncode, stdout, stderr = head_runner.run( @@ -324,6 +362,7 @@ def start_ray_on_head_node(cluster_name: str, custom_resource: Optional[str], @common.log_function_start_end @_auto_retry() +@timeline.event def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, custom_resource: Optional[str], ray_port: int, cluster_info: common.ClusterInfo, @@ -358,43 +397,17 @@ def start_ray_on_worker_nodes(cluster_name: str, no_restart: bool, head_ip = (head_instance.internal_ip if not use_external_ip else head_instance.external_ip) - ray_options = (f'--address={head_ip}:{constants.SKY_REMOTE_RAY_PORT} ' - f'--object-manager-port=8076') - - if custom_resource: - ray_options += f' --resources=\'{custom_resource}\'' - ray_options += _ray_gpu_options(custom_resource) - - if cluster_info.custom_ray_options: - for key, value in cluster_info.custom_ray_options.items(): - ray_options += f' --{key}={value}' + ray_cmd = ray_worker_start_command(custom_resource, + cluster_info.custom_ray_options, + no_restart) - # Unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY, see the comment in - # `start_ray_on_head_node`. - cmd = ( - f'unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY; ' - 'RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 ' - f'{constants.SKY_RAY_CMD} start --disable-usage-stats {ray_options} || ' - 'exit 1;' + _RAY_PRLIMIT) - if no_restart: - # We do not use ray status to check whether ray is running, because - # on worker node, if the user started their own ray cluster, ray status - # will return 0, i.e., we don't know skypilot's ray cluster is running. - # Instead, we check whether the raylet process is running on gcs address - # that is connected to the head with the correct port. - cmd = (f'RAY_PORT={ray_port}; ps aux | grep "ray/raylet/raylet" | ' - f'grep "gcs-address={head_ip}:${{RAY_PORT}}" || ' - f'{{ {cmd} }}') - else: - cmd = f'{constants.SKY_RAY_CMD} stop; ' + cmd + cmd = (f'export SKYPILOT_RAY_HEAD_IP="{head_ip}"; ' + f'export SKYPILOT_RAY_PORT={ray_port}; ' + ray_cmd) logger.info(f'Running command on worker nodes: {cmd}') def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner, str]): - # for cmd in config_from_yaml['worker_start_ray_commands']: - # cmd = cmd.replace('$RAY_HEAD_IP', ip_list[0][0]) - # runner.run(cmd) runner, instance_id = runner_and_id log_dir = metadata_utils.get_instance_log_dir(cluster_name, instance_id) log_path_abs = str(log_dir / ('ray_cluster' + '.log')) @@ -407,8 +420,10 @@ def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner, # by ray will have the correct PATH. 
source_bashrc=True) + num_threads = subprocess_utils.get_parallel_threads( + cluster_info.provider_name) results = subprocess_utils.run_in_parallel( - _setup_ray_worker, list(zip(worker_runners, cache_ids))) + _setup_ray_worker, list(zip(worker_runners, cache_ids)), num_threads) for returncode, stdout, stderr in results: if returncode: with ux_utils.print_exception_no_traceback(): @@ -421,6 +436,7 @@ def _setup_ray_worker(runner_and_id: Tuple[command_runner.CommandRunner, @common.log_function_start_end @_auto_retry() +@timeline.event def start_skylet_on_head_node(cluster_name: str, cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, Any]) -> None: @@ -482,28 +498,8 @@ def _internal_file_mounts(file_mounts: Dict, ) -def _max_workers_for_file_mounts(common_file_mounts: Dict[str, str]) -> int: - fd_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE) - - fd_per_rsync = 5 - for src in common_file_mounts.values(): - if os.path.isdir(src): - # Assume that each file/folder under src takes 5 file descriptors - # on average. - fd_per_rsync = max(fd_per_rsync, len(os.listdir(src)) * 5) - - # Reserve some file descriptors for the system and other processes - fd_reserve = 100 - - max_workers = (fd_limit - fd_reserve) // fd_per_rsync - # At least 1 worker, and avoid too many workers overloading the system. - max_workers = min(max(max_workers, 1), - subprocess_utils.get_parallel_threads()) - logger.debug(f'Using {max_workers} workers for file mounts.') - return max_workers - - @common.log_function_start_end +@timeline.event def internal_file_mounts(cluster_name: str, common_file_mounts: Dict[str, str], cluster_info: common.ClusterInfo, ssh_credentials: Dict[str, str]) -> None: @@ -524,4 +520,5 @@ def _setup_node(runner: command_runner.CommandRunner, log_path: str): digest=None, cluster_info=cluster_info, ssh_credentials=ssh_credentials, - max_workers=_max_workers_for_file_mounts(common_file_mounts)) + max_workers=subprocess_utils.get_max_workers_for_file_mounts( + common_file_mounts, cluster_info.provider_name)) diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index 14eea45149c..2b13e78fdf8 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -2,7 +2,7 @@ import copy import json import time -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import uuid from sky import exceptions @@ -20,12 +20,13 @@ from sky.utils import common_utils from sky.utils import kubernetes_enums from sky.utils import subprocess_utils +from sky.utils import timeline from sky.utils import ux_utils POLL_INTERVAL = 2 _TIMEOUT_FOR_POD_TERMINATION = 60 # 1 minutes _MAX_RETRIES = 3 -NUM_THREADS = subprocess_utils.get_parallel_threads() * 2 +_NUM_THREADS = subprocess_utils.get_parallel_threads('kubernetes') logger = sky_logging.init_logger(__name__) TAG_RAY_CLUSTER_NAME = 'ray-cluster-name' @@ -47,6 +48,72 @@ def head_service_selector(cluster_name: str) -> Dict[str, str]: return {'component': f'{cluster_name}-head'} +def _formatted_resource_requirements(pod_or_spec: Union[Any, dict]) -> str: + # Returns a formatted string of resource requirements for a pod. 
+ resource_requirements = {} + + if isinstance(pod_or_spec, dict): + containers = pod_or_spec.get('spec', {}).get('containers', []) + else: + containers = pod_or_spec.spec.containers + + for container in containers: + if isinstance(container, dict): + resources = container.get('resources', {}) + requests = resources.get('requests', {}) + else: + resources = container.resources + requests = resources.requests or {} + + for resource, value in requests.items(): + if resource not in resource_requirements: + resource_requirements[resource] = 0 + if resource == 'memory': + int_value = kubernetes_utils.parse_memory_resource(value) + else: + int_value = kubernetes_utils.parse_cpu_or_gpu_resource(value) + resource_requirements[resource] += int(int_value) + return ', '.join(f'{resource}={value}' + for resource, value in resource_requirements.items()) + + +def _formatted_node_selector(pod_or_spec: Union[Any, dict]) -> Optional[str]: + # Returns a formatted string of node selectors for a pod. + node_selectors = [] + + if isinstance(pod_or_spec, dict): + selectors = pod_or_spec.get('spec', {}).get('nodeSelector', {}) + else: + selectors = pod_or_spec.spec.node_selector + + if not selectors: + return None + + for label_key, label_value in selectors.items(): + node_selectors.append(f'{label_key}={label_value}') + return ', '.join(node_selectors) + + +def _lack_resource_msg(resource: str, + pod_or_spec: Union[Any, dict], + extra_msg: Optional[str] = None, + details: Optional[str] = None) -> str: + resource_requirements = _formatted_resource_requirements(pod_or_spec) + node_selectors = _formatted_node_selector(pod_or_spec) + node_selector_str = f' and labels ({node_selectors})' if ( + node_selectors) else '' + msg = (f'Insufficient {resource} capacity on the cluster. ' + f'Required resources ({resource_requirements}){node_selector_str} ' + 'were not found in a single node. Other SkyPilot tasks or pods may ' + 'be using resources. Check resource usage by running ' + '`kubectl describe nodes`.') + if extra_msg: + msg += f' {extra_msg}' + if details: + msg += f'\nFull error: {details}' + return msg + + def _raise_pod_scheduling_errors(namespace, context, new_nodes): """Raise pod scheduling failure reason. @@ -54,52 +121,9 @@ def _raise_pod_scheduling_errors(namespace, context, new_nodes): are recorded as events. This function retrieves those events and raises descriptive errors for better debugging and user feedback. """ - - def _formatted_resource_requirements(pod): - # Returns a formatted string of resource requirements for a pod. - resource_requirements = {} - for container in pod.spec.containers: - for resource, value in container.resources.requests.items(): - if resource not in resource_requirements: - resource_requirements[resource] = 0 - if resource == 'memory': - int_value = kubernetes_utils.parse_memory_resource(value) - else: - int_value = kubernetes_utils.parse_cpu_or_gpu_resource( - value) - resource_requirements[resource] += int_value - return ', '.join(f'{resource}={value}' - for resource, value in resource_requirements.items()) - - def _formatted_node_selector(pod) -> Optional[str]: - # Returns a formatted string of node selectors for a pod. 
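Hoisting `_lack_resource_msg()` and its two formatting helpers to module level, and letting them accept either a live pod object or a plain spec dict, lets the same message be produced both after a scheduling failure and at pod-creation time (see the TPU branch added further below). A rough illustration that calls the private helper directly with a hand-written spec dict; all resource figures, the node selector, and the error details are invented for the example.

from sky.provision.kubernetes import instance as k8s_instance

# Illustrative only: a bare-bones pod spec dict of the shape passed to
# _create_namespaced_pod_with_retries(), not a real SkyPilot-generated spec.
pod_spec = {
    'spec': {
        'containers': [{
            'resources': {
                'requests': {
                    'cpu': '8',
                    'memory': '32Gi',
                    'google.com/tpu': '4',
                },
            },
        }],
        'nodeSelector': {
            'cloud.google.com/gke-tpu-accelerator': 'tpu-v5-lite-podslice',
        },
    },
}

# With the dict form, the message can report the requested resources and node
# selector even though no pod object exists yet.
print(k8s_instance._lack_resource_msg(
    'TPU', pod_spec,
    extra_msg='Verify the requested TPU topology exists in the cluster.',
    details='Invalid resource requests for google.com/tpu.'))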
- node_selectors = [] - if pod.spec.node_selector is None: - return None - for label_key, label_value in pod.spec.node_selector.items(): - node_selectors.append(f'{label_key}={label_value}') - return ', '.join(node_selectors) - - def _lack_resource_msg(resource: str, - pod, - extra_msg: Optional[str] = None, - details: Optional[str] = None) -> str: - resource_requirements = _formatted_resource_requirements(pod) - node_selectors = _formatted_node_selector(pod) - node_selector_str = f' and labels ({node_selectors})' if ( - node_selectors) else '' - msg = ( - f'Insufficient {resource} capacity on the cluster. ' - f'Required resources ({resource_requirements}){node_selector_str} ' - 'were not found in a single node. Other SkyPilot tasks or pods may ' - 'be using resources. Check resource usage by running ' - '`kubectl describe nodes`.') - if extra_msg: - msg += f' {extra_msg}' - if details: - msg += f'\nFull error: {details}' - return msg - + timeout_err_msg = ('Timed out while waiting for nodes to start. ' + 'Cluster may be out of resources or ' + 'may be too slow to autoscale.') for new_node in new_nodes: pod = kubernetes.core_api(context).read_namespaced_pod( new_node.metadata.name, namespace) @@ -128,9 +152,6 @@ def _lack_resource_msg(resource: str, if event.reason == 'FailedScheduling': event_message = event.message break - timeout_err_msg = ('Timed out while waiting for nodes to start. ' - 'Cluster may be out of resources or ' - 'may be too slow to autoscale.') if event_message is not None: if pod_status == 'Pending': logger.info(event_message) @@ -148,8 +169,8 @@ def _lack_resource_msg(resource: str, '`kubectl delete pods -n skypilot-system -l name=smarter-device-manager`.' # pylint: disable=line-too-long f' Full error: {event_message}') gpu_lf_keys = [ - lf.get_label_key() - for lf in kubernetes_utils.LABEL_FORMATTER_REGISTRY + key for lf in kubernetes_utils.LABEL_FORMATTER_REGISTRY + for key in lf.get_label_keys() ] if pod.spec.node_selector: for label_key in pod.spec.node_selector.keys(): @@ -157,10 +178,24 @@ def _lack_resource_msg(resource: str, # TODO(romilb): We may have additional node # affinity selectors in the future - in that # case we will need to update this logic. - if (('Insufficient nvidia.com/gpu' - in event_message) or - ('didn\'t match Pod\'s node affinity/selector' - in event_message)): + # TODO(Doyoung): Update the error message raised + # with the multi-host TPU support. + if 'Insufficient google.com/tpu' in event_message: + extra_msg = ( + f'Verify if ' + f'{pod.spec.node_selector[label_key]}' + ' is available in the cluster. Note ' + 'that multi-host TPU podslices are ' + 'currently not supported.') + raise config_lib.KubernetesError( + _lack_resource_msg('TPU', + pod, + extra_msg, + details=event_message)) + elif (('Insufficient nvidia.com/gpu' + in event_message) or + ('didn\'t match Pod\'s node affinity/selector' + in event_message)): extra_msg = ( f'Verify if ' f'{pod.spec.node_selector[label_key]}' @@ -185,6 +220,7 @@ def _raise_command_running_error(message: str, command: str, pod_name: str, f'code {rc}: {command!r}\nOutput: {stdout}.') +@timeline.event def _wait_for_pods_to_schedule(namespace, context, new_nodes, timeout: int): """Wait for all pods to be scheduled. @@ -195,6 +231,10 @@ def _wait_for_pods_to_schedule(namespace, context, new_nodes, timeout: int): If timeout is set to a negative value, this method will wait indefinitely.
""" + # Create a set of pod names we're waiting for + if not new_nodes: + return + expected_pod_names = {node.metadata.name for node in new_nodes} start_time = time.time() def _evaluate_timeout() -> bool: @@ -204,19 +244,34 @@ def _evaluate_timeout() -> bool: return time.time() - start_time < timeout while _evaluate_timeout(): - all_pods_scheduled = True - for node in new_nodes: - # Iterate over each pod to check their status - pod = kubernetes.core_api(context).read_namespaced_pod( - node.metadata.name, namespace) - if pod.status.phase == 'Pending': + # Get all pods in a single API call using the cluster name label + # which all pods in new_nodes should share + cluster_name = new_nodes[0].metadata.labels[TAG_SKYPILOT_CLUSTER_NAME] + pods = kubernetes.core_api(context).list_namespaced_pod( + namespace, + label_selector=f'{TAG_SKYPILOT_CLUSTER_NAME}={cluster_name}').items + + # Get the set of found pod names and check if we have all expected pods + found_pod_names = {pod.metadata.name for pod in pods} + missing_pods = expected_pod_names - found_pod_names + if missing_pods: + logger.info('Retrying waiting for pods: ' + f'Missing pods: {missing_pods}') + time.sleep(0.5) + continue + + # Check if all pods are scheduled + all_scheduled = True + for pod in pods: + if (pod.metadata.name in expected_pod_names and + pod.status.phase == 'Pending'): # If container_statuses is None, then the pod hasn't # been scheduled yet. if pod.status.container_statuses is None: - all_pods_scheduled = False + all_scheduled = False break - if all_pods_scheduled: + if all_scheduled: return time.sleep(1) @@ -232,12 +287,18 @@ def _evaluate_timeout() -> bool: f'Error: {common_utils.format_exception(e)}') from None +@timeline.event def _wait_for_pods_to_run(namespace, context, new_nodes): """Wait for pods and their containers to be ready. Pods may be pulling images or may be in the process of container creation. """ + if not new_nodes: + return + + # Create a set of pod names we're waiting for + expected_pod_names = {node.metadata.name for node in new_nodes} def _check_init_containers(pod): # Check if any of the init containers failed @@ -265,12 +326,25 @@ def _check_init_containers(pod): f'{pod.metadata.name}. Error details: {msg}.') while True: - all_pods_running = True - # Iterate over each pod to check their status - for node in new_nodes: - pod = kubernetes.core_api(context).read_namespaced_pod( - node.metadata.name, namespace) + # Get all pods in a single API call + cluster_name = new_nodes[0].metadata.labels[TAG_SKYPILOT_CLUSTER_NAME] + all_pods = kubernetes.core_api(context).list_namespaced_pod( + namespace, + label_selector=f'{TAG_SKYPILOT_CLUSTER_NAME}={cluster_name}').items + + # Get the set of found pod names and check if we have all expected pods + found_pod_names = {pod.metadata.name for pod in all_pods} + missing_pods = expected_pod_names - found_pod_names + if missing_pods: + logger.info('Retrying running pods check: ' + f'Missing pods: {missing_pods}') + time.sleep(0.5) + continue + all_pods_running = True + for pod in all_pods: + if pod.metadata.name not in expected_pod_names: + continue # Continue if pod and all the containers within the # pod are successfully created and running. if pod.status.phase == 'Running' and all( @@ -333,52 +407,38 @@ def _run_function_with_retries(func: Callable, raise -def _set_env_vars_in_pods(namespace: str, context: Optional[str], - new_pods: List): - """Setting environment variables in pods. 
+@timeline.event +def pre_init(namespace: str, context: Optional[str], new_nodes: List) -> None: + """Pre-initialization step for SkyPilot pods. - Once all containers are ready, we can exec into them and set env vars. - Kubernetes automatically populates containers with critical - environment variables, such as those for discovering services running - in the cluster and CUDA/nvidia environment variables. We need to - make sure these env vars are available in every task and ssh session. - This is needed for GPU support and service discovery. - See https://github.com/skypilot-org/skypilot/issues/2287 for - more details. + This step is run in the pod right after it is created and before the + SkyPilot runtime is setup. - To do so, we capture env vars from the pod's runtime and write them to - /etc/profile.d/, making them available for all users in future - shell sessions. - """ - set_k8s_env_var_cmd = docker_utils.SETUP_ENV_VARS_CMD + This step includes three key steps: - def _set_env_vars_thread(new_pod): - pod_name = new_pod.metadata.name - logger.info(f'{"-"*20}Start: Set up env vars in pod {pod_name!r} ' - f'{"-"*20}') - runner = command_runner.KubernetesCommandRunner( - ((namespace, context), pod_name)) + 1. Privilege check: Checks if the default user has sufficient privilege + to set up the kubernetes instance pod. + 2. SSH setup: Sets up SSH for the pod instance. + 3. Environment variable setup to populate k8s env vars in the pod. - def _run_env_vars_cmd(): - rc, stdout, _ = runner.run(set_k8s_env_var_cmd, - require_outputs=True, - stream_logs=False) - _raise_command_running_error('set env vars', set_k8s_env_var_cmd, - pod_name, rc, stdout) + Make sure commands used in these methods are generic and work + on most base images. E.g., do not use Python, since that may not + be installed by default. - _run_function_with_retries(_run_env_vars_cmd, - f'set env vars in pod {pod_name}') - logger.info(f'{"-"*20}End: Set up env vars in pod {pod_name!r} ' - f'{"-"*20}') + If you run any apt commands, be sure to check if the lock is available. + It is possible the `apt update` run in the pod container args may still + be running. - subprocess_utils.run_in_parallel(_set_env_vars_thread, new_pods, - NUM_THREADS) + Args: + namespace (str): Kubernetes namespace. + context (Optional[str]): Kubernetes context. + new_nodes (List): List of new pod instances. + Raises: + config_lib.KubernetesError: If user privileges are insufficient or + setup fails. + """ -def _check_user_privilege(namespace: str, context: Optional[str], - new_nodes: List) -> None: - # Checks if the default user has sufficient privilege to set up - # the kubernetes instance pod. 
check_k8s_user_sudo_cmd = ( 'if [ $(id -u) -eq 0 ]; then' # If user is root, create an alias for sudo used in skypilot setup @@ -386,56 +446,67 @@ def _check_user_privilege(namespace: str, context: Optional[str], 'else ' ' if command -v sudo >/dev/null 2>&1; then ' ' timeout 2 sudo -l >/dev/null 2>&1 && echo succeed || ' - f' ( echo {exceptions.INSUFFICIENT_PRIVILEGES_CODE!r}; ); ' + f' ( echo {exceptions.INSUFFICIENT_PRIVILEGES_CODE!r}; ' + f' exit {exceptions.INSUFFICIENT_PRIVILEGES_CODE}; ); ' ' else ' - f' ( echo {exceptions.INSUFFICIENT_PRIVILEGES_CODE!r}; ); ' + f' ( echo {exceptions.INSUFFICIENT_PRIVILEGES_CODE!r}; ' + f' exit {exceptions.INSUFFICIENT_PRIVILEGES_CODE}; ); ' ' fi; ' - 'fi') + 'fi;') + + # Kubernetes automatically populates containers with critical + # environment variables, such as those for discovering services running + # in the cluster and CUDA/nvidia environment variables. We need to + # make sure these env vars are available in every task and ssh session. + # This is needed for GPU support and service discovery. + # See https://github.com/skypilot-org/skypilot/issues/2287 for more details. + # To do so, we capture env vars from the pod's runtime and write them to + # /etc/profile.d/, making them available for all users in future + # shell sessions. + set_k8s_env_var_cmd = docker_utils.SETUP_ENV_VARS_CMD - # This check needs to run on a per-image basis, so running the check on - # any one pod is sufficient. - new_node = new_nodes[0] - pod_name = new_node.metadata.name + check_apt_update_complete_cmd = ( + 'echo "Checking if apt update from container init is complete..."; ' + 'timeout_secs=600; ' + 'start_time=$(date +%s); ' + 'while ! grep -q "Fetched" /tmp/apt-update.log 2>/dev/null; do ' + ' echo "apt update still running. Logs:"; ' + ' cat /tmp/apt-update.log || true; ' + ' current_time=$(date +%s); ' + ' elapsed=$((current_time - start_time)); ' + ' if [ $elapsed -ge $timeout_secs ]; then ' + ' echo "Timed out waiting for apt update"; ' + ' exit 1; ' + ' fi; ' + ' sleep 5; ' + 'done; ' + 'echo "apt update complete."; ') - runner = command_runner.KubernetesCommandRunner( - ((namespace, context), pod_name)) - logger.info(f'{"-"*20}Start: Check user privilege in pod {pod_name!r} ' - f'{"-"*20}') - - def _run_privilege_check(): - rc, stdout, stderr = runner.run(check_k8s_user_sudo_cmd, - require_outputs=True, - separate_stderr=True, - stream_logs=False) - _raise_command_running_error('check user privilege', - check_k8s_user_sudo_cmd, pod_name, rc, - stdout + stderr) - return stdout - - stdout = _run_function_with_retries( - _run_privilege_check, f'check user privilege in pod {pod_name!r}') - - if stdout == str(exceptions.INSUFFICIENT_PRIVILEGES_CODE): - raise config_lib.KubernetesError( - 'Insufficient system privileges detected. ' - 'Ensure the default user has root access or ' - '"sudo" is installed and the user is added to the sudoers ' - 'from the image.') - logger.info(f'{"-"*20}End: Check user privilege in pod {pod_name!r} ' - f'{"-"*20}') - - -def _setup_ssh_in_pods(namespace: str, context: Optional[str], - new_nodes: List) -> None: - # Setting up ssh for the pod instance. This is already setup for - # the jump pod so it does not need to be run for it. 
- set_k8s_ssh_cmd = ( - 'set -ex; ' + install_ssh_k8s_cmd = ( 'prefix_cmd() ' '{ if [ $(id -u) -ne 0 ]; then echo "sudo"; else echo ""; fi; }; ' 'export DEBIAN_FRONTEND=noninteractive;' - '$(prefix_cmd) apt-get update;' - '$(prefix_cmd) apt install openssh-server rsync -y; ' + 'echo "Installing missing packages..."; ' + 'for i in {1..5}; do ' + ' output=$($(prefix_cmd) apt install openssh-server rsync -y 2>&1); ' + ' rc=$?; ' + ' if [ $rc -eq 0 ]; then ' + ' break; ' + ' fi; ' + ' echo "$output" | grep -qi "could not get lock" || ' + ' grep -qi "Unable to acquire the dpkg frontend lock"; ' + ' if [ $? -eq 0 ]; then ' + ' echo "apt install failed due to lock, retrying. (Attempt $i/5)"; ' + ' sleep 5; ' + ' else ' + ' echo "apt install failed for a non-lock reason: $output"; ' + ' exit $rc; ' + ' fi; ' + 'done; ' + 'if [ $rc -ne 0 ]; then ' + ' echo "apt install failed after 5 attempts due to lock errors."; ' + ' exit $rc; ' + 'fi; ' '$(prefix_cmd) mkdir -p /var/run/sshd; ' '$(prefix_cmd) ' 'sed -i "s/PermitRootLogin prohibit-password/PermitRootLogin yes/" ' @@ -456,24 +527,35 @@ def _setup_ssh_in_pods(namespace: str, context: Optional[str], # See https://www.educative.io/answers/error-mesg-ttyname-failed-inappropriate-ioctl-for-device # pylint: disable=line-too-long '$(prefix_cmd) sed -i "s/mesg n/tty -s \\&\\& mesg n/" ~/.profile;') - def _setup_ssh_thread(new_node): + pre_init_cmd = ('set -ex; ' + check_k8s_user_sudo_cmd + + set_k8s_env_var_cmd + check_apt_update_complete_cmd + + install_ssh_k8s_cmd) + + def _pre_init_thread(new_node): pod_name = new_node.metadata.name + logger.info(f'{"-"*20}Start: Pre-init in pod {pod_name!r} {"-"*20}') runner = command_runner.KubernetesCommandRunner( ((namespace, context), pod_name)) - logger.info(f'{"-"*20}Start: Set up SSH in pod {pod_name!r} {"-"*20}') - def _run_ssh_setup(): - rc, stdout, _ = runner.run(set_k8s_ssh_cmd, - require_outputs=True, - stream_logs=False) - _raise_command_running_error('setup ssh', set_k8s_ssh_cmd, pod_name, - rc, stdout) + # Run the combined pre-init command + rc, stdout, _ = runner.run(pre_init_cmd, + require_outputs=True, + stream_logs=False) + if rc == exceptions.INSUFFICIENT_PRIVILEGES_CODE: + raise config_lib.KubernetesError( + 'Insufficient system privileges detected. ' + 'Ensure the default user has root access or ' + '"sudo" is installed and the user is added to the sudoers ' + 'from the image.') + + op_name = 'pre-init' + _raise_command_running_error(op_name, pre_init_cmd, pod_name, rc, + stdout) - _run_function_with_retries(_run_ssh_setup, - f'setup ssh in pod {pod_name!r}') - logger.info(f'{"-"*20}End: Set up SSH in pod {pod_name!r} {"-"*20}') + logger.info(f'{"-"*20}End: Pre-init in pod {pod_name!r} {"-"*20}') - subprocess_utils.run_in_parallel(_setup_ssh_thread, new_nodes, NUM_THREADS) + # Run pre_init in parallel across all new_nodes + subprocess_utils.run_in_parallel(_pre_init_thread, new_nodes, _NUM_THREADS) def _label_pod(namespace: str, context: Optional[str], pod_name: str, @@ -487,6 +569,7 @@ def _label_pod(namespace: str, context: Optional[str], pod_name: str, _request_timeout=kubernetes.API_TIMEOUT) +@timeline.event def _create_namespaced_pod_with_retries(namespace: str, pod_spec: dict, context: Optional[str]) -> Any: """Attempts to create a Kubernetes Pod and handle any errors. 
@@ -546,11 +629,26 @@ def _create_namespaced_pod_with_retries(namespace: str, pod_spec: dict, logger.info('Failed to create Pod without AppArmor annotation: ' f'{retry_exception}') raise retry_exception + # Unlike other errors caused by insufficient CPU/GPU/memory resources, a + # TPU shortage error is raised when the pod creation is attempted. + # TODO(Doyoung): Update the error message raised with the multi-host + # TPU support. + elif 'Invalid resource requests for google.com/tpu.' in error_message: + extra_message = ('Verify if the cluster has a TPU slice node with ' + 'a topology matching the number of TPU(s) ' + 'requested. Note that multi-host TPU podslices ' + 'are currently not supported.') + raise config_lib.KubernetesError( + _lack_resource_msg('TPU', + pod_spec, + details=error_message, + extra_msg=extra_message)) else: # Re-raise the exception if it's a different error raise e +@timeline.event def _create_pods(region: str, cluster_name_on_cloud: str, config: common.ProvisionConfig) -> common.ProvisionRecord: """Create pods based on the config.""" @@ -572,7 +670,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str, terminating_pods = kubernetes_utils.filter_pods(namespace, context, tags, ['Terminating']) start_time = time.time() - while (len(terminating_pods) > 0 and + while (terminating_pods and time.time() - start_time < _TIMEOUT_FOR_POD_TERMINATION): logger.debug(f'run_instances: Found {len(terminating_pods)} ' 'terminating pods. Waiting them to finish: ' @@ -581,7 +679,7 @@ def _create_pods(region: str, cluster_name_on_cloud: str, terminating_pods = kubernetes_utils.filter_pods(namespace, context, tags, ['Terminating']) - if len(terminating_pods) > 0: + if terminating_pods: # If there are still terminating pods, we force delete them. logger.debug(f'run_instances: Found {len(terminating_pods)} ' 'terminating pods still in terminating state after ' @@ -626,32 +724,43 @@ def _create_pods(region: str, cluster_name_on_cloud: str, 'override runtimeClassName in ~/.sky/config.yaml. ' 'For more details, refer to https://skypilot.readthedocs.io/en/latest/reference/config.html') # pylint: disable=line-too-long - needs_gpus = (pod_spec['spec']['containers'][0].get('resources', {}).get( - 'limits', {}).get('nvidia.com/gpu', 0) > 0) + needs_gpus = False + limits = pod_spec['spec']['containers'][0].get('resources', + {}).get('limits') + if limits is not None: + needs_gpus = limits.get(kubernetes_utils.GPU_RESOURCE_KEY, 0) > 0 + + # TPU pods provisioned on GKE use the default containerd runtime.
+ # Reference: https://cloud.google.com/kubernetes-engine/docs/how-to/migrate-containerd#overview # pylint: disable=line-too-long if nvidia_runtime_exists and needs_gpus: pod_spec['spec']['runtimeClassName'] = 'nvidia' created_pods = {} logger.debug(f'run_instances: calling create_namespaced_pod ' f'(count={to_start_count}).') - for _ in range(to_start_count): - if head_pod_name is None: - pod_spec['metadata']['labels'].update(constants.HEAD_NODE_TAGS) + + def _create_pod_thread(i: int): + pod_spec_copy = copy.deepcopy(pod_spec) + if head_pod_name is None and i == 0: + # First pod should be head if no head exists + pod_spec_copy['metadata']['labels'].update(constants.HEAD_NODE_TAGS) head_selector = head_service_selector(cluster_name_on_cloud) - pod_spec['metadata']['labels'].update(head_selector) - pod_spec['metadata']['name'] = f'{cluster_name_on_cloud}-head' + pod_spec_copy['metadata']['labels'].update(head_selector) + pod_spec_copy['metadata']['name'] = f'{cluster_name_on_cloud}-head' else: - pod_spec['metadata']['labels'].update(constants.WORKER_NODE_TAGS) - pod_uuid = str(uuid.uuid4())[:4] + # Worker pods + pod_spec_copy['metadata']['labels'].update( + constants.WORKER_NODE_TAGS) + pod_uuid = str(uuid.uuid4())[:6] pod_name = f'{cluster_name_on_cloud}-{pod_uuid}' - pod_spec['metadata']['name'] = f'{pod_name}-worker' + pod_spec_copy['metadata']['name'] = f'{pod_name}-worker' # For multi-node support, we put a soft-constraint to schedule # worker pods on different nodes than the head pod. # This is not set as a hard constraint because if different nodes # are not available, we still want to be able to schedule worker # pods on larger nodes which may be able to fit multiple SkyPilot # "nodes". - pod_spec['spec']['affinity'] = { + pod_spec_copy['spec']['affinity'] = { 'podAntiAffinity': { # Set as a soft constraint 'preferredDuringSchedulingIgnoredDuringExecution': [{ @@ -672,15 +781,36 @@ def _create_pods(region: str, cluster_name_on_cloud: str, } } - pod = _create_namespaced_pod_with_retries(namespace, pod_spec, context) + # TPU slice nodes are given a taint, google.com/tpu=present:NoSchedule. + # This is to prevent from non-TPU workloads from being scheduled on TPU + # slice nodes. We need this toleration to allow the pod to be scheduled + # on TPU nodes. 
+ # Reference: https://cloud.google.com/kubernetes-engine/docs/concepts/tpus#how_tpus_work # pylint: disable=line-too-long + tpu_label = kubernetes_utils.GKELabelFormatter.TPU_LABEL_KEY + if tpu_label in config.node_config.get('spec', + {}).get('nodeSelector', {}): + tpu_toleration = { + 'key': kubernetes_utils.TPU_RESOURCE_KEY, + 'operator': 'Equal', + 'value': 'present', + 'effect': 'NoSchedule' + } + pod_spec_copy['spec']['tolerations'] = [tpu_toleration] + + return _create_namespaced_pod_with_retries(namespace, pod_spec_copy, + context) + + # Create pods in parallel + pods = subprocess_utils.run_in_parallel(_create_pod_thread, + range(to_start_count), _NUM_THREADS) + + # Process created pods + for pod in pods: created_pods[pod.metadata.name] = pod - if head_pod_name is None: + if head_pod_name is None and pod.metadata.labels.get( + constants.TAG_RAY_NODE_KIND) == 'head': head_pod_name = pod.metadata.name - wait_pods_dict = kubernetes_utils.filter_pods(namespace, context, tags, - ['Pending']) - wait_pods = list(wait_pods_dict.values()) - networking_mode = network_utils.get_networking_mode( config.provider_config.get('networking_mode')) if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT: @@ -689,57 +819,24 @@ def _create_pods(region: str, cluster_name_on_cloud: str, ssh_jump_pod_name = pod_spec['metadata']['labels']['skypilot-ssh-jump'] jump_pod = kubernetes.core_api(context).read_namespaced_pod( ssh_jump_pod_name, namespace) - wait_pods.append(jump_pod) + pods.append(jump_pod) provision_timeout = provider_config['timeout'] wait_str = ('indefinitely' if provision_timeout < 0 else f'for {provision_timeout}s') logger.debug(f'run_instances: waiting {wait_str} for pods to schedule and ' - f'run: {list(wait_pods_dict.keys())}') + f'run: {[pod.metadata.name for pod in pods]}') # Wait until the pods are scheduled and surface cause for error # if there is one - _wait_for_pods_to_schedule(namespace, context, wait_pods, provision_timeout) + _wait_for_pods_to_schedule(namespace, context, pods, provision_timeout) # Wait until the pods and their containers are up and running, and # fail early if there is an error logger.debug(f'run_instances: waiting for pods to be running (pulling ' - f'images): {list(wait_pods_dict.keys())}') - _wait_for_pods_to_run(namespace, context, wait_pods) + f'images): {[pod.metadata.name for pod in pods]}') + _wait_for_pods_to_run(namespace, context, pods) logger.debug(f'run_instances: all pods are scheduled and running: ' - f'{list(wait_pods_dict.keys())}') - - running_pods = kubernetes_utils.filter_pods(namespace, context, tags, - ['Running']) - initialized_pods = kubernetes_utils.filter_pods(namespace, context, { - TAG_POD_INITIALIZED: 'true', - **tags - }, ['Running']) - uninitialized_pods = { - pod_name: pod - for pod_name, pod in running_pods.items() - if pod_name not in initialized_pods - } - if len(uninitialized_pods) > 0: - logger.debug(f'run_instances: Initializing {len(uninitialized_pods)} ' - f'pods: {list(uninitialized_pods.keys())}') - uninitialized_pods_list = list(uninitialized_pods.values()) - - # Setup SSH and environment variables in pods. - # Make sure commands used in these methods are generic and work - # on most base images. E.g., do not use Python, since that may not - # be installed by default. 
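The pod creation above is parallelized with `subprocess_utils.run_in_parallel(_create_pod_thread, range(to_start_count), _NUM_THREADS)`, which only stays safe because each thread deep-copies `pod_spec` before customizing it; mutating one shared dict from several threads would race on the pod name and labels. A minimal standalone sketch of that copy-then-customize pattern (the thread pool and spec contents are simplified stand-ins).

import copy
from concurrent import futures

base_spec = {'metadata': {'name': '', 'labels': {'role': ''}}}


def create_one(i: int) -> dict:
    # Each worker customizes its own deep copy, so concurrent workers never
    # observe each other's half-written names or labels.
    spec = copy.deepcopy(base_spec)
    spec['metadata']['name'] = f'demo-cluster-{i:04d}'
    spec['metadata']['labels']['role'] = 'head' if i == 0 else 'worker'
    return spec


with futures.ThreadPoolExecutor(max_workers=4) as pool:
    specs = list(pool.map(create_one, range(3)))

print([s['metadata']['name'] for s in specs])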
- _check_user_privilege(namespace, context, uninitialized_pods_list) - _setup_ssh_in_pods(namespace, context, uninitialized_pods_list) - _set_env_vars_in_pods(namespace, context, uninitialized_pods_list) - - for pod in uninitialized_pods.values(): - _label_pod(namespace, - context, - pod.metadata.name, - label={ - TAG_POD_INITIALIZED: 'true', - **pod.metadata.labels - }) + f'{[pod.metadata.name for pod in pods]}') assert head_pod_name is not None, 'head_instance_id should not be None' return common.ProvisionRecord( @@ -782,11 +879,6 @@ def _terminate_node(namespace: str, context: Optional[str], pod_name: str) -> None: """Terminate a pod.""" logger.debug('terminate_instances: calling delete_namespaced_pod') - try: - kubernetes_utils.clean_zombie_ssh_jump_pod(namespace, context, pod_name) - except Exception as e: # pylint: disable=broad-except - logger.warning('terminate_instances: Error occurred when analyzing ' - f'SSH Jump pod: {e}') try: kubernetes.core_api(context).delete_namespaced_service( pod_name, namespace, _request_timeout=config_lib.DELETION_TIMEOUT) @@ -823,6 +915,18 @@ def terminate_instances( } pods = kubernetes_utils.filter_pods(namespace, context, tag_filters, None) + # Clean up the SSH jump pod if in use + networking_mode = network_utils.get_networking_mode( + provider_config.get('networking_mode')) + if networking_mode == kubernetes_enums.KubernetesNetworkingMode.NODEPORT: + pod_name = list(pods.keys())[0] + try: + kubernetes_utils.clean_zombie_ssh_jump_pod(namespace, context, + pod_name) + except Exception as e: # pylint: disable=broad-except + logger.warning('terminate_instances: Error occurred when analyzing ' + f'SSH Jump pod: {e}') + def _is_head(pod) -> bool: return pod.metadata.labels[constants.TAG_RAY_NODE_KIND] == 'head' @@ -835,7 +939,7 @@ def _terminate_pod_thread(pod_info): # Run pod termination in parallel subprocess_utils.run_in_parallel(_terminate_pod_thread, pods.items(), - NUM_THREADS) + _NUM_THREADS) def get_cluster_info( diff --git a/sky/provision/kubernetes/utils.py b/sky/provision/kubernetes/utils.py index 0156c4d1091..7442c9be7a6 100644 --- a/sky/provision/kubernetes/utils.py +++ b/sky/provision/kubernetes/utils.py @@ -28,6 +28,7 @@ from sky.utils import env_options from sky.utils import kubernetes_enums from sky.utils import schemas +from sky.utils import timeline from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -36,7 +37,6 @@ # TODO(romilb): Move constants to constants.py DEFAULT_NAMESPACE = 'default' -IN_CLUSTER_REGION = 'in-cluster' DEFAULT_SERVICE_ACCOUNT_NAME = 'skypilot-service-account' @@ -48,10 +48,18 @@ 'T': 2**40, 'P': 2**50, } -NO_GPU_HELP_MESSAGE = ('If your cluster contains GPUs, make sure ' - 'nvidia.com/gpu resource is available on the nodes and ' - 'the node labels for identifying GPUs ' - '(e.g., skypilot.co/accelerator) are setup correctly. ') + +# The resource keys used by Kubernetes to track NVIDIA GPUs and Google TPUs on +# nodes. These keys are typically used in the node's status.allocatable +# or status.capacity fields to indicate the available resources on the node. +GPU_RESOURCE_KEY = 'nvidia.com/gpu' +TPU_RESOURCE_KEY = 'google.com/tpu' + +NO_ACCELERATOR_HELP_MESSAGE = ( + 'If your cluster contains GPUs or TPUs, make sure ' + f'{GPU_RESOURCE_KEY} or {TPU_RESOURCE_KEY} resource is available ' + 'on the nodes and the node labels for identifying GPUs/TPUs ' + '(e.g., skypilot.co/accelerator) are setup correctly. ') KUBERNETES_AUTOSCALER_NOTE = ( 'Note: Kubernetes cluster autoscaling is enabled. 
' @@ -74,6 +82,17 @@ PORT_FORWARD_PROXY_CMD_PATH = ('~/.sky/kubernetes-port-forward-proxy-command-' f'v{PORT_FORWARD_PROXY_CMD_VERSION}.sh') +# Mapping used to get generation for TPU accelerator name. +# https://cloud.google.com/kubernetes-engine/docs/how-to/tpus#run +GKE_TPU_ACCELERATOR_TO_GENERATION = { + 'tpu-v4-podslice': 'v4', + # Only Single-host v5e TPU configurations are allowed. + 'tpu-v5-lite-device': 'v5e', + # Multi-host compatible v5e TPU configurations allowed. + 'tpu-v5-lite-podslice': 'v5e', + 'tpu-v5p-slice': 'v5p', +} + POD_STATUSES = { 'Pending', 'Running', 'Succeeded', 'Failed', 'Unknown', 'Terminating' } @@ -96,15 +115,25 @@ class GPULabelFormatter: """ @classmethod - def get_label_key(cls) -> str: + def get_label_key(cls, accelerator: Optional[str] = None) -> str: """Returns the label key for GPU type used by the Kubernetes cluster""" raise NotImplementedError + @classmethod + def get_label_keys(cls) -> List[str]: + """Returns a list of label keys for GPU used by Kubernetes cluster.""" + raise NotImplementedError + @classmethod def get_label_value(cls, accelerator: str) -> str: """Given a GPU type, returns the label value to be used""" raise NotImplementedError + @classmethod + def match_label_key(cls, label_key: str) -> bool: + """Checks if the given label key matches the formatter's label keys""" + raise NotImplementedError + @classmethod def get_accelerator_from_label_value(cls, value: str) -> str: """Given a label value, returns the GPU type""" @@ -126,10 +155,11 @@ def validate_label_value(cls, value: str) -> Tuple[bool, str]: def get_gke_accelerator_name(accelerator: str) -> str: - """Returns the accelerator name for GKE clusters + """Returns the accelerator name for GKE clusters. Uses the format - nvidia-tesla-. - A100-80GB, H100-80GB and L4 are an exception. They use nvidia-. + A100-80GB, H100-80GB, L4 are an exception. They use nvidia-. + TPU types are an exception as well keeping the given name. """ if accelerator == 'H100': # H100 is named as H100-80GB in GKE. @@ -138,6 +168,8 @@ def get_gke_accelerator_name(accelerator: str) -> str: # A100-80GB, L4, H100-80GB and H100-MEGA-80GB # have a different name pattern. return 'nvidia-{}'.format(accelerator.lower()) + elif accelerator.startswith('tpu-'): + return accelerator else: return 'nvidia-tesla-{}'.format(accelerator.lower()) @@ -152,15 +184,23 @@ class SkyPilotLabelFormatter(GPULabelFormatter): LABEL_KEY = 'skypilot.co/accelerator' @classmethod - def get_label_key(cls) -> str: + def get_label_key(cls, accelerator: Optional[str] = None) -> str: return cls.LABEL_KEY + @classmethod + def get_label_keys(cls) -> List[str]: + return [cls.LABEL_KEY] + @classmethod def get_label_value(cls, accelerator: str) -> str: # For SkyPilot formatter, we use the accelerator str directly. # See sky.utils.kubernetes.gpu_labeler. 
return accelerator.lower() + @classmethod + def match_label_key(cls, label_key: str) -> bool: + return label_key == cls.LABEL_KEY + @classmethod def get_accelerator_from_label_value(cls, value: str) -> str: return value.upper() @@ -184,13 +224,21 @@ class CoreWeaveLabelFormatter(GPULabelFormatter): LABEL_KEY = 'gpu.nvidia.com/class' @classmethod - def get_label_key(cls) -> str: + def get_label_key(cls, accelerator: Optional[str] = None) -> str: return cls.LABEL_KEY + @classmethod + def get_label_keys(cls) -> List[str]: + return [cls.LABEL_KEY] + @classmethod def get_label_value(cls, accelerator: str) -> str: return accelerator.upper() + @classmethod + def match_label_key(cls, label_key: str) -> bool: + return label_key == cls.LABEL_KEY + @classmethod def get_accelerator_from_label_value(cls, value: str) -> str: return value @@ -203,11 +251,28 @@ class GKELabelFormatter(GPULabelFormatter): label, which is used to identify the GPU type. """ - LABEL_KEY = 'cloud.google.com/gke-accelerator' + GPU_LABEL_KEY = 'cloud.google.com/gke-accelerator' + TPU_LABEL_KEY = 'cloud.google.com/gke-tpu-accelerator' + ACCELERATOR_COUNT_LABEL_KEY = 'cloud.google.com/gke-accelerator-count' + TPU_TOPOLOGY_LABEL_KEY = 'cloud.google.com/gke-tpu-topology' @classmethod - def get_label_key(cls) -> str: - return cls.LABEL_KEY + def get_label_key(cls, accelerator: Optional[str] = None) -> str: + if accelerator is not None and accelerator.startswith('tpu-'): + return cls.TPU_LABEL_KEY + return cls.GPU_LABEL_KEY + + @classmethod + def get_label_keys(cls) -> List[str]: + return [cls.GPU_LABEL_KEY, cls.TPU_LABEL_KEY] + + @classmethod + def match_label_key(cls, label_key: str) -> bool: + return label_key in cls.get_label_keys() + + @classmethod + def get_tpu_topology_label_key(cls) -> str: + return cls.TPU_TOPOLOGY_LABEL_KEY @classmethod def get_label_value(cls, accelerator: str) -> str: @@ -225,6 +290,8 @@ def get_accelerator_from_label_value(cls, value: str) -> str: # to distinguish between a3-high and a3-mega instances return 'H100' return acc + elif is_tpu_on_gke(value): + return value else: raise ValueError( f'Invalid accelerator name in GKE cluster: {value}') @@ -248,9 +315,13 @@ class GFDLabelFormatter(GPULabelFormatter): LABEL_KEY = 'nvidia.com/gpu.product' @classmethod - def get_label_key(cls) -> str: + def get_label_key(cls, accelerator: Optional[str] = None) -> str: return cls.LABEL_KEY + @classmethod + def get_label_keys(cls) -> List[str]: + return [cls.LABEL_KEY] + @classmethod def get_label_value(cls, accelerator: str) -> str: """An accelerator can map to many Nvidia GFD labels @@ -258,6 +329,10 @@ def get_label_value(cls, accelerator: str) -> str: As a result, we do not support get_label_value for GFDLabelFormatter.""" raise NotImplementedError + @classmethod + def match_label_key(cls, label_key: str) -> bool: + return label_key == cls.LABEL_KEY + @classmethod def get_accelerator_from_label_value(cls, value: str) -> str: """Searches against a canonical list of NVIDIA GPUs and pattern @@ -335,10 +410,9 @@ def detect_gpu_label_formatter( # Check if the node labels contain any of the GPU label prefixes for lf in LABEL_FORMATTER_REGISTRY: - label_key = lf.get_label_key() for _, label_list in node_labels.items(): for label, _ in label_list: - if label.startswith(label_key): + if lf.match_label_key(label): label_formatter = lf() return label_formatter, node_labels @@ -346,24 +420,28 @@ def detect_gpu_label_formatter( @functools.lru_cache(maxsize=10) -def detect_gpu_resource(context: Optional[str]) -> Tuple[bool, 
Set[str]]: - """Checks if the Kubernetes cluster has nvidia.com/gpu resource. +def detect_accelerator_resource( + context: Optional[str]) -> Tuple[bool, Set[str]]: + """Checks if the Kubernetes cluster has GPU/TPU resource. - If nvidia.com/gpu resource is missing, that typically means that the - Kubernetes cluster does not have GPUs or the nvidia GPU operator and/or - device drivers are not installed. + Two types of accelerator resources are available which are each checked + with nvidia.com/gpu and google.com/tpu. If nvidia.com/gpu resource is + missing, that typically means that the Kubernetes cluster does not have + GPUs or the nvidia GPU operator and/or device drivers are not installed. Returns: - bool: True if the cluster has nvidia.com/gpu resource, False otherwise. + bool: True if the cluster has GPU_RESOURCE_KEY or TPU_RESOURCE_KEY + resource, False otherwise. """ # Get the set of resources across all nodes cluster_resources: Set[str] = set() nodes = get_kubernetes_nodes(context) for node in nodes: cluster_resources.update(node.status.allocatable.keys()) - has_gpu = 'nvidia.com/gpu' in cluster_resources + has_accelerator = (GPU_RESOURCE_KEY in cluster_resources or + TPU_RESOURCE_KEY in cluster_resources) - return has_gpu, cluster_resources + return has_accelerator, cluster_resources @functools.lru_cache(maxsize=10) @@ -451,16 +529,52 @@ def check_cpu_mem_fits(candidate_instance_type: 'KubernetesInstanceType', 'Maximum resources found on a single node: ' f'{max_cpu} CPUs, {common_utils.format_float(max_mem)}G Memory') + def check_tpu_fits(candidate_instance_type: 'KubernetesInstanceType', + node_list: List[Any]) -> Tuple[bool, Optional[str]]: + """Checks if the instance fits on the cluster based on requested TPU. + + It checks if the TPU type and count on each node match the required + number of TPU chips for the instance. In the case of multi-host TPU + podslice, the function ensures that the number of TPU chips on a single + node (node_tpu_chip_count) and the total TPU chips across the entire + podslice (topology_chip_count) are correctly handled. + """ + acc_type = candidate_instance_type.accelerator_type + acc_count = candidate_instance_type.accelerator_count + tpu_list_in_cluster = [] + for node in node_list: + if acc_type == node.metadata.labels[ + GKELabelFormatter.TPU_LABEL_KEY]: + # TODO(Doyoung): Update the logic when adding support for + # multi-host TPUs. + if is_multi_host_tpu(node.metadata.labels): + continue + node_tpu_chip_count = int(node.metadata.labels[ + GKELabelFormatter.ACCELERATOR_COUNT_LABEL_KEY]) + tpu_type = f'{acc_type}:{node_tpu_chip_count}' + tpu_list_in_cluster.append(tpu_type) + if node_tpu_chip_count == acc_count: + return True, None + tpu_list_in_cluster_str = ','.join(tpu_list_in_cluster) + # TODO(Doyoung): Update the error message raised with the multi-host + # TPU support. + return False, ('Requested TPU type was not found in the cluster. TPU ' + 'types found in the cluster: ' + f'{tpu_list_in_cluster_str}. Note that multi-host TPU ' + 'podslices are currently not supported.') + nodes = get_kubernetes_nodes(context) k8s_instance_type = KubernetesInstanceType.\ from_instance_type(instance) acc_type = k8s_instance_type.accelerator_type + acc_count = k8s_instance_type.accelerator_count if acc_type is not None: - # If GPUs are requested, check if GPU type is available, and if so, - # check if CPU and memory requirements on the specific node are met.
+ # If GPU/TPUs are requested, check if GPU/TPU type is available, and + # if so, check if CPU and memory requirements on the specific node are + # met. try: - gpu_label_key, gpu_label_val = get_gpu_label_key_value( - context, acc_type) + gpu_label_key, gpu_label_val, _, _ = ( + get_accelerator_label_key_value(context, acc_type, acc_count)) except exceptions.ResourcesUnavailableError as e: # If GPU not found, return empty list and error message. return False, str(e) @@ -470,6 +584,13 @@ def check_cpu_mem_fits(candidate_instance_type: 'KubernetesInstanceType', node.metadata.labels[gpu_label_key] == gpu_label_val ] assert len(gpu_nodes) > 0, 'GPU nodes not found' + if is_tpu_on_gke(acc_type): + # If requested accelerator is a TPU type, check if the cluster + # has sufficient TPU resource to meet the requirement. + fits, reason = check_tpu_fits(k8s_instance_type, gpu_nodes) + if reason is not None: + return fits, reason + candidate_nodes = gpu_nodes not_fit_reason_prefix = ( f'GPU nodes with {acc_type} do not have ' @@ -481,7 +602,7 @@ def check_cpu_mem_fits(candidate_instance_type: 'KubernetesInstanceType', f'CPU (> {k8s_instance_type.cpus} CPUs) ' 'and/or memory ' f'(> {k8s_instance_type.memory} G). ') - # Check if CPU and memory requirements are met on at least one + # Check if CPU and memory requirements are met on at least one # candidate node. fits, reason = check_cpu_mem_fits(k8s_instance_type, candidate_nodes) if not fits: @@ -492,25 +613,33 @@ def check_cpu_mem_fits(candidate_instance_type: 'KubernetesInstanceType', return fits, reason -def get_gpu_label_key_value(context: Optional[str], - acc_type: str, - check_mode=False) -> Tuple[str, str]: - """Returns the label key and value for the given GPU type. +def get_accelerator_label_key_value( + context: Optional[str], + acc_type: str, + acc_count: Optional[int], + check_mode=False +) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: + """Returns the label key and value for the given GPU/TPU type. Args: - acc_type: The GPU type required by the task. - check_mode: If True, only checks if the cluster has GPU resources and - labels are setup on the cluster. acc_type is ignore does not return - the label key and value. Useful for checking if GPUs are configured - correctly on the cluster without explicitly requesting a acc_type. + acc_type: The GPU/TPU type required by the task. + acc_count: Number of GPU/TPUs required by the task. + check_mode: If True, only checks if the cluster has GPU/TPU resources + and labels are setup on the cluster. acc_type is ignored and the + label key and value are not returned. Useful for checking if GPUs + are configured correctly on the cluster without explicitly + requesting an acc_type. Returns: - A tuple of the label key and value. Returns empty strings if check_mode - is True. + A tuple of the accelerator label key, value, topology label key, and + topology value. The topology label key and value are populated only if + the requested accelerator type is TPU. Returns None if check_mode is + True.
Raises: ResourcesUnavailableError: Can be raised from the following conditions: - - The cluster does not have GPU resources (nvidia.com/gpu) - - The cluster does not have GPU labels setup correctly - - The cluster doesn't have any nodes with acc_type GPU + - The cluster does not have GPU/TPU resources + (nvidia.com/gpu, google.com/tpu) + - The cluster does not have GPU/TPU labels setup correctly + - The cluster doesn't have any nodes with acc_type GPU/TPU """ # Check if the cluster has GPU resources # TODO(romilb): This assumes the accelerator is a nvidia GPU. We @@ -529,13 +658,14 @@ def get_gpu_label_key_value(context: Optional[str], # If check mode is enabled and autoscaler is set, we can return # early since we assume the cluster autoscaler will handle GPU # node provisioning. - return '', '' + return None, None, None, None formatter = AUTOSCALER_TO_LABEL_FORMATTER.get(autoscaler_type) assert formatter is not None, ('Unsupported autoscaler type:' f' {autoscaler_type}') - return formatter.get_label_key(), formatter.get_label_value(acc_type) + return formatter.get_label_key(acc_type), formatter.get_label_value( + acc_type), None, None - has_gpus, cluster_resources = detect_gpu_resource(context) + has_gpus, cluster_resources = detect_accelerator_resource(context) if has_gpus: # Check if the cluster has GPU labels setup correctly label_formatter, node_labels = \ @@ -544,8 +674,10 @@ def get_gpu_label_key_value(context: Optional[str], # If none of the GPU labels from LABEL_FORMATTER_REGISTRY are # detected, raise error with ux_utils.print_exception_no_traceback(): - supported_formats = ', '.join( - [f.get_label_key() for f in LABEL_FORMATTER_REGISTRY]) + supported_formats = ', '.join([ + key for f in LABEL_FORMATTER_REGISTRY + for key in f.get_label_keys() + ]) suffix = '' if env_options.Options.SHOW_DEBUG_INFO.get(): suffix = f' Found node labels: {node_labels}' @@ -561,7 +693,7 @@ def get_gpu_label_key_value(context: Optional[str], # correctly setup and will behave as expected. for node_name, label_list in node_labels.items(): for label, value in label_list: - if label == label_formatter.get_label_key(): + if label_formatter.match_label_key(label): is_valid, reason = label_formatter.validate_label_value( value) if not is_valid: @@ -571,8 +703,7 @@ def get_gpu_label_key_value(context: Optional[str], if check_mode: # If check mode is enabled and we reached so far, we can # conclude that the cluster is setup correctly and return. - return '', '' - k8s_acc_label_key = label_formatter.get_label_key() + return None, None, None, None # Search in node_labels to see if any node has the requested # GPU type. # Note - this only checks if the label is available on a @@ -580,11 +711,38 @@ def get_gpu_label_key_value(context: Optional[str], # quantity is available since that is dynamic and can change # during scheduling. for node_name, label_list in node_labels.items(): + node_metadata_labels = dict(label_list) + # TODO(Doyoung): Update the logic when adding support for + # multi-host TPUs. 
+ if is_multi_host_tpu(node_metadata_labels): + continue for label, value in label_list: - if (label == k8s_acc_label_key and + if (label_formatter.match_label_key(label) and label_formatter.get_accelerator_from_label_value( value) == acc_type): - return label, value + if is_tpu_on_gke(acc_type): + assert isinstance(label_formatter, + GKELabelFormatter) + if node_metadata_labels.get( + label_formatter.TPU_LABEL_KEY) == acc_type: + topology_label_key = ( + label_formatter.TPU_TOPOLOGY_LABEL_KEY) + topology_value = node_metadata_labels.get( + topology_label_key) + assert topology_value is not None + tpu_topology_chip_count = reduce_tpu_topology( + topology_value) + # For single-host TPUs, there aren't multiple + # different topologies that map to an identical + # number of TPU chips. + if tpu_topology_chip_count == acc_count: + return (label, value, topology_label_key, + topology_value) + else: + continue + else: + return label, value, None, None + # If no node is found with the requested acc_type, raise error with ux_utils.print_exception_no_traceback(): suffix = '' @@ -592,15 +750,19 @@ def get_gpu_label_key_value(context: Optional[str], all_labels = [] for node_name, label_list in node_labels.items(): all_labels.extend(label_list) - gpus_available = set( - v for k, v in all_labels if k == k8s_acc_label_key) - suffix = f' Available GPUs on the cluster: {gpus_available}' + acc_available = set(v for k, v in all_labels + if label_formatter.match_label_key(k)) + suffix = (' Available GPU/TPUs on the cluster: ' + f'{acc_available}') + # TODO(Doyoung): Update the error message raised with the + # multi-host TPU support. raise exceptions.ResourcesUnavailableError( 'Could not find any node in the Kubernetes cluster ' - f'with {acc_type} GPU. Please ensure at least ' - f'one node in the cluster has {acc_type} GPU and node ' - 'labels are setup correctly. ' - f'Please refer to the documentation for more. {suffix}') + f'with {acc_type}. Please ensure at least one node in the ' + f'cluster has {acc_type} and node labels are setup ' + 'correctly. Please refer to the documentation for more. ' + f'{suffix}. Note that multi-host TPU podslices are ' + 'currently not supported.') else: # If GPU resources are not detected, raise error with ux_utils.print_exception_no_traceback(): @@ -609,13 +771,14 @@ def get_gpu_label_key_value(context: Optional[str], suffix = (' Available resources on the cluster: ' f'{cluster_resources}') raise exceptions.ResourcesUnavailableError( - 'Could not detect GPU resources (`nvidia.com/gpu`) in ' - 'Kubernetes cluster. If this cluster contains GPUs, please ' - 'ensure GPU drivers are installed on the node. Check if the ' - 'GPUs are setup correctly by running `kubectl describe nodes` ' - 'and looking for the nvidia.com/gpu resource. ' - 'Please refer to the documentation on how ' - f'to set up GPUs.{suffix}') + f'Could not detect GPU/TPU resources ({GPU_RESOURCE_KEY!r} or ' + f'{TPU_RESOURCE_KEY!r}) in Kubernetes cluster. If this cluster' + ' contains GPUs, please ensure GPU drivers are installed on ' + 'the node. Check if the GPUs are setup correctly by running ' + '`kubectl describe nodes` and looking for the ' + f'{GPU_RESOURCE_KEY!r} or {TPU_RESOURCE_KEY!r} resource. ' + 'Please refer to the documentation on how to set up GPUs.' + f'{suffix}') def get_head_ssh_port(cluster_name: str, namespace: str, @@ -710,7 +873,10 @@ def check_credentials(context: Optional[str], # provider if their cluster GPUs are not setup correctly.
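To make the topology check above concrete: GKE publishes the podslice topology as a label value such as `2x4`, and the product of its dimensions is the chip count that must equal the requested accelerator count before the node's labels are returned. A rough sketch with hypothetical helper names (not the SkyPilot functions themselves):

```python
import math
from typing import Dict, Optional, Tuple

TPU_LABEL_KEY = 'cloud.google.com/gke-tpu-accelerator'
TPU_TOPOLOGY_LABEL_KEY = 'cloud.google.com/gke-tpu-topology'


def chips_in_topology(topology: str) -> int:
    """'2x4' -> 8, '2x2x2' -> 8."""
    return math.prod(int(dim) for dim in topology.split('x'))


def match_tpu_node(node_labels: Dict[str, str], acc_type: str,
                   acc_count: int) -> Optional[Tuple[str, str]]:
    """Return (topology label key, value) if the node satisfies the request."""
    if node_labels.get(TPU_LABEL_KEY) != acc_type:
        return None
    topology = node_labels.get(TPU_TOPOLOGY_LABEL_KEY)
    if topology is not None and chips_in_topology(topology) == acc_count:
        return TPU_TOPOLOGY_LABEL_KEY, topology
    return None


# A single-host v5e node advertising a 2x4 slice satisfies a request for
# 8 chips of 'tpu-v5-lite-podslice'.
assert match_tpu_node(
    {TPU_LABEL_KEY: 'tpu-v5-lite-podslice', TPU_TOPOLOGY_LABEL_KEY: '2x4'},
    'tpu-v5-lite-podslice', 8) == (TPU_TOPOLOGY_LABEL_KEY, '2x4')
```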
gpu_msg = '' try: - _, _ = get_gpu_label_key_value(context, acc_type='', check_mode=True) + get_accelerator_label_key_value(context, + acc_type='', + acc_count=0, + check_mode=True) except exceptions.ResourcesUnavailableError as e: # If GPUs are not available, we return cluster as enabled (since it can # be a CPU-only cluster) but we also return the exception message which @@ -754,6 +920,9 @@ def is_kubeconfig_exec_auth( str: Error message if exec-based authentication is used, None otherwise """ k8s = kubernetes.kubernetes + if context == kubernetes.in_cluster_context_name(): + # If in-cluster config is used, exec-based auth is not used. + return False, None try: k8s.config.load_kube_config() except kubernetes.config_exception(): @@ -836,30 +1005,34 @@ def is_incluster_config_available() -> bool: return os.path.exists('/var/run/secrets/kubernetes.io/serviceaccount/token') -def get_all_kube_config_context_names() -> List[Optional[str]]: - """Get all kubernetes context names from the kubeconfig file. +def get_all_kube_context_names() -> List[str]: + """Get all kubernetes context names available in the environment. + + Fetches context names from the kubeconfig file and in-cluster auth, if any. - If running in-cluster, returns [None] to indicate in-cluster config. + If running in-cluster and IN_CLUSTER_CONTEXT_NAME_ENV_VAR is not set, + returns the default in-cluster kubernetes context name. We should not cache the result of this function as the admin policy may update the contexts. Returns: List[Optional[str]]: The list of kubernetes context names if - available, an empty list otherwise. If running in-cluster, - returns [None] to indicate in-cluster config. + available, an empty list otherwise. """ k8s = kubernetes.kubernetes + context_names = [] try: all_contexts, _ = k8s.config.list_kube_config_contexts() # all_contexts will always have at least one context. If kubeconfig # does not have any contexts defined, it will raise ConfigException. - return [context['name'] for context in all_contexts] + context_names = [context['name'] for context in all_contexts] except k8s.config.config_exception.ConfigException: - # If running in cluster, return [None] to indicate in-cluster config - if is_incluster_config_available(): - return [None] - return [] + # If no config found, continue + pass + if is_incluster_config_available(): + context_names.append(kubernetes.in_cluster_context_name()) + return context_names @functools.lru_cache() @@ -872,11 +1045,15 @@ def get_kube_config_context_namespace( the default namespace. """ k8s = kubernetes.kubernetes - # Get namespace if using in-cluster config ns_path = '/var/run/secrets/kubernetes.io/serviceaccount/namespace' - if os.path.exists(ns_path): - with open(ns_path, encoding='utf-8') as f: - return f.read().strip() + # If using in-cluster context, get the namespace from the service account + # namespace file. Uses the same logic as adaptors.kubernetes._load_config() + # to stay consistent with in-cluster config loading. 
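The context handling above merges kubeconfig contexts with a synthetic in-cluster context name instead of returning `[None]`. A hedged sketch of that enumeration, with an assumed in-cluster context name (the real name comes from SkyPilot's kubernetes adaptor):

```python
import os
from typing import List

from kubernetes import config as k8s_config

SA_TOKEN_PATH = '/var/run/secrets/kubernetes.io/serviceaccount/token'
IN_CLUSTER_CONTEXT_NAME = 'in-cluster'  # assumed default name


def all_context_names() -> List[str]:
    """Kubeconfig contexts plus the in-cluster context, if available."""
    names: List[str] = []
    try:
        contexts, _ = k8s_config.list_kube_config_contexts()
        names = [c['name'] for c in contexts]
    except k8s_config.ConfigException:
        pass  # No kubeconfig found; fall through to the in-cluster check.
    if os.path.exists(SA_TOKEN_PATH):
        names.append(IN_CLUSTER_CONTEXT_NAME)
    return names
```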
+ if (context_name == kubernetes.in_cluster_context_name() or + context_name is None): + if os.path.exists(ns_path): + with open(ns_path, encoding='utf-8') as f: + return f.read().strip() # If not in-cluster, get the namespace from kubeconfig try: contexts, current_context = k8s.config.list_kube_config_contexts() @@ -963,7 +1140,11 @@ def name(self) -> str: name = (f'{common_utils.format_float(self.cpus)}CPU--' f'{common_utils.format_float(self.memory)}GB') if self.accelerator_count: - name += f'--{self.accelerator_count}{self.accelerator_type}' + # Replace spaces with underscores in accelerator type to make it a + # valid logical instance type name. + assert self.accelerator_type is not None, self.accelerator_count + acc_name = self.accelerator_type.replace(' ', '_') + name += f'--{self.accelerator_count}{acc_name}' return name @staticmethod @@ -994,7 +1175,9 @@ def _parse_instance_type( accelerator_type = match.group('accelerator_type') if accelerator_count: accelerator_count = int(accelerator_count) - accelerator_type = str(accelerator_type) + # This is to revert the accelerator types with spaces back to + # the original format. + accelerator_type = str(accelerator_type).replace('_', ' ') else: accelerator_count = None accelerator_type = None @@ -1527,6 +1710,8 @@ def merge_dicts(source: Dict[Any, Any], destination: Dict[Any, Any]): else: destination[key].extend(value) else: + if destination is None: + destination = {} destination[key] = value @@ -1787,7 +1972,7 @@ def __init__(self, obj): class KubernetesNodeInfo: """Dataclass to store Kubernetes node information.""" name: str - gpu_type: Optional[str] + accelerator_type: Optional[str] # Resources available on the node. E.g., {'nvidia.com/gpu': '2'} total: Dict[str, int] free: Dict[str, int] @@ -1801,52 +1986,71 @@ def get_kubernetes_node_info( number of GPUs available on the node and the number of free GPUs on the node. + If the user does not have sufficient permissions to list pods in all + namespaces, the function will return free GPUs as -1. 
+ Returns: Dict[str, KubernetesNodeInfo]: Dictionary containing the node name as key and the KubernetesNodeInfo object as value """ nodes = get_kubernetes_nodes(context) # Get the pods to get the real-time resource usage - pods = get_all_pods_in_kubernetes_cluster(context) + try: + pods = get_all_pods_in_kubernetes_cluster(context) + except kubernetes.api_exception() as e: + if e.status == 403: + pods = None + else: + raise - label_formatter, _ = detect_gpu_label_formatter(context) - if not label_formatter: + lf, _ = detect_gpu_label_formatter(context) + if not lf: label_key = None else: - label_key = label_formatter.get_label_key() + label_keys = lf.get_label_keys() node_info_dict: Dict[str, KubernetesNodeInfo] = {} - for node in nodes: - allocated_qty = 0 - if label_formatter is not None and label_key in node.metadata.labels: - accelerator_name = label_formatter.get_accelerator_from_label_value( - node.metadata.labels.get(label_key)) - else: - accelerator_name = None - - accelerator_count = int(node.status.allocatable.get( - 'nvidia.com/gpu', 0)) - - for pod in pods: - # Get all the pods running on the node - if (pod.spec.node_name == node.metadata.name and - pod.status.phase in ['Running', 'Pending']): - # Iterate over all the containers in the pod and sum the - # GPU requests - for container in pod.spec.containers: - if container.resources.requests: - allocated_qty += int( - container.resources.requests.get( - 'nvidia.com/gpu', 0)) - - accelerators_available = accelerator_count - allocated_qty - - node_info_dict[node.metadata.name] = KubernetesNodeInfo( - name=node.metadata.name, - gpu_type=accelerator_name, - total={'nvidia.com/gpu': int(accelerator_count)}, - free={'nvidia.com/gpu': int(accelerators_available)}) + for label_key in label_keys: + for node in nodes: + allocated_qty = 0 + if lf is not None and label_key in node.metadata.labels: + accelerator_name = lf.get_accelerator_from_label_value( + node.metadata.labels.get(label_key)) + else: + accelerator_name = None + + accelerator_count = get_node_accelerator_count( + node.status.allocatable) + + if pods is None: + accelerators_available = -1 + + else: + for pod in pods: + # Get all the pods running on the node + if (pod.spec.node_name == node.metadata.name and + pod.status.phase in ['Running', 'Pending']): + # Iterate over all the containers in the pod and sum the + # GPU requests + for container in pod.spec.containers: + if container.resources.requests: + allocated_qty += get_node_accelerator_count( + container.resources.requests) + + accelerators_available = accelerator_count - allocated_qty + + # Exclude multi-host TPUs from being processed. + # TODO(Doyoung): Remove the logic when adding support for + # multi-host TPUs. 
+ if is_multi_host_tpu(node.metadata.labels): + continue + + node_info_dict[node.metadata.name] = KubernetesNodeInfo( + name=node.metadata.name, + accelerator_type=accelerator_name, + total={'accelerator_count': int(accelerator_count)}, + free={'accelerators_available': int(accelerators_available)}) return node_info_dict @@ -1866,6 +2070,7 @@ def get_namespace_from_config(provider_config: Dict[str, Any]) -> str: get_kube_config_context_namespace(context)) +@timeline.event def filter_pods(namespace: str, context: Optional[str], tag_filters: Dict[str, str], @@ -1996,9 +2201,9 @@ def set_autodown_annotations(handle: 'backends.CloudVmRayResourceHandle', def get_context_from_config(provider_config: Dict[str, Any]) -> Optional[str]: context = provider_config.get('context', get_current_kube_config_context_name()) - if context == IN_CLUSTER_REGION: - # If the context (also used as the region) is set to IN_CLUSTER_REGION - # we need to use in-cluster auth. + if context == kubernetes.in_cluster_context_name(): + # If the context (also used as the region) is in-cluster, we need to + # use in-cluster auth by setting the context to None. context = None return context @@ -2028,6 +2233,80 @@ def get_skypilot_pods(context: Optional[str] = None) -> List[Any]: return pods +def is_tpu_on_gke(accelerator: str) -> bool: + """Determines if the given accelerator is a TPU supported on GKE.""" + return accelerator in GKE_TPU_ACCELERATOR_TO_GENERATION + + +def get_node_accelerator_count(attribute_dict: dict) -> int: + """Retrieves the count of accelerators from a node's resource dictionary. + + This method checks the node's allocatable resources or the accelerators + already deployed on the node, using pod objects that describe resource + requests. + + Args: + attribute_dict: Containing resource information from a node, such as + allocatable or requested resources. + + Returns: + Number of accelerators allocated or available from the node. If no + resource is found, it returns 0. + """ + assert not (GPU_RESOURCE_KEY in attribute_dict and + TPU_RESOURCE_KEY in attribute_dict) + if GPU_RESOURCE_KEY in attribute_dict: + return int(attribute_dict[GPU_RESOURCE_KEY]) + elif TPU_RESOURCE_KEY in attribute_dict: + return int(attribute_dict[TPU_RESOURCE_KEY]) + return 0 + + +def reduce_tpu_topology(topology: str) -> int: + """Computes the number of TPU chips from its topology string.""" + chip_dimensions = [int(chip_count) for chip_count in topology.split('x')] + # tpu_topology_chip_count represents the total number of TPU chips in the + # entire podslice, whether it is a single-host or multi-host TPU podslice. + tpu_topology_chip_count = functools.reduce(lambda x, y: x * y, + chip_dimensions) + return tpu_topology_chip_count + + +def is_multi_host_tpu(node_metadata_labels: dict) -> bool: + """Determines whether the given node is a multi-host TPU configuration.""" + if GKELabelFormatter.TPU_LABEL_KEY in node_metadata_labels: + assert GKELabelFormatter.TPU_TOPOLOGY_LABEL_KEY in node_metadata_labels + topology_value = ( + node_metadata_labels[GKELabelFormatter.TPU_TOPOLOGY_LABEL_KEY]) + accelerator_count_label_key = ( + GKELabelFormatter.ACCELERATOR_COUNT_LABEL_KEY) + assert accelerator_count_label_key in node_metadata_labels + # node_tpu_chip_count represents the number of TPU chips + # available in this node. If the node is part of a node pool + # forming a multi-host TPU podslice, it only reflects the + # number of TPU chips in this individual node, not the entire + # multi-host TPU podslice.
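+ # For example, a tpu-v4-podslice node in a 2x2x4 slice exposes 4 chips + # on this node while the topology implies 16 chips in total, so the + # two counts differ and the node is treated as multi-host; a + # single-host 2x2x1 slice keeps all 4 chips on one node, so the + # counts match.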
+ node_tpu_chip_count = int( + node_metadata_labels[accelerator_count_label_key]) + topology_chip_count = reduce_tpu_topology(topology_value) + # For multi-host TPU podslices, topology_chip_count and + # node_tpu_chip_count will differ, as topology_chip_count + # reflects the total across all hosts, while + # node_tpu_chip_count reflects only the chips in a single node. + if node_tpu_chip_count != topology_chip_count: + return True + return False + + +def multi_host_tpu_exists_in_cluster(context: Optional[str] = None) -> bool: + """Checks if there exists a multi-host TPU within the cluster.""" + nodes = get_kubernetes_nodes(context) + for node in nodes: + if is_multi_host_tpu(node.metadata.labels): + return True + return False + + @dataclasses.dataclass class KubernetesSkyPilotClusterInfo: cluster_name_on_cloud: str diff --git a/sky/provision/oci/__init__.py b/sky/provision/oci/__init__.py new file mode 100644 index 00000000000..eb9128cc04c --- /dev/null +++ b/sky/provision/oci/__init__.py @@ -0,0 +1,15 @@ +"""OCI provisioner for SkyPilot. + +History: + - Hysun He (hysun.he@oracle.com) @ Oct.16, 2024: Initial implementation +""" + +from sky.provision.oci.config import bootstrap_instances +from sky.provision.oci.instance import cleanup_ports +from sky.provision.oci.instance import get_cluster_info +from sky.provision.oci.instance import open_ports +from sky.provision.oci.instance import query_instances +from sky.provision.oci.instance import run_instances +from sky.provision.oci.instance import stop_instances +from sky.provision.oci.instance import terminate_instances +from sky.provision.oci.instance import wait_instances diff --git a/sky/provision/oci/config.py b/sky/provision/oci/config.py new file mode 100644 index 00000000000..e688bf12443 --- /dev/null +++ b/sky/provision/oci/config.py @@ -0,0 +1,51 @@ +"""OCI configuration bootstrapping. + +Creates the resource group and deploys the configuration template to OCI for +a cluster to be launched. + +History: + - Hysun He (hysun.he@oracle.com) @ Oct.16, 2024: Initial implementation +""" + +from sky import exceptions +from sky import sky_logging +from sky.adaptors import oci as oci_adaptor +from sky.clouds.utils import oci_utils +from sky.provision import common +from sky.provision.oci.query_utils import query_helper + +logger = sky_logging.init_logger(__name__) + + +@common.log_function_start_end +def bootstrap_instances( + region: str, cluster_name_on_cloud: str, + config: common.ProvisionConfig) -> common.ProvisionConfig: + """See sky/provision/__init__.py""" + # OCI module import and oci client + oci_adaptor.get_core_client(region, oci_utils.oci_config.get_profile()) + + # Find / create a compartment for creating instances. + compartment = query_helper.find_compartment(region) + + # Find the configured VCN, or create a new one. + vcn = query_helper.find_create_vcn_subnet(region) + if vcn is None: + # pylint: disable=line-too-long + raise exceptions.ResourcesUnavailableError( + 'Failed to create a new VCN, possibly you hit the resource limitation.' + ) + + node_config = config.node_config + + # Subscribe the image if it is from Marketplace listing. 
+ query_helper.subscribe_image( + compartment_id=compartment, + listing_id=node_config['AppCatalogListingId'], + resource_version=node_config['ResourceVersion'], + region=region, + ) + + logger.info(f'Using cluster name: {cluster_name_on_cloud}') + + return config diff --git a/sky/provision/oci/instance.py b/sky/provision/oci/instance.py new file mode 100644 index 00000000000..e04089ff8d4 --- /dev/null +++ b/sky/provision/oci/instance.py @@ -0,0 +1,433 @@ +"""OCI instance provisioning. + +History: + - Hysun He (hysun.he@oracle.com) @ Oct.16, 2024: Initial implementation + - Hysun He (hysun.he@oracle.com) @ Nov.13, 2024: Implement open_ports + and cleanup_ports for supporting SkyServe. +""" + +import copy +from datetime import datetime +import time +from typing import Any, Dict, List, Optional + +from sky import exceptions +from sky import sky_logging +from sky import status_lib +from sky.adaptors import oci as oci_adaptor +from sky.clouds.utils import oci_utils +from sky.provision import common +from sky.provision import constants +from sky.provision.oci import query_utils +from sky.provision.oci.query_utils import query_helper +from sky.utils import common_utils +from sky.utils import ux_utils + +logger = sky_logging.init_logger(__name__) + + +@query_utils.debug_enabled(logger) +@common_utils.retry +def query_instances( + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, + non_terminated_only: bool = True, +) -> Dict[str, Optional[status_lib.ClusterStatus]]: + """Query instances. + + Returns a dictionary of instance IDs and status. + + A None status means the instance is marked as "terminated" + or "terminating". + """ + assert provider_config is not None, cluster_name_on_cloud + region = provider_config['region'] + + status_map = oci_utils.oci_config.STATE_MAPPING_OCI_TO_SKY + statuses: Dict[str, Optional[status_lib.ClusterStatus]] = {} + filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + + instances = _get_filtered_nodes(region, filters) + for node in instances: + vm_status = node['status'] + sky_status = status_map[vm_status] + if non_terminated_only and sky_status is None: + continue + statuses[node['inst_id']] = sky_status + + return statuses + + +@query_utils.debug_enabled(logger) +def run_instances(region: str, cluster_name_on_cloud: str, + config: common.ProvisionConfig) -> common.ProvisionRecord: + """Start instances with bootstrapped configuration.""" + tags = dict(sorted(copy.deepcopy(config.tags).items())) + + start_time = round(time.time() * 1000) + filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + + # Starting stopped nodes if resume_stopped_nodes=True + resume_instances = [] + if config.resume_stopped_nodes: + logger.debug('Checking existing stopped nodes.') + + existing_instances = _get_filtered_nodes(region, filters) + if len(existing_instances) > config.count: + raise RuntimeError( + 'The number of pending/running/stopped/stopping ' + f'instances combined ({len(existing_instances)}) in ' + f'cluster "{cluster_name_on_cloud}" is greater than the ' + f'number requested by the user ({config.count}). ' + 'This is likely a resource leak. 
' + 'Use "sky down" to terminate the cluster.') + + # pylint: disable=line-too-long + logger.debug( + f'run_instances: Found {[inst["name"] for inst in existing_instances]} ' + 'existing instances in cluster.') + existing_instances.sort(key=lambda x: x['name']) + + stopped_instances = [] + for existing_node in existing_instances: + if existing_node['status'] == 'STOPPING': + query_helper.wait_instance_until_status( + region, existing_node['inst_id'], 'STOPPED') + stopped_instances.append(existing_node) + elif existing_node['status'] == 'STOPPED': + stopped_instances.append(existing_node) + elif existing_node['status'] in ('PROVISIONING', 'STARTING', + 'RUNNING'): + resume_instances.append(existing_node) + + for stopped_node in stopped_instances: + stopped_node_id = stopped_node['inst_id'] + instance_action_response = query_helper.start_instance( + region, stopped_node_id) + + starting_inst = instance_action_response.data + resume_instances.append({ + 'inst_id': starting_inst.id, + 'name': starting_inst.display_name, + 'ad': starting_inst.availability_domain, + 'compartment': starting_inst.compartment_id, + 'status': starting_inst.lifecycle_state, + 'oci_tags': starting_inst.freeform_tags, + }) + # end if config.resume_stopped_nodes + + # Try get head id from the existing instances + head_instance_id = _get_head_instance_id(resume_instances) + logger.debug(f'Check existing head node: {head_instance_id}') + + # Let's create additional new nodes (if neccessary) + to_start_count = config.count - len(resume_instances) + created_instances = [] + node_config = config.node_config + if to_start_count > 0: + compartment = query_helper.find_compartment(region) + vcn = query_helper.find_create_vcn_subnet(region) + + ocpu_count = 0 + vcpu_str = node_config['VCPUs'] + instance_type_str = node_config['InstanceType'] + + if vcpu_str is not None and vcpu_str != 'None': + if instance_type_str.startswith( + f'{oci_utils.oci_config.VM_PREFIX}.A'): + # For ARM cpu, 1*ocpu = 1*vcpu + ocpu_count = round(float(vcpu_str)) + else: + # For Intel / AMD cpu, 1*ocpu = 2*vcpu + ocpu_count = round(float(vcpu_str) / 2) + ocpu_count = 1 if (ocpu_count > 0 and ocpu_count < 1) else ocpu_count + + machine_shape_config = None + if ocpu_count > 0: + mem = node_config['MemoryInGbs'] + if mem is not None and mem != 'None': + # pylint: disable=line-too-long + machine_shape_config = oci_adaptor.oci.core.models.LaunchInstanceShapeConfigDetails( + ocpus=ocpu_count, memory_in_gbs=mem) + else: + # pylint: disable=line-too-long + machine_shape_config = oci_adaptor.oci.core.models.LaunchInstanceShapeConfigDetails( + ocpus=ocpu_count) + + preempitible_config = ( + oci_adaptor.oci.core.models.PreemptibleInstanceConfigDetails( + preemption_action=oci_adaptor.oci.core.models. 
+ TerminatePreemptionAction(type='TERMINATE', + preserve_boot_volume=False)) + if node_config['Preemptible'] else None) + + batch_id = datetime.now().strftime('%Y%m%d%H%M%S') + + vm_tags_head = { + **tags, + **constants.HEAD_NODE_TAGS, + constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, + 'sky_spot_flag': str(node_config['Preemptible']).lower(), + } + vm_tags_worker = { + **tags, + **constants.WORKER_NODE_TAGS, + constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud, + 'sky_spot_flag': str(node_config['Preemptible']).lower(), + } + + for seq in range(1, to_start_count + 1): + if head_instance_id is None: + vm_tags = vm_tags_head + node_type = constants.HEAD_NODE_TAGS[ + constants.TAG_RAY_NODE_KIND] + else: + vm_tags = vm_tags_worker + node_type = constants.WORKER_NODE_TAGS[ + constants.TAG_RAY_NODE_KIND] + + launch_instance_response = query_helper.launch_instance( + region, + oci_adaptor.oci.core.models.LaunchInstanceDetails( + availability_domain=node_config['AvailabilityDomain'], + compartment_id=compartment, + shape=instance_type_str, + display_name= + f'{cluster_name_on_cloud}_{node_type}_{batch_id}_{seq}', + freeform_tags=vm_tags, + metadata={ + 'ssh_authorized_keys': node_config['AuthorizedKey'] + }, + source_details=oci_adaptor.oci.core.models. + InstanceSourceViaImageDetails( + source_type='image', + image_id=node_config['ImageId'], + boot_volume_size_in_gbs=node_config['BootVolumeSize'], + boot_volume_vpus_per_gb=int( + node_config['BootVolumePerf']), + ), + create_vnic_details=oci_adaptor.oci.core.models. + CreateVnicDetails( + assign_public_ip=True, + subnet_id=vcn, + ), + shape_config=machine_shape_config, + preemptible_instance_config=preempitible_config, + )) + + new_inst = launch_instance_response.data + if head_instance_id is None: + head_instance_id = new_inst.id + logger.debug(f'New head node: {head_instance_id}') + + created_instances.append({ + 'inst_id': new_inst.id, + 'name': new_inst.display_name, + 'ad': new_inst.availability_domain, + 'compartment': new_inst.compartment_id, + 'status': new_inst.lifecycle_state, + 'oci_tags': new_inst.freeform_tags, + }) + # end for loop + # end if to_start_count > 0:... + + for inst in (resume_instances + created_instances): + logger.debug(f'Provisioning for node {inst["name"]}') + query_helper.wait_instance_until_status(region, inst['inst_id'], + 'RUNNING') + logger.debug(f'Instance {inst["name"]} is RUNNING.') + + total_time = round(time.time() * 1000) - start_time + logger.debug('Total time elapsed: {0} milli-seconds.'.format(total_time)) + + assert head_instance_id is not None, head_instance_id + + # Format: TenancyPrefix:AvailabilityDomain, e.g. 
bxtG:US-SANJOSE-1-AD-1 + _, ad = str(node_config['AvailabilityDomain']).split(':', maxsplit=1) + return common.ProvisionRecord( + provider_name='oci', + region=region, + zone=ad, + cluster_name=cluster_name_on_cloud, + head_instance_id=head_instance_id, + created_instance_ids=[n['inst_id'] for n in created_instances], + resumed_instance_ids=[n['inst_id'] for n in resume_instances], + ) + + +@query_utils.debug_enabled(logger) +def stop_instances( + cluster_name_on_cloud: str, + provider_config: Dict[str, Any], + worker_only: bool = False, +) -> None: + """Stop running instances.""" + # pylint: disable=line-too-long + assert provider_config is not None, (cluster_name_on_cloud, provider_config) + + region = provider_config['region'] + tag_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + if worker_only: + tag_filters[constants.TAG_RAY_NODE_KIND] = 'worker' + + nodes = _get_filtered_nodes(region, tag_filters) + for node in nodes: + query_helper.stop_instance(region, node['inst_id']) + + +@query_utils.debug_enabled(logger) +def terminate_instances( + cluster_name_on_cloud: str, + provider_config: Dict[str, Any], + worker_only: bool = False, +) -> None: + """Terminate running or stopped instances.""" + region = provider_config['region'] + tag_filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + if worker_only: + tag_filters[constants.TAG_RAY_NODE_KIND] = 'worker' + query_helper.terminate_instances_by_tags(tag_filters, region) + + +@query_utils.debug_enabled(logger) +def open_ports( + cluster_name_on_cloud: str, + ports: List[str], + provider_config: Optional[Dict[str, Any]] = None, +) -> None: + """Open ports for inbound traffic.""" + assert provider_config is not None, cluster_name_on_cloud + region = provider_config['region'] + query_helper.create_nsg_rules(region=region, + cluster_name=cluster_name_on_cloud, + ports=ports) + + +@query_utils.debug_enabled(logger) +def cleanup_ports( + cluster_name_on_cloud: str, + ports: List[str], + provider_config: Optional[Dict[str, Any]] = None, +) -> None: + """Delete any opened ports.""" + assert provider_config is not None, cluster_name_on_cloud + region = provider_config['region'] + del ports + query_helper.remove_cluster_nsg(region=region, + cluster_name=cluster_name_on_cloud) + + +@query_utils.debug_enabled(logger) +def wait_instances(region: str, cluster_name_on_cloud: str, + state: Optional[status_lib.ClusterStatus]) -> None: + del region, cluster_name_on_cloud, state + # We already wait for the instances to be running in run_instances. + # We can not implement the wait logic here because the provisioning + # instances are not retrieveable by the QL 'query instance resources ...'. 
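The provision record assembly above splits OCI's availability-domain string (tenancy prefix, then the AD name) to fill the zone field. A small sketch of that parsing and of the resulting record shape; the sample OCID and cluster name are illustrative only:

```python
def availability_domain_to_zone(ad_value: str) -> str:
    """'bxtG:US-SANJOSE-1-AD-1' -> 'US-SANJOSE-1-AD-1'."""
    _, zone = ad_value.split(':', maxsplit=1)
    return zone


provision_record = {
    'provider_name': 'oci',
    'region': 'us-sanjose-1',
    'zone': availability_domain_to_zone('bxtG:US-SANJOSE-1-AD-1'),
    'cluster_name': 'sky-cluster-example',           # illustrative
    'head_instance_id': 'ocid1.instance.oc1..xxxx',  # illustrative OCID
    'created_instance_ids': [],
    'resumed_instance_ids': [],
}
```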
+ + +@query_utils.debug_enabled(logger) +def get_cluster_info( + region: str, + cluster_name_on_cloud: str, + provider_config: Optional[Dict[str, Any]] = None, +) -> common.ClusterInfo: + """Get the metadata of instances in a cluster.""" + filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name_on_cloud} + running_instances = _get_filtered_nodes(region, filters) + + instances = {} + for running_instance in running_instances: + inst = _get_inst_obj_with_ip(region, running_instance) + instances[inst['id']] = [ + common.InstanceInfo( + instance_id=inst['id'], + internal_ip=inst['internal_ip'], + external_ip=inst['external_ip'], + tags=inst['tags'], + ) + ] + + instances = dict(sorted(instances.items(), key=lambda x: x[0])) + logger.debug(f'Cluster info: {instances}') + + head_instance_id = _get_head_instance_id(running_instances) + logger.debug(f'Head instance id is {head_instance_id}') + + return common.ClusterInfo( + provider_name='oci', + head_instance_id=head_instance_id, + instances=instances, + provider_config=provider_config, + ) + + +def _get_filtered_nodes(region: str, + tag_filters: Dict[str, str]) -> List[Dict[str, Any]]: + return_nodes = [] + + try: + insts = query_helper.query_instances_by_tags(tag_filters, region) + except oci_adaptor.oci.exceptions.ServiceError as e: + with ux_utils.print_exception_no_traceback(): + raise exceptions.ClusterStatusFetchingError( + f'Failed to query status for OCI cluster {tag_filters}.' + 'Details: ' + f'{common_utils.format_exception(e, use_bracket=True)}') + + for inst in insts: + inst_id = inst.identifier + return_nodes.append({ + 'inst_id': inst_id, + 'name': inst.display_name, + 'ad': inst.availability_domain, + 'compartment': inst.compartment_id, + 'status': inst.lifecycle_state, + 'oci_tags': inst.freeform_tags, + }) + + return return_nodes + + +def _get_inst_obj_with_ip(region: str, inst_info: Dict[str, + Any]) -> Dict[str, Any]: + get_vnic_response = query_helper.get_instance_primary_vnic( + region, inst_info) + internal_ip = get_vnic_response.private_ip + external_ip = get_vnic_response.public_ip + if external_ip is None: + external_ip = internal_ip + + return { + 'id': inst_info['inst_id'], + 'name': inst_info['name'], + 'external_ip': external_ip, + 'internal_ip': internal_ip, + 'tags': inst_info['oci_tags'], + 'status': inst_info['status'], + } + + +def _get_head_instance_id(instances: List[Dict[str, Any]]) -> Optional[str]: + head_instance_id = None + head_node_tags = tuple(constants.HEAD_NODE_TAGS.items()) + for inst in instances: + is_matched = True + for k, v in head_node_tags: + if (k, v) not in inst['oci_tags'].items(): + is_matched = False + break + if is_matched: + if head_instance_id is not None: + logger.warning( + 'There are multiple head nodes in the cluster ' + f'(current head instance id: {head_instance_id}, ' + f'newly discovered id: {inst["inst_id"]}. It is likely ' + f'that something goes wrong.') + # Don't break here so that we can continue to check and + # warn user about duplicate head instance issue so that + # user can take further action on the abnormal cluster. 
+ + head_instance_id = inst['inst_id'] + + return head_instance_id diff --git a/sky/skylet/providers/oci/query_helper.py b/sky/provision/oci/query_utils.py similarity index 50% rename from sky/skylet/providers/oci/query_helper.py rename to sky/provision/oci/query_utils.py index 8bbaab62b7f..47a0438cb21 100644 --- a/sky/skylet/providers/oci/query_helper.py +++ b/sky/provision/oci/query_utils.py @@ -1,56 +1,80 @@ -""" -Helper class for some OCI operations methods which needs to be shared/called -by multiple places. +"""OCI query helper class History: - - Hysun He (hysun.he@oracle.com) @ Apr, 2023: Initial implementation - + - Hysun He (hysun.he@oracle.com) @ Oct.16, 2024: Code here mainly + migrated from the old provisioning API. + - Hysun He (hysun.he@oracle.com) @ Oct.18, 2024: Enhancement. + find_compartment: allow search subtree when find a compartment. + - Hysun He (hysun.he@oracle.com) @ Nov.12, 2024: Add methods to + Add/remove security rules: create_nsg_rules & remove_nsg """ - from datetime import datetime -import logging +import functools +from logging import Logger import re import time import traceback import typing -from typing import Optional +from typing import List, Optional, Tuple +from sky import exceptions +from sky import sky_logging from sky.adaptors import common as adaptors_common from sky.adaptors import oci as oci_adaptor from sky.clouds.utils import oci_utils -from sky.skylet.providers.oci import utils +from sky.provision import constants +from sky.utils import resources_utils if typing.TYPE_CHECKING: import pandas as pd else: pd = adaptors_common.LazyImport('pandas') -logger = logging.getLogger(__name__) +logger = sky_logging.init_logger(__name__) + +def debug_enabled(log: Logger): + + def decorate(f): + + @functools.wraps(f) + def wrapper(*args, **kwargs): + dt_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + log.debug(f'{dt_str} Enter {f}, {args}, {kwargs}') + try: + return f(*args, **kwargs) + finally: + log.debug(f'{dt_str} Exit {f}') -class oci_query_helper: + return wrapper + return decorate + + +class QueryHelper: + """Helper class for some OCI operations + """ # Call Cloud API to try getting the satisfied nodes. @classmethod - @utils.debug_enabled(logger=logger) + @debug_enabled(logger) def query_instances_by_tags(cls, tag_filters, region): - where_clause_tags = "" + where_clause_tags = '' for tag_key in tag_filters: - if where_clause_tags != "": - where_clause_tags += " && " + if where_clause_tags != '': + where_clause_tags += ' && ' tag_value = tag_filters[tag_key] - where_clause_tags += (f"(freeformTags.key = '{tag_key}'" - f" && freeformTags.value = '{tag_value}')") + where_clause_tags += (f'(freeformTags.key = \'{tag_key}\'' + f' && freeformTags.value = \'{tag_value}\')') - qv_str = (f"query instance resources where {where_clause_tags}" - f" && (lifecycleState != 'TERMINATED'" - f" && lifecycleState != 'TERMINATING')") + qv_str = (f'query instance resources where {where_clause_tags}' + f' && (lifecycleState != \'TERMINATED\'' + f' && lifecycleState != \'TERMINATING\')') qv = oci_adaptor.oci.resource_search.models.StructuredSearchDetails( query=qv_str, - type="Structured", + type='Structured', matching_context_type=oci_adaptor.oci.resource_search.models. 
SearchDetails.MATCHING_CONTEXT_TYPE_NONE, ) @@ -62,45 +86,113 @@ def query_instances_by_tags(cls, tag_filters, region): return result_set @classmethod + @debug_enabled(logger) def terminate_instances_by_tags(cls, tag_filters, region) -> int: - logger.debug(f"Terminate instance by tags: {tag_filters}") + logger.debug(f'Terminate instance by tags: {tag_filters}') + + cluster_name = tag_filters[constants.TAG_RAY_CLUSTER_NAME] + nsg_name = oci_utils.oci_config.NSG_NAME_TEMPLATE.format( + cluster_name=cluster_name) + nsg_id = cls.find_nsg(region, nsg_name, create_if_not_exist=False) + + core_client = oci_adaptor.get_core_client( + region, oci_utils.oci_config.get_profile()) + insts = cls.query_instances_by_tags(tag_filters, region) fail_count = 0 for inst in insts: inst_id = inst.identifier - logger.debug(f"Got instance(to be terminated): {inst_id}") + logger.debug(f'Terminating instance {inst_id}') try: - oci_adaptor.get_core_client( - region, - oci_utils.oci_config.get_profile()).terminate_instance( - inst_id) - except Exception as e: + # Release the NSG reference so that the NSG can be + # deleted without waiting the instance being terminated. + if nsg_id is not None: + cls.detach_nsg(region, inst, nsg_id) + + # Terminate the instance + core_client.terminate_instance(inst_id) + + except oci_adaptor.oci.exceptions.ServiceError as e: fail_count += 1 - logger.error(f"Terminate instance failed: {str(e)}\n: {inst}") + logger.error(f'Terminate instance failed: {str(e)}\n: {inst}') traceback.print_exc() if fail_count == 0: - logger.debug(f"Instance teardown result: OK") + logger.debug('Instance teardown result: OK') else: - logger.warn(f"Instance teardown result: {fail_count} failed!") + logger.warning(f'Instance teardown result: {fail_count} failed!') return fail_count @classmethod - @utils.debug_enabled(logger=logger) + @debug_enabled(logger) + def launch_instance(cls, region, launch_config): + """ To create a new instance """ + return oci_adaptor.get_core_client( + region, oci_utils.oci_config.get_profile()).launch_instance( + launch_instance_details=launch_config) + + @classmethod + @debug_enabled(logger) + def start_instance(cls, region, instance_id): + """ To start an existing instance """ + return oci_adaptor.get_core_client( + region, oci_utils.oci_config.get_profile()).instance_action( + instance_id=instance_id, action='START') + + @classmethod + @debug_enabled(logger) + def stop_instance(cls, region, instance_id): + """ To stop an instance """ + return oci_adaptor.get_core_client( + region, oci_utils.oci_config.get_profile()).instance_action( + instance_id=instance_id, action='STOP') + + @classmethod + @debug_enabled(logger) + def wait_instance_until_status(cls, region, node_id, status): + """ To wait a instance becoming the specified state """ + compute_client = oci_adaptor.get_core_client( + region, oci_utils.oci_config.get_profile()) + + resp = compute_client.get_instance(instance_id=node_id) + + oci_adaptor.oci.wait_until( + compute_client, + resp, + 'lifecycle_state', + status, + ) + + @classmethod + def get_instance_primary_vnic(cls, region, inst_info): + """ Get the primary vnic infomation of the instance """ + list_vnic_attachments_response = oci_adaptor.get_core_client( + region, oci_utils.oci_config.get_profile()).list_vnic_attachments( + availability_domain=inst_info['ad'], + compartment_id=inst_info['compartment'], + instance_id=inst_info['inst_id'], + ) + vnic = list_vnic_attachments_response.data[0] + return oci_adaptor.get_net_client( + region, 
oci_utils.oci_config.get_profile()).get_vnic( + vnic_id=vnic.vnic_id).data + + @classmethod + @debug_enabled(logger) def subscribe_image(cls, compartment_id, listing_id, resource_version, region): - if (pd.isna(listing_id) or listing_id.strip() == "None" or - listing_id.strip() == "nan"): + if (pd.isna(listing_id) or listing_id.strip() == 'None' or + listing_id.strip() == 'nan'): return core_client = oci_adaptor.get_core_client( region, oci_utils.oci_config.get_profile()) try: - agreements_response = core_client.get_app_catalog_listing_agreements( + agreements_resp = core_client.get_app_catalog_listing_agreements( listing_id=listing_id, resource_version=resource_version) - agreements = agreements_response.data + agreements = agreements_resp.data core_client.create_app_catalog_subscription( create_app_catalog_subscription_details=oci_adaptor.oci.core. @@ -113,24 +205,24 @@ def subscribe_image(cls, compartment_id, listing_id, resource_version, oracle_terms_of_use_link, time_retrieved=datetime.strptime( re.sub( - "\d{3}\+\d{2}\:\d{2}", - "Z", + r'\d{3}\+\d{2}\:\d{2}', + 'Z', str(agreements.time_retrieved), 0, ), - "%Y-%m-%d %H:%M:%S.%fZ", + '%Y-%m-%d %H:%M:%S.%fZ', ), signature=agreements.signature, eula_link=agreements.eula_link, )) - except Exception as e: + except oci_adaptor.oci.exceptions.ServiceError as e: logger.critical( - f"subscribe_image: {listing_id} - {resource_version} ... [Failed]" - f"Error message: {str(e)}") - raise RuntimeError("ERR: Image subscription error!") + f'[Failed] subscribe_image: {listing_id} - {resource_version}' + f'Error message: {str(e)}') + raise RuntimeError('ERR: Image subscription error!') from e @classmethod - @utils.debug_enabled(logger=logger) + @debug_enabled(logger) def find_compartment(cls, region) -> str: """ If compartment is not configured, we use root compartment """ # Try to use the configured one first @@ -143,12 +235,18 @@ def find_compartment(cls, region) -> str: # config file is supported (2023/06/09). root = oci_adaptor.get_oci_config( region, oci_utils.oci_config.get_profile())['tenancy'] + list_compartments_response = oci_adaptor.get_identity_client( region, oci_utils.oci_config.get_profile()).list_compartments( compartment_id=root, name=oci_utils.oci_config.COMPARTMENT, + compartment_id_in_subtree=True, + access_level='ACCESSIBLE', lifecycle_state='ACTIVE', + sort_by='TIMECREATED', + sort_order='DESC', limit=1) + compartments = list_compartments_response.data if len(compartments) > 0: skypilot_compartment = compartments[0].id @@ -159,7 +257,7 @@ def find_compartment(cls, region) -> str: return skypilot_compartment @classmethod - @utils.debug_enabled(logger=logger) + @debug_enabled(logger) def find_create_vcn_subnet(cls, region) -> Optional[str]: """ If sub is not configured, we find/create VCN skypilot_vcn """ subnet = oci_utils.oci_config.get_vcn_subnet(region) @@ -174,7 +272,7 @@ def find_create_vcn_subnet(cls, region) -> Optional[str]: list_vcns_response = net_client.list_vcns( compartment_id=skypilot_compartment, display_name=oci_utils.oci_config.VCN_NAME, - lifecycle_state="AVAILABLE") + lifecycle_state='AVAILABLE') vcns = list_vcns_response.data if len(vcns) > 0: # Found the VCN. 
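The query helper above assembles an OCI structured-search query from the cluster's freeform tags and excludes terminated instances. A standalone sketch of the string construction (the real helper additionally wraps it in a `StructuredSearchDetails` object and calls the resource search client); the tag key and value below are illustrative:

```python
from typing import Dict


def build_instance_query(tag_filters: Dict[str, str]) -> str:
    """Compose a structured search query matching all given freeform tags."""
    tag_clauses = ' && '.join(
        f"(freeformTags.key = '{key}' && freeformTags.value = '{value}')"
        for key, value in tag_filters.items())
    return (f'query instance resources where {tag_clauses}'
            " && (lifecycleState != 'TERMINATED'"
            " && lifecycleState != 'TERMINATING')")


print(build_instance_query({'ray-cluster-name': 'my-cluster-1234'}))
```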
@@ -184,7 +282,7 @@ def find_create_vcn_subnet(cls, region) -> Optional[str]: limit=1, vcn_id=skypilot_vcn, display_name=oci_utils.oci_config.VCN_SUBNET_NAME, - lifecycle_state="AVAILABLE") + lifecycle_state='AVAILABLE') logger.debug(f'Got VCN subnet \n{list_subnets_response.data}') if len(list_subnets_response.data) < 1: logger.error( @@ -201,10 +299,17 @@ def find_create_vcn_subnet(cls, region) -> Optional[str]: return cls.create_vcn_subnet(net_client, skypilot_compartment) @classmethod - @utils.debug_enabled(logger=logger) + @debug_enabled(logger) def create_vcn_subnet(cls, net_client, skypilot_compartment) -> Optional[str]: + + skypilot_vcn = None # VCN for the resources + subnet = None # Subnet for the VMs + ig = None # Internet gateway + sg = None # Service gateway + try: + # pylint: disable=line-too-long create_vcn_response = net_client.create_vcn( create_vcn_details=oci_adaptor.oci.core.models.CreateVcnDetails( compartment_id=skypilot_compartment, @@ -274,38 +379,38 @@ def create_vcn_subnet(cls, net_client, update_security_list_details=oci_adaptor.oci.core.models. UpdateSecurityListDetails(ingress_security_rules=[ oci_adaptor.oci.core.models.IngressSecurityRule( - protocol="6", + protocol='6', source=oci_utils.oci_config.VCN_CIDR_INTERNET, is_stateless=False, - source_type="CIDR_BLOCK", + source_type='CIDR_BLOCK', tcp_options=oci_adaptor.oci.core.models.TcpOptions( destination_port_range=oci_adaptor.oci.core.models. PortRange(max=22, min=22), source_port_range=oci_adaptor.oci.core.models. PortRange(max=65535, min=1)), - description="Allow SSH port."), + description='Allow SSH port.'), oci_adaptor.oci.core.models.IngressSecurityRule( - protocol="all", + protocol='all', source=oci_utils.oci_config.VCN_SUBNET_CIDR, is_stateless=False, - source_type="CIDR_BLOCK", - description="Allow all traffic from/to same subnet."), + source_type='CIDR_BLOCK', + description='Allow all traffic from/to same subnet.'), oci_adaptor.oci.core.models.IngressSecurityRule( - protocol="1", + protocol='1', source=oci_utils.oci_config.VCN_CIDR_INTERNET, is_stateless=False, - source_type="CIDR_BLOCK", + source_type='CIDR_BLOCK', icmp_options=oci_adaptor.oci.core.models.IcmpOptions( type=3, code=4), - description="ICMP traffic."), + description='ICMP traffic.'), oci_adaptor.oci.core.models.IngressSecurityRule( - protocol="1", + protocol='1', source=oci_utils.oci_config.VCN_CIDR, is_stateless=False, - source_type="CIDR_BLOCK", + source_type='CIDR_BLOCK', icmp_options=oci_adaptor.oci.core.models.IcmpOptions( type=3), - description="ICMP traffic (VCN)."), + description='ICMP traffic (VCN).'), ])) logger.debug( f'Updated security_list: \n{update_security_list_response.data}' @@ -325,7 +430,7 @@ def create_vcn_subnet(cls, net_client, ])) logger.debug(f'Route table: \n{update_route_table_response.data}') - except oci_adaptor.service_exception() as e: + except oci_adaptor.oci.exceptions.ServiceError as e: logger.error(f'Create VCN Error: Create new VCN ' f'{oci_utils.oci_config.VCN_NAME} failed: {str(e)}') # In case of partial success while creating vcn @@ -335,7 +440,7 @@ def create_vcn_subnet(cls, net_client, return subnet @classmethod - @utils.debug_enabled(logger=logger) + @debug_enabled(logger) def delete_vcn(cls, net_client, skypilot_vcn, skypilot_subnet, internet_gateway, service_gateway): if skypilot_vcn is None: @@ -369,7 +474,7 @@ def delete_vcn(cls, net_client, skypilot_vcn, skypilot_subnet, f'Deleted vcn {skypilot_vcn}-{delete_vcn_response.data}' ) break - except oci_adaptor.service_exception() as e: + 
except oci_adaptor.oci.exceptions.ServiceError as e: logger.info(f'Waiting del SG/IG/Subnet finish: {str(e)}') retry_count = retry_count + 1 if retry_count == oci_utils.oci_config.MAX_RETRY_COUNT: @@ -378,6 +483,196 @@ def delete_vcn(cls, net_client, skypilot_vcn, skypilot_subnet, time.sleep( oci_utils.oci_config.RETRY_INTERVAL_BASE_SECONDS) - except oci_adaptor.service_exception() as e: + except oci_adaptor.oci.exceptions.ServiceError as e: logger.error( f'Delete VCN {oci_utils.oci_config.VCN_NAME} Error: {str(e)}') + + @classmethod + @debug_enabled(logger) + def find_nsg(cls, region: str, nsg_name: str, + create_if_not_exist: bool) -> Optional[str]: + net_client = oci_adaptor.get_net_client( + region, oci_utils.oci_config.get_profile()) + + compartment = cls.find_compartment(region) + + list_vcns_resp = net_client.list_vcns( + compartment_id=compartment, + display_name=oci_utils.oci_config.VCN_NAME, + lifecycle_state='AVAILABLE', + ) + + if not list_vcns_resp: + raise exceptions.ResourcesUnavailableError( + 'The VCN is not available') + + # Get the primary vnic. + assert len(list_vcns_resp.data) > 0 + vcn = list_vcns_resp.data[0] + + list_nsg_resp = net_client.list_network_security_groups( + compartment_id=compartment, + vcn_id=vcn.id, + limit=1, + display_name=nsg_name, + ) + + nsgs = list_nsg_resp.data + if nsgs: + assert len(nsgs) == 1 + return nsgs[0].id + elif not create_if_not_exist: + return None + + # Continue to create new NSG if not exists + create_nsg_resp = net_client.create_network_security_group( + create_network_security_group_details=oci_adaptor.oci.core.models. + CreateNetworkSecurityGroupDetails( + compartment_id=compartment, + vcn_id=vcn.id, + display_name=nsg_name, + )) + get_nsg_resp = net_client.get_network_security_group( + network_security_group_id=create_nsg_resp.data.id) + oci_adaptor.oci.wait_until( + net_client, + get_nsg_resp, + 'lifecycle_state', + 'AVAILABLE', + ) + + return get_nsg_resp.data.id + + @classmethod + def get_range_min_max(cls, port_range: str) -> Tuple[int, int]: + range_list = port_range.split('-') + if len(range_list) == 1: + return (int(range_list[0]), int(range_list[0])) + from_port, to_port = range_list + return (int(from_port), int(to_port)) + + @classmethod + @debug_enabled(logger) + def create_nsg_rules(cls, region: str, cluster_name: str, + ports: List[str]) -> None: + """ Create per-cluster NSG with ingress rules """ + if not ports: + return + + net_client = oci_adaptor.get_net_client( + region, oci_utils.oci_config.get_profile()) + + nsg_name = oci_utils.oci_config.NSG_NAME_TEMPLATE.format( + cluster_name=cluster_name) + nsg_id = cls.find_nsg(region, nsg_name, create_if_not_exist=True) + + filters = {constants.TAG_RAY_CLUSTER_NAME: cluster_name} + insts = query_helper.query_instances_by_tags(filters, region) + for inst in insts: + vnic = cls.get_instance_primary_vnic( + region=region, + inst_info={ + 'inst_id': inst.identifier, + 'ad': inst.availability_domain, + 'compartment': inst.compartment_id, + }) + nsg_ids = vnic.nsg_ids + if not nsg_ids: + net_client.update_vnic( + vnic_id=vnic.id, + update_vnic_details=oci_adaptor.oci.core.models. 
+ UpdateVnicDetails(nsg_ids=[nsg_id], + skip_source_dest_check=False), + ) + + # pylint: disable=line-too-long + list_nsg_rules_resp = net_client.list_network_security_group_security_rules( + network_security_group_id=nsg_id, + direction='INGRESS', + sort_by='TIMECREATED', + sort_order='DESC', + ) + + ingress_rules: List = list_nsg_rules_resp.data + existing_port_ranges: List[str] = [] + for r in ingress_rules: + if r.tcp_options: + options_range = r.tcp_options.destination_port_range + rule_port_range = f'{options_range.min}-{options_range.max}' + existing_port_ranges.append(rule_port_range) + + new_ports = resources_utils.port_ranges_to_set(ports) + existing_ports = resources_utils.port_ranges_to_set( + existing_port_ranges) + if new_ports.issubset(existing_ports): + # ports already contains in the existing rules, nothing to add. + return + + # Determine the ports to be added, without overlapping. + ports_to_open = new_ports - existing_ports + port_ranges_to_open = resources_utils.port_set_to_ranges(ports_to_open) + + new_rules = [] + for port_range in port_ranges_to_open: + port_range_min, port_range_max = cls.get_range_min_max(port_range) + new_rules.append( + oci_adaptor.oci.core.models.AddSecurityRuleDetails( + direction='INGRESS', + protocol='6', + is_stateless=False, + source=oci_utils.oci_config.VCN_CIDR_INTERNET, + source_type='CIDR_BLOCK', + tcp_options=oci_adaptor.oci.core.models.TcpOptions( + destination_port_range=oci_adaptor.oci.core.models. + PortRange(min=port_range_min, max=port_range_max),), + description=oci_utils.oci_config.SERVICE_PORT_RULE_TAG, + )) + + net_client.add_network_security_group_security_rules( + network_security_group_id=nsg_id, + add_network_security_group_security_rules_details=oci_adaptor.oci. + core.models.AddNetworkSecurityGroupSecurityRulesDetails( + security_rules=new_rules), + ) + + @classmethod + @debug_enabled(logger) + def detach_nsg(cls, region: str, inst, nsg_id: Optional[str]) -> None: + if nsg_id is None: + return + + vnic = cls.get_instance_primary_vnic( + region=region, + inst_info={ + 'inst_id': inst.identifier, + 'ad': inst.availability_domain, + 'compartment': inst.compartment_id, + }) + + # Detatch the NSG before removing it. 
+ oci_adaptor.get_net_client(region, oci_utils.oci_config.get_profile( + )).update_vnic( + vnic_id=vnic.id, + update_vnic_details=oci_adaptor.oci.core.models.UpdateVnicDetails( + nsg_ids=[], skip_source_dest_check=False), + ) + + @classmethod + @debug_enabled(logger) + def remove_cluster_nsg(cls, region: str, cluster_name: str) -> None: + """ Remove NSG of the cluster """ + net_client = oci_adaptor.get_net_client( + region, oci_utils.oci_config.get_profile()) + + nsg_name = oci_utils.oci_config.NSG_NAME_TEMPLATE.format( + cluster_name=cluster_name) + nsg_id = cls.find_nsg(region, nsg_name, create_if_not_exist=False) + if nsg_id is None: + return + + # Delete the NSG + net_client.delete_network_security_group( + network_security_group_id=nsg_id) + + +query_helper = QueryHelper() diff --git a/sky/provision/provisioner.py b/sky/provision/provisioner.py index b3e965769c9..cc2ca73e1dc 100644 --- a/sky/provision/provisioner.py +++ b/sky/provision/provisioner.py @@ -29,6 +29,7 @@ from sky.utils import resources_utils from sky.utils import rich_utils from sky.utils import subprocess_utils +from sky.utils import timeline from sky.utils import ux_utils # Do not use __name__ as we do not want to propagate logs to sky.provision, @@ -343,6 +344,7 @@ def _wait_ssh_connection_indirect(ip: str, return True, '' +@timeline.event def wait_for_ssh(cluster_info: provision_common.ClusterInfo, ssh_credentials: Dict[str, str]): """Wait until SSH is ready. @@ -432,11 +434,15 @@ def _post_provision_setup( ux_utils.spinner_message( 'Launching - Waiting for SSH access', provision_logging.config.log_path)) as status: - - logger.debug( - f'\nWaiting for SSH to be available for {cluster_name!r} ...') - wait_for_ssh(cluster_info, ssh_credentials) - logger.debug(f'SSH Connection ready for {cluster_name!r}') + # If on Kubernetes, skip SSH check since the pods are guaranteed to be + # ready by the provisioner, and we use kubectl instead of SSH to run the + # commands and rsync on the pods. SSH will still be ready after a while + # for the users to SSH into the pod. + if cloud_name.lower() != 'kubernetes': + logger.debug( + f'\nWaiting for SSH to be available for {cluster_name!r} ...') + wait_for_ssh(cluster_info, ssh_credentials) + logger.debug(f'SSH Connection ready for {cluster_name!r}') vm_str = 'Instance' if cloud_name.lower() != 'kubernetes' else 'Pod' plural = '' if len(cluster_info.instances) == 1 else 's' verb = 'is' if len(cluster_info.instances) == 1 else 'are' @@ -496,31 +502,94 @@ def _post_provision_setup( **ssh_credentials) head_runner = runners[0] - status.update( - runtime_preparation_str.format(step=3, step_name='runtime')) - full_ray_setup = True - ray_port = constants.SKY_REMOTE_RAY_PORT - if not provision_record.is_instance_just_booted( - head_instance.instance_id): + def is_ray_cluster_healthy(ray_status_output: str, + expected_num_nodes: int) -> bool: + """Parse the output of `ray status` to get #active nodes. 
+ + The output of `ray status` looks like: + Node status + --------------------------------------------------------------- + Active: + 1 node_291a8b849439ad6186387c35dc76dc43f9058108f09e8b68108cf9ec + 1 node_0945fbaaa7f0b15a19d2fd3dc48f3a1e2d7c97e4a50ca965f67acbfd + Pending: + (no pending nodes) + Recent failures: + (no failures) + """ + start = ray_status_output.find('Active:') + end = ray_status_output.find('Pending:', start) + if start == -1 or end == -1: + return False + num_active_nodes = 0 + for line in ray_status_output[start:end].split('\n'): + if line.strip() and not line.startswith('Active:'): + num_active_nodes += 1 + return num_active_nodes == expected_num_nodes + + def check_ray_port_and_cluster_healthy() -> Tuple[int, bool, bool]: + head_ray_needs_restart = True + ray_cluster_healthy = False + ray_port = constants.SKY_REMOTE_RAY_PORT + # Check if head node Ray is alive returncode, stdout, _ = head_runner.run( instance_setup.RAY_STATUS_WITH_SKY_RAY_PORT_COMMAND, stream_logs=False, require_outputs=True) - if returncode: - logger.debug('Ray cluster on head is not up. Restarting...') - else: - logger.debug('Ray cluster on head is up.') + if not returncode: ray_port = common_utils.decode_payload(stdout)['ray_port'] - full_ray_setup = bool(returncode) + logger.debug(f'Ray cluster on head is up with port {ray_port}.') + + head_ray_needs_restart = bool(returncode) + # This is a best effort check to see if the ray cluster has expected + # number of nodes connected. + ray_cluster_healthy = (not head_ray_needs_restart and + is_ray_cluster_healthy( + stdout, cluster_info.num_instances)) + return ray_port, ray_cluster_healthy, head_ray_needs_restart + + status.update( + runtime_preparation_str.format(step=3, step_name='runtime')) + + ray_port = constants.SKY_REMOTE_RAY_PORT + head_ray_needs_restart = True + ray_cluster_healthy = False + if (not provision_record.is_instance_just_booted( + head_instance.instance_id)): + # Check if head node Ray is alive + (ray_port, ray_cluster_healthy, + head_ray_needs_restart) = check_ray_port_and_cluster_healthy() + elif cloud_name.lower() == 'kubernetes': + timeout = 90 # 1.5-min maximum timeout + start = time.time() + while True: + # Wait until Ray cluster is ready + (ray_port, ray_cluster_healthy, + head_ray_needs_restart) = check_ray_port_and_cluster_healthy() + if ray_cluster_healthy: + logger.debug('Ray cluster is ready. Skip head and worker ' + 'node ray cluster setup.') + break + if time.time() - start > timeout: + # In most cases, the ray cluster will be ready after a few + # seconds. Trigger ray start on head or worker nodes to be + # safe, if the ray cluster is not ready after timeout. + break + logger.debug('Ray cluster is not ready yet, waiting for the ' + 'async setup to complete...') + time.sleep(1) - if full_ray_setup: + if head_ray_needs_restart: logger.debug('Starting Ray on the entire cluster.') instance_setup.start_ray_on_head_node( cluster_name.name_on_cloud, custom_resource=custom_resource, cluster_info=cluster_info, ssh_credentials=ssh_credentials) + else: + logger.debug('Ray cluster on head is ready. 
Skip starting ray ' + 'cluster on head node.') # NOTE: We have to check all worker nodes to make sure they are all # healthy, otherwise we can only start Ray on newly started worker @@ -531,10 +600,13 @@ def _post_provision_setup( # if provision_record.is_instance_just_booted(inst.instance_id): # worker_ips.append(inst.public_ip) - if cluster_info.num_instances > 1: + # We don't need to restart ray on worker nodes if the ray cluster is + # already healthy, i.e. the head node has expected number of nodes + # connected to the ray cluster. + if cluster_info.num_instances > 1 and not ray_cluster_healthy: instance_setup.start_ray_on_worker_nodes( cluster_name.name_on_cloud, - no_restart=not full_ray_setup, + no_restart=not head_ray_needs_restart, custom_resource=custom_resource, # Pass the ray_port to worker nodes for backward compatibility # as in some existing clusters the ray_port is not dumped with @@ -543,6 +615,9 @@ def _post_provision_setup( ray_port=ray_port, cluster_info=cluster_info, ssh_credentials=ssh_credentials) + elif ray_cluster_healthy: + logger.debug('Ray cluster is ready. Skip starting ray cluster on ' + 'worker nodes.') instance_setup.start_skylet_on_head_node(cluster_name.name_on_cloud, cluster_info, ssh_credentials) @@ -553,6 +628,7 @@ def _post_provision_setup( return cluster_info +@timeline.event def post_provision_runtime_setup( cloud_name: str, cluster_name: resources_utils.ClusterName, cluster_yaml: str, provision_record: provision_common.ProvisionRecord, diff --git a/sky/resources.py b/sky/resources.py index 3b33476713b..5184278e02e 100644 --- a/sky/resources.py +++ b/sky/resources.py @@ -14,6 +14,7 @@ from sky import skypilot_config from sky.clouds import service_catalog from sky.provision import docker_utils +from sky.provision.kubernetes import utils as kubernetes_utils from sky.skylet import constants from sky.utils import accelerator_registry from sky.utils import common_utils @@ -44,7 +45,7 @@ class Resources: """ # If any fields changed, increment the version. For backward compatibility, # modify the __setstate__ method to handle the old version. - _VERSION = 19 + _VERSION = 20 def __init__( self, @@ -582,36 +583,46 @@ def _set_accelerators( acc, _ = list(accelerators.items())[0] if 'tpu' in acc.lower(): if self.cloud is None: - self._cloud = clouds.GCP() - assert self.cloud.is_same_cloud( - clouds.GCP()), 'Cloud must be GCP.' + if kubernetes_utils.is_tpu_on_gke(acc): + self._cloud = clouds.Kubernetes() + else: + self._cloud = clouds.GCP() + assert (self.cloud.is_same_cloud(clouds.GCP()) or + self.cloud.is_same_cloud(clouds.Kubernetes())), ( + 'Cloud must be GCP or Kubernetes for TPU ' + 'accelerators.') + if accelerator_args is None: accelerator_args = {} + use_tpu_vm = accelerator_args.get('tpu_vm', True) - if self.instance_type is not None and use_tpu_vm: - if self.instance_type != 'TPU-VM': - with ux_utils.print_exception_no_traceback(): - raise ValueError( - 'Cannot specify instance type' - f' (got "{self.instance_type}") for TPU VM.') - if 'runtime_version' not in accelerator_args: - - def _get_default_runtime_version() -> str: - if not use_tpu_vm: - return '2.12.0' - # TPU V5 requires a newer runtime version. - if acc.startswith('tpu-v5'): - return 'v2-alpha-tpuv5' - # TPU V6e requires a newer runtime version. 
- if acc.startswith('tpu-v6e'): - return 'v2-alpha-tpuv6e' - return 'tpu-vm-base' - - accelerator_args['runtime_version'] = ( - _get_default_runtime_version()) - logger.info( - 'Missing runtime_version in accelerator_args, using' - f' default ({accelerator_args["runtime_version"]})') + if (self.cloud.is_same_cloud(clouds.GCP()) and + not kubernetes_utils.is_tpu_on_gke(acc)): + if 'runtime_version' not in accelerator_args: + + def _get_default_runtime_version() -> str: + if not use_tpu_vm: + return '2.12.0' + # TPU V5 requires a newer runtime version. + if acc.startswith('tpu-v5'): + return 'v2-alpha-tpuv5' + # TPU V6e requires a newer runtime version. + elif acc.startswith('tpu-v6e'): + return 'v2-alpha-tpuv6e' + return 'tpu-vm-base' + + accelerator_args['runtime_version'] = ( + _get_default_runtime_version()) + logger.info( + 'Missing runtime_version in accelerator_args, using' + f' default ({accelerator_args["runtime_version"]})') + + if self.instance_type is not None and use_tpu_vm: + if self.instance_type != 'TPU-VM': + with ux_utils.print_exception_no_traceback(): + raise ValueError( + 'Cannot specify instance type (got ' + f'{self.instance_type!r}) for TPU VM.') self._accelerators = accelerators self._accelerator_args = accelerator_args @@ -1030,6 +1041,7 @@ def get_spot_str(self) -> str: def make_deploy_variables(self, cluster_name: resources_utils.ClusterName, region: clouds.Region, zones: Optional[List[clouds.Zone]], + num_nodes: int, dryrun: bool) -> Dict[str, Optional[str]]: """Converts planned sky.Resources to resource variables. @@ -1051,7 +1063,7 @@ def make_deploy_variables(self, cluster_name: resources_utils.ClusterName, # Cloud specific variables cloud_specific_variables = self.cloud.make_deploy_resources_variables( - self, cluster_name, region, zones, dryrun) + self, cluster_name, region, zones, num_nodes, dryrun) # Docker run options docker_run_options = skypilot_config.get_nested( @@ -1595,4 +1607,25 @@ def __setstate__(self, state): self._cluster_config_overrides = state.pop( '_cluster_config_overrides', None) + if version < 20: + # Pre-0.7.0, we used 'kubernetes' as the default region for + # Kubernetes clusters. With the introduction of support for + # multiple contexts, we now set the region to the context name. + # Since we do not have information on which context the cluster + # was run in, we default it to the current active context. 
+ legacy_region = clouds.Kubernetes().LEGACY_SINGLETON_REGION + original_cloud = state.get('_cloud', None) + original_region = state.get('_region', None) + if (isinstance(original_cloud, clouds.Kubernetes) and + original_region == legacy_region): + current_context = ( + kubernetes_utils.get_current_kube_config_context_name()) + state['_region'] = current_context + # Also update the image_id dict if it contains the old region + if isinstance(state['_image_id'], dict): + if legacy_region in state['_image_id']: + state['_image_id'][current_context] = ( + state['_image_id'][legacy_region]) + del state['_image_id'][legacy_region] + self.__dict__.update(state) diff --git a/sky/serve/__init__.py b/sky/serve/__init__.py index f93495809c3..6bda949d3c3 100644 --- a/sky/serve/__init__.py +++ b/sky/serve/__init__.py @@ -11,6 +11,7 @@ from sky.serve.core import terminate_replica from sky.serve.core import up from sky.serve.core import update +from sky.serve.load_balancing_policies import LB_POLICIES from sky.serve.serve_state import ReplicaStatus from sky.serve.serve_state import ServiceStatus from sky.serve.serve_utils import DEFAULT_UPDATE_MODE @@ -35,6 +36,7 @@ 'get_endpoint', 'INITIAL_VERSION', 'LB_CONTROLLER_SYNC_INTERVAL_SECONDS', + 'LB_POLICIES', 'ReplicaStatus', 'ServiceComponent', 'ServiceStatus', diff --git a/sky/serve/core.py b/sky/serve/core.py index abf9bfbc719..f6f6c53ad7b 100644 --- a/sky/serve/core.py +++ b/sky/serve/core.py @@ -701,6 +701,7 @@ def tail_logs( with ux_utils.print_exception_no_traceback(): raise ValueError(f'`target` must be a string or ' f'sky.serve.ServiceComponent, got {type(target)}.') + if target == serve_utils.ServiceComponent.REPLICA: if replica_id is None: with ux_utils.print_exception_no_traceback(): diff --git a/sky/serve/load_balancer.py b/sky/serve/load_balancer.py index c15f71e214a..30697532a22 100644 --- a/sky/serve/load_balancer.py +++ b/sky/serve/load_balancer.py @@ -2,7 +2,7 @@ import asyncio import logging import threading -from typing import Dict, Union +from typing import Dict, Optional, Union import aiohttp import fastapi @@ -27,18 +27,24 @@ class SkyServeLoadBalancer: policy. """ - def __init__(self, controller_url: str, load_balancer_port: int) -> None: + def __init__(self, + controller_url: str, + load_balancer_port: int, + load_balancing_policy_name: Optional[str] = None) -> None: """Initialize the load balancer. Args: controller_url: The URL of the controller. load_balancer_port: The port where the load balancer listens to. + load_balancing_policy_name: The name of the load balancing policy + to use. Defaults to None. 
""" self._app = fastapi.FastAPI() self._controller_url: str = controller_url self._load_balancer_port: int = load_balancer_port - self._load_balancing_policy: lb_policies.LoadBalancingPolicy = ( - lb_policies.RoundRobinPolicy()) + # Use the registry to create the load balancing policy + self._load_balancing_policy = lb_policies.LoadBalancingPolicy.make( + load_balancing_policy_name) self._request_aggregator: serve_utils.RequestsAggregator = ( serve_utils.RequestTimestamp()) # TODO(tian): httpx.Client has a resource limit of 100 max connections @@ -223,9 +229,21 @@ async def startup(): uvicorn.run(self._app, host='0.0.0.0', port=self._load_balancer_port) -def run_load_balancer(controller_addr: str, load_balancer_port: int): - load_balancer = SkyServeLoadBalancer(controller_url=controller_addr, - load_balancer_port=load_balancer_port) +def run_load_balancer(controller_addr: str, + load_balancer_port: int, + load_balancing_policy_name: Optional[str] = None) -> None: + """ Run the load balancer. + + Args: + controller_addr: The address of the controller. + load_balancer_port: The port where the load balancer listens to. + policy_name: The name of the load balancing policy to use. Defaults to + None. + """ + load_balancer = SkyServeLoadBalancer( + controller_url=controller_addr, + load_balancer_port=load_balancer_port, + load_balancing_policy_name=load_balancing_policy_name) load_balancer.run() @@ -241,5 +259,13 @@ def run_load_balancer(controller_addr: str, load_balancer_port: int): required=True, default=8890, help='The port where the load balancer listens to.') + available_policies = list(lb_policies.LB_POLICIES.keys()) + parser.add_argument( + '--load-balancing-policy', + choices=available_policies, + default='round_robin', + help=f'The load balancing policy to use. 
Available policies: ' + f'{", ".join(available_policies)}.') args = parser.parse_args() - run_load_balancer(args.controller_addr, args.load_balancer_port) + run_load_balancer(args.controller_addr, args.load_balancer_port, + args.load_balancing_policy) diff --git a/sky/serve/load_balancing_policies.py b/sky/serve/load_balancing_policies.py index 34c1fa4249b..aec6eb01487 100644 --- a/sky/serve/load_balancing_policies.py +++ b/sky/serve/load_balancing_policies.py @@ -10,6 +10,10 @@ logger = sky_logging.init_logger(__name__) +# Define a registry for load balancing policies +LB_POLICIES = {} +DEFAULT_LB_POLICY = None + def _request_repr(request: 'fastapi.Request') -> str: return (' None: self.ready_replicas: List[str] = [] + def __init_subclass__(cls, name: str, default: bool = False): + LB_POLICIES[name] = cls + if default: + global DEFAULT_LB_POLICY + assert DEFAULT_LB_POLICY is None, ( + 'Only one policy can be default.') + DEFAULT_LB_POLICY = name + + @classmethod + def make(cls, policy_name: Optional[str] = None) -> 'LoadBalancingPolicy': + """Create a load balancing policy from a name.""" + if policy_name is None: + policy_name = DEFAULT_LB_POLICY + + if policy_name not in LB_POLICIES: + raise ValueError(f'Unknown load balancing policy: {policy_name}') + return LB_POLICIES[policy_name]() + def set_ready_replicas(self, ready_replicas: List[str]) -> None: raise NotImplementedError @@ -44,7 +66,7 @@ def _select_replica(self, request: 'fastapi.Request') -> Optional[str]: raise NotImplementedError -class RoundRobinPolicy(LoadBalancingPolicy): +class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin', default=True): """Round-robin load balancing policy.""" def __init__(self) -> None: diff --git a/sky/serve/serve_utils.py b/sky/serve/serve_utils.py index 6e7b6f6eb4a..6ab932f278a 100644 --- a/sky/serve/serve_utils.py +++ b/sky/serve/serve_utils.py @@ -46,8 +46,14 @@ constants.CONTROLLER_MEMORY_USAGE_GB) _CONTROLLER_URL = 'http://localhost:{CONTROLLER_PORT}' -_SKYPILOT_PROVISION_LOG_PATTERN = r'.*tail -n100 -f (.*provision\.log).*' -_SKYPILOT_LOG_PATTERN = r'.*tail -n100 -f (.*\.log).*' +# NOTE(dev): We assume log paths are either in ~/sky_logs/... or ~/.sky/... +# and always appear after a space. Be careful when changing UX as this +# assumption is used to expand some log files while ignoring others. +_SKYPILOT_LOG_DIRS = r'~/(sky_logs|\.sky)' +_SKYPILOT_PROVISION_LOG_PATTERN = ( + fr'.* ({_SKYPILOT_LOG_DIRS}/.*provision\.log)') +_SKYPILOT_LOG_PATTERN = fr'.* ({_SKYPILOT_LOG_DIRS}/.*\.log)' + # TODO(tian): Find all existing replica id and print here. _FAILED_TO_FIND_REPLICA_MSG = ( f'{colorama.Fore.RED}Failed to find replica ' @@ -591,16 +597,27 @@ def get_latest_version_with_min_replicas( return active_versions[-1] if active_versions else None -def _follow_replica_logs( - file: TextIO, - cluster_name: str, - *, - finish_stream: Callable[[], bool], - exit_if_stream_end: bool = False, - no_new_content_timeout: Optional[int] = None) -> Iterator[str]: - line = '' - log_file = None - no_new_content_cnt = 0 +def _follow_logs_with_provision_expanding( + file: TextIO, + cluster_name: str, + *, + should_stop: Callable[[], bool], + stop_on_eof: bool = False, + idle_timeout_seconds: Optional[int] = None, +) -> Iterator[str]: + """Follows logs and expands any provision.log references found. + + Args: + file: Log file to read from. + cluster_name: Name of the cluster being launched. + should_stop: Callback that returns True when streaming should stop. 
+ stop_on_eof: If True, stop when reaching end of file. + idle_timeout_seconds: If set, stop after these many seconds without + new content. + + Yields: + Log lines, including expanded content from referenced provision logs. + """ def cluster_is_up() -> bool: cluster_record = global_user_state.get_cluster_from_name(cluster_name) @@ -608,51 +625,51 @@ def cluster_is_up() -> bool: return False return cluster_record['status'] == status_lib.ClusterStatus.UP - while True: - tmp = file.readline() - if tmp is not None and tmp != '': - no_new_content_cnt = 0 - line += tmp - if '\n' in line or '\r' in line: - # Tailing detailed progress for user. All logs in skypilot is - # of format `To view detailed progress: tail -n100 -f *.log`. - x = re.match(_SKYPILOT_PROVISION_LOG_PATTERN, line) - if x is not None: - log_file = os.path.expanduser(x.group(1)) - elif re.match(_SKYPILOT_LOG_PATTERN, line) is None: - # Not print other logs (file sync logs) since we lack - # utility to determine when these log files are finished - # writing. - # TODO(tian): Not skip these logs since there are small - # chance that error will happen in file sync. Need to find - # a better way to do this. - yield line - # Output next line first since it indicates the process is - # starting. For our launching logs, it's always: - # Launching on () - if log_file is not None: - with open(log_file, 'r', newline='', - encoding='utf-8') as f: - # We still exit if more than 10 seconds without new - # content to avoid any internal bug that causes - # the launch failed and cluster status remains INIT. - for l in _follow_replica_logs( - f, - cluster_name, - finish_stream=cluster_is_up, - exit_if_stream_end=exit_if_stream_end, - no_new_content_timeout=10): - yield l - log_file = None - line = '' - else: - if exit_if_stream_end or finish_stream(): - break - if no_new_content_timeout is not None: - if no_new_content_cnt >= no_new_content_timeout: - break - no_new_content_cnt += 1 - time.sleep(1) + def process_line(line: str) -> Iterator[str]: + # The line might be directing users to view logs, like + # `✓ Cluster launched: new-http. View logs at: *.log` + # We should tail the detailed logs for user. + provision_log_prompt = re.match(_SKYPILOT_PROVISION_LOG_PATTERN, line) + log_prompt = re.match(_SKYPILOT_LOG_PATTERN, line) + + if provision_log_prompt is not None: + nested_log_path = os.path.expanduser(provision_log_prompt.group(1)) + + try: + with open(nested_log_path, 'r', newline='', + encoding='utf-8') as f: + # We still exit if more than 10 seconds without new content + # to avoid any internal bug that causes the launch to fail + # while cluster status remains INIT. + yield from log_utils.follow_logs(f, + should_stop=cluster_is_up, + stop_on_eof=stop_on_eof, + idle_timeout_seconds=10) + except FileNotFoundError: + yield line + + yield (f'{colorama.Fore.YELLOW}{colorama.Style.BRIGHT}' + f'Try to expand log file {nested_log_path} but not ' + f'found. Skipping...{colorama.Style.RESET_ALL}') + pass + return + + if log_prompt is not None: + # Now we skip other logs (file sync logs) since we lack + # utility to determine when these log files are finished + # writing. + # TODO(tian): We should not skip these logs since there are + # small chance that error will happen in file sync. Need to + # find a better way to do this. 
+ return + + yield line + + return log_utils.follow_logs(file, + should_stop=should_stop, + stop_on_eof=stop_on_eof, + process_line=process_line, + idle_timeout_seconds=idle_timeout_seconds) def stream_replica_logs(service_name: str, replica_id: int, @@ -687,14 +704,17 @@ def _get_replica_status() -> serve_state.ReplicaStatus: raise ValueError( _FAILED_TO_FIND_REPLICA_MSG.format(replica_id=replica_id)) - finish_stream = ( + replica_provisioned = ( lambda: _get_replica_status() != serve_state.ReplicaStatus.PROVISIONING) with open(launch_log_file_name, 'r', newline='', encoding='utf-8') as f: - for line in _follow_replica_logs(f, - replica_cluster_name, - finish_stream=finish_stream, - exit_if_stream_end=not follow): + for line in _follow_logs_with_provision_expanding( + f, + replica_cluster_name, + should_stop=replica_provisioned, + stop_on_eof=not follow, + ): print(line, end='', flush=True) + if (not follow and _get_replica_status() == serve_state.ReplicaStatus.PROVISIONING): # Early exit if not following the logs. @@ -719,22 +739,6 @@ def _get_replica_status() -> serve_state.ReplicaStatus: return '' -def _follow_logs(file: TextIO, *, finish_stream: Callable[[], bool], - exit_if_stream_end: bool) -> Iterator[str]: - line = '' - while True: - tmp = file.readline() - if tmp is not None and tmp != '': - line += tmp - if '\n' in line or '\r' in line: - yield line - line = '' - else: - if exit_if_stream_end or finish_stream(): - break - time.sleep(1) - - def stream_serve_process_logs(service_name: str, stream_controller: bool, follow: bool) -> str: msg = check_service_status_healthy(service_name) @@ -753,9 +757,11 @@ def _service_is_terminal() -> bool: with open(os.path.expanduser(log_file), 'r', newline='', encoding='utf-8') as f: - for line in _follow_logs(f, - finish_stream=_service_is_terminal, - exit_if_stream_end=not follow): + for line in log_utils.follow_logs( + f, + should_stop=_service_is_terminal, + stop_on_eof=not follow, + ): print(line, end='', flush=True) return '' diff --git a/sky/serve/service.py b/sky/serve/service.py index 956a4839a87..0a1c7f34766 100644 --- a/sky/serve/service.py +++ b/sky/serve/service.py @@ -219,6 +219,9 @@ def _get_host(): load_balancer_port = common_utils.find_free_port( constants.LOAD_BALANCER_PORT_START) + # Extract the load balancing policy from the service spec + policy_name = service_spec.load_balancing_policy + # Start the load balancer. # TODO(tian): Probably we could enable multiple ports specified in # service spec and we could start multiple load balancers. 
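Editor's note: the `load_balancing_policies.py` hunk above relies on the `__init_subclass__` registry pattern: each policy subclass registers itself in `LB_POLICIES` under a name at class-definition time, and `LoadBalancingPolicy.make()` resolves that name back to a class. A stripped-down, standalone sketch of the same pattern (illustrative only, not the full SkyPilot class):

from typing import Dict, Optional, Type

POLICIES: Dict[str, Type['Policy']] = {}
DEFAULT_POLICY: Optional[str] = None


class Policy:

    # Called automatically whenever a subclass is defined; the keyword
    # arguments come from the class statement itself.
    def __init_subclass__(cls, name: str, default: bool = False):
        POLICIES[name] = cls
        if default:
            global DEFAULT_POLICY
            assert DEFAULT_POLICY is None, 'Only one policy can be default.'
            DEFAULT_POLICY = name

    @classmethod
    def make(cls, policy_name: Optional[str] = None) -> 'Policy':
        if policy_name is None:
            policy_name = DEFAULT_POLICY
        if policy_name not in POLICIES:
            raise ValueError(f'Unknown policy: {policy_name}')
        return POLICIES[policy_name]()


class RoundRobin(Policy, name='round_robin', default=True):
    pass


# Policy.make() and Policy.make('round_robin') both return a RoundRobin
# instance, mirroring how the load balancer picks its policy from the
# --load-balancing-policy flag.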
@@ -227,7 +230,7 @@ def _get_host(): target=ux_utils.RedirectOutputForProcess( load_balancer.run_load_balancer, load_balancer_log_file).run, - args=(controller_addr, load_balancer_port)) + args=(controller_addr, load_balancer_port, policy_name)) load_balancer_process.start() serve_state.set_service_load_balancer_port(service_name, load_balancer_port) diff --git a/sky/serve/service_spec.py b/sky/serve/service_spec.py index 2eff6f40a9d..000eed139f1 100644 --- a/sky/serve/service_spec.py +++ b/sky/serve/service_spec.py @@ -6,6 +6,7 @@ import yaml +from sky import serve from sky.serve import constants from sky.utils import common_utils from sky.utils import schemas @@ -29,6 +30,7 @@ def __init__( base_ondemand_fallback_replicas: Optional[int] = None, upscale_delay_seconds: Optional[int] = None, downscale_delay_seconds: Optional[int] = None, + load_balancing_policy: Optional[str] = None, ) -> None: if max_replicas is not None and max_replicas < min_replicas: with ux_utils.print_exception_no_traceback(): @@ -55,6 +57,13 @@ def __init__( raise ValueError('readiness_path must start with a slash (/). ' f'Got: {readiness_path}') + # Add the check for unknown load balancing policies + if (load_balancing_policy is not None and + load_balancing_policy not in serve.LB_POLICIES): + with ux_utils.print_exception_no_traceback(): + raise ValueError( + f'Unknown load balancing policy: {load_balancing_policy}. ' + f'Available policies: {list(serve.LB_POLICIES.keys())}') self._readiness_path: str = readiness_path self._initial_delay_seconds: int = initial_delay_seconds self._readiness_timeout_seconds: int = readiness_timeout_seconds @@ -69,6 +78,7 @@ def __init__( int] = base_ondemand_fallback_replicas self._upscale_delay_seconds: Optional[int] = upscale_delay_seconds self._downscale_delay_seconds: Optional[int] = downscale_delay_seconds + self._load_balancing_policy: Optional[str] = load_balancing_policy self._use_ondemand_fallback: bool = ( self.dynamic_ondemand_fallback is not None and @@ -150,6 +160,8 @@ def from_yaml_config(config: Dict[str, Any]) -> 'SkyServiceSpec': service_config['dynamic_ondemand_fallback'] = policy_section.get( 'dynamic_ondemand_fallback', None) + service_config['load_balancing_policy'] = config.get( + 'load_balancing_policy', None) return SkyServiceSpec(**service_config) @staticmethod @@ -205,6 +217,8 @@ def add_if_not_none(section, key, value, no_empty: bool = False): self.upscale_delay_seconds) add_if_not_none('replica_policy', 'downscale_delay_seconds', self.downscale_delay_seconds) + add_if_not_none('load_balancing_policy', None, + self._load_balancing_policy) return config def probe_str(self): @@ -256,6 +270,7 @@ def __repr__(self) -> str: Readiness probe timeout seconds: {self.readiness_timeout_seconds} Replica autoscaling policy: {self.autoscaling_policy_str()} Spot Policy: {self.spot_policy_str()} + Load Balancing Policy: {self.load_balancing_policy} """) @property @@ -310,3 +325,7 @@ def downscale_delay_seconds(self) -> Optional[int]: @property def use_ondemand_fallback(self) -> bool: return self._use_ondemand_fallback + + @property + def load_balancing_policy(self) -> Optional[str]: + return self._load_balancing_policy diff --git a/sky/setup_files/MANIFEST.in b/sky/setup_files/MANIFEST.in index 0cd93f485e0..ea5ceb50cfb 100644 --- a/sky/setup_files/MANIFEST.in +++ b/sky/setup_files/MANIFEST.in @@ -6,7 +6,6 @@ include sky/setup_files/* include sky/skylet/*.sh include sky/skylet/LICENSE include sky/skylet/providers/ibm/* -include sky/skylet/providers/oci/* include 
sky/skylet/providers/scp/* include sky/skylet/providers/*.py include sky/skylet/ray_patches/*.patch diff --git a/sky/setup_files/dependencies.py b/sky/setup_files/dependencies.py new file mode 100644 index 00000000000..18d2f5cdc08 --- /dev/null +++ b/sky/setup_files/dependencies.py @@ -0,0 +1,141 @@ +"""Dependencies for SkyPilot. + +This file is imported by setup.py, so: +- It may not be able to import other skypilot modules, since sys.path may not be + correct. +- It should not import any dependencies, as they may not be installed yet. +""" +from typing import Dict, List + +install_requires = [ + 'wheel', + 'cachetools', + # NOTE: ray requires click>=7.0. + 'click >= 7.0', + 'colorama', + 'cryptography', + # Jinja has a bug in older versions because of the lack of pinning + # the version of the underlying markupsafe package. See: + # https://github.com/pallets/jinja/issues/1585 + 'jinja2 >= 3.0', + 'jsonschema', + 'networkx', + 'pandas>=1.3.0', + 'pendulum', + # PrettyTable with version >=2.0.0 is required for the support of + # `add_rows` method. + 'PrettyTable >= 2.0.0', + 'python-dotenv', + 'rich', + 'tabulate', + # Light weight requirement, can be replaced with "typing" once + # we deprecate Python 3.7 (this will take a while). + 'typing_extensions', + 'filelock >= 3.6.0', + 'packaging', + 'psutil', + 'pulp', + # Cython 3.0 release breaks PyYAML 5.4.* + # (https://github.com/yaml/pyyaml/issues/601) + # <= 3.13 may encounter https://github.com/ultralytics/yolov5/issues/414 + 'pyyaml > 3.13, != 5.4.*', + 'requests', +] + +local_ray = [ + # Lower version of ray will cause dependency conflict for + # click/grpcio/protobuf. + # Excluded 2.6.0 as it has a bug in the cluster launcher: + # https://github.com/ray-project/ray/releases/tag/ray-2.6.1 + 'ray[default] >= 2.2.0, != 2.6.0', +] + +remote = [ + # Adopted from ray's setup.py: + # https://github.com/ray-project/ray/blob/ray-2.4.0/python/setup.py + # SkyPilot: != 1.48.0 is required to avoid the error where ray dashboard + # fails to start when ray start is called (#2054). + # Tracking issue: https://github.com/ray-project/ray/issues/30984 + 'grpcio >= 1.32.0, <= 1.49.1, != 1.48.0; python_version < \'3.10\' and sys_platform == \'darwin\'', # noqa:E501 pylint: disable=line-too-long + 'grpcio >= 1.42.0, <= 1.49.1, != 1.48.0; python_version >= \'3.10\' and sys_platform == \'darwin\'', # noqa:E501 pylint: disable=line-too-long + # Original issue: https://github.com/ray-project/ray/issues/33833 + 'grpcio >= 1.32.0, <= 1.51.3, != 1.48.0; python_version < \'3.10\' and sys_platform != \'darwin\'', # noqa:E501 pylint: disable=line-too-long + 'grpcio >= 1.42.0, <= 1.51.3, != 1.48.0; python_version >= \'3.10\' and sys_platform != \'darwin\'', # noqa:E501 pylint: disable=line-too-long + # Adopted from ray's setup.py: + # https://github.com/ray-project/ray/blob/ray-2.9.3/python/setup.py#L343 + 'protobuf >= 3.15.3, != 3.19.5', + # Some pydantic versions are not compatible with ray. Adopted from ray's + # setup.py: + # https://github.com/ray-project/ray/blob/ray-2.9.3/python/setup.py#L254 + 'pydantic!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3', +] + +# NOTE: Change the templates/jobs-controller.yaml.j2 file if any of the +# following packages dependencies are changed. +aws_dependencies = [ + # botocore does not work with urllib3>=2.0.0, according to + # https://github.com/boto/botocore/issues/2926 + # We have to explicitly pin the version to optimize the time for + # poetry install. 
See https://github.com/orgs/python-poetry/discussions/7937 + 'urllib3<2', + # NOTE: this installs CLI V1. To use AWS SSO (e.g., `aws sso login`), users + # should instead use CLI V2 which is not pip-installable. See + # https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html. + 'awscli>=1.27.10', + 'botocore>=1.29.10', + 'boto3>=1.26.1', + # NOTE: required by awscli. To avoid ray automatically installing + # the latest version. + 'colorama < 0.4.5', +] + +# azure-cli cannot be installed normally by uv, so we need to work around it in +# a few places. +AZURE_CLI = 'azure-cli>=2.65.0' + +extras_require: Dict[str, List[str]] = { + 'aws': aws_dependencies, + # TODO(zongheng): azure-cli is huge and takes a long time to install. + # Tracked in: https://github.com/Azure/azure-cli/issues/7387 + # azure-identity is needed in node_provider. + # We need azure-identity>=1.13.0 to enable the customization of the + # timeout of AzureCliCredential. + 'azure': [ + AZURE_CLI, + 'azure-core>=1.31.0', + 'azure-identity>=1.19.0', + 'azure-mgmt-network>=27.0.0', + 'azure-mgmt-compute>=33.0.0', + 'azure-storage-blob>=12.23.1', + 'msgraph-sdk', + ] + local_ray, + # We need google-api-python-client>=2.69.0 to enable 'discardLocalSsd' + # parameter for stopping instances. Reference: + # https://github.com/googleapis/google-api-python-client/commit/f6e9d3869ed605b06f7cbf2e8cf2db25108506e6 + 'gcp': ['google-api-python-client>=2.69.0', 'google-cloud-storage'], + 'ibm': [ + 'ibm-cloud-sdk-core', 'ibm-vpc', 'ibm-platform-services', 'ibm-cos-sdk' + ] + local_ray, + 'docker': ['docker'] + local_ray, + 'lambda': local_ray, + 'cloudflare': aws_dependencies, + 'scp': local_ray, + 'oci': ['oci'] + local_ray, + 'kubernetes': ['kubernetes>=20.0.0'], + 'remote': remote, + 'runpod': ['runpod>=1.5.1'], + 'fluidstack': [], # No dependencies needed for fluidstack + 'cudo': ['cudo-compute>=0.1.10'], + 'paperspace': [], # No dependencies needed for paperspace + 'vsphere': [ + 'pyvmomi==8.0.1.0.2', + # vsphere-automation-sdk is also required, but it does not have + # pypi release, which cause failure of our pypi release. + # https://peps.python.org/pep-0440/#direct-references + # We have the instruction for its installation in our + # docs instead. + # 'vsphere-automation-sdk @ git+https://github.com/vmware/vsphere-automation-sdk-python.git@v8.0.1.0' pylint: disable=line-too-long + ], +} + +extras_require['all'] = sum(extras_require.values(), []) diff --git a/sky/setup_files/setup.py b/sky/setup_files/setup.py index 0fd6978ec03..121f96d8e8b 100644 --- a/sky/setup_files/setup.py +++ b/sky/setup_files/setup.py @@ -18,19 +18,28 @@ import os import platform import re +import runpy import subprocess import sys -from typing import Dict, List import setuptools +# __file__ is setup.py at the root of the repo. We shouldn't assume it's a +# symlink - e.g. in the sdist it's resolved to a normal file. ROOT_DIR = os.path.dirname(__file__) +DEPENDENCIES_FILE_PATH = os.path.join(ROOT_DIR, 'sky', 'setup_files', + 'dependencies.py') INIT_FILE_PATH = os.path.join(ROOT_DIR, 'sky', '__init__.py') _COMMIT_FAILURE_MESSAGE = ( 'WARNING: SkyPilot fail to {verb} the commit hash in ' f'{INIT_FILE_PATH!r} (SkyPilot can still be normally used): ' '{error}') +# setuptools does not include the script dir on the search path, so we can't +# just do `import dependencies`. Instead, use runpy to manually load it. Note: +# dependencies here is a dict, not a module, so we access it by subscripting. 
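Editor's note: `runpy.run_path()` from the standard library executes a file and returns its module-level globals as a plain dict, which is why the `dependencies` object below is subscripted rather than attribute-accessed. A tiny illustration, assuming the dependencies.py layout shown above:

import runpy

# Returns a dict of the module's globals after executing the file.
deps = runpy.run_path('sky/setup_files/dependencies.py')
print(deps['install_requires'][:3])    # first few requirement strings
print(sorted(deps['extras_require']))  # extras keys such as 'aws', 'gcp', ...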
+dependencies = runpy.run_path(DEPENDENCIES_FILE_PATH) + original_init_content = None system = platform.system() @@ -130,127 +139,6 @@ def parse_readme(readme: str) -> str: return readme -install_requires = [ - 'wheel', - 'cachetools', - # NOTE: ray requires click>=7.0. - 'click >= 7.0', - 'colorama', - 'cryptography', - # Jinja has a bug in older versions because of the lack of pinning - # the version of the underlying markupsafe package. See: - # https://github.com/pallets/jinja/issues/1585 - 'jinja2 >= 3.0', - 'jsonschema', - 'networkx', - 'pandas>=1.3.0', - 'pendulum', - # PrettyTable with version >=2.0.0 is required for the support of - # `add_rows` method. - 'PrettyTable >= 2.0.0', - 'python-dotenv', - 'rich', - 'tabulate', - # Light weight requirement, can be replaced with "typing" once - # we deprecate Python 3.7 (this will take a while). - 'typing_extensions', - 'filelock >= 3.6.0', - 'packaging', - 'psutil', - 'pulp', - # Cython 3.0 release breaks PyYAML 5.4.* (https://github.com/yaml/pyyaml/issues/601) - # <= 3.13 may encounter https://github.com/ultralytics/yolov5/issues/414 - 'pyyaml > 3.13, != 5.4.*', - 'requests', -] - -local_ray = [ - # Lower version of ray will cause dependency conflict for - # click/grpcio/protobuf. - # Excluded 2.6.0 as it has a bug in the cluster launcher: - # https://github.com/ray-project/ray/releases/tag/ray-2.6.1 - 'ray[default] >= 2.2.0, != 2.6.0', -] - -remote = [ - # Adopted from ray's setup.py: https://github.com/ray-project/ray/blob/ray-2.4.0/python/setup.py - # SkyPilot: != 1.48.0 is required to avoid the error where ray dashboard fails to start when - # ray start is called (#2054). - # Tracking issue: https://github.com/ray-project/ray/issues/30984 - "grpcio >= 1.32.0, <= 1.49.1, != 1.48.0; python_version < '3.10' and sys_platform == 'darwin'", # noqa:E501 - "grpcio >= 1.42.0, <= 1.49.1, != 1.48.0; python_version >= '3.10' and sys_platform == 'darwin'", # noqa:E501 - # Original issue: https://github.com/ray-project/ray/issues/33833 - "grpcio >= 1.32.0, <= 1.51.3, != 1.48.0; python_version < '3.10' and sys_platform != 'darwin'", # noqa:E501 - "grpcio >= 1.42.0, <= 1.51.3, != 1.48.0; python_version >= '3.10' and sys_platform != 'darwin'", # noqa:E501 - # Adopted from ray's setup.py: - # https://github.com/ray-project/ray/blob/ray-2.9.3/python/setup.py#L343 - 'protobuf >= 3.15.3, != 3.19.5', - # Some pydantic versions are not compatible with ray. Adopted from ray's - # setup.py: https://github.com/ray-project/ray/blob/ray-2.9.3/python/setup.py#L254 - 'pydantic!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3', -] - -# NOTE: Change the templates/jobs-controller.yaml.j2 file if any of the -# following packages dependencies are changed. -aws_dependencies = [ - # botocore does not work with urllib3>=2.0.0, according to https://github.com/boto/botocore/issues/2926 - # We have to explicitly pin the version to optimize the time for - # poetry install. See https://github.com/orgs/python-poetry/discussions/7937 - 'urllib3<2', - # NOTE: this installs CLI V1. To use AWS SSO (e.g., `aws sso login`), users - # should instead use CLI V2 which is not pip-installable. See - # https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html. - 'awscli>=1.27.10', - 'botocore>=1.29.10', - 'boto3>=1.26.1', - # NOTE: required by awscli. To avoid ray automatically installing - # the latest version. 
- 'colorama < 0.4.5', -] - -extras_require: Dict[str, List[str]] = { - 'aws': aws_dependencies, - # TODO(zongheng): azure-cli is huge and takes a long time to install. - # Tracked in: https://github.com/Azure/azure-cli/issues/7387 - # azure-identity is needed in node_provider. - # We need azure-identity>=1.13.0 to enable the customization of the - # timeout of AzureCliCredential. - 'azure': [ - 'azure-cli>=2.65.0', 'azure-core>=1.31.0', 'azure-identity>=1.19.0', - 'azure-mgmt-network>=27.0.0', 'azure-mgmt-compute>=33.0.0', - 'azure-storage-blob>=12.23.1', 'msgraph-sdk' - ] + local_ray, - # We need google-api-python-client>=2.69.0 to enable 'discardLocalSsd' - # parameter for stopping instances. - # Reference: https://github.com/googleapis/google-api-python-client/commit/f6e9d3869ed605b06f7cbf2e8cf2db25108506e6 - 'gcp': ['google-api-python-client>=2.69.0', 'google-cloud-storage'], - 'ibm': [ - 'ibm-cloud-sdk-core', 'ibm-vpc', 'ibm-platform-services', 'ibm-cos-sdk' - ] + local_ray, - 'docker': ['docker'] + local_ray, - 'lambda': local_ray, - 'cloudflare': aws_dependencies, - 'scp': local_ray, - 'oci': ['oci'] + local_ray, - 'kubernetes': ['kubernetes>=20.0.0'], - 'remote': remote, - 'runpod': ['runpod>=1.5.1'], - 'fluidstack': [], # No dependencies needed for fluidstack - 'cudo': ['cudo-compute>=0.1.10'], - 'paperspace': [], # No dependencies needed for paperspace - 'vsphere': [ - 'pyvmomi==8.0.1.0.2', - # vsphere-automation-sdk is also required, but it does not have - # pypi release, which cause failure of our pypi release. - # https://peps.python.org/pep-0440/#direct-references - # We have the instruction for its installation in our - # docs instead. - # 'vsphere-automation-sdk @ git+https://github.com/vmware/vsphere-automation-sdk-python.git@v8.0.1.0' - ], -} - -extras_require['all'] = sum(extras_require.values(), []) - long_description = '' readme_filepath = 'README.md' # When sky/backends/wheel_utils.py builds wheels, it will not contain the @@ -277,8 +165,8 @@ def parse_readme(readme: str) -> str: long_description_content_type='text/markdown', setup_requires=['wheel'], requires_python='>=3.7', - install_requires=install_requires, - extras_require=extras_require, + install_requires=dependencies['install_requires'], + extras_require=dependencies['extras_require'], entry_points={ 'console_scripts': ['sky = sky.cli:cli'], }, diff --git a/sky/skylet/constants.py b/sky/skylet/constants.py index a9b8013cad7..0b2a5b08e1b 100644 --- a/sky/skylet/constants.py +++ b/sky/skylet/constants.py @@ -4,6 +4,7 @@ from packaging import version import sky +from sky.setup_files import dependencies SKY_LOGS_DIRECTORY = '~/sky_logs' SKY_REMOTE_WORKDIR = '~/sky_workdir' @@ -39,6 +40,8 @@ 'which python3') # Python executable, e.g., /opt/conda/bin/python3 SKY_PYTHON_CMD = f'$({SKY_GET_PYTHON_PATH_CMD})' +# Prefer SKY_UV_PIP_CMD, which is faster. +# TODO(cooperc): remove remaining usage (GCP TPU setup). SKY_PIP_CMD = f'{SKY_PYTHON_CMD} -m pip' # Ray executable, e.g., /opt/conda/bin/ray # We need to add SKY_PYTHON_CMD before ray executable because: @@ -50,6 +53,14 @@ SKY_REMOTE_PYTHON_ENV_NAME = 'skypilot-runtime' SKY_REMOTE_PYTHON_ENV = f'~/{SKY_REMOTE_PYTHON_ENV_NAME}' ACTIVATE_SKY_REMOTE_PYTHON_ENV = f'source {SKY_REMOTE_PYTHON_ENV}/bin/activate' +# uv is used for venv and pip, much faster than python implementations. +SKY_UV_INSTALL_DIR = '"$HOME/.local/bin"' +SKY_UV_CMD = f'{SKY_UV_INSTALL_DIR}/uv' +# This won't reinstall uv if it's already installed, so it's safe to re-run. 
+SKY_UV_INSTALL_CMD = (f'{SKY_UV_CMD} -V >/dev/null 2>&1 || ' + 'curl -LsSf https://astral.sh/uv/install.sh ' + f'| UV_INSTALL_DIR={SKY_UV_INSTALL_DIR} sh') +SKY_UV_PIP_CMD = f'VIRTUAL_ENV={SKY_REMOTE_PYTHON_ENV} {SKY_UV_CMD} pip' # Deleting the SKY_REMOTE_PYTHON_ENV_NAME from the PATH to deactivate the # environment. `deactivate` command does not work when conda is used. DEACTIVATE_SKY_REMOTE_PYTHON_ENV = ( @@ -75,11 +86,11 @@ # cluster yaml is updated. # # TODO(zongheng,zhanghao): make the upgrading of skylet automatic? -SKYLET_VERSION = '8' +SKYLET_VERSION = '9' # The version of the lib files that skylet/jobs use. Whenever there is an API # change for the job_lib or log_lib, we need to bump this version, so that the # user can be notified to update their SkyPilot version on the remote cluster. -SKYLET_LIB_VERSION = 1 +SKYLET_LIB_VERSION = 2 SKYLET_VERSION_FILE = '~/.sky/skylet_version' # `sky jobs dashboard`-related @@ -148,31 +159,30 @@ 'echo "Creating conda env with Python 3.10" && ' f'conda create -y -n {SKY_REMOTE_PYTHON_ENV_NAME} python=3.10 && ' f'conda activate {SKY_REMOTE_PYTHON_ENV_NAME};' + # Install uv for venv management and pip installation. + f'{SKY_UV_INSTALL_CMD};' # Create a separate conda environment for SkyPilot dependencies. f'[ -d {SKY_REMOTE_PYTHON_ENV} ] || ' # Do NOT use --system-site-packages here, because if users upgrade any # packages in the base env, they interfere with skypilot dependencies. # Reference: https://github.com/skypilot-org/skypilot/issues/4097 - f'{SKY_PYTHON_CMD} -m venv {SKY_REMOTE_PYTHON_ENV};' + # --seed will include pip and setuptools, which are present in venvs created + # with python -m venv. + f'{SKY_UV_CMD} venv --seed {SKY_REMOTE_PYTHON_ENV};' f'echo "$(echo {SKY_REMOTE_PYTHON_ENV})/bin/python" > {SKY_PYTHON_PATH_FILE};' ) _sky_version = str(version.parse(sky.__version__)) RAY_STATUS = f'RAY_ADDRESS=127.0.0.1:{SKY_REMOTE_RAY_PORT} {SKY_RAY_CMD} status' -# Install ray and skypilot on the remote cluster if they are not already -# installed. {var} will be replaced with the actual value in -# backend_utils.write_cluster_config. -RAY_SKYPILOT_INSTALLATION_COMMANDS = ( +RAY_INSTALLATION_COMMANDS = ( + f'{SKY_UV_INSTALL_CMD};' 'mkdir -p ~/sky_workdir && mkdir -p ~/.sky/sky_app;' - # Disable the pip version check to avoid the warning message, which makes - # the output hard to read. - 'export PIP_DISABLE_PIP_VERSION_CHECK=1;' # Print the PATH in provision.log to help debug PATH issues. 'echo PATH=$PATH; ' # Install setuptools<=69.5.1 to avoid the issue with the latest setuptools # causing the error: # ImportError: cannot import name 'packaging' from 'pkg_resources'" - f'{SKY_PIP_CMD} install "setuptools<70"; ' + f'{SKY_UV_PIP_CMD} install "setuptools<70"; ' # Backward compatibility for ray upgrade (#3248): do not upgrade ray if the # ray cluster is already running, to avoid the ray cluster being restarted. # @@ -186,10 +196,10 @@ # latest ray port 6380, but those existing cluster launched before #1790 # that has ray cluster on the default port 6379 will be upgraded and # restarted. - f'{SKY_PIP_CMD} list | grep "ray " | ' + f'{SKY_UV_PIP_CMD} list | grep "ray " | ' f'grep {SKY_REMOTE_RAY_VERSION} 2>&1 > /dev/null ' f'|| {RAY_STATUS} || ' - f'{SKY_PIP_CMD} install --exists-action w -U ray[default]=={SKY_REMOTE_RAY_VERSION}; ' # pylint: disable=line-too-long + f'{SKY_UV_PIP_CMD} install -U ray[default]=={SKY_REMOTE_RAY_VERSION}; ' # pylint: disable=line-too-long # In some envs, e.g. 
pip does not have permission to write under /opt/conda # ray package will be installed under ~/.local/bin. If the user's PATH does # not include ~/.local/bin (the pip install will have the output: `WARNING: @@ -202,24 +212,43 @@ # Writes ray path to file if it does not exist or the file is empty. f'[ -s {SKY_RAY_PATH_FILE} ] || ' f'{{ {ACTIVATE_SKY_REMOTE_PYTHON_ENV} && ' - f'which ray > {SKY_RAY_PATH_FILE} || exit 1; }}; ' - # END ray package check and installation - f'{{ {SKY_PIP_CMD} list | grep "skypilot " && ' + f'which ray > {SKY_RAY_PATH_FILE} || exit 1; }}; ') + +SKYPILOT_WHEEL_INSTALLATION_COMMANDS = ( + f'{SKY_UV_INSTALL_CMD};' + f'{{ {SKY_UV_PIP_CMD} list | grep "skypilot " && ' '[ "$(cat ~/.sky/wheels/current_sky_wheel_hash)" == "{sky_wheel_hash}" ]; } || ' # pylint: disable=line-too-long - f'{{ {SKY_PIP_CMD} uninstall skypilot -y; ' - f'{SKY_PIP_CMD} install "$(echo ~/.sky/wheels/{{sky_wheel_hash}}/' + f'{{ {SKY_UV_PIP_CMD} uninstall skypilot; ' + # uv cannot install azure-cli normally, since it depends on pre-release + # packages. Manually install azure-cli with the --prerelease=allow flag + # first. This will allow skypilot to successfully install. See + # https://docs.astral.sh/uv/pip/compatibility/#pre-release-compatibility. + # We don't want to use --prerelease=allow for all packages, because it will + # cause uv to use pre-releases for some other packages that have sufficient + # stable releases. + 'if [ "{cloud}" = "azure" ]; then ' + f'{SKY_UV_PIP_CMD} install --prerelease=allow "{dependencies.AZURE_CLI}";' + 'fi;' + # Install skypilot from wheel + f'{SKY_UV_PIP_CMD} install "$(echo ~/.sky/wheels/{{sky_wheel_hash}}/' f'skypilot-{_sky_version}*.whl)[{{cloud}}, remote]" && ' 'echo "{sky_wheel_hash}" > ~/.sky/wheels/current_sky_wheel_hash || ' - 'exit 1; }; ' - # END SkyPilot package check and installation + 'exit 1; }; ') +# Install ray and skypilot on the remote cluster if they are not already +# installed. {var} will be replaced with the actual value in +# backend_utils.write_cluster_config. +RAY_SKYPILOT_INSTALLATION_COMMANDS = ( + f'{RAY_INSTALLATION_COMMANDS} ' + f'{SKYPILOT_WHEEL_INSTALLATION_COMMANDS} ' # Only patch ray when the ray version is the same as the expected version. # The ray installation above can be skipped due to the existing ray cluster # for backward compatibility. In this case, we should not patch the ray # files. - f'{SKY_PIP_CMD} list | grep "ray " | grep {SKY_REMOTE_RAY_VERSION} 2>&1 > /dev/null ' - f'&& {{ {SKY_PYTHON_CMD} -c "from sky.skylet.ray_patches import patch; patch()" ' - '|| exit 1; };') + f'{SKY_UV_PIP_CMD} list | grep "ray " | ' + f'grep {SKY_REMOTE_RAY_VERSION} 2>&1 > /dev/null && ' + f'{{ {SKY_PYTHON_CMD} -c ' + '"from sky.skylet.ray_patches import patch; patch()" || exit 1; }; ') # The name for the environment variable that stores SkyPilot user hash, which # is mainly used to make sure sky commands runs on a VM launched by SkyPilot diff --git a/sky/skylet/job_lib.py b/sky/skylet/job_lib.py index ff1a9184bd6..ef120011496 100644 --- a/sky/skylet/job_lib.py +++ b/sky/skylet/job_lib.py @@ -8,6 +8,7 @@ import os import pathlib import shlex +import signal import sqlite3 import subprocess import time @@ -25,7 +26,12 @@ logger = sky_logging.init_logger(__name__) +_LINUX_NEW_LINE = '\n' _JOB_STATUS_LOCK = '~/.sky/locks/.job_{}.lock' +# JOB_CMD_IDENTIFIER is used for identifying the process retrieved +# with pid is the same driver process to guard against the case where +# the same pid is reused by a different process. 
+JOB_CMD_IDENTIFIER = 'echo "SKYPILOT_JOB_ID <{}>"' _MAX_PENDING_SUBMIT = 2 @@ -47,6 +53,7 @@ class JobInfoLoc(enum.IntEnum): START_AT = 6 END_AT = 7 RESOURCES = 8 + PID = 9 _DB_PATH = os.path.expanduser('~/.sky/jobs.db') @@ -68,6 +75,16 @@ def create_table(cursor, conn): # If the database is locked, it is OK to continue, as the WAL mode # is not critical and is likely to be enabled by other processes. + # Pid column is used for keeping track of the driver process of a job. It + # can be in three states: + # -1: The job was submitted with SkyPilot older than #4318, where we use + # ray job submit to submit the job, i.e. no pid is recorded. This is for + # backward compatibility and should be removed after 0.10.0. + # 0: The job driver process has never been started. When adding a job with + # INIT state, the pid will be set to 0 (the default -1 value is just for + # backward compatibility). + # >=0: The job has been started. The pid is the driver process's pid. + # The driver can be actually running or finished. cursor.execute("""\ CREATE TABLE IF NOT EXISTS jobs ( job_id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -76,7 +93,10 @@ def create_table(cursor, conn): submitted_at FLOAT, status TEXT, run_timestamp TEXT CANDIDATE KEY, - start_at FLOAT DEFAULT -1)""") + start_at FLOAT DEFAULT -1, + end_at FLOAT DEFAULT NULL, + resources TEXT DEFAULT NULL, + pid INTEGER DEFAULT -1)""") cursor.execute("""CREATE TABLE IF NOT EXISTS pending_jobs( job_id INTEGER, @@ -87,7 +107,8 @@ def create_table(cursor, conn): db_utils.add_column_to_table(cursor, conn, 'jobs', 'end_at', 'FLOAT') db_utils.add_column_to_table(cursor, conn, 'jobs', 'resources', 'TEXT') - + db_utils.add_column_to_table(cursor, conn, 'jobs', 'pid', + 'INTEGER DEFAULT -1') conn.commit() @@ -119,6 +140,11 @@ class JobStatus(enum.Enum): # In the 'jobs' table, the `start_at` column will be set to the current # time, when the job is firstly transitioned to RUNNING. RUNNING = 'RUNNING' + # The job driver process failed. This happens when the job driver process + # finishes when the status in job table is still not set to terminal state. + # We should keep this state before the SUCCEEDED, as our job status update + # relies on the order of the statuses to keep the latest status. + FAILED_DRIVER = 'FAILED_DRIVER' # 3 terminal states below: once reached, they do not transition. # The job finished successfully. SUCCEEDED = 'SUCCEEDED' @@ -149,11 +175,16 @@ def colored_str(self): return f'{color}{self.value}{colorama.Style.RESET_ALL}' -# Only update status of the jobs after this many seconds of job submission, -# to avoid race condition with `ray job` to make sure it job has been -# correctly updated. +# We have two steps for job submissions: +# 1. Client reserve a job id from the job table by adding a INIT state job. +# 2. Client updates the job status to PENDING by actually submitting the job's +# command to the scheduler. +# In normal cases, the two steps happens very close to each other through two +# consecutive SSH connections. +# We should update status for INIT job that has been staying in INIT state for +# a while (60 seconds), which likely fails to reach step 2. # TODO(zhwu): This number should be tuned based on heuristics. 
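[Editor's note] The pid column semantics spelled out in the comment above (-1 for legacy `ray job submit` rows, 0 for a reserved-but-unstarted driver, a positive value for a started driver) can be read off with a tiny helper. This is an illustrative sketch, not code from the change:

def describe_driver_pid(pid: int) -> str:
    """Interpret the `jobs.pid` column as documented above (sketch)."""
    if pid < 0:
        # Row written by a pre-#4318 runtime that used `ray job submit`;
        # no driver pid was ever recorded.
        return 'legacy submission, no pid tracked'
    if pid == 0:
        # Job id reserved (INIT) but the driver process was never started.
        return 'driver not started yet'
    # Driver was started with this pid; it may be running or already done.
    return f'driver started (pid={pid})'

assert describe_driver_pid(0) == 'driver not started yet'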
-_PENDING_SUBMIT_GRACE_PERIOD = 60 +_INIT_SUBMIT_GRACE_PERIOD = 60 _PRE_RESOURCE_STATUSES = [JobStatus.PENDING] @@ -176,7 +207,39 @@ def _run_job(self, job_id: int, run_cmd: str): _CURSOR.execute((f'UPDATE pending_jobs SET submit={int(time.time())} ' f'WHERE job_id={job_id!r}')) _CONN.commit() - subprocess.Popen(run_cmd, shell=True, stdout=subprocess.DEVNULL) + # Use nohup to ensure the job driver process is a separate process tree, + # instead of being a child of the current process. This is important to + # avoid a chain of driver processes (job driver can call schedule_step() + # to submit new jobs, and the new job can also call schedule_step() + # recursively). + # + # echo $! will output the PID of the last background process started + # in the current shell, so we can retrieve it and record in the DB. + # + # TODO(zhwu): A more elegant solution is to use another daemon process + # to be in charge of starting these driver processes, instead of + # starting them in the current process. + wrapped_cmd = (f'nohup bash -c {shlex.quote(run_cmd)} ' + '/dev/null 2>&1 & echo $!') + proc = subprocess.run(wrapped_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.DEVNULL, + start_new_session=True, + check=True, + shell=True, + text=True) + # Get the PID of the detached process + pid = int(proc.stdout.strip()) + + # TODO(zhwu): Backward compatibility, remove this check after 0.10.0. + # This is for the case where the job is submitted with SkyPilot older + # than #4318, using ray job submit. + if 'job submit' in run_cmd: + pid = -1 + _CURSOR.execute((f'UPDATE jobs SET pid={pid} ' + f'WHERE job_id={job_id!r}')) + _CONN.commit() def schedule_step(self, force_update_jobs: bool = False) -> None: if force_update_jobs: @@ -243,59 +306,13 @@ def _get_pending_job_ids(self) -> List[int]: JobStatus.SETTING_UP: colorama.Fore.BLUE, JobStatus.PENDING: colorama.Fore.BLUE, JobStatus.RUNNING: colorama.Fore.GREEN, + JobStatus.FAILED_DRIVER: colorama.Fore.RED, JobStatus.SUCCEEDED: colorama.Fore.GREEN, JobStatus.FAILED: colorama.Fore.RED, JobStatus.FAILED_SETUP: colorama.Fore.RED, JobStatus.CANCELLED: colorama.Fore.YELLOW, } -_RAY_TO_JOB_STATUS_MAP = { - # These are intentionally set this way, because: - # 1. when the ray status indicates the job is PENDING the generated - # python program has been `ray job submit` from the job queue - # and is now PENDING - # 2. when the ray status indicates the job is RUNNING the job can be in - # setup or resources may not be allocated yet, i.e. the job should be - # PENDING. - # For case 2, update_job_status() would compare this mapped PENDING to - # the status in our jobs DB and take the max. This is because the job's - # generated ray program is the only place that can determine a job has - # reserved resources and actually started running: it will set the - # status in the DB to SETTING_UP or RUNNING. - # If there is no setup specified in the task, as soon as it is started - # (ray's status becomes RUNNING), i.e. it will be very rare that the job - # will be set to SETTING_UP by the update_job_status, as our generated - # ray program will set the status to PENDING immediately. 
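[Editor's note] The submission path above detaches the driver with `nohup bash -c ... &` in a new session and reads its pid back from `echo $!`. A standalone sketch of that launch-and-record step; the command is a placeholder, and the sketch assumes the real code also redirects the driver's output to /dev/null:

import shlex
import subprocess

run_cmd = 'echo "SKYPILOT_JOB_ID <42>" && sleep 30'  # placeholder driver command

# Detach the driver into its own session and print the background pid.
wrapped = f'nohup bash -c {shlex.quote(run_cmd)} > /dev/null 2>&1 & echo $!'
proc = subprocess.run(wrapped,
                      shell=True,
                      text=True,
                      check=True,
                      stdout=subprocess.PIPE,
                      stderr=subprocess.PIPE,
                      stdin=subprocess.DEVNULL,
                      start_new_session=True)
pid = int(proc.stdout.strip())
print('recorded driver pid:', pid)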
- 'PENDING': JobStatus.PENDING, - 'RUNNING': JobStatus.PENDING, - 'SUCCEEDED': JobStatus.SUCCEEDED, - 'FAILED': JobStatus.FAILED, - 'STOPPED': JobStatus.CANCELLED, -} - - -def _create_ray_job_submission_client(): - """Import the ray job submission client.""" - try: - import ray # pylint: disable=import-outside-toplevel - except ImportError: - logger.error('Failed to import ray') - raise - try: - # pylint: disable=import-outside-toplevel - from ray import job_submission - except ImportError: - logger.error( - f'Failed to import job_submission with ray=={ray.__version__}') - raise - port = get_job_submission_port() - return job_submission.JobSubmissionClient( - address=f'http://127.0.0.1:{port}') - - -def make_ray_job_id(sky_job_id: int) -> str: - return f'{sky_job_id}-{getpass.getuser()}' - def make_job_command_with_user_switching(username: str, command: str) -> List[str]: @@ -307,9 +324,10 @@ def add_job(job_name: str, username: str, run_timestamp: str, """Atomically reserve the next available job id for the user.""" job_submitted_at = time.time() # job_id will autoincrement with the null value - _CURSOR.execute('INSERT INTO jobs VALUES (null, ?, ?, ?, ?, ?, ?, null, ?)', - (job_name, username, job_submitted_at, JobStatus.INIT.value, - run_timestamp, None, resources_str)) + _CURSOR.execute( + 'INSERT INTO jobs VALUES (null, ?, ?, ?, ?, ?, ?, null, ?, 0)', + (job_name, username, job_submitted_at, JobStatus.INIT.value, + run_timestamp, None, resources_str)) _CONN.commit() rows = _CURSOR.execute('SELECT job_id FROM jobs WHERE run_timestamp=(?)', (run_timestamp,)) @@ -484,6 +502,7 @@ def _get_records_from_rows(rows) -> List[Dict[str, Any]]: 'start_at': row[JobInfoLoc.START_AT.value], 'end_at': row[JobInfoLoc.END_AT.value], 'resources': row[JobInfoLoc.RESOURCES.value], + 'pid': row[JobInfoLoc.PID.value], }) return records @@ -543,6 +562,23 @@ def _get_pending_job(job_id: int) -> Optional[Dict[str, Any]]: return None +def _is_job_driver_process_running(job_pid: int, job_id: int) -> bool: + """Check if the job driver process is running. + + We check the cmdline to avoid the case where the same pid is reused by a + different process. + """ + if job_pid <= 0: + return False + try: + job_process = psutil.Process(job_pid) + return job_process.is_running() and any( + JOB_CMD_IDENTIFIER.format(job_id) in line + for line in job_process.cmdline()) + except psutil.NoSuchProcess: + return False + + def update_job_status(job_ids: List[int], silent: bool = False) -> List[JobStatus]: """Updates and returns the job statuses matching our `JobStatus` semantics. @@ -560,11 +596,8 @@ def update_job_status(job_ids: List[int], if len(job_ids) == 0: return [] - ray_job_ids = [make_ray_job_id(job_id) for job_id in job_ids] - job_client = _create_ray_job_submission_client() - statuses = [] - for job_id, ray_job_id in zip(job_ids, ray_job_ids): + for job_id in job_ids: # Per-job status lock is required because between the job status # query and the job status update, the job status in the databse # can be modified by the generated ray program. 
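[Editor's note] `_is_job_driver_process_running` defends against pid reuse by requiring the job's marker string to appear in the process cmdline, not just that the pid exists. Roughly the same check in isolation (this mirrors the function added above; psutil is already a skylet dependency):

import psutil

JOB_CMD_IDENTIFIER = 'echo "SKYPILOT_JOB_ID <{}>"'

def driver_alive(job_pid: int, job_id: int) -> bool:
    if job_pid <= 0:
        return False
    try:
        proc = psutil.Process(job_pid)
        marker = JOB_CMD_IDENTIFIER.format(job_id)
        # A recycled pid will not carry this job's marker in its cmdline.
        return proc.is_running() and any(marker in arg for arg in proc.cmdline())
    except psutil.NoSuchProcess:
        return False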
@@ -573,11 +606,13 @@ def update_job_status(job_ids: List[int], job_record = _get_jobs_by_ids([job_id])[0] original_status = job_record['status'] job_submitted_at = job_record['submitted_at'] + job_pid = job_record['pid'] - ray_job_query_time = time.time() + pid_query_time = time.time() + failed_driver_transition_message = None if original_status == JobStatus.INIT: if (job_submitted_at >= psutil.boot_time() and job_submitted_at - >= ray_job_query_time - _PENDING_SUBMIT_GRACE_PERIOD): + >= pid_query_time - _INIT_SUBMIT_GRACE_PERIOD): # The job id is reserved, but the job is not submitted yet. # We should keep it in INIT. status = JobStatus.INIT @@ -588,74 +623,98 @@ def update_job_status(job_ids: List[int], # was killed before the job is submitted. We should set it # to FAILED then. Note, if ray job indicates the job is # running, we will change status to PENDING below. - echo(f'INIT job {job_id} is stale, setting to FAILED') - status = JobStatus.FAILED - - try: - # Querying status within the lock is safer than querying - # outside, as it avoids the race condition when job table is - # updated after the ray job status query. - # Also, getting per-job status is faster than querying all jobs, - # when there are significant number of finished jobs. - # Reference: getting 124 finished jobs takes 0.038s, while - # querying a single job takes 0.006s, 10 jobs takes 0.066s. - # TODO: if too slow, directly query against redis. - ray_job_status = job_client.get_job_status(ray_job_id) - status = _RAY_TO_JOB_STATUS_MAP[ray_job_status.value] - except RuntimeError: - # Job not found. - pass + failed_driver_transition_message = ( + f'INIT job {job_id} is stale, setting to FAILED_DRIVER') + status = JobStatus.FAILED_DRIVER + + # job_pid is 0 if the job is not submitted yet. + # job_pid is -1 if the job is submitted with SkyPilot older than + # #4318, using ray job submit. We skip the checking for those + # jobs. + if job_pid > 0: + if _is_job_driver_process_running(job_pid, job_id): + status = JobStatus.PENDING + else: + # By default, if the job driver process does not exist, + # the actual SkyPilot job is one of the following: + # 1. Still pending to be submitted. + # 2. Submitted and finished. + # 3. Driver failed without correctly setting the job + # status in the job table. + # Although we set the status to FAILED_DRIVER, it can be + # overridden to PENDING if the job is not submitted, or + # any other terminal status if the job driver process + # finished correctly. + failed_driver_transition_message = ( + f'Job {job_id} driver process is not running, but ' + 'the job state is not in terminal states, setting ' + 'it to FAILED_DRIVER') + status = JobStatus.FAILED_DRIVER + elif job_pid < 0: + # TODO(zhwu): Backward compatibility, remove after 0.9.0. + # We set the job status to PENDING instead of actually + # checking ray job status and let the status in job table + # take effect in the later max. + status = JobStatus.PENDING pending_job = _get_pending_job(job_id) if pending_job is not None: if pending_job['created_time'] < psutil.boot_time(): - echo(f'Job {job_id} is stale, setting to FAILED: ' - f'created_time={pending_job["created_time"]}, ' - f'boot_time={psutil.boot_time()}') + failed_driver_transition_message = ( + f'Job {job_id} is stale, setting to FAILED_DRIVER: ' + f'created_time={pending_job["created_time"]}, ' + f'boot_time={psutil.boot_time()}') # The job is stale as it is created before the instance # is booted, e.g. the instance is rebooted. 
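[Editor's note] The INIT branch above keeps a job in INIT only while it is newer than both the last boot and the 60-second grace window; otherwise the driver is assumed to have failed before submission. The same condition written as a predicate (a sketch; constants as in the diff):

import time
import psutil

_INIT_SUBMIT_GRACE_PERIOD = 60  # seconds

def init_job_is_stale(job_submitted_at: float) -> bool:
    """True when an INIT job should be flipped to FAILED_DRIVER."""
    # Stale if it was reserved before the machine last booted, or has been
    # waiting longer than the grace period without a driver submission.
    return (job_submitted_at < psutil.boot_time() or
            job_submitted_at < time.time() - _INIT_SUBMIT_GRACE_PERIOD)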
- status = JobStatus.FAILED - # Gives a 60 second grace period between job being submit from - # the pending table until appearing in ray jobs. For jobs - # submitted outside of the grace period, we will consider the - # ray job status. - if not (pending_job['submit'] > 0 and pending_job['submit'] < - ray_job_query_time - _PENDING_SUBMIT_GRACE_PERIOD): - # Reset the job status to PENDING even though it may not - # appear in the ray jobs, so that it will not be considered - # as stale. + status = JobStatus.FAILED_DRIVER + elif pending_job['submit'] <= 0: + # The job is not submitted (submit <= 0), we set it to + # PENDING. + # For submitted jobs, the driver should have been started, + # because the job_lib.JobScheduler.schedule_step() have + # the submit field and driver process pid set in the same + # job lock. + # The job process check in the above section should + # correctly figured out the status and we don't overwrite + # it here. (Note: the FAILED_DRIVER status will be + # overridden by the actual job terminal status in the table + # if the job driver process finished correctly.) status = JobStatus.PENDING assert original_status is not None, (job_id, status) if status is None: + # The job is submitted but the job driver process pid is not + # set in the database. This is guarding against the case where + # the schedule_step() function is interrupted (e.g., VM stop) + # at the middle of starting a new process and setting the pid. status = original_status if (original_status is not None and not original_status.is_terminal()): - echo(f'Ray job status for job {job_id} is None, ' - 'setting it to FAILED.') - # The job may be stale, when the instance is restarted - # (the ray redis is volatile). We need to reset the - # status of the task to FAILED if its original status - # is RUNNING or PENDING. - status = JobStatus.FAILED + echo(f'Job {job_id} status is None, setting it to ' + 'FAILED_DRIVER.') + # The job may be stale, when the instance is restarted. We + # need to reset the job status to FAILED_DRIVER if its + # original status is in nonterminal_statuses. + echo(f'Job {job_id} is in a unknown state, setting it to ' + 'FAILED_DRIVER') + status = JobStatus.FAILED_DRIVER _set_status_no_lock(job_id, status) - echo(f'Updated job {job_id} status to {status}') else: # Taking max of the status is necessary because: - # 1. It avoids race condition, where the original status has - # already been set to later state by the job. We skip the - # update. - # 2. _RAY_TO_JOB_STATUS_MAP would map `ray job status`'s - # `RUNNING` to our JobStatus.SETTING_UP; if a job has already - # been set to JobStatus.PENDING or JobStatus.RUNNING by the - # generated ray program, `original_status` (job status from our - # DB) would already have that value. So we take the max here to - # keep it at later status. + # 1. The original status has already been set to later + # terminal state by a finished job driver. + # 2. Job driver process check would map any running job process + # to `PENDING`, so we need to take the max to keep it at + # later status for jobs actually started in SETTING_UP or + # RUNNING. status = max(status, original_status) assert status is not None, (job_id, status, original_status) if status != original_status: # Prevents redundant update. 
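[Editor's note] Taking `max(status, original_status)` only works because JobStatus members are ordered from earliest to latest lifecycle stage, with FAILED_DRIVER deliberately placed before the terminal states per the comment earlier in this diff. A toy stand-in for that ordering trick; the enum here is a simplified illustration, not the full JobStatus:

import enum

class Status(enum.Enum):
    INIT = 'INIT'
    PENDING = 'PENDING'
    RUNNING = 'RUNNING'
    FAILED_DRIVER = 'FAILED_DRIVER'
    SUCCEEDED = 'SUCCEEDED'

    def __lt__(self, other: 'Status') -> bool:
        order = list(type(self))
        return order.index(self) < order.index(other)

    def __gt__(self, other: 'Status') -> bool:
        return other < self

# A live driver maps to PENDING, but if the job table already says RUNNING,
# max() keeps the later state instead of regressing it.
assert max(Status.PENDING, Status.RUNNING) is Status.RUNNING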
_set_status_no_lock(job_id, status) echo(f'Updated job {job_id} status to {status}') + if (status == JobStatus.FAILED_DRIVER and + failed_driver_transition_message is not None): + echo(failed_driver_transition_message) statuses.append(status) return statuses @@ -668,17 +727,13 @@ def fail_all_jobs_in_progress() -> None: f"""\ UPDATE jobs SET status=(?) WHERE status IN ({','.join(['?'] * len(in_progress_status))}) - """, (JobStatus.FAILED.value, *in_progress_status)) + """, (JobStatus.FAILED_DRIVER.value, *in_progress_status)) _CONN.commit() def update_status() -> None: # This will be called periodically by the skylet to update the status # of the jobs in the database, to avoid stale job status. - # NOTE: there might be a INIT job in the database set to FAILED by this - # function, as the ray job status does not exist due to the app - # not submitted yet. It will be then reset to PENDING / RUNNING when the - # app starts. nonterminal_jobs = _get_jobs(username=None, status_list=JobStatus.nonterminal_statuses()) nonterminal_job_ids = [job['job_id'] for job in nonterminal_jobs] @@ -761,6 +816,31 @@ def load_job_queue(payload: str) -> List[Dict[str, Any]]: return jobs +# TODO(zhwu): Backward compatibility for jobs submitted before #4318, remove +# after 0.10.0. +def _create_ray_job_submission_client(): + """Import the ray job submission client.""" + try: + import ray # pylint: disable=import-outside-toplevel + except ImportError: + logger.error('Failed to import ray') + raise + try: + # pylint: disable=import-outside-toplevel + from ray import job_submission + except ImportError: + logger.error( + f'Failed to import job_submission with ray=={ray.__version__}') + raise + port = get_job_submission_port() + return job_submission.JobSubmissionClient( + address=f'http://127.0.0.1:{port}') + + +def _make_ray_job_id(sky_job_id: int) -> str: + return f'{sky_job_id}-{getpass.getuser()}' + + def cancel_jobs_encoded_results(jobs: Optional[List[int]], cancel_all: bool = False) -> str: """Cancel jobs. @@ -788,28 +868,54 @@ def cancel_jobs_encoded_results(jobs: Optional[List[int]], # Cancel jobs with specified IDs. job_records = _get_jobs_by_ids(jobs) - # TODO(zhwu): `job_client.stop_job` will wait for the jobs to be killed, but - # when the memory is not enough, this will keep waiting. - job_client = _create_ray_job_submission_client() cancelled_ids = [] # Sequentially cancel the jobs to avoid the resource number bug caused by # ray cluster (tracked in #1262). - for job in job_records: - job_id = make_ray_job_id(job['job_id']) + for job_record in job_records: + job_id = job_record['job_id'] # Job is locked to ensure that pending queue does not start it while # it is being cancelled - with filelock.FileLock(_get_lock_path(job['job_id'])): - try: - job_client.stop_job(job_id) - except RuntimeError as e: - # If the request to the job server fails, we should not - # set the job to CANCELLED. - if 'does not exist' not in str(e): - logger.warning(str(e)) - continue - - if job['status'] in [ + with filelock.FileLock(_get_lock_path(job_id)): + job = _get_jobs_by_ids([job_id])[0] + if _is_job_driver_process_running(job['pid'], job_id): + # Not use process.terminate() as that will only terminate the + # process shell process, not the ray driver process + # under the shell. 
+ # + # We don't kill all the children of the process, like + # subprocess_utils.kill_process_daemon() does, but just the + # process group here, because the underlying job driver can + # start other jobs with `schedule_step`, causing the other job + # driver processes to be children of the current job driver + # process. + # + # Killing the process group is enough as the underlying job + # should be able to clean itself up correctly by ray driver. + # + # The process group pid should be the same as the job pid as we + # use start_new_session=True, but we use os.getpgid() to be + # extra cautious. + job_pgid = os.getpgid(job['pid']) + os.killpg(job_pgid, signal.SIGTERM) + # We don't have to start a daemon to forcefully kill the process + # as our job driver process will clean up the underlying + # child processes. + elif job['pid'] < 0: + try: + # TODO(zhwu): Backward compatibility, remove after 0.9.0. + # The job was submitted with ray job submit before #4318. + job_client = _create_ray_job_submission_client() + job_client.stop_job(_make_ray_job_id(job['job_id'])) + except RuntimeError as e: + # If the request to the job server fails, we should not + # set the job to CANCELLED. + if 'does not exist' not in str(e): + logger.warning(str(e)) + continue + # Get the job status again to avoid race condition. + job_status = get_status_no_lock(job['job_id']) + if job_status in [ JobStatus.PENDING, JobStatus.SETTING_UP, JobStatus.RUNNING ]: _set_status_no_lock(job['job_id'], JobStatus.CANCELLED) @@ -868,10 +974,17 @@ def add_job(cls, job_name: Optional[str], username: str, run_timestamp: str, if job_name is None: job_name = '-' code = [ - 'job_id = job_lib.add_job(' - f'{job_name!r}, ' - f'{username!r}, ' - f'{run_timestamp!r}, ' + # We disallow job submission when SKYLET_VERSION is older than 9, as + # it was using ray job submit before #4318, and switched to raw + # process. Using the old skylet version will cause the job status + # to be stuck in PENDING state or transition to FAILED_DRIVER state. + '\nif int(constants.SKYLET_VERSION) < 9: ' + 'raise RuntimeError("SkyPilot runtime is too old, which does not ' + 'support submitting jobs.")', + '\njob_id = job_lib.add_job(' + f'{job_name!r},' + f'{username!r},' + f'{run_timestamp!r},' f'{resources_str!r})', 'print("Job ID: " + str(job_id), flush=True)', ] @@ -879,9 +992,11 @@ def add_job(cls, job_name: Optional[str], username: str, run_timestamp: str, @classmethod def queue_job(cls, job_id: int, cmd: str) -> str: - code = ['job_lib.scheduler.queue(' - f'{job_id!r},' - f'{cmd!r})'] + code = [ + 'job_lib.scheduler.queue(' + f'{job_id!r},' + f'{cmd!r})', + ] return cls._build(code) @classmethod @@ -920,14 +1035,19 @@ def fail_all_jobs_in_progress(cls) -> str: def tail_logs(cls, job_id: Optional[int], managed_job_id: Optional[int], - follow: bool = True) -> str: + follow: bool = True, + tail: int = 0) -> str: # pylint: disable=line-too-long + code = [ + # We use != instead of is not because 1 is not None will print a warning: + # :1: SyntaxWarning: "is not" with a literal. Did you mean "!="? 
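[Editor's note] Cancellation above signals the driver's whole process group rather than a single pid: because the driver was started with start_new_session=True, its group contains every job process it spawned, while sibling drivers started via schedule_step() live in their own sessions and are untouched. A minimal sketch of that pattern, with a throwaway sleep standing in for the driver:

import os
import signal
import subprocess
import time

# Stand-in "driver" started in its own session, like the real submission path.
driver = subprocess.Popen(['sleep', '300'], start_new_session=True)
time.sleep(0.1)

# With a new session, the process-group id equals the driver pid, but the
# cautious form queries it explicitly before signalling the whole group.
pgid = os.getpgid(driver.pid)
os.killpg(pgid, signal.SIGTERM)
driver.wait()
print('driver exited with', driver.returncode)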
f'job_id = {job_id} if {job_id} != None else job_lib.get_latest_job_id()', 'run_timestamp = job_lib.get_run_timestamp(job_id)', f'log_dir = None if run_timestamp is None else os.path.join({constants.SKY_LOGS_DIRECTORY!r}, run_timestamp)', - f'log_lib.tail_logs(job_id=job_id, log_dir=log_dir, ' - f'managed_job_id={managed_job_id!r}, follow={follow})', + f'tail_log_kwargs = {{"job_id": job_id, "log_dir": log_dir, "managed_job_id": {managed_job_id!r}, "follow": {follow}}}', + f'{_LINUX_NEW_LINE}if getattr(constants, "SKYLET_LIB_VERSION", 1) > 1: tail_log_kwargs["tail"] = {tail}', + f'{_LINUX_NEW_LINE}log_lib.tail_logs(**tail_log_kwargs)', ] return cls._build(code) diff --git a/sky/skylet/log_lib.py b/sky/skylet/log_lib.py index eb64440077e..8a40982972a 100644 --- a/sky/skylet/log_lib.py +++ b/sky/skylet/log_lib.py @@ -2,6 +2,7 @@ This is a remote utility module that provides logging functionality. """ +import collections import copy import io import multiprocessing.pool @@ -12,7 +13,8 @@ import tempfile import textwrap import time -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import (Deque, Dict, Iterable, Iterator, List, Optional, TextIO, + Tuple, Union) import colorama @@ -26,9 +28,14 @@ _SKY_LOG_WAITING_GAP_SECONDS = 1 _SKY_LOG_WAITING_MAX_RETRY = 5 _SKY_LOG_TAILING_GAP_SECONDS = 0.2 +# Peek the head of the lines to check if we need to start +# streaming when tail > 0. +PEEK_HEAD_LINES_FOR_START_STREAM = 20 logger = sky_logging.init_logger(__name__) +LOG_FILE_START_STREAMING_AT = 'Waiting for task resources on ' + class _ProcessingArgs: """Arguments for processing logs.""" @@ -178,40 +185,7 @@ def run_with_log( shell=shell, **kwargs) as proc: try: - # The proc can be defunct if the python program is killed. Here we - # open a new subprocess to gracefully kill the proc, SIGTERM - # and then SIGKILL the process group. - # Adapted from ray/dashboard/modules/job/job_manager.py#L154 - parent_pid = os.getpid() - daemon_script = os.path.join( - os.path.dirname(os.path.abspath(job_lib.__file__)), - 'subprocess_daemon.py') - python_path = subprocess.check_output( - constants.SKY_GET_PYTHON_PATH_CMD, - shell=True, - stderr=subprocess.DEVNULL, - encoding='utf-8').strip() - daemon_cmd = [ - python_path, - daemon_script, - '--parent-pid', - str(parent_pid), - '--proc-pid', - str(proc.pid), - ] - - # We do not need to set `start_new_session=True` here, as the - # daemon script will detach itself from the parent process with - # fork to avoid being killed by ray job. See the reason we - # daemonize the process in `sky/skylet/subprocess_daemon.py`. - subprocess.Popen( - daemon_cmd, - # Suppress output - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - # Disable input - stdin=subprocess.DEVNULL, - ) + subprocess_utils.kill_process_daemon(proc.pid) stdout = '' stderr = '' @@ -315,11 +289,8 @@ def run_bash_command_with_log(bash_command: str, # Need this `-i` option to make sure `source ~/.bashrc` work. inner_command = f'/bin/bash -i {script_path}' - subprocess_cmd: Union[str, List[str]] - subprocess_cmd = inner_command - return run_with_log( - subprocess_cmd, + inner_command, log_path, stream_logs=stream_logs, with_ray=with_ray, @@ -330,6 +301,7 @@ def run_bash_command_with_log(bash_command: str, def _follow_job_logs(file, job_id: int, + start_streaming: bool, start_streaming_at: str = '') -> Iterator[str]: """Yield each line from a file as they are written. 
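[Editor's note] The generated tail_logs call above only adds the new `tail` keyword when the remote runtime reports SKYLET_LIB_VERSION > 1, so a newer client can still drive an older skylet. The same idea written as plain Python rather than generated code strings; tail_logs is stubbed and the version value is illustrative:

def tail_logs(job_id, log_dir, managed_job_id=None, follow=True, tail=0):
    """Stub standing in for sky.skylet.log_lib.tail_logs."""
    print(f'job {job_id}: tail={tail or "all"} follow={follow}')

SKYLET_LIB_VERSION = 2  # what the remote runtime advertises (illustrative)

kwargs = {'job_id': 7, 'log_dir': '/tmp/sky_logs', 'managed_job_id': None,
          'follow': False}
if SKYLET_LIB_VERSION > 1:
    # Older runtimes do not accept `tail`; only pass it when supported.
    kwargs['tail'] = 100
tail_logs(**kwargs)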
@@ -338,7 +310,6 @@ def _follow_job_logs(file, # No need to lock the status here, as the while loop can handle # the older status. status = job_lib.get_status_no_lock(job_id) - start_streaming = False wait_last_logs = True while True: tmp = file.readline() @@ -378,10 +349,45 @@ def _follow_job_logs(file, status = job_lib.get_status_no_lock(job_id) +def _peek_head_lines(log_file: TextIO) -> List[str]: + """Peek the head of the file.""" + lines = [ + log_file.readline() for _ in range(PEEK_HEAD_LINES_FOR_START_STREAM) + ] + # Reset the file pointer to the beginning + log_file.seek(0, os.SEEK_SET) + return [line for line in lines if line] + + +def _should_stream_the_whole_tail_lines(head_lines_of_log_file: List[str], + tail_lines: Deque[str], + start_stream_at: str) -> bool: + """Check if the entire tail lines should be streamed.""" + # See comment: + # https://github.com/skypilot-org/skypilot/pull/4241#discussion_r1833611567 + # for more details. + # Case 1: If start_stream_at is found at the head of the tail lines, + # we should not stream the whole tail lines. + for index, line in enumerate(tail_lines): + if index >= PEEK_HEAD_LINES_FOR_START_STREAM: + break + if start_stream_at in line: + return False + # Case 2: If start_stream_at is found at the head of log file, but not at + # the tail lines, we need to stream the whole tail lines. + for line in head_lines_of_log_file: + if start_stream_at in line: + return True + # Case 3: If start_stream_at is not at the head, and not found at the tail + # lines, we should not stream the whole tail lines. + return False + + def tail_logs(job_id: Optional[int], log_dir: Optional[str], managed_job_id: Optional[int] = None, - follow: bool = True) -> None: + follow: bool = True, + tail: int = 0) -> None: """Tail the logs of a job. Args: @@ -390,6 +396,8 @@ def tail_logs(job_id: Optional[int], managed_job_id: The managed job id (for logging info only to avoid confusion). follow: Whether to follow the logs or print the logs so far and exit. + tail: The number of lines to display from the end of the log file, + if 0, print all lines. """ if job_id is None: # This only happens when job_lib.get_latest_job_id() returns None, @@ -429,7 +437,9 @@ def tail_logs(job_id: Optional[int], time.sleep(_SKY_LOG_WAITING_GAP_SECONDS) status = job_lib.update_job_status([job_id], silent=True)[0] - start_stream_at = 'Waiting for task resources on ' + start_stream_at = LOG_FILE_START_STREAMING_AT + # Explicitly declare the type to avoid mypy warning. + lines: Iterable[str] = [] if follow and status in [ job_lib.JobStatus.SETTING_UP, job_lib.JobStatus.PENDING, @@ -440,18 +450,43 @@ def tail_logs(job_id: Optional[int], with open(log_path, 'r', newline='', encoding='utf-8') as log_file: # Using `_follow` instead of `tail -f` to streaming the whole # log and creating a new process for tail. 
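[Editor's note] `_peek_head_lines` and `_should_stream_the_whole_tail_lines` decide whether the marker line ('Waiting for task resources on ') has already scrolled out of the tail window: if the marker is in the file head but not among the kept tail lines, the whole tail should be printed. A condensed sketch of that decision, with shortened constant names:

from typing import Deque, List

PEEK_HEAD_LINES = 20
START_AT = 'Waiting for task resources on '

def should_stream_whole_tail(head_lines: List[str],
                             tail_lines: Deque[str]) -> bool:
    # Marker is still inside the tail window: let the normal loop find it.
    for i, line in enumerate(tail_lines):
        if i >= PEEK_HEAD_LINES:
            break
        if START_AT in line:
            return False
    # Marker appeared near the start of the file but scrolled past the tail:
    # everything kept in the tail is after it, so stream all of it.
    return any(START_AT in line for line in head_lines)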
+ start_streaming = False + if tail > 0: + head_lines_of_log_file = _peek_head_lines(log_file) + lines = collections.deque(log_file, maxlen=tail) + start_streaming = _should_stream_the_whole_tail_lines( + head_lines_of_log_file, lines, start_stream_at) + for line in lines: + if start_stream_at in line: + start_streaming = True + if start_streaming: + print(line, end='') + # Flush the last n lines + print(end='', flush=True) + # Now, the cursor is at the end of the last lines + # if tail > 0 for line in _follow_job_logs(log_file, job_id=job_id, + start_streaming=start_streaming, start_streaming_at=start_stream_at): print(line, end='', flush=True) else: try: - start_stream = False - with open(log_path, 'r', encoding='utf-8') as f: - for line in f.readlines(): + start_streaming = False + with open(log_path, 'r', encoding='utf-8') as log_file: + if tail > 0: + # If tail > 0, we need to read the last n lines. + # We use double ended queue to rotate the last n lines. + head_lines_of_log_file = _peek_head_lines(log_file) + lines = collections.deque(log_file, maxlen=tail) + start_streaming = _should_stream_the_whole_tail_lines( + head_lines_of_log_file, lines, start_stream_at) + else: + lines = log_file + for line in lines: if start_stream_at in line: - start_stream = True - if start_stream: + start_streaming = True + if start_streaming: print(line, end='', flush=True) except FileNotFoundError: print(f'{colorama.Fore.RED}ERROR: Logs for job {job_id} (status:' diff --git a/sky/skylet/log_lib.pyi b/sky/skylet/log_lib.pyi index 01b08b6444f..89d1628ec11 100644 --- a/sky/skylet/log_lib.pyi +++ b/sky/skylet/log_lib.pyi @@ -13,6 +13,9 @@ from sky.skylet import constants as constants from sky.skylet import job_lib as job_lib from sky.utils import log_utils as log_utils +LOG_FILE_START_STREAMING_AT: str = ... + + class _ProcessingArgs: log_path: str stream_logs: bool diff --git a/sky/skylet/providers/oci/__init__.py b/sky/skylet/providers/oci/__init__.py deleted file mode 100644 index f7c3aa255ae..00000000000 --- a/sky/skylet/providers/oci/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""OCI node provider""" -from sky.skylet.providers.oci.node_provider import OCINodeProvider diff --git a/sky/skylet/providers/oci/node_provider.py b/sky/skylet/providers/oci/node_provider.py deleted file mode 100644 index 35d4304582b..00000000000 --- a/sky/skylet/providers/oci/node_provider.py +++ /dev/null @@ -1,488 +0,0 @@ -"""OCI Node Provider. - -Node provider is called by the Ray Autoscaler to provision new compute -resources (head / worker nodes). 
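[Editor's note] For the `tail > 0` path above, collections.deque(maxlen=tail) keeps only the last N lines in a single pass over the file, and streaming starts either at the marker line or immediately when the marker is known to be earlier in the file. A self-contained toy version of that loop, using in-memory lines instead of a log file:

import collections

START_AT = 'Waiting for task resources on '
log_lines = ['provisioning...\n', START_AT + 'my-cluster\n',
             'step 1\n', 'step 2\n', 'step 3\n']

tail = 2
head = log_lines[:20]                               # the "peeked" head
last = collections.deque(log_lines, maxlen=tail)    # only the final N lines

# Marker is in the head but not in the tail window -> stream the whole tail.
start_streaming = (any(START_AT in line for line in head) and
                   not any(START_AT in line for line in last))
for line in last:
    if START_AT in line:
        start_streaming = True
    if start_streaming:
        print(line, end='')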
- -To show debug messages, export SKYPILOT_DEBUG=1 - -History: - - Hysun He (hysun.he@oracle.com) @ Apr, 2023: Initial implementation - -""" - -import copy -from datetime import datetime -import logging -import threading -import time - -from ray.autoscaler.node_provider import NodeProvider -from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME -from ray.autoscaler.tags import TAG_RAY_LAUNCH_CONFIG -from ray.autoscaler.tags import TAG_RAY_NODE_KIND -from ray.autoscaler.tags import TAG_RAY_USER_NODE_TYPE - -from sky.adaptors import oci as oci_adaptor -from sky.clouds.utils import oci_utils -from sky.skylet.providers.oci import utils -from sky.skylet.providers.oci.query_helper import oci_query_helper - -logger = logging.getLogger(__name__) - - -def synchronized(f): - - def wrapper(self, *args, **kwargs): - self.lock.acquire() - try: - return f(self, *args, **kwargs) - finally: - self.lock.release() - - return wrapper - - -class OCINodeProvider(NodeProvider): - """Node Provider for OracleCloud (OCI).""" - - def __init__(self, provider_config, cluster_name): - NodeProvider.__init__(self, provider_config, cluster_name) - self.lock = threading.RLock() - self.cached_nodes = {} - self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", - True) - self.region = provider_config["region"] - - # Do a read-ahead cache loading to improve performance. - self._get_filtered_nodes({}) - - @synchronized - def _get_filtered_nodes(self, tag_filters, force=False): - # Make sure the cluster_name is always an criterion - tag_filters = {**tag_filters, TAG_RAY_CLUSTER_NAME: self.cluster_name} - - return_nodes = {} - if not force: - # Query cache first to reduce API call. - cache_hit = False - for k, node in self.cached_nodes.items(): - tags = node["tags"] - unmatched_tags = [ - k for k, v in tag_filters.items() - if k not in tags or v != tags[k] - ] - if len(unmatched_tags) == 0: - return_nodes[k] = node - cache_hit |= True - - if cache_hit: - return return_nodes - - insts = oci_query_helper.query_instances_by_tags( - tag_filters, self.region) - for inst in insts: - inst_id = inst.identifier - if inst_id in self.cached_nodes: - del self.cached_nodes[inst_id] - - item = self.get_inst_obj({ - "inst_id": inst_id, - "ad": inst.availability_domain, - "compartment": inst.compartment_id, - "lifecycle_state": inst.lifecycle_state, - "oci_tags": inst.freeform_tags, - }) - return_nodes[inst_id] = item - self.cached_nodes[inst_id] = item - - return return_nodes - - @utils.debug_enabled(logger=logger) - def non_terminated_nodes(self, tag_filters): - """Return a list of node ids filtered by the specified tags dict. - - This list must not include terminated nodes. For performance reasons, - providers are allowed to cache the result of a call to - non_terminated_nodes() to serve single-node queries - (e.g. is_running(node_id)). This means that non_terminated_nodes() - must be called again to refresh results. 
- """ - VALIDITY_TAGS = [ - TAG_RAY_CLUSTER_NAME, - TAG_RAY_NODE_KIND, - TAG_RAY_USER_NODE_TYPE, - TAG_RAY_LAUNCH_CONFIG, - ] - filters = { - tag: tag_filters[tag] for tag in VALIDITY_TAGS if tag in tag_filters - } - - nodes = self._get_filtered_nodes(tag_filters=filters) - return [k for k, v in nodes.items() if v["status"] == "RUNNING"] - - @utils.debug_enabled(logger=logger) - def is_running(self, node_id): - """Return whether the specified node is running.""" - node = self._get_cached_node(node_id=node_id) - check_result = node is None or node["status"] == "RUNNING" - - return check_result - - @utils.debug_enabled(logger=logger) - def is_terminated(self, node_id): - """Return whether the specified node is terminated.""" - node = self._get_cached_node(node_id=node_id) - check_result = ((node is None) or (node["status"] == "TERMINATED") or - (node["status"] == "TERMINATING")) - - return check_result - - @utils.debug_enabled(logger=logger) - def node_tags(self, node_id): - return self.cached_nodes[node_id]["tags"] - - @utils.debug_enabled(logger=logger) - def external_ip(self, node_id): - """Returns the external ip of the given node.""" - return self._get_cached_node(node_id=node_id)["external_ip"] - - @utils.debug_enabled(logger=logger) - def internal_ip(self, node_id): - """Returns the internal ip (Ray ip) of the given node.""" - return self._get_cached_node(node_id=node_id)["internal_ip"] - - @synchronized - @utils.debug_enabled(logger=logger) - def create_node(self, node_config, tags, count): - """Creates a number of nodes within the namespace.""" - start_time = round(time.time() * 1000) - starting_insts = [] - # Check first if it neccessary to create new nodes / start stopped nodes - VALIDITY_TAGS = [ - TAG_RAY_CLUSTER_NAME, - TAG_RAY_NODE_KIND, - TAG_RAY_USER_NODE_TYPE, - ] - filters = {tag: tags[tag] for tag in VALIDITY_TAGS if tag in tags} - - # Starting stopped nodes if cache_stopped_nodes=True - if self.cache_stopped_nodes: - logger.debug("Checking existing stopped nodes.") - - filters_with_launch_config = copy.copy(filters) - if TAG_RAY_LAUNCH_CONFIG in tags: - filters_with_launch_config[TAG_RAY_LAUNCH_CONFIG] = tags[ - TAG_RAY_LAUNCH_CONFIG] - - nodes_matching_launch_config = self.stopped_nodes( - filters_with_launch_config) - logger.debug(f"Found stopped nodes (with same launch config): " - f"{len(nodes_matching_launch_config)}") - - reuse_nodes = [] - if len(nodes_matching_launch_config) >= count: - reuse_nodes = nodes_matching_launch_config[:count] - else: - nodes_all = self.stopped_nodes(filters) - logger.debug(f"Found stopped nodes (regardless launch config): " - f"{len(nodes_all)}") - nodes_matching_launch_config_ids = [ - n["id"] for n in nodes_matching_launch_config - ] - nodes_non_matching_launch_config = [ - n for n in nodes_all - if n["id"] not in nodes_matching_launch_config_ids - ] - reuse_nodes = (nodes_matching_launch_config + - nodes_non_matching_launch_config) - reuse_nodes = reuse_nodes[:count] - - logger.info( - f"Reusing nodes {len(reuse_nodes)}: {list(reuse_nodes)}. 
" - "To disable reuse, set `cache_stopped_nodes: False` " - "under `provider` in the cluster configuration.",) - - for reuse_node in reuse_nodes: - if reuse_node["status"] == "STOPPING": - get_instance_response = oci_adaptor.get_core_client( - self.region, - oci_utils.oci_config.get_profile()).get_instance( - instance_id=reuse_node["id"]) - oci_adaptor.oci.wait_until( - oci_adaptor.get_core_client( - self.region, oci_utils.oci_config.get_profile()), - get_instance_response, - "lifecycle_state", - "STOPPED", - ) - - start_time1 = round(time.time() * 1000) - for matched_node in reuse_nodes: - matched_node_id = matched_node["id"] - instance_action_response = oci_adaptor.get_core_client( - self.region, - oci_utils.oci_config.get_profile()).instance_action( - instance_id=matched_node_id, action="START") - - starting_inst = instance_action_response.data - starting_insts.append({ - "inst_id": starting_inst.id, - "ad": starting_inst.availability_domain, - "compartment": starting_inst.compartment_id, - "lifecycle_state": starting_inst.lifecycle_state, - "oci_tags": starting_inst.freeform_tags, - }) - count -= len(reuse_nodes) - - launch_stopped_time = round(time.time() * 1000) - start_time1 - logger.debug( - "Time elapsed(Launch stopped): {0} milli-seconds.".format( - launch_stopped_time)) - # end if self.cache_stopped_nodes:... - - # Let's create additional new nodes (if neccessary) - if count > 0: - compartment = oci_query_helper.find_compartment(self.region) - vcn = oci_query_helper.find_create_vcn_subnet(self.region) - if vcn is None: - raise RuntimeError("VcnSubnetNotFound Error!") - - ocpu_count = 0 - vcpu_str = node_config["VCPUs"] - instance_type_str = node_config["InstanceType"] - - if vcpu_str is not None and vcpu_str != "None": - if instance_type_str.startswith( - f"{oci_utils.oci_config.VM_PREFIX}.A"): - # For ARM cpu, 1*ocpu = 1*vcpu - ocpu_count = round(float(vcpu_str)) - else: - # For Intel / AMD cpu, 1*ocpu = 2*vcpu - ocpu_count = round(float(vcpu_str) / 2) - ocpu_count = 1 if (ocpu_count > 0 and - ocpu_count < 1) else ocpu_count - - machine_shape_config = None - if ocpu_count > 0: - mem = node_config["MemoryInGbs"] - if mem is not None and mem != "None": - machine_shape_config = (oci_adaptor.oci.core.models. - LaunchInstanceShapeConfigDetails( - ocpus=ocpu_count, - memory_in_gbs=mem)) - else: - machine_shape_config = (oci_adaptor.oci.core.models. - LaunchInstanceShapeConfigDetails( - ocpus=ocpu_count)) - - preempitible_config = ( - oci_adaptor.oci.core.models.PreemptibleInstanceConfigDetails( - preemption_action=oci_adaptor.oci.core.models. 
- TerminatePreemptionAction(type="TERMINATE", - preserve_boot_volume=False)) - if node_config["Preemptible"] else None) - - logger.debug(f"Shape: {instance_type_str}, ocpu: {ocpu_count}") - logger.debug(f"Shape config is {machine_shape_config}") - logger.debug(f"Spot config is {preempitible_config}") - - vm_tags = { - **tags, - TAG_RAY_CLUSTER_NAME: self.cluster_name, - "sky_spot_flag": str(node_config["Preemptible"]).lower(), - } - # Use UTC time so that header & worker nodes use same rule - batch_id = datetime.utcnow().strftime("%Y%m%d%H%M%S") - node_type = tags[TAG_RAY_NODE_KIND] - - oci_query_helper.subscribe_image( - compartment_id=compartment, - listing_id=node_config["AppCatalogListingId"], - resource_version=node_config["ResourceVersion"], - region=self.region, - ) - - start_time1 = round(time.time() * 1000) - for seq in range(1, count + 1): - launch_instance_response = oci_adaptor.get_core_client( - self.region, oci_utils.oci_config.get_profile() - ).launch_instance( - launch_instance_details=oci_adaptor.oci.core.models. - LaunchInstanceDetails( - availability_domain=node_config["AvailabilityDomain"], - compartment_id=compartment, - shape=instance_type_str, - display_name= - f"{self.cluster_name}_{node_type}_{batch_id}_{seq}", - freeform_tags=vm_tags, - metadata={ - "ssh_authorized_keys": node_config["AuthorizedKey"] - }, - source_details=oci_adaptor.oci.core.models. - InstanceSourceViaImageDetails( - source_type="image", - image_id=node_config["ImageId"], - boot_volume_size_in_gbs=node_config[ - "BootVolumeSize"], - boot_volume_vpus_per_gb=int( - node_config["BootVolumePerf"]), - ), - create_vnic_details=oci_adaptor.oci.core.models. - CreateVnicDetails( - assign_public_ip=True, - subnet_id=vcn, - ), - shape_config=machine_shape_config, - preemptible_instance_config=preempitible_config, - )) - - new_inst = launch_instance_response.data - starting_insts.append({ - "inst_id": new_inst.id, - "ad": new_inst.availability_domain, - "compartment": new_inst.compartment_id, - "lifecycle_state": new_inst.lifecycle_state, - "oci_tags": new_inst.freeform_tags, - }) - # end for loop - - launch_new_time = round(time.time() * 1000) - start_time1 - logger.debug("Time elapsed(Launch): {0} milli-seconds.".format( - launch_new_time)) - # end if count > 0:... 
- - for ninst in starting_insts: - # Waiting for the instance to be RUNNING state - get_instance_response = oci_adaptor.get_core_client( - self.region, oci_utils.oci_config.get_profile()).get_instance( - instance_id=ninst["inst_id"]) - oci_adaptor.oci.wait_until( - oci_adaptor.get_core_client(self.region, - oci_utils.oci_config.get_profile()), - get_instance_response, - "lifecycle_state", - "RUNNING", - ) - ninst["lifecycle_state"] = "RUNNING" - self.cached_nodes[ninst["inst_id"]] = self.get_inst_obj(ninst) - - total_time = round(time.time() * 1000) - start_time - logger.debug( - "Total time elapsed: {0} milli-seconds.".format(total_time)) - - def get_inst_obj(self, inst_info): - list_vnic_attachments_response = oci_adaptor.get_core_client( - self.region, - oci_utils.oci_config.get_profile()).list_vnic_attachments( - availability_domain=inst_info["ad"], - compartment_id=inst_info["compartment"], - instance_id=inst_info["inst_id"], - ) - - vnic = list_vnic_attachments_response.data[0] - get_vnic_response = (oci_adaptor.get_net_client( - self.region, oci_utils.oci_config.get_profile()).get_vnic( - vnic_id=vnic.vnic_id).data) - - internal_ip = get_vnic_response.private_ip - external_ip = get_vnic_response.public_ip - if external_ip is None: - external_ip = internal_ip - - return { - "id": inst_info["inst_id"], - "external_ip": external_ip, - "internal_ip": internal_ip, - "tags": inst_info["oci_tags"], - "status": inst_info["lifecycle_state"], - } - - @synchronized - @utils.debug_enabled(logger=logger) - def set_node_tags(self, node_id, tags): - existing_tags = self._get_cached_node(node_id)["tags"] - combined_tags = dict(existing_tags, **tags) - - self.cached_nodes[node_id]["tags"] = combined_tags - retry_count = 0 - while retry_count < oci_utils.oci_config.MAX_RETRY_COUNT: - try: - oci_adaptor.get_core_client( - self.region, - oci_utils.oci_config.get_profile()).update_instance( - instance_id=node_id, - update_instance_details=oci_adaptor.oci.core.models. - UpdateInstanceDetails(freeform_tags=combined_tags), - ) - logger.info(f"Tags are well set for node {node_id}") - break - except Exception as e: - retry_count = retry_count + 1 - wait_seconds = oci_utils.oci_config.RETRY_INTERVAL_BASE_SECONDS * retry_count - logger.warn( - f"Not ready yet, wait {wait_seconds} seconds & retry!") - logger.warn(f"Exception message is {str(e)}") - time.sleep(wait_seconds) - - @synchronized - def terminate_node(self, node_id): - """Terminates the specified node.""" - logger.info(f"terminate_node {node_id}...") - node = self._get_cached_node(node_id) - if node is None: - logger.info(f"The node is not existed: {node_id}..") - return # Node not exists yet. 
- - logger.debug(f"sky_spot_flag: {node['tags']['sky_spot_flag']}") - preemptibleFlag = (True if node and - (str(node["tags"]["sky_spot_flag"]) == "true") else - False) - - if self.cache_stopped_nodes and not preemptibleFlag: - logger.info(f"Stopping instance {node_id}" - "(to fully terminate instead, " - "set `cache_stopped_nodes: False` " - "under `provider` in the cluster configuration)") - instance_action_response = oci_adaptor.get_core_client( - self.region, - oci_utils.oci_config.get_profile()).instance_action( - instance_id=node_id, action="STOP") - logger.info( - f"Stopped the instance {instance_action_response.data.id}") - if node_id in self.cached_nodes: - self.cached_nodes[node_id]["status"] = "STOPPED" - state_word = "Stopped" - else: - terminate_instance_response = oci_adaptor.get_core_client( - self.region, - oci_utils.oci_config.get_profile()).terminate_instance(node_id) - logger.debug(terminate_instance_response.data) - if node_id in self.cached_nodes: - del self.cached_nodes[node_id] - state_word = "Terminated" - - logger.info( - f"{state_word} {node_id} w/ sky_spot_flag: {preemptibleFlag}.") - - def _get_node(self, node_id): - self._get_filtered_nodes({}, - force=True) # All except for those terminated. - return self.cached_nodes.get(node_id, None) - - def _get_cached_node(self, node_id): - if node_id in self.cached_nodes: - return self.cached_nodes[node_id] - return self._get_node(node_id=node_id) - - def stopped_nodes(self, tag_filters): - """Return a list of stopped nodes filtered by the specified tags dict.""" - nodes = self._get_filtered_nodes(tag_filters=tag_filters, force=True) - return [ - v for _, v in nodes.items() - if v["status"] in ("STOPPED", "STOPPING") - ] - - def running_nodes(self, tag_filters): - """Return a list of running node ids filtered by the specified tags dict.""" - nodes = self._get_filtered_nodes(tag_filters=tag_filters) - return [k for k, v in nodes.items() if v["status"] == "RUNNING"] diff --git a/sky/skylet/providers/oci/utils.py b/sky/skylet/providers/oci/utils.py deleted file mode 100644 index 5628cee2524..00000000000 --- a/sky/skylet/providers/oci/utils.py +++ /dev/null @@ -1,21 +0,0 @@ -from datetime import datetime -import functools -from logging import Logger - - -def debug_enabled(logger: Logger): - - def decorate(f): - - @functools.wraps(f) - def wrapper(*args, **kwargs): - dt_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - logger.debug(f"{dt_str} Enter {f}, {args}, {kwargs}") - try: - return f(*args, **kwargs) - finally: - logger.debug(f"{dt_str} Exit {f}") - - return wrapper - - return decorate diff --git a/sky/skylet/subprocess_daemon.py b/sky/skylet/subprocess_daemon.py index 1261f4ecf72..55b63d1f9a5 100644 --- a/sky/skylet/subprocess_daemon.py +++ b/sky/skylet/subprocess_daemon.py @@ -15,10 +15,11 @@ def daemonize(): This detachment is crucial in the context of SkyPilot and Ray job. When 'sky cancel' is executed, it uses Ray's stop job API to terminate the job. - Without daemonization, this subprocess_daemon process would be terminated - along with its parent process, ray::task, which is launched with Ray job. - Daemonization ensures this process survives the 'sky cancel' command, - allowing it to prevent orphaned processes of Ray job. + Without daemonization, this subprocess_daemon process will still be a child + of the parent process which would be terminated along with the parent + process, ray::task or the cancel request for jobs, which is launched with + Ray job. 
Daemonization ensures this process survives the 'sky cancel' + command, allowing it to prevent orphaned processes of Ray job. """ # First fork: Creates a child process identical to the parent if os.fork() > 0: @@ -42,6 +43,15 @@ def daemonize(): parser = argparse.ArgumentParser() parser.add_argument('--parent-pid', type=int, required=True) parser.add_argument('--proc-pid', type=int, required=True) + parser.add_argument( + '--initial-children', + type=str, + default='', + help=( + 'Comma-separated list of initial children PIDs. This is to guard ' + 'against the case where the target process has already terminated, ' + 'while the children are still running.'), + ) args = parser.parse_args() process = None @@ -52,24 +62,34 @@ def daemonize(): except psutil.NoSuchProcess: pass - if process is None: - sys.exit() - + # Initialize children list from arguments children = [] - if parent_process is not None: - # Wait for either parent or target process to exit. + if args.initial_children: + for pid in args.initial_children.split(','): + try: + child = psutil.Process(int(pid)) + children.append(child) + except (psutil.NoSuchProcess, ValueError): + pass + + if process is not None and parent_process is not None: + # Wait for either parent or target process to exit while process.is_running() and parent_process.is_running(): try: - # process.children() must be called while the target process - # is alive, as it will return an empty list if the target - # process has already terminated. tmp_children = process.children(recursive=True) if tmp_children: children = tmp_children except psutil.NoSuchProcess: pass time.sleep(1) - children.append(process) + + if process is not None: + # Kill the target process first to avoid having more children, or fail + # the process due to the children being defunct. + children = [process] + children + + if not children: + sys.exit() for child in children: try: diff --git a/sky/templates/aws-ray.yml.j2 b/sky/templates/aws-ray.yml.j2 index 95751ab1849..8e9898cb784 100644 --- a/sky/templates/aws-ray.yml.j2 +++ b/sky/templates/aws-ray.yml.j2 @@ -173,6 +173,7 @@ setup_commands: # Line 'sudo grep ..': set the number of threads per process to unlimited to avoid ray job submit stucking issue when the number of running ray jobs increase. # Line 'mkdir -p ..': disable host key check # Line 'python3 -c ..': patch the buggy ray files and enable `-o allow_other` option for `goofys` + # Line 'rm ~/.aws/credentials': explicitly remove the credentials file to be safe. This is to guard against the case where the credential files was uploaded once as `remote_identity` was not set in a previous launch. 
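[Editor's note] The new --initial-children argument to subprocess_daemon.py above lets the daemon clean up known grandchildren even if the target process has already exited by the time the daemon starts. A trimmed sketch of parsing and seeding that list with psutil; the argument name matches the diff, the rest is simplified:

import argparse
import psutil

parser = argparse.ArgumentParser()
parser.add_argument('--proc-pid', type=int, required=True)
parser.add_argument('--initial-children', type=str, default='')
args = parser.parse_args(['--proc-pid', '12345', '--initial-children', '67,89'])

children = []
for pid in filter(None, args.initial_children.split(',')):
    try:
        # Known children recorded at launch time, kept even if the parent dies.
        children.append(psutil.Process(int(pid)))
    except (psutil.NoSuchProcess, ValueError):
        pass
print([c.pid for c in children])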
- mkdir -p ~/.ssh; touch ~/.ssh/config; {%- for initial_setup_command in initial_setup_commands %} {{ initial_setup_command }} @@ -185,7 +186,12 @@ setup_commands: sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload; {%- endif %} mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n StrictHostKeyChecking no" ~/.ssh/config) || printf "Host *\n StrictHostKeyChecking no\n" >> ~/.ssh/config; - [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); # This is needed for `-o allow_other` option for `goofys`; + [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); + {%- if remote_identity != 'LOCAL_CREDENTIALS' %} + rm ~/.aws/credentials || true; + {%- endif %} + + # Command to start ray clusters are now placed in `sky.provision.instance_setup`. # We do not need to list it here anymore. diff --git a/sky/templates/azure-ray.yml.j2 b/sky/templates/azure-ray.yml.j2 index 7b9737748d3..1140704a708 100644 --- a/sky/templates/azure-ray.yml.j2 +++ b/sky/templates/azure-ray.yml.j2 @@ -75,9 +75,6 @@ available_node_types: {%- if use_spot %} # optionally set priority to use Spot instances priority: Spot - # set a maximum price for spot instances if desired - # billingProfile: - # maxPrice: -1 {%- endif %} cloudInitSetupCommands: |- {%- for cmd in cloud_init_setup_commands %} diff --git a/sky/templates/kubernetes-port-forward-proxy-command.sh b/sky/templates/kubernetes-port-forward-proxy-command.sh index 0407209a77c..f8205c2393c 100644 --- a/sky/templates/kubernetes-port-forward-proxy-command.sh +++ b/sky/templates/kubernetes-port-forward-proxy-command.sh @@ -58,6 +58,11 @@ KUBECTL_ARGS=() if [ -n "$KUBE_CONTEXT" ]; then KUBECTL_ARGS+=("--context=$KUBE_CONTEXT") fi +# If context is not provided, it means we are using incluster auth. In this case, +# we need to set KUBECONFIG to /dev/null to avoid using kubeconfig file. +if [ -z "$KUBE_CONTEXT" ]; then + KUBECTL_ARGS+=("--kubeconfig=/dev/null") +fi if [ -n "$KUBE_NAMESPACE" ]; then KUBECTL_ARGS+=("--namespace=$KUBE_NAMESPACE") fi diff --git a/sky/templates/kubernetes-ray.yml.j2 b/sky/templates/kubernetes-ray.yml.j2 index e367cd536f6..2087d9c6e9d 100644 --- a/sky/templates/kubernetes-ray.yml.j2 +++ b/sky/templates/kubernetes-ray.yml.j2 @@ -222,7 +222,9 @@ provider: - protocol: TCP port: 22 targetPort: 22 - # Service that maps to the head node of the Ray cluster. + # Service that maps to the head node of the Ray cluster, so that the + # worker nodes can find the head node using + # {{cluster_name_on_cloud}}-head.{{k8s_namespace}}.svc.cluster.local - apiVersion: v1 kind: Service metadata: @@ -235,18 +237,12 @@ provider: # names. name: {{cluster_name_on_cloud}}-head spec: + # Create a headless service so that the head node can be reached by + # the worker nodes with any port number. + clusterIP: None # This selector must match the head node pod's selector below. selector: component: {{cluster_name_on_cloud}}-head - ports: - - name: client - protocol: TCP - port: 10001 - targetPort: 10001 - - name: dashboard - protocol: TCP - port: 8265 - targetPort: 8265 # Specify the pod type for the ray head node (as configured below). 
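[Editor's note] With the headless Service above (clusterIP: None), the head pod is reachable at {{cluster_name_on_cloud}}-head.{{k8s_namespace}}.svc.cluster.local on any port, which is what the worker-side wait loop later in this template relies on (polling with `nc -z` until the ray port answers). The same wait expressed in Python, purely to illustrate the idea; the host and port values are placeholders:

import socket
import time

head_host = 'mycluster-head.default.svc.cluster.local'  # placeholder DNS name
ray_port = 6380                                          # placeholder port

def wait_for_head(host: str, port: int, timeout: float = 1.0) -> None:
    """Poll until the head node's ray port accepts TCP connections."""
    while True:
        try:
            with socket.create_connection((host, port), timeout=timeout):
                return
        except OSError:
            time.sleep(0.1)

# wait_for_head(head_host, ray_port)  # would block until the head is up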
head_node_type: ray_head_default @@ -280,15 +276,17 @@ available_node_types: # serviceAccountName: skypilot-service-account serviceAccountName: {{k8s_service_account_name}} automountServiceAccountToken: {{k8s_automount_sa_token}} - restartPolicy: Never - # Add node selector if GPUs are requested: + # Add node selector if GPU/TPUs are requested: {% if (k8s_acc_label_key is not none and k8s_acc_label_value is not none) or (k8s_spot_label_key is not none) %} nodeSelector: {% if k8s_acc_label_key is not none and k8s_acc_label_value is not none %} {{k8s_acc_label_key}}: {{k8s_acc_label_value}} {% endif %} + {% if k8s_topology_label_key is not none and k8s_topology_label_value is not none %} + {{k8s_topology_label_key}}: {{k8s_topology_label_value}} + {% endif %} {% if k8s_spot_label_key is not none %} {{k8s_spot_label_key}}: {{k8s_spot_label_value|tojson}} {% endif %} @@ -319,11 +317,160 @@ available_node_types: - name: ray-node imagePullPolicy: IfNotPresent image: {{image_id}} + env: + - name: SKYPILOT_POD_NODE_TYPE + valueFrom: + fieldRef: + fieldPath: metadata.labels['ray-node-type'] + {% for key, value in k8s_env_vars.items() if k8s_env_vars is not none %} + - name: {{ key }} + value: {{ value }} + {% endfor %} # Do not change this command - it keeps the pod alive until it is # explicitly killed. command: ["/bin/bash", "-c", "--"] args: - | + # For backwards compatibility, we put a marker file in the pod + # to indicate that the pod is running with the changes introduced + # in project nimbus: https://github.com/skypilot-org/skypilot/pull/4393 + # TODO: Remove this marker file and it's usage in setup_commands + # after v0.10.0 release. + touch /tmp/skypilot_is_nimbus + + # Helper function to conditionally use sudo + # TODO(zhwu): consolidate the two prefix_cmd and sudo replacements + prefix_cmd() { if [ $(id -u) -ne 0 ]; then echo "sudo"; else echo ""; fi; } + [ $(id -u) -eq 0 ] && function sudo() { "$@"; } || true; + + STEPS=("apt-ssh-setup" "runtime-setup" "env-setup") + + # STEP 1: Run apt update, install missing packages, and set up ssh. + ( + ( + DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get update > /tmp/apt-update.log 2>&1 || \ + echo "Warning: apt-get update failed. Continuing anyway..." >> /tmp/apt-update.log + PACKAGES="rsync curl netcat gcc patch pciutils fuse openssh-server"; + + # Separate packages into two groups: packages that are installed first + # so that curl and rsync are available sooner to unblock the following + # conda installation and rsync. + set -e + INSTALL_FIRST=""; + MISSING_PACKAGES=""; + for pkg in $PACKAGES; do + if [ "$pkg" == "netcat" ]; then + if ! dpkg -l | grep -q "^ii \(netcat\|netcat-openbsd\|netcat-traditional\) "; then + INSTALL_FIRST="$INSTALL_FIRST netcat-openbsd"; + fi + elif ! dpkg -l | grep -q "^ii $pkg "; then + if [ "$pkg" == "curl" ] || [ "$pkg" == "rsync" ]; then + INSTALL_FIRST="$INSTALL_FIRST $pkg"; + else + MISSING_PACKAGES="$MISSING_PACKAGES $pkg"; + fi + fi + done; + if [ ! -z "$INSTALL_FIRST" ]; then + echo "Installing core packages: $INSTALL_FIRST"; + DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get install -y $INSTALL_FIRST; + fi; + # SSH and other packages are not necessary, so we disable set -e + set +e + + if [ ! 
-z "$MISSING_PACKAGES" ]; then + echo "Installing missing packages: $MISSING_PACKAGES"; + DEBIAN_FRONTEND=noninteractive $(prefix_cmd) apt-get install -y $MISSING_PACKAGES; + fi; + $(prefix_cmd) mkdir -p /var/run/sshd; + $(prefix_cmd) sed -i "s/PermitRootLogin prohibit-password/PermitRootLogin yes/" /etc/ssh/sshd_config; + $(prefix_cmd) sed "s@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g" -i /etc/pam.d/sshd; + cd /etc/ssh/ && $(prefix_cmd) ssh-keygen -A; + $(prefix_cmd) mkdir -p ~/.ssh; + $(prefix_cmd) chown -R $(whoami) ~/.ssh; + $(prefix_cmd) chmod 700 ~/.ssh; + $(prefix_cmd) cat /etc/secret-volume/ssh-publickey* > ~/.ssh/authorized_keys; + $(prefix_cmd) chmod 644 ~/.ssh/authorized_keys; + $(prefix_cmd) service ssh restart; + $(prefix_cmd) sed -i "s/mesg n/tty -s \&\& mesg n/" ~/.profile; + ) > /tmp/${STEPS[0]}.log 2>&1 || { + echo "Error: ${STEPS[0]} failed. Continuing anyway..." > /tmp/${STEPS[0]}.failed + cat /tmp/${STEPS[0]}.log + exit 1 + } + ) & + + # STEP 2: Install conda, ray and skypilot (for dependencies); start + # ray cluster. + ( + ( + set -e + mkdir -p ~/.sky + # Wait for `curl` package to be installed before installing conda + # and ray. + until dpkg -l | grep -q "^ii curl "; do + sleep 0.1 + echo "Waiting for curl package to be installed..." + done + {{ conda_installation_commands }} + {{ ray_installation_commands }} + VIRTUAL_ENV=~/skypilot-runtime ~/.local/bin/uv pip install skypilot[kubernetes,remote] + touch /tmp/ray_skypilot_installation_complete + echo "=== Ray and skypilot installation completed ===" + + # Disable set -e, as we have some commands that are ok to fail + # after the ray start. + # TODO(zhwu): this is a hack, we should fix the commands that are + # ok to fail. + if [ "$SKYPILOT_POD_NODE_TYPE" == "head" ]; then + set +e + {{ ray_head_start_command }} + else + # Start ray worker on the worker pod. + # Wait until the head pod is available with an IP address + export SKYPILOT_RAY_HEAD_IP="{{cluster_name_on_cloud}}-head.{{k8s_namespace}}.svc.cluster.local" + export SKYPILOT_RAY_PORT={{skypilot_ray_port}} + # Wait until the ray cluster is started on the head pod + until dpkg -l | grep -q "^ii \(netcat\|netcat-openbsd\|netcat-traditional\) "; do + sleep 0.1 + echo "Waiting for netcat package to be installed..." + done + until nc -z -w 1 ${SKYPILOT_RAY_HEAD_IP} ${SKYPILOT_RAY_PORT}; do + sleep 0.1 + done + + set +e + {{ ray_worker_start_command }} + fi + ) > /tmp/${STEPS[1]}.log 2>&1 || { + echo "Error: ${STEPS[1]} failed. Continuing anyway..." > /tmp/${STEPS[1]}.failed + cat /tmp/${STEPS[1]}.log + exit 1 + } + ) & + + + # STEP 3: Set up environment variables; this should be relatively fast. + ( + ( + set -e + if [ $(id -u) -eq 0 ]; then + echo 'alias sudo=""' >> ~/.bashrc; echo succeed; + else + if command -v sudo >/dev/null 2>&1; then + timeout 2 sudo -l >/dev/null 2>&1 && echo succeed || { echo 52; exit 52; }; + else + { echo 52; exit 52; }; + fi; + fi; + printenv | while IFS='=' read -r key value; do echo "export $key=\"$value\""; done > ~/container_env_var.sh && $(prefix_cmd) mv ~/container_env_var.sh /etc/profile.d/container_env_var.sh + ) > /tmp/${STEPS[2]}.log 2>&1 || { + echo "Error: ${STEPS[2]} failed. Continuing anyway..." 
> /tmp/${STEPS[2]}.failed + cat /tmp/${STEPS[2]}.log + exit 1 + } + ) & + function mylsof { p=$(for pid in /proc/{0..9}*; do i=$(basename "$pid"); for file in "$pid"/fd/*; do link=$(readlink -e "$file"); if [ "$link" = "$1" ]; then echo "$i"; fi; done; done); echo "$p"; }; # Tails file and checks every 5 sec for @@ -400,18 +547,30 @@ available_node_types: requests: cpu: {{cpus}} memory: {{memory}}G - nvidia.com/gpu: {{accelerator_count}} + {% if k8s_resource_key is not none %} + # Number of requested google.com/tpu must be equal to the total + # number of available TPU chips on the TPU slice node either it + # being a node from multi-host TPU slice or single-host TPU + # slice. Example reference: + # https://cloud.google.com/kubernetes-engine/docs/concepts/tpus#how_tpus_work + {{k8s_resource_key}}: {{accelerator_count}} + {% endif %} {% if k8s_fuse_device_required %} # Kubernetes resource exposed by the fuse device manager # https://gitlab.com/arm-research/smarter/smarter-device-manager smarter-devices/fuse: "1" {% endif %} + {% if k8s_resource_key is not none or k8s_fuse_device_required %} limits: - nvidia.com/gpu: {{accelerator_count}} # Limits need to be defined for GPU requests + # Limits need to be defined for GPU/TPU requests + {% if k8s_resource_key is not none %} + {{k8s_resource_key}}: {{accelerator_count}} + {% endif %} {% if k8s_fuse_device_required %} smarter-devices/fuse: "1" {% endif %} - + {% endif %} + setup_commands: # Disable `unattended-upgrades` to prevent apt-get from hanging. It should be called at the beginning before the process started to avoid being blocked. (This is a temporary fix.) # Create ~/.ssh/config file in case the file does not exist in the image. @@ -419,18 +578,51 @@ setup_commands: # Line 'sudo grep ..': set the number of threads per process to unlimited to avoid ray job submit stucking issue when the number of running ray jobs increase. # Line 'mkdir -p ..': disable host key check # Line 'python3 -c ..': patch the buggy ray files and enable `-o allow_other` option for `goofys` - - sudo DEBIAN_FRONTEND=noninteractive apt install lsof gcc patch pciutils rsync fuse curl -y; + # Line 'for step in ..': check if any failure indicator exists for the setup done in pod args and print the error message. This is only a best effort, as the + # commands in pod args are asynchronous and we cannot guarantee the failure indicators are created before the setup commands finish. + - | mkdir -p ~/.ssh; touch ~/.ssh/config; {%- for initial_setup_command in initial_setup_commands %} {{ initial_setup_command }} {%- endfor %} - {{ conda_installation_commands }} - {{ ray_skypilot_installation_commands }} + STEPS=("apt-ssh-setup" "runtime-setup" "env-setup") + start_epoch=$(date +%s); + echo "=== Logs for asynchronous ray and skypilot installation ==="; + if [ -f /tmp/skypilot_is_nimbus ]; then + echo "=== Logs for asynchronous ray and skypilot installation ==="; + [ -f /tmp/ray_skypilot_installation_complete ] && cat /tmp/${STEPS[1]}.log || + { tail -f -n +1 /tmp/${STEPS[1]}.log & TAIL_PID=$!; echo "Tail PID: $TAIL_PID"; until [ -f /tmp/ray_skypilot_installation_complete ]; do sleep 0.5; done; kill $TAIL_PID || true; }; + [ -f /tmp/${STEPS[1]}.failed ] && { echo "Error: ${STEPS[1]} failed. 
Exiting."; exit 1; } || true; + fi + end_epoch=$(date +%s); + echo "=== Ray and skypilot dependencies installation completed in $(($end_epoch - $start_epoch)) secs ==="; + start_epoch=$(date +%s); + {{ skypilot_wheel_installation_commands }} + end_epoch=$(date +%s); + echo "=== Skypilot wheel installation completed in $(($end_epoch - $start_epoch)) secs ==="; + start_epoch=$(date +%s); sudo touch ~/.sudo_as_admin_successful; sudo bash -c 'rm -rf /etc/security/limits.d; echo "* soft nofile 1048576" >> /etc/security/limits.conf; echo "* hard nofile 1048576" >> /etc/security/limits.conf'; - sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); sudo systemctl set-property user-$(id -u $(whoami)).slice TasksMax=infinity; sudo systemctl daemon-reload; + sudo grep -e '^DefaultTasksMax' /etc/systemd/system.conf || (sudo bash -c 'echo "DefaultTasksMax=infinity" >> /etc/systemd/system.conf'); + ulimit -n 1048576; mkdir -p ~/.ssh; (grep -Pzo -q "Host \*\n StrictHostKeyChecking no" ~/.ssh/config) || printf "Host *\n StrictHostKeyChecking no\n" >> ~/.ssh/config; [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); # This is needed for `-o allow_other` option for `goofys`; + end_epoch=$(date +%s); + echo "=== Setup system configs and fuse completed in $(($end_epoch - $start_epoch)) secs ==="; + for step in $STEPS; do [ -f "/tmp/${step}.failed" ] && { echo "Error: /tmp/${step}.failed found:"; cat /tmp/${step}.log; exit 1; } || true; done; + {% if tpu_requested %} + # The /tmp/tpu_logs directory is where TPU-related logs, such as logs from + # the TPU runtime, are written. These capture runtime information about the + # TPU execution, including any warnings, errors, or general activity of + # the TPU driver. By default, the /tmp/tpu_logs directory is created with + # 755 permissions, and the user of the provisioned pod is not necessarily + # a root. Hence, we need to update the write permission so the logs can be + # properly written. + # TODO(Doyoung): Investigate to see why TPU workload fails to run without + # execution permission, such as granting 766 to log file. Check if it's a + # must and see if there's a workaround to grant minimum permission. + sudo chmod 777 /tmp/tpu_logs; + {% endif %} # Format: `REMOTE_PATH : LOCAL_PATH` file_mounts: { diff --git a/sky/templates/oci-ray.yml.j2 b/sky/templates/oci-ray.yml.j2 index 64fa4e745c7..17c3e34459f 100644 --- a/sky/templates/oci-ray.yml.j2 +++ b/sky/templates/oci-ray.yml.j2 @@ -7,7 +7,7 @@ idle_timeout_minutes: 60 provider: type: external - module: sky.skylet.providers.oci.OCINodeProvider + module: sky.provision.oci region: {{region}} cache_stopped_nodes: True # Disable launch config check for worker nodes as it can cause resource leakage. 
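The pod `args` above run their bootstrap steps as background subshells, each writing `/tmp/<step>.log` and, on failure, a `/tmp/<step>.failed` marker; the setup commands later do a best-effort check of those markers and dump the corresponding log. A small Python sketch of the same check (step names and paths are copied from the template, the rest is illustrative):

```python
import pathlib
import sys

STEPS = ('apt-ssh-setup', 'runtime-setup', 'env-setup')


def check_bootstrap_markers(tmp_dir: str = '/tmp') -> None:
    """Exit non-zero if any asynchronous bootstrap step left a .failed marker."""
    for step in STEPS:
        failed = pathlib.Path(tmp_dir) / f'{step}.failed'
        if not failed.exists():
            continue
        log = pathlib.Path(tmp_dir) / f'{step}.log'
        print(f'Error: {failed} found:')
        if log.exists():
            print(log.read_text())
        sys.exit(1)


if __name__ == '__main__':
    check_bootstrap_markers()
```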
@@ -39,25 +39,6 @@ available_node_types: Preemptible: {{use_spot}} AuthorizedKey: | skypilot:ssh_public_key_content -{% if num_nodes > 1 %} - ray_worker_default: - min_workers: {{num_nodes - 1}} - max_workers: {{num_nodes - 1}} - resources: {} - node_config: - InstanceType: {{instance_type}} - VCPUs: {{cpus}} - MemoryInGbs: {{memory}} - BootVolumeSize: {{disk_size}} - BootVolumePerf: {{vpu}} - AvailabilityDomain: {{zone}} - ImageId: {{image}} - AppCatalogListingId: {{app_catalog_listing_id}} - ResourceVersion: {{resource_version}} - Preemptible: {{use_spot}} - AuthorizedKey: | - skypilot:ssh_public_key_content -{%- endif %} head_node_type: ray_head_default @@ -70,9 +51,6 @@ file_mounts: { {%- endfor %} } -rsync_exclude: [] - -initialization_commands: [] # List of shell commands to run to set up nodes. # NOTE: these are very performance-sensitive. Each new item opens/closes an SSH @@ -113,34 +91,6 @@ setup_commands: [ -f /etc/fuse.conf ] && sudo sed -i 's/#user_allow_other/user_allow_other/g' /etc/fuse.conf || (sudo sh -c 'echo "user_allow_other" > /etc/fuse.conf'); sudo iptables -I INPUT -i ens3 -m state --state ESTABLISHED,RELATED,NEW -j ACCEPT; -# Command to start ray on the head node. You don't need to change this. -# NOTE: these are very performance-sensitive. Each new item opens/closes an SSH -# connection, which is expensive. Try your best to co-locate commands into fewer -# items! The same comment applies for worker_start_ray_commands. -# -# Increment the following for catching performance bugs easier: -# current num items (num SSH connections): 2 -head_start_ray_commands: - # NOTE: --disable-usage-stats in `ray start` saves 10 seconds of idle wait. - # Line "which prlimit ..": increase the limit of the number of open files for the raylet process, as the `ulimit` may not take effect at this point, because it requires - # all the sessions to be reloaded. This is a workaround. - - {{ sky_activate_python_env }}; {{ sky_ray_cmd }} stop; RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 {{ sky_ray_cmd }} start --disable-usage-stats --head --port={{ray_port}} --dashboard-port={{ray_dashboard_port}} --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml {{"--resources='%s'" % custom_resources if custom_resources}} --temp-dir {{ray_temp_dir}} || exit 1; - which prlimit && for id in $(pgrep -f raylet/raylet); do sudo prlimit --nofile=1048576:1048576 --pid=$id || true; done; - {{dump_port_command}}; {{ray_head_wait_initialized_command}} - -{%- if num_nodes > 1 %} -worker_start_ray_commands: - - {{ sky_activate_python_env }}; {{ sky_ray_cmd }} stop; RAY_SCHEDULER_EVENTS=0 RAY_DEDUP_LOGS=0 {{ sky_ray_cmd }} start --disable-usage-stats --address=$RAY_HEAD_IP:{{ray_port}} --object-manager-port=8076 {{"--resources='%s'" % custom_resources if custom_resources}} --temp-dir {{ray_temp_dir}} || exit 1; - which prlimit && for id in $(pgrep -f raylet/raylet); do sudo prlimit --nofile=1048576:1048576 --pid=$id || true; done; -{%- else %} -worker_start_ray_commands: [] -{%- endif %} - -head_node: {} -worker_nodes: {} +# Command to start ray clusters are now placed in `sky.provision.instance_setup`. +# We do not need to list it here anymore. -# These fields are required for external cloud providers. 
-head_setup_commands: [] -worker_setup_commands: [] -cluster_synced_files: [] -file_mounts_sync_continuously: False diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index e705debaf8d..92d1f2749d7 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -767,6 +767,10 @@ def run( ] if self.context: kubectl_args += ['--context', self.context] + # If context is none, it means we are using incluster auth. In this + # case, need to set KUBECONFIG to /dev/null to avoid using kubeconfig. + if self.context is None: + kubectl_args += ['--kubeconfig', '/dev/null'] kubectl_args += [self.pod_name] if ssh_mode == SshMode.LOGIN: assert isinstance(cmd, list), 'cmd must be a list for login mode.' diff --git a/sky/utils/common_utils.py b/sky/utils/common_utils.py index 5fce435b770..3fcdd24e505 100644 --- a/sky/utils/common_utils.py +++ b/sky/utils/common_utils.py @@ -697,3 +697,22 @@ def truncate_long_string(s: str, max_length: int = 35) -> str: if len(prefix) < max_length: prefix += s[len(prefix):max_length] return prefix + '...' + + +def hash_file(path: str, hash_alg: str) -> 'hashlib._Hash': + # In python 3.11, hashlib.file_digest is available, but for <3.11 we have to + # do it manually. + # This implementation is simplified from the implementation in CPython. + # TODO(cooperc): Use hashlib.file_digest once we move to 3.11+. + # Beware of f.read() as some files may be larger than memory. + with open(path, 'rb') as f: + file_hash = hashlib.new(hash_alg) + buf = bytearray(2**18) + view = memoryview(buf) + while True: + size = f.readinto(buf) + if size == 0: + # EOF + break + file_hash.update(view[:size]) + return file_hash diff --git a/sky/utils/controller_utils.py b/sky/utils/controller_utils.py index 0ab2fd7e117..3f0bd5c5ed7 100644 --- a/sky/utils/controller_utils.py +++ b/sky/utils/controller_utils.py @@ -25,7 +25,9 @@ from sky.jobs import utils as managed_job_utils from sky.serve import constants as serve_constants from sky.serve import serve_utils +from sky.setup_files import dependencies from sky.skylet import constants +from sky.skylet import log_lib from sky.utils import common_utils from sky.utils import env_options from sky.utils import rich_utils @@ -187,79 +189,49 @@ def from_type(cls, controller_type: str) -> Optional['Controllers']: # Install cli dependencies. Not using SkyPilot wheels because the wheel # can be cleaned up by another process. -# TODO(zhwu): Keep the dependencies align with the ones in setup.py def _get_cloud_dependencies_installation_commands( controller: Controllers) -> List[str]: - # TODO(tian): Make dependency installation command a method of cloud - # class and get all installation command for enabled clouds. - commands = [] # We use / instead of strong formatting, as we need to update # the at the end of the for loop, and python does not support # partial string formatting. prefix_str = ('[/] Check & install cloud dependencies ' 'on controller: ') + commands: List[str] = [] # This is to make sure the shorter checking message does not have junk # characters from the previous message. - empty_str = ' ' * 10 - aws_dependencies_installation = ( - 'pip list | grep boto3 > /dev/null 2>&1 || pip install ' - 'botocore>=1.29.10 boto3>=1.26.1; ' - # Need to separate the installation of awscli from above because some - # other clouds will install boto3 but not awscli. 
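The new `hash_file` helper above hashes a file in fixed-size chunks through a reusable buffer, so large files never need to fit in memory; it mirrors `hashlib.file_digest`, which only exists from Python 3.11. A hedged usage sketch, assuming the helper is importable from `sky.utils.common_utils` as added in the hunk above:

```python
import tempfile

from sky.utils.common_utils import hash_file  # added in the hunk above

# Write a small throwaway file and hash it; any algorithm name accepted by
# hashlib.new() (e.g. 'md5', 'sha256') can be passed as hash_alg.
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write('hello world\n')
    path = f.name

print(hash_file(path, 'sha256').hexdigest())
```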
- 'pip list | grep awscli> /dev/null 2>&1 || pip install "urllib3<2" ' - 'awscli>=1.27.10 "colorama<0.4.5" > /dev/null 2>&1') - setup_clouds: List[str] = [] + empty_str = ' ' * 20 + + # All python dependencies will be accumulated and then installed in one + # command at the end. This is very fast if the packages are already + # installed, so we don't check that. + python_packages: Set[str] = set() + + step_prefix = prefix_str.replace('', str(len(commands) + 1)) + commands.append(f'echo -en "\\r{step_prefix}uv{empty_str}" &&' + f'{constants.SKY_UV_INSTALL_CMD} >/dev/null 2>&1') + for cloud in sky_check.get_cached_enabled_clouds_or_refresh(): - if isinstance( - clouds, - (clouds.Lambda, clouds.SCP, clouds.Fluidstack, clouds.Paperspace)): - # no need to install any cloud dependencies for lambda, scp, - # fluidstack and paperspace - continue - if isinstance(cloud, clouds.AWS): - step_prefix = prefix_str.replace('', - str(len(setup_clouds) + 1)) - commands.append(f'echo -en "\\r{step_prefix}AWS{empty_str}" && ' + - aws_dependencies_installation) - setup_clouds.append(str(cloud)) - elif isinstance(cloud, clouds.Azure): - step_prefix = prefix_str.replace('', - str(len(setup_clouds) + 1)) - commands.append( - f'echo -en "\\r{step_prefix}Azure{empty_str}" && ' - 'pip list | grep azure-cli > /dev/null 2>&1 || ' - 'pip install "azure-cli>=2.31.0" azure-core ' - '"azure-identity>=1.13.0" azure-mgmt-network > /dev/null 2>&1') - # Have to separate this installation of az blob storage from above - # because this is newly-introduced and not part of azure-cli. We - # need a separate installed check for this. + cloud_python_dependencies: List[str] = dependencies.extras_require[ + cloud.canonical_name()] + + if isinstance(cloud, clouds.Azure): + # azure-cli cannot be normally installed by uv. + # See comments in sky/skylet/constants.py. + cloud_python_dependencies.remove(dependencies.AZURE_CLI) + + step_prefix = prefix_str.replace('', str(len(commands) + 1)) commands.append( - 'pip list | grep azure-storage-blob > /dev/null 2>&1 || ' - 'pip install azure-storage-blob msgraph-sdk > /dev/null 2>&1') - setup_clouds.append(str(cloud)) + f'echo -en "\\r{step_prefix}azure-cli{empty_str}" &&' + f'{constants.SKY_UV_PIP_CMD} install --prerelease=allow ' + f'"{dependencies.AZURE_CLI}" > /dev/null 2>&1') elif isinstance(cloud, clouds.GCP): - step_prefix = prefix_str.replace('', - str(len(setup_clouds) + 1)) - commands.append( - f'echo -en "\\r{step_prefix}GCP{empty_str}" && ' - 'pip list | grep google-api-python-client > /dev/null 2>&1 || ' - 'pip install "google-api-python-client>=2.69.0" ' - '> /dev/null 2>&1') - # Have to separate the installation of google-cloud-storage from - # above because for a VM launched on GCP, the VM may have - # google-api-python-client installed alone. 
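One wrinkle in the consolidated install path: as the comment in the hunk notes, `azure-cli` cannot be installed by uv in the normal way (see sky/skylet/constants.py), so it is removed from the per-cloud package set and installed on its own with `--prerelease=allow`. A minimal sketch of that split, using a placeholder pin rather than the real `dependencies.AZURE_CLI` value:

```python
from typing import List, Tuple

AZURE_CLI = 'azure-cli>=2.65.0'  # placeholder for dependencies.AZURE_CLI


def split_out_azure_cli(cloud_deps: List[str]) -> Tuple[List[str], List[str]]:
    """Separate deps that need a dedicated `uv pip install --prerelease=allow`."""
    common = [d for d in cloud_deps if d != AZURE_CLI]
    special = [d for d in cloud_deps if d == AZURE_CLI]
    return common, special


common, special = split_out_azure_cli([AZURE_CLI, 'azure-core'])
print('common install:', common)
print('separate --prerelease=allow install:', special)
```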
- commands.append( - 'pip list | grep google-cloud-storage > /dev/null 2>&1 || ' - 'pip install google-cloud-storage > /dev/null 2>&1') - commands.append(f'{gcp.GOOGLE_SDK_INSTALLATION_COMMAND}') - setup_clouds.append(str(cloud)) + step_prefix = prefix_str.replace('', str(len(commands) + 1)) + commands.append(f'echo -en "\\r{step_prefix}GCP SDK{empty_str}" &&' + f'{gcp.GOOGLE_SDK_INSTALLATION_COMMAND}') elif isinstance(cloud, clouds.Kubernetes): - step_prefix = prefix_str.replace('', - str(len(setup_clouds) + 1)) + step_prefix = prefix_str.replace('', str(len(commands) + 1)) commands.append( f'echo -en "\\r{step_prefix}Kubernetes{empty_str}" && ' - 'pip list | grep kubernetes > /dev/null 2>&1 || ' - 'pip install "kubernetes>=20.0.0" > /dev/null 2>&1 &&' # Install k8s + skypilot dependencies 'sudo bash -c "if ' '! command -v curl &> /dev/null || ' @@ -275,54 +247,36 @@ def _get_cloud_dependencies_installation_commands( '/bin/linux/amd64/kubectl" && ' 'sudo install -o root -g root -m 0755 ' 'kubectl /usr/local/bin/kubectl))') - setup_clouds.append(str(cloud)) elif isinstance(cloud, clouds.Cudo): - step_prefix = prefix_str.replace('', - str(len(setup_clouds) + 1)) + step_prefix = prefix_str.replace('', str(len(commands) + 1)) commands.append( - f'echo -en "\\r{step_prefix}Cudo{empty_str}" && ' - 'pip list | grep cudo-compute > /dev/null 2>&1 || ' - 'pip install "cudo-compute>=0.1.10" > /dev/null 2>&1 && ' + f'echo -en "\\r{step_prefix}cudoctl{empty_str}" && ' 'wget https://download.cudo.org/compute/cudoctl-0.3.2-amd64.deb -O ~/cudoctl.deb > /dev/null 2>&1 && ' # pylint: disable=line-too-long 'sudo dpkg -i ~/cudoctl.deb > /dev/null 2>&1') - setup_clouds.append(str(cloud)) - elif isinstance(cloud, clouds.RunPod): - step_prefix = prefix_str.replace('', - str(len(setup_clouds) + 1)) - commands.append(f'echo -en "\\r{step_prefix}RunPod{empty_str}" && ' - 'pip list | grep runpod > /dev/null 2>&1 || ' - 'pip install "runpod>=1.5.1" > /dev/null 2>&1') - setup_clouds.append(str(cloud)) - if controller == Controllers.JOBS_CONTROLLER: - if isinstance(cloud, clouds.IBM): - step_prefix = prefix_str.replace('', - str(len(setup_clouds) + 1)) - commands.append( - f'echo -en "\\r{step_prefix}IBM{empty_str}" ' - '&& pip list | grep ibm-cloud-sdk-core > /dev/null 2>&1 || ' - 'pip install ibm-cloud-sdk-core ibm-vpc ' - 'ibm-platform-services ibm-cos-sdk > /dev/null 2>&1') - setup_clouds.append(str(cloud)) - elif isinstance(cloud, clouds.OCI): - step_prefix = prefix_str.replace('', - str(len(setup_clouds) + 1)) - commands.append(f'echo -en "\\r{prefix_str}OCI{empty_str}" && ' - 'pip list | grep oci > /dev/null 2>&1 || ' - 'pip install oci > /dev/null 2>&1') - setup_clouds.append(str(cloud)) + elif isinstance(cloud, clouds.IBM): + if controller != Controllers.JOBS_CONTROLLER: + # We only need IBM deps on the jobs controller. 
+ cloud_python_dependencies = [] + + python_packages.update(cloud_python_dependencies) + if (cloudflare.NAME in storage_lib.get_cached_enabled_storage_clouds_or_refresh()): - step_prefix = prefix_str.replace('', str(len(setup_clouds) + 1)) - commands.append( - f'echo -en "\\r{step_prefix}Cloudflare{empty_str}" && ' + - aws_dependencies_installation) - setup_clouds.append(cloudflare.NAME) + python_packages.update(dependencies.extras_require['cloudflare']) + + packages_string = ' '.join([f'"{package}"' for package in python_packages]) + step_prefix = prefix_str.replace('', str(len(commands) + 1)) + commands.append( + f'echo -en "\\r{step_prefix}cloud python packages{empty_str}" && ' + f'{constants.SKY_UV_PIP_CMD} install {packages_string} > /dev/null 2>&1' + ) + total_commands = len(commands) finish_prefix = prefix_str.replace('[/] ', ' ') commands.append(f'echo -e "\\r{finish_prefix}done.{empty_str}"') + commands = [ - command.replace('', str(len(setup_clouds))) - for command in commands + command.replace('', str(total_commands)) for command in commands ] return commands @@ -380,11 +334,19 @@ def download_and_stream_latest_job_log( else: log_dir = list(log_dirs.values())[0] log_file = os.path.join(log_dir, 'run.log') - # Print the logs to the console. + # TODO(zhwu): refactor this into log_utils, along with the + # refactoring for the log_lib.tail_logs. try: with open(log_file, 'r', encoding='utf-8') as f: - print(f.read()) + # Stream the logs to the console without reading the whole + # file into memory. + start_streaming = False + for line in f: + if log_lib.LOG_FILE_START_STREAMING_AT in line: + start_streaming = True + if start_streaming: + print(line, end='', flush=True) except FileNotFoundError: logger.error('Failed to find the logs for the user ' f'program at {log_file}.') @@ -818,8 +780,9 @@ def maybe_translate_local_file_mounts_and_sync_up(task: 'task_lib.Task', '[dim]View storages: sky storage ls')) try: task.sync_storage_mounts() - except ValueError as e: - if 'No enabled cloud for storage' in str(e): + except (ValueError, exceptions.NoCloudAccessError) as e: + if 'No enabled cloud for storage' in str(e) or isinstance( + e, exceptions.NoCloudAccessError): data_src = None if has_local_source_paths_file_mounts: data_src = 'file_mounts' diff --git a/sky/utils/kubernetes/generate_kubeconfig.sh b/sky/utils/kubernetes/generate_kubeconfig.sh index 8923df9c051..4ed27b62e1e 100755 --- a/sky/utils/kubernetes/generate_kubeconfig.sh +++ b/sky/utils/kubernetes/generate_kubeconfig.sh @@ -12,6 +12,7 @@ # * Specify SKYPILOT_NAMESPACE env var to override the default namespace where the service account is created. # * Specify SKYPILOT_SA_NAME env var to override the default service account name. # * Specify SKIP_SA_CREATION=1 to skip creating the service account and use an existing one +# * Specify SUPER_USER=1 to create a service account with cluster-admin permissions # # Usage: # # Create "sky-sa" service account with minimal permissions in "default" namespace and generate kubeconfig @@ -22,6 +23,9 @@ # # # Use an existing service account "my-sa" in "my-namespace" namespace and generate kubeconfig # $ SKIP_SA_CREATION=1 SKYPILOT_SA_NAME=my-sa SKYPILOT_NAMESPACE=my-namespace ./generate_kubeconfig.sh +# +# # Create "sky-sa" service account with cluster-admin permissions in "default" namespace +# $ SUPER_USER=1 ./generate_kubeconfig.sh set -eu -o pipefail @@ -29,9 +33,11 @@ set -eu -o pipefail # use default. 
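Instead of emitting one guarded pip command per cloud, the refactored `_get_cloud_dependencies_installation_commands` above collects every enabled cloud's Python packages into a single set and installs them with one uv command at the end, which is close to a no-op when the packages are already present. A simplified sketch of that accumulation, with placeholder package lists standing in for `sky.setup_files.dependencies.extras_require`:

```python
from typing import Dict, List, Set

# Placeholder extras; the real mapping lives in sky/setup_files/dependencies.py.
EXTRAS_REQUIRE: Dict[str, List[str]] = {
    'aws': ['boto3>=1.26.1', 'botocore>=1.29.10'],
    'gcp': ['google-api-python-client>=2.69.0', 'google-cloud-storage'],
}
UV_PIP_INSTALL = 'uv pip install'  # stands in for constants.SKY_UV_PIP_CMD


def build_install_command(enabled_clouds: List[str]) -> str:
    """Accumulate all enabled clouds' packages and emit one install command."""
    packages: Set[str] = set()
    for cloud in enabled_clouds:
        packages.update(EXTRAS_REQUIRE.get(cloud, []))
    quoted = ' '.join(f'"{pkg}"' for pkg in sorted(packages))
    return f'{UV_PIP_INSTALL} {quoted} > /dev/null 2>&1'


print(build_install_command(['aws', 'gcp']))
```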
SKYPILOT_SA=${SKYPILOT_SA_NAME:-sky-sa} NAMESPACE=${SKYPILOT_NAMESPACE:-default} +SUPER_USER=${SUPER_USER:-0} echo "Service account: ${SKYPILOT_SA}" echo "Namespace: ${NAMESPACE}" +echo "Super user permissions: ${SUPER_USER}" # Set OS specific values. if [[ "$OSTYPE" == "linux-gnu" ]]; then @@ -47,8 +53,43 @@ fi # If the user has set SKIP_SA_CREATION=1, skip creating the service account. if [ -z ${SKIP_SA_CREATION+x} ]; then - echo "Creating the Kubernetes Service Account with minimal RBAC permissions." - kubectl apply -f - <`') + f'`{kubernetes_utils.GPU_RESOURCE_KEY}: `') else: print('GPU labeling started - this may take 10 min or more to complete.' '\nTo check the status of GPU labeling jobs, run ' diff --git a/sky/utils/kubernetes/rsync_helper.sh b/sky/utils/kubernetes/rsync_helper.sh index 79bd5fa79f8..719ee00d872 100755 --- a/sky/utils/kubernetes/rsync_helper.sh +++ b/sky/utils/kubernetes/rsync_helper.sh @@ -16,7 +16,9 @@ echo "context: $context" >&2 context_lower=$(echo "$context" | tr '[:upper:]' '[:lower:]') shift if [ -z "$context" ] || [ "$context_lower" = "none" ]; then - kubectl exec -i $pod -n $namespace -- "$@" + # If context is none, it means we are using incluster auth. In this case, + # use need to set KUBECONFIG to /dev/null to avoid using kubeconfig file. + kubectl exec -i $pod -n $namespace --kubeconfig=/dev/null -- "$@" else kubectl exec -i $pod -n $namespace --context=$context -- "$@" fi diff --git a/sky/utils/log_utils.py b/sky/utils/log_utils.py index e116f36819e..a5884333609 100644 --- a/sky/utils/log_utils.py +++ b/sky/utils/log_utils.py @@ -1,7 +1,8 @@ """Logging utils.""" import enum +import time import types -from typing import List, Optional, Type +from typing import Callable, Iterator, List, Optional, TextIO, Type import colorama import pendulum @@ -284,3 +285,53 @@ def readable_time_duration(start: Optional[float], diff = diff.replace('hour', 'hr') return diff + + +def follow_logs( + file: TextIO, + *, + should_stop: Callable[[], bool], + stop_on_eof: bool = False, + process_line: Optional[Callable[[str], Iterator[str]]] = None, + idle_timeout_seconds: Optional[int] = None, +) -> Iterator[str]: + """Streams and processes logs line by line from a file. + + Args: + file: File object to read logs from. + should_stop: Callback that returns True when streaming should stop. + stop_on_eof: If True, stop when reaching end of file. + process_line: Optional callback to transform/filter each line. + idle_timeout_seconds: If set, stop after these many seconds without + new content. + + Yields: + Log lines, possibly transformed by process_line if provided. + """ + current_line: str = '' + seconds_without_content: int = 0 + + while True: + content = file.readline() + + if not content: + if stop_on_eof or should_stop(): + break + + if idle_timeout_seconds is not None: + if seconds_without_content >= idle_timeout_seconds: + break + seconds_without_content += 1 + + time.sleep(1) + continue + + seconds_without_content = 0 + current_line += content + + if '\n' in current_line or '\r' in current_line: + if process_line is not None: + yield from process_line(current_line) + else: + yield current_line + current_line = '' diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 81c4cb332a6..851e77a57fc 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -308,6 +308,9 @@ def get_storage_schema(): def get_service_schema(): """Schema for top-level `service:` field (for SkyServe).""" + # To avoid circular imports, only import when needed. 
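The new `follow_logs` generator in sky/utils/log_utils.py above centralizes the tail-a-file-until-done pattern: it yields lines as they are appended, can stop at EOF, on a caller-provided predicate, or after an idle timeout, and optionally pipes each line through a transform callback. A hedged usage sketch (the sentinel string and log path are illustrative; real callers supply their own callbacks):

```python
from sky.utils import log_utils  # follow_logs added in the hunk above


def stream_job_log(log_path: str) -> None:
    """Tail log_path until a sentinel line is seen or 30 idle seconds pass."""
    done = False

    def process_line(line: str):
        nonlocal done
        if 'Job finished' in line:  # hypothetical sentinel
            done = True
        yield line

    with open(log_path, 'r', encoding='utf-8') as f:
        for line in log_utils.follow_logs(f,
                                           should_stop=lambda: done,
                                           process_line=process_line,
                                           idle_timeout_seconds=30):
            print(line, end='', flush=True)
```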
+ # pylint: disable=import-outside-toplevel + from sky.serve import load_balancing_policies return { '$schema': 'https://json-schema.org/draft/2020-12/schema', 'type': 'object', @@ -382,6 +385,11 @@ def get_service_schema(): 'replicas': { 'type': 'integer', }, + 'load_balancing_policy': { + 'type': 'string', + 'case_insensitive_enum': list( + load_balancing_policies.LB_POLICIES.keys()) + }, } } @@ -655,6 +663,7 @@ class RemoteIdentityOptions(enum.Enum): """ LOCAL_CREDENTIALS = 'LOCAL_CREDENTIALS' SERVICE_ACCOUNT = 'SERVICE_ACCOUNT' + NO_UPLOAD = 'NO_UPLOAD' def get_default_remote_identity(cloud: str) -> str: @@ -675,7 +684,14 @@ def get_default_remote_identity(cloud: str) -> str: _REMOTE_IDENTITY_SCHEMA_KUBERNETES = { 'remote_identity': { - 'type': 'string' + 'anyOf': [{ + 'type': 'string' + }, { + 'type': 'object', + 'additionalProperties': { + 'type': 'string' + } + }] }, } diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index acb8fb9f490..992c6bbe3ff 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -2,21 +2,25 @@ from multiprocessing import pool import os import random +import resource import subprocess import time -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import colorama import psutil from sky import exceptions from sky import sky_logging +from sky.skylet import constants from sky.skylet import log_lib from sky.utils import timeline from sky.utils import ux_utils logger = sky_logging.init_logger(__name__) +_fd_limit_warning_shown = False + @timeline.event def run(cmd, **kwargs): @@ -42,12 +46,54 @@ def run_no_outputs(cmd, **kwargs): **kwargs) -def get_parallel_threads() -> int: - """Returns the number of idle CPUs.""" +def _get_thread_multiplier(cloud_str: Optional[str] = None) -> int: + # If using Kubernetes, we use 4x the number of cores. + if cloud_str and cloud_str.lower() == 'kubernetes': + return 4 + return 1 + + +def get_max_workers_for_file_mounts(common_file_mounts: Dict[str, str], + cloud_str: Optional[str] = None) -> int: + global _fd_limit_warning_shown + fd_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE) + + # Raise warning for low fd_limit (only once) + if fd_limit < 1024 and not _fd_limit_warning_shown: + logger.warning( + f'Open file descriptor limit ({fd_limit}) is low. File sync to ' + 'remote clusters may be slow. Consider increasing the limit using ' + '`ulimit -n ` or modifying system limits.') + _fd_limit_warning_shown = True + + fd_per_rsync = 5 + for src in common_file_mounts.values(): + if os.path.isdir(src): + # Assume that each file/folder under src takes 5 file descriptors + # on average. + fd_per_rsync = max(fd_per_rsync, len(os.listdir(src)) * 5) + + # Reserve some file descriptors for the system and other processes + fd_reserve = 100 + + max_workers = (fd_limit - fd_reserve) // fd_per_rsync + # At least 1 worker, and avoid too many workers overloading the system. + num_threads = get_parallel_threads(cloud_str) + max_workers = min(max(max_workers, 1), num_threads) + logger.debug(f'Using {max_workers} workers for file mounts.') + return max_workers + + +def get_parallel_threads(cloud_str: Optional[str] = None) -> int: + """Returns the number of threads to use for parallel execution. 
+ + Args: + cloud_str: The cloud + """ cpu_count = os.cpu_count() if cpu_count is None: cpu_count = 1 - return max(4, cpu_count - 1) + return max(4, cpu_count - 1) * _get_thread_multiplier(cloud_str) def run_in_parallel(func: Callable, @@ -198,3 +244,52 @@ def run_with_retries( continue break return returncode, stdout, stderr + + +def kill_process_daemon(process_pid: int) -> None: + """Start a daemon as a safety net to kill the process. + + Args: + process_pid: The PID of the process to kill. + """ + # Get initial children list + try: + process = psutil.Process(process_pid) + initial_children = [p.pid for p in process.children(recursive=True)] + except psutil.NoSuchProcess: + initial_children = [] + + parent_pid = os.getpid() + daemon_script = os.path.join( + os.path.dirname(os.path.abspath(log_lib.__file__)), + 'subprocess_daemon.py') + python_path = subprocess.check_output(constants.SKY_GET_PYTHON_PATH_CMD, + shell=True, + stderr=subprocess.DEVNULL, + encoding='utf-8').strip() + daemon_cmd = [ + python_path, + daemon_script, + '--parent-pid', + str(parent_pid), + '--proc-pid', + str(process_pid), + # We pass the initial children list to avoid the race condition where + # the process_pid is terminated before the daemon starts and gets the + # children list. + '--initial-children', + ','.join(map(str, initial_children)), + ] + + # We do not need to set `start_new_session=True` here, as the + # daemon script will detach itself from the parent process with + # fork to avoid being killed by parent process. See the reason we + # daemonize the process in `sky/skylet/subprocess_daemon.py`. + subprocess.Popen( + daemon_cmd, + # Suppress output + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + # Disable input + stdin=subprocess.DEVNULL, + ) diff --git a/sky/utils/timeline.py b/sky/utils/timeline.py index 29c6c3d94ee..4db9bd149b2 100644 --- a/sky/utils/timeline.py +++ b/sky/utils/timeline.py @@ -9,6 +9,7 @@ import os import threading import time +import traceback from typing import Callable, Optional, Union import filelock @@ -48,8 +49,9 @@ def begin(self): 'ph': 'B', 'ts': f'{time.time() * 10 ** 6: .3f}', }) + event_begin['args'] = {'stack': '\n'.join(traceback.format_stack())} if self._message is not None: - event_begin['args'] = {'message': self._message} + event_begin['args']['message'] = self._message _events.append(event_begin) def end(self): @@ -77,11 +79,9 @@ def event(name_or_fn: Union[str, Callable], message: Optional[str] = None): class FileLockEvent: """Serve both as a file lock and event for the lock.""" - def __init__(self, lockfile: Union[str, os.PathLike]): + def __init__(self, lockfile: Union[str, os.PathLike], timeout: float = -1): self._lockfile = lockfile - # TODO(mraheja): remove pylint disabling when filelock version updated - # pylint: disable=abstract-class-instantiated - self._lock = filelock.FileLock(self._lockfile) + self._lock = filelock.FileLock(self._lockfile, timeout) self._hold_lock_event = Event(f'[FileLock.hold]:{self._lockfile}') def acquire(self): diff --git a/tests/backward_compatibility_tests.sh b/tests/backward_compatibility_tests.sh index 276fda899dd..696b87ff6ad 100644 --- a/tests/backward_compatibility_tests.sh +++ b/tests/backward_compatibility_tests.sh @@ -187,7 +187,7 @@ sky jobs logs -n "${MANAGED_JOB_JOB_NAME}-7-1" || exit 1 s=$(sky jobs queue | grep ${MANAGED_JOB_JOB_NAME}-7) echo "$s" echo "$s" | grep "SUCCEEDED" | wc -l | grep 2 || exit 1 -echo "$s" | grep "CANCELLED" | wc -l | grep 1 || exit 1 +echo "$s" | grep "CANCELLING\|CANCELLED" | 
wc -l | grep 1 || exit 1 fi sky down ${CLUSTER_NAME}* -y diff --git a/tests/common.py b/tests/common.py index d41ff3bead0..5f38cb73855 100644 --- a/tests/common.py +++ b/tests/common.py @@ -64,8 +64,9 @@ def _get_az_mappings(_): monkeypatch.setattr( 'sky.provision.kubernetes.utils.detect_gpu_label_formatter', lambda *_args, **_kwargs: [kubernetes_utils.SkyPilotLabelFormatter, {}]) - monkeypatch.setattr('sky.provision.kubernetes.utils.detect_gpu_resource', - lambda *_args, **_kwargs: [True, []]) + monkeypatch.setattr( + 'sky.provision.kubernetes.utils.detect_accelerator_resource', + lambda *_args, **_kwargs: [True, []]) monkeypatch.setattr('sky.provision.kubernetes.utils.check_instance_fits', lambda *_args, **_kwargs: [True, '']) monkeypatch.setattr('sky.provision.kubernetes.utils.get_spot_label', diff --git a/tests/conftest.py b/tests/conftest.py index c5c0f4a1515..ee5caf062b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional import common # TODO(zongheng): for some reason isort places it here. import pytest @@ -73,7 +73,6 @@ def pytest_addoption(parser): parser.addoption( '--generic-cloud', type=str, - default='aws', choices=all_clouds_in_smoke_tests, help='Cloud to use for generic tests. If the generic cloud is ' 'not within the clouds to be run, it will be reset to the first ' @@ -102,14 +101,21 @@ def pytest_configure(config): def _get_cloud_to_run(config) -> List[str]: cloud_to_run = [] + for cloud in all_clouds_in_smoke_tests: if config.getoption(f'--{cloud}'): if cloud == 'cloudflare': cloud_to_run.append(default_clouds_to_run[0]) else: cloud_to_run.append(cloud) - if not cloud_to_run: + + generic_cloud_option = config.getoption('--generic-cloud') + if generic_cloud_option is not None and generic_cloud_option not in cloud_to_run: + cloud_to_run.append(generic_cloud_option) + + if len(cloud_to_run) == 0: cloud_to_run = default_clouds_to_run + return cloud_to_run @@ -187,11 +193,10 @@ def _is_generic_test(item) -> bool: def _generic_cloud(config) -> str: - c = config.getoption('--generic-cloud') - cloud_to_run = _get_cloud_to_run(config) - if c not in cloud_to_run: - c = cloud_to_run[0] - return c + generic_cloud_option = config.getoption('--generic-cloud') + if generic_cloud_option is not None: + return generic_cloud_option + return _get_cloud_to_run(config)[0] @pytest.fixture diff --git a/tests/kubernetes/README.md b/tests/kubernetes/README.md index 7c5ed7586ff..e15f593e006 100644 --- a/tests/kubernetes/README.md +++ b/tests/kubernetes/README.md @@ -1,10 +1,10 @@ # SkyPilot Kubernetes Development Scripts -This directory contains useful scripts and notes for developing SkyPilot on Kubernetes. +This directory contains useful scripts and notes for developing SkyPilot on Kubernetes. ## Building and pushing SkyPilot image -We maintain a container image that has all basic SkyPilot dependencies installed. +We maintain a container image that has all basic SkyPilot dependencies installed. This image is hosted at `us-central1-docker.pkg.dev/skypilot-375900/skypilotk8s/skypilot:latest`. To build this image locally and optionally push to the SkyPilot registry, run: @@ -18,10 +18,10 @@ To build this image locally and optionally push to the SkyPilot registry, run: ``` ## Running a local development cluster -We use (kind)[https://kind.sigs.k8s.io/] to run a local Kubernetes cluster +We use (kind)[https://kind.sigs.k8s.io/] to run a local Kubernetes cluster for development. 
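For reference on the new `get_max_workers_for_file_mounts` in sky/utils/subprocess_utils.py above: it budgets roughly 5 file descriptors per rsync (scaled by the entry count of the largest source directory), reserves 100 descriptors for the rest of the system, and clamps the result between 1 and `get_parallel_threads()`. A worked sketch of that arithmetic with made-up inputs (the real values come from `resource.getrlimit(resource.RLIMIT_NOFILE)` and `os.listdir()`):

```python
# Made-up inputs: a soft fd limit of 1024 and a largest mount dir with 40 entries.
fd_limit = 1024
largest_dir_entries = 40

fd_per_rsync = max(5, largest_dir_entries * 5)  # 200 fds per rsync, estimated
fd_reserve = 100                                # kept for other processes
parallel_threads = 7                            # e.g. max(4, 8 cores - 1)

max_workers = (fd_limit - fd_reserve) // fd_per_rsync   # (1024 - 100) // 200 = 4
max_workers = min(max(max_workers, 1), parallel_threads)
print(max_workers)  # -> 4
```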
To create a local development cluster, run: -```bash +```bash sky local up ``` @@ -50,7 +50,13 @@ curl --header "Content-Type: application/json-patch+json" \ ```bash PROJECT_ID=$(gcloud config get-value project) CLUSTER_NAME=testclusterromil - gcloud beta container --project "${PROJECT_ID}" clusters create "${CLUSTER_NAME}" --zone "us-central1-c" --no-enable-basic-auth --cluster-version "1.29.1-gke.1589020" --release-channel "regular" --machine-type "n1-standard-8" --accelerator "type=nvidia-tesla-t4,count=1" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --logging=SYSTEM,WORKLOAD --monitoring=SYSTEM --enable-ip-alias --network "projects/${PROJECT_ID}/global/networks/default" --subnetwork "projects/${PROJECT_ID}/regions/us-central1/subnetworks/default" --no-enable-intra-node-visibility --default-max-pods-per-node "110" --security-posture=standard --workload-vulnerability-scanning=disabled --no-enable-master-authorized-networks --addons HorizontalPodAutoscaling,HttpLoadBalancing,GcePersistentDiskCsiDriver --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --enable-managed-prometheus --enable-shielded-nodes --node-locations "us-central1-c" && gcloud beta container --project "${PROJECT_ID}" node-pools create "v100" --cluster "${CLUSTER_NAME}" --zone "us-central1-c" --machine-type "n1-standard-8" --accelerator "type=nvidia-tesla-v100,count=1" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations "us-central1-c" && gcloud beta container --project "${PROJECT_ID}" node-pools create "largecpu" --cluster "${CLUSTER_NAME}" --zone "us-central1-c" --machine-type "n1-standard-16" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations "us-central1-c" && gcloud beta container --project "${PROJECT_ID}" node-pools create "l4" --cluster "${CLUSTER_NAME}" --zone "us-central1-c" --machine-type "g2-standard-4" --accelerator "type=nvidia-l4,count=1" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes 
"https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations "us-central1-c" + REGION=us-central1-c + GKE_VERSION=$(gcloud container get-server-config \ + --region=${REGION} \ + --flatten=channels \ + --filter="channels.channel=REGULAR" \ + --format="value(channels.defaultVersion)") + gcloud beta container --project "${PROJECT_ID}" clusters create "${CLUSTER_NAME}" --zone "${REGION}" --no-enable-basic-auth --cluster-version "${GKE_VERSION}" --release-channel "regular" --machine-type "n1-standard-8" --accelerator "type=nvidia-tesla-t4,count=1" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --logging=SYSTEM,WORKLOAD --monitoring=SYSTEM --enable-ip-alias --network "projects/${PROJECT_ID}/global/networks/default" --subnetwork "projects/${PROJECT_ID}/regions/${REGION%-*}/subnetworks/default" --no-enable-intra-node-visibility --default-max-pods-per-node "110" --security-posture=standard --workload-vulnerability-scanning=disabled --no-enable-master-authorized-networks --addons HorizontalPodAutoscaling,HttpLoadBalancing,GcePersistentDiskCsiDriver --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --enable-managed-prometheus --enable-shielded-nodes --node-locations "${REGION}" && gcloud beta container --project "${PROJECT_ID}" node-pools create "v100" --cluster "${CLUSTER_NAME}" --zone "${REGION}" --machine-type "n1-standard-8" --accelerator "type=nvidia-tesla-v100,count=1" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations "${REGION}" && gcloud beta container --project "${PROJECT_ID}" node-pools create "largecpu" --cluster "${CLUSTER_NAME}" --zone "${REGION}" --machine-type "n1-standard-16" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations "${REGION}" && gcloud beta container --project "${PROJECT_ID}" node-pools 
create "l4" --cluster "${CLUSTER_NAME}" --zone "${REGION}" --machine-type "g2-standard-4" --accelerator "type=nvidia-l4,count=1" --image-type "COS_CONTAINERD" --disk-type "pd-balanced" --disk-size "100" --metadata disable-legacy-endpoints=true --scopes "https://www.googleapis.com/auth/devstorage.read_only","https://www.googleapis.com/auth/logging.write","https://www.googleapis.com/auth/monitoring","https://www.googleapis.com/auth/servicecontrol","https://www.googleapis.com/auth/service.management.readonly","https://www.googleapis.com/auth/trace.append" --num-nodes "2" --enable-autoupgrade --enable-autorepair --max-surge-upgrade 1 --max-unavailable-upgrade 0 --node-locations "${REGION}" ``` 2. Get the kubeconfig for your cluster and place it in `~/.kube/config`: ```bash @@ -65,7 +71,7 @@ curl --header "Content-Type: application/json-patch+json" \ kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/cos/daemonset-preloaded.yaml kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/cos/daemonset-preloaded-latest.yaml - + # If using Ubuntu based nodes: kubectl apply -f https://raw.githubusercontent.com/GoogleCloudPlatform/container-engine-accelerators/master/nvidia-driver-installer/ubuntu/daemonset-preloaded.yaml @@ -123,6 +129,6 @@ NOTE - If are using nodeport networking, make sure port 32100 is open in your no NOTE - If are using nodeport networking, make sure port 32100 is open in your EKS cluster's default security group. ## Other useful scripts -`scripts` directory contains other useful scripts for development, including -Kubernetes dashboard, ray yaml for testing the SkyPilot Kubernetes node provider +`scripts` directory contains other useful scripts for development, including +Kubernetes dashboard, ray yaml for testing the SkyPilot Kubernetes node provider and more. diff --git a/tests/skyserve/http/oci.yaml b/tests/skyserve/http/oci.yaml new file mode 100644 index 00000000000..d7d98c18ab4 --- /dev/null +++ b/tests/skyserve/http/oci.yaml @@ -0,0 +1,10 @@ +service: + readiness_probe: / + replicas: 2 + +resources: + cloud: oci + ports: 8080 + cpus: 2+ + +run: python -m http.server 8080 \ No newline at end of file diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 4433b0ae4df..f37467417fa 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -25,6 +25,7 @@ # Change cloud for generic tests to aws # > pytest tests/test_smoke.py --generic-cloud aws +import enum import inspect import json import os @@ -95,6 +96,166 @@ 'sleep 10; s=$(sky jobs queue);' 'echo "Waiting for job to stop RUNNING"; echo "$s"; done') +# Cluster functions +_ALL_JOB_STATUSES = "|".join([status.value for status in sky.JobStatus]) +_ALL_CLUSTER_STATUSES = "|".join([status.value for status in sky.ClusterStatus]) +_ALL_MANAGED_JOB_STATUSES = "|".join( + [status.value for status in sky.ManagedJobStatus]) + + +def _statuses_to_str(statuses: List[enum.Enum]): + """Convert a list of enums to a string with all the values separated by |.""" + assert len(statuses) > 0, 'statuses must not be empty' + if len(statuses) > 1: + return '(' + '|'.join([status.value for status in statuses]) + ')' + else: + return statuses[0].value + + +_WAIT_UNTIL_CLUSTER_STATUS_CONTAINS = ( + # A while loop to wait until the cluster status + # becomes certain status, with timeout. 
+ 'start_time=$SECONDS; ' + 'while true; do ' + 'if (( $SECONDS - $start_time > {timeout} )); then ' + ' echo "Timeout after {timeout} seconds waiting for cluster status \'{cluster_status}\'"; exit 1; ' + 'fi; ' + 'current_status=$(sky status {cluster_name} --refresh | ' + 'awk "/^{cluster_name}/ ' + '{{for (i=1; i<=NF; i++) if (\$i ~ /^(' + _ALL_CLUSTER_STATUSES + + ')$/) print \$i}}"); ' + 'if [[ "$current_status" =~ {cluster_status} ]]; ' + 'then echo "Target cluster status {cluster_status} reached."; break; fi; ' + 'echo "Waiting for cluster status to become {cluster_status}, current status: $current_status"; ' + 'sleep 10; ' + 'done') + + +def _get_cmd_wait_until_cluster_status_contains( + cluster_name: str, cluster_status: List[sky.ClusterStatus], + timeout: int): + return _WAIT_UNTIL_CLUSTER_STATUS_CONTAINS.format( + cluster_name=cluster_name, + cluster_status=_statuses_to_str(cluster_status), + timeout=timeout) + + +def _get_cmd_wait_until_cluster_status_contains_wildcard( + cluster_name_wildcard: str, cluster_status: List[sky.ClusterStatus], + timeout: int): + wait_cmd = _WAIT_UNTIL_CLUSTER_STATUS_CONTAINS.replace( + 'sky status {cluster_name}', + 'sky status "{cluster_name}"').replace('awk "/^{cluster_name}/', + 'awk "/^{cluster_name_awk}/') + return wait_cmd.format(cluster_name=cluster_name_wildcard, + cluster_name_awk=cluster_name_wildcard.replace( + '*', '.*'), + cluster_status=_statuses_to_str(cluster_status), + timeout=timeout) + + +_WAIT_UNTIL_CLUSTER_IS_NOT_FOUND = ( + # A while loop to wait until the cluster is not found or timeout + 'start_time=$SECONDS; ' + 'while true; do ' + 'if (( $SECONDS - $start_time > {timeout} )); then ' + ' echo "Timeout after {timeout} seconds waiting for cluster to be removed"; exit 1; ' + 'fi; ' + 'if sky status -r {cluster_name}; sky status {cluster_name} | grep "{cluster_name} not found"; then ' + ' echo "Cluster {cluster_name} successfully removed."; break; ' + 'fi; ' + 'echo "Waiting for cluster {cluster_name} to be removed..."; ' + 'sleep 10; ' + 'done') + + +def _get_cmd_wait_until_cluster_is_not_found(cluster_name: str, timeout: int): + return _WAIT_UNTIL_CLUSTER_IS_NOT_FOUND.format(cluster_name=cluster_name, + timeout=timeout) + + +_WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID = ( + # A while loop to wait until the job status + # contains certain status, with timeout. 
+ 'start_time=$SECONDS; ' + 'while true; do ' + 'if (( $SECONDS - $start_time > {timeout} )); then ' + ' echo "Timeout after {timeout} seconds waiting for job status \'{job_status}\'"; exit 1; ' + 'fi; ' + 'current_status=$(sky queue {cluster_name} | ' + 'awk "\\$1 == \\"{job_id}\\" ' + '{{for (i=1; i<=NF; i++) if (\$i ~ /^(' + _ALL_JOB_STATUSES + + ')$/) print \$i}}"); ' + 'found=0; ' # Initialize found variable outside the loop + 'while read -r line; do ' # Read line by line + ' if [[ "$line" =~ {job_status} ]]; then ' # Check each line + ' echo "Target job status {job_status} reached."; ' + ' found=1; ' + ' break; ' # Break inner loop + ' fi; ' + 'done <<< "$current_status"; ' + 'if [ "$found" -eq 1 ]; then break; fi; ' # Break outer loop if match found + 'echo "Waiting for job status to contain {job_status}, current status: $current_status"; ' + 'sleep 10; ' + 'done') + +_WAIT_UNTIL_JOB_STATUS_CONTAINS_WITHOUT_MATCHING_JOB = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.replace( + 'awk "\\$1 == \\"{job_id}\\"', 'awk "') + +_WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.replace( + 'awk "\\$1 == \\"{job_id}\\"', 'awk "\\$2 == \\"{job_name}\\"') + + +def _get_cmd_wait_until_job_status_contains_matching_job_id( + cluster_name: str, job_id: str, job_status: List[sky.JobStatus], + timeout: int): + return _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_ID.format( + cluster_name=cluster_name, + job_id=job_id, + job_status=_statuses_to_str(job_status), + timeout=timeout) + + +def _get_cmd_wait_until_job_status_contains_without_matching_job( + cluster_name: str, job_status: List[sky.JobStatus], timeout: int): + return _WAIT_UNTIL_JOB_STATUS_CONTAINS_WITHOUT_MATCHING_JOB.format( + cluster_name=cluster_name, + job_status=_statuses_to_str(job_status), + timeout=timeout) + + +def _get_cmd_wait_until_job_status_contains_matching_job_name( + cluster_name: str, job_name: str, job_status: List[sky.JobStatus], + timeout: int): + return _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.format( + cluster_name=cluster_name, + job_name=job_name, + job_status=_statuses_to_str(job_status), + timeout=timeout) + + +# Managed job functions + +_WAIT_UNTIL_MANAGED_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME = _WAIT_UNTIL_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.replace( + 'sky queue {cluster_name}', 'sky jobs queue').replace( + 'awk "\\$2 == \\"{job_name}\\"', + 'awk "\\$2 == \\"{job_name}\\" || \\$3 == \\"{job_name}\\"').replace( + _ALL_JOB_STATUSES, _ALL_MANAGED_JOB_STATUSES) + + +def _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name: str, job_status: List[sky.JobStatus], timeout: int): + return _WAIT_UNTIL_MANAGED_JOB_STATUS_CONTAINS_MATCHING_JOB_NAME.format( + job_name=job_name, + job_status=_statuses_to_str(job_status), + timeout=timeout) + + +# After the timeout, the cluster will stop if autostop is set, and our check +# should be more than the timeout. To address this, we extend the timeout by +# _BUMP_UP_SECONDS before exiting. +_BUMP_UP_SECONDS = 35 + DEFAULT_CMD_TIMEOUT = 15 * 60 @@ -288,13 +449,13 @@ def test_example_app(): # (min, pid=1277) # (min, pid=1277) task run finish # ✓ Job finished (status: SUCCEEDED). 
- - # 📋 Useful Commands + # # Job ID: 1 + # 📋 Useful Commands # ├── To cancel the job: sky cancel test 1 # ├── To stream job logs: sky logs test 1 # └── To view job queue: sky queue test - + # # Cluster name: test # ├── To log into the head VM: ssh test # ├── To submit a job: sky exec test yaml_file @@ -314,8 +475,8 @@ def test_example_app(): 'grep "Job finished (status: SUCCEEDED)" && ' 'echo "==Validating task output ending 2==" && ' 'echo "$s" | grep -A 5 "Job finished (status: SUCCEEDED)" | ' - 'grep "Useful Commands" && ' - 'echo "$s" | grep -A 1 "Useful Commands" | grep "Job ID:"') + 'grep "Job ID:" && ' + 'echo "$s" | grep -A 1 "Job ID:" | grep "Useful Commands"') # ---------- A minimal task ---------- @@ -399,7 +560,6 @@ def test_launch_fast_with_autostop(generic_cloud: str): # Azure takes ~ 7m15s (435s) to autostop a VM, so here we use 600 to ensure # the VM is stopped. autostop_timeout = 600 if generic_cloud == 'azure' else 250 - test = Test( 'test_launch_fast_with_autostop', [ @@ -407,11 +567,15 @@ def test_launch_fast_with_autostop(generic_cloud: str): f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --cloud {generic_cloud} --fast -i 1 tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}', f'sky logs {name} 1 --status', f'sky status -r {name} | grep UP', - f'sleep {autostop_timeout}', # Ensure cluster is stopped - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', - + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=autostop_timeout), + # Even the cluster is stopped, cloud platform may take a while to + # delete the VM. + f'sleep {_BUMP_UP_SECONDS}', # Launch again. Do full output validation - we expect the cluster to re-launch f'unset SKYPILOT_DEBUG; s=$(sky launch -y -c {name} --fast -i 1 tests/test_yamls/minimal.yaml) && {_VALIDATE_LAUNCH_OUTPUT}', f'sky logs {name} 2 --status', @@ -448,6 +612,7 @@ def test_aws_region(): @pytest.mark.aws def test_aws_with_ssh_proxy_command(): name = _get_cluster_name() + with tempfile.NamedTemporaryFile(mode='w') as f: f.write( textwrap.dedent(f"""\ @@ -465,9 +630,23 @@ def test_aws_with_ssh_proxy_command(): f'sky logs {name} 1 --status', f'export SKYPILOT_CONFIG={f.name}; sky exec {name} echo hi', f'sky logs {name} 2 --status', + # Start a small job to make sure the controller is created. + f'sky jobs launch -n {name}-0 --cloud aws --cpus 2 --use-spot -y echo hi', + # Wait other tests to create the job controller first, so that + # the job controller is not launched with proxy command. 
+ _get_cmd_wait_until_cluster_status_contains_wildcard( + cluster_name_wildcard='sky-jobs-controller-*', + cluster_status=[sky.ClusterStatus.UP], + timeout=300), f'export SKYPILOT_CONFIG={f.name}; sky jobs launch -n {name} --cpus 2 --cloud aws --region us-east-1 -yd echo hi', - 'sleep 300', - f'{_GET_JOB_QUEUE} | grep {name} | grep "STARTING\|RUNNING\|SUCCEEDED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ + sky.ManagedJobStatus.SUCCEEDED, + sky.ManagedJobStatus.RUNNING, + sky.ManagedJobStatus.STARTING + ], + timeout=300), ], f'sky down -y {name} jump-{name}; sky jobs cancel -y -n {name}', ) @@ -837,6 +1016,12 @@ def test_clone_disk_aws(): f'sky launch -y -c {name} --cloud aws --region us-east-2 --retry-until-up "echo hello > ~/user_file.txt"', f'sky launch --clone-disk-from {name} -y -c {name}-clone && exit 1 || true', f'sky stop {name} -y', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=60), + # Wait for EC2 instance to be in stopped state. + # TODO: event based wait. 'sleep 60', f'sky launch --clone-disk-from {name} -y -c {name}-clone --cloud aws -d --region us-east-2 "cat ~/user_file.txt | grep hello"', f'sky launch --clone-disk-from {name} -y -c {name}-clone-2 --cloud aws -d --region us-east-2 "cat ~/user_file.txt | grep hello"', @@ -883,8 +1068,8 @@ def test_gcp_mig(): # Check MIG exists. f'gcloud compute instance-groups managed list --format="value(name)" | grep "^sky-mig-{name}"', f'sky autostop -i 0 --down -y {name}', - 'sleep 120', - f'sky status -r {name}; sky status {name} | grep "{name} not found"', + _get_cmd_wait_until_cluster_is_not_found(cluster_name=name, + timeout=120), f'gcloud compute instance-templates list | grep "sky-it-{name}"', # Launch again with the same region. The original instance template # should be removed. @@ -951,8 +1136,10 @@ def test_custom_default_conda_env(generic_cloud: str): f'sky exec {name} tests/test_yamls/test_custom_default_conda_env.yaml', f'sky logs {name} 2 --status', f'sky autostop -y -i 0 {name}', - 'sleep 60', - f'sky status -r {name} | grep "STOPPED"', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=80), f'sky start -y {name}', f'sky logs {name} 2 --no-follow | grep -E "myenv\\s+\\*"', f'sky exec {name} tests/test_yamls/test_custom_default_conda_env.yaml', @@ -973,10 +1160,13 @@ def test_stale_job(generic_cloud: str): f'sky launch -y -c {name} --cloud {generic_cloud} "echo hi"', f'sky exec {name} -d "echo start; sleep 10000"', f'sky stop {name} -y', - 'sleep 100', # Ensure this is large enough, else GCP leaks. + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=100), f'sky start {name} -y', f'sky logs {name} 1 --status', - f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED', + f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED_DRIVER', ], f'sky down -y {name}', ) @@ -1001,13 +1191,18 @@ def test_aws_stale_job_manual_restart(): '--output text`; ' f'aws ec2 stop-instances --region {region} ' '--instance-ids $id', - 'sleep 40', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=40), f'sky launch -c {name} -y "echo hi"', f'sky logs {name} 1 --status', f'sky logs {name} 3 --status', # Ensure the skylet updated the stale job status. 
- f'sleep {events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS}', - f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED', + _get_cmd_wait_until_job_status_contains_without_matching_job( + cluster_name=name, + job_status=[sky.JobStatus.FAILED_DRIVER], + timeout=events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS), ], f'sky down -y {name}', ) @@ -1037,8 +1232,10 @@ def test_gcp_stale_job_manual_restart(): f'sky logs {name} 1 --status', f'sky logs {name} 3 --status', # Ensure the skylet updated the stale job status. - f'sleep {events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS}', - f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep FAILED', + _get_cmd_wait_until_job_status_contains_without_matching_job( + cluster_name=name, + job_status=[sky.JobStatus.FAILED_DRIVER], + timeout=events.JobSchedulerEvent.EVENT_INTERVAL_SECONDS) ], f'sky down -y {name}', ) @@ -1056,6 +1253,10 @@ def test_env_check(generic_cloud: str): [ f'sky launch -y -c {name} --cloud {generic_cloud} --detach-setup examples/env_check.yaml', f'sky logs {name} 1 --status', # Ensure the job succeeded. + # Test --detach-setup with only setup. + f'sky launch -y -c {name} --detach-setup tests/test_yamls/test_only_setup.yaml', + f'sky logs {name} 2 --status', + f'sky logs {name} 2 | grep "hello world"', ], f'sky down -y {name}', timeout=total_timeout_minutes * 60, @@ -1715,6 +1916,7 @@ def test_large_job_queue(generic_cloud: str): f'for i in `seq 1 75`; do sky exec {name} -n {name}-$i -d "echo $i; sleep 100000000"; done', f'sky cancel -y {name} 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16', 'sleep 90', + # Each job takes 0.5 CPU and the default VM has 8 CPUs, so there should be 8 / 0.5 = 16 jobs running. # The first 16 jobs are canceled, so there should be 75 - 32 = 43 jobs PENDING. f's=$(sky queue {name}); echo "$s"; echo; echo; echo "$s" | grep -v grep | grep PENDING | wc -l | grep 43', @@ -1856,7 +2058,13 @@ def test_multi_echo(generic_cloud: str): f'until sky logs {name} 32 --status; do echo "Waiting for job 32 to finish..."; sleep 1; done', ] + # Ensure jobs succeeded. - [f'sky logs {name} {i + 1} --status' for i in range(32)] + + [ + _get_cmd_wait_until_job_status_contains_matching_job_id( + cluster_name=name, + job_id=i + 1, + job_status=[sky.JobStatus.SUCCEEDED], + timeout=120) for i in range(32) + ] + # Ensure monitor/autoscaler didn't crash on the 'assert not # unfulfilled' error. If process not found, grep->ssh returns 1. [f'ssh {name} \'ps aux | grep "[/]"monitor.py\''], @@ -1999,6 +2207,25 @@ def test_tpu_vm_pod(): run_one_test(test) +# ---------- TPU Pod Slice on GKE. ---------- +@pytest.mark.kubernetes +def test_tpu_pod_slice_gke(): + name = _get_cluster_name() + test = Test( + 'tpu_pod_slice_gke', + [ + f'sky launch -y -c {name} examples/tpu/tpuvm_mnist.yaml --cloud kubernetes --gpus tpu-v5-lite-podslice', + f'sky logs {name} 1', # Ensure the job finished. + f'sky logs {name} 1 --status', # Ensure the job succeeded. + f'sky exec {name} "conda activate flax; python -c \'import jax; print(jax.devices()[0].platform);\' | grep tpu || exit 1;"', # Ensure TPU is reachable. + f'sky logs {name} 2 --status' + ], + f'sky down -y {name}', + timeout=30 * 60, # can take 30 mins + ) + run_one_test(test) + + # ---------- Simple apps. 
---------- @pytest.mark.no_scp # SCP does not support num_nodes > 1 yet def test_multi_hostname(generic_cloud: str): @@ -2409,12 +2636,19 @@ def test_gcp_start_stop(): f'sky exec {name} "prlimit -n --pid=\$(pgrep -f \'raylet/raylet --raylet_socket_name\') | grep \'"\'1048576 1048576\'"\'"', # Ensure the raylet process has the correct file descriptor limit. f'sky logs {name} 3 --status', # Ensure the job succeeded. f'sky stop -y {name}', - f'sleep 20', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=40), f'sky start -y {name} -i 1', f'sky exec {name} examples/gcp_start_stop.yaml', f'sky logs {name} 4 --status', # Ensure the job succeeded. - 'sleep 180', - f'sky status -r {name} | grep "INIT\|STOPPED"', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ + sky.ClusterStatus.STOPPED, sky.ClusterStatus.INIT + ], + timeout=200), ], f'sky down -y {name}', ) @@ -2437,9 +2671,13 @@ def test_azure_start_stop(): f'sky start -y {name} -i 1', f'sky exec {name} examples/azure_start_stop.yaml', f'sky logs {name} 3 --status', # Ensure the job succeeded. - 'sleep 260', - f's=$(sky status -r {name}) && echo "$s" && echo "$s" | grep "INIT\|STOPPED"' - f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}' + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ + sky.ClusterStatus.STOPPED, sky.ClusterStatus.INIT + ], + timeout=280) + + f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}', ], f'sky down -y {name}', timeout=30 * 60, # 30 mins @@ -2475,8 +2713,10 @@ def test_autostop(generic_cloud: str): f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', # Ensure the cluster is STOPPED. - f'sleep {autostop_timeout}', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=autostop_timeout), # Ensure the cluster is UP and the autostop setting is reset ('-'). f'sky start -y {name}', @@ -2492,8 +2732,10 @@ def test_autostop(generic_cloud: str): f'sky autostop -y {name} -i 1', # Should restart the timer. 'sleep 40', f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', - f'sleep {autostop_timeout}', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=autostop_timeout), # Test restarting the idleness timer via exec: f'sky start -y {name}', @@ -2502,9 +2744,10 @@ def test_autostop(generic_cloud: str): 'sleep 45', # Almost reached the threshold. f'sky exec {name} echo hi', # Should restart the timer. 'sleep 45', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep UP', - f'sleep {autostop_timeout}', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=autostop_timeout + _BUMP_UP_SECONDS), ], f'sky down -y {name}', timeout=total_timeout_minutes * 60, @@ -2644,7 +2887,7 @@ def test_cancel_pytorch(generic_cloud: str): f'sky launch -c {name} --cloud {generic_cloud} examples/resnet_distributed_torch.yaml -y -d', # Wait the GPU process to start. 
'sleep 90', - f'sky exec {name} "(nvidia-smi | grep python) || ' + f'sky exec {name} --num-nodes 2 "(nvidia-smi | grep python) || ' # When run inside container/k8s, nvidia-smi cannot show process ids. # See https://github.com/NVIDIA/nvidia-docker/issues/179 # To work around, we check if GPU utilization is greater than 0. @@ -2652,7 +2895,7 @@ def test_cancel_pytorch(generic_cloud: str): f'sky logs {name} 2 --status', # Ensure the job succeeded. f'sky cancel -y {name} 1', 'sleep 60', - f'sky exec {name} "(nvidia-smi | grep \'No running process\') || ' + f'sky exec {name} --num-nodes 2 "(nvidia-smi | grep \'No running process\') || ' # Ensure Xorg is the only process running. '[ \$(nvidia-smi | grep -A 10 Processes | grep -A 10 === | grep -v Xorg) -eq 2 ]"', f'sky logs {name} 3 --status', # Ensure the job succeeded. @@ -2721,15 +2964,19 @@ def test_stop_gcp_spot(): f'sky exec {name} -- ls myfile', f'sky logs {name} 2 --status', f'sky autostop {name} -i0 -y', - 'sleep 90', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=90), f'sky start {name} -y', f'sky exec {name} -- ls myfile', f'sky logs {name} 3 --status', # -i option at launch should go through: f'sky launch -c {name} -i0 -y', - 'sleep 120', - f's=$(sky status {name} --refresh); echo "$s"; echo; echo; echo "$s" | grep {name} | grep STOPPED', + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=120), ], f'sky down -y {name}', ) @@ -2749,14 +2996,27 @@ def test_managed_jobs(generic_cloud: str): [ f'sky jobs launch -n {name}-1 --cloud {generic_cloud} examples/managed_job.yaml -y -d', f'sky jobs launch -n {name}-2 --cloud {generic_cloud} examples/managed_job.yaml -y -d', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-1 | head -n1 | grep "PENDING\|SUBMITTED\|STARTING\|RUNNING"', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "PENDING\|SUBMITTED\|STARTING\|RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-1', + job_status=[ + sky.ManagedJobStatus.PENDING, + sky.ManagedJobStatus.SUBMITTED, + sky.ManagedJobStatus.STARTING, sky.ManagedJobStatus.RUNNING + ], + timeout=60), + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[ + sky.ManagedJobStatus.PENDING, + sky.ManagedJobStatus.SUBMITTED, + sky.ManagedJobStatus.STARTING, sky.ManagedJobStatus.RUNNING + ], + timeout=60), f'sky jobs cancel -y -n {name}-1', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-1 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name}-1 | head -n1 | grep CANCELLED', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-1', + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=230), # Test the functionality for logging. f's=$(sky jobs logs -n {name}-2 --no-follow); echo "$s"; echo "$s" | grep "start counting"', f's=$(sky jobs logs --controller -n {name}-2 --no-follow); echo "$s"; echo "$s" | grep "Cluster launched:"', @@ -2826,9 +3086,11 @@ def test_managed_jobs_failed_setup(generic_cloud: str): 'managed_jobs_failed_setup', [ f'sky jobs launch -n {name} --cloud {generic_cloud} -y -d tests/test_yamls/failed_setup.yaml', - 'sleep 330', # Make sure the job failed quickly. 
- f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "FAILED_SETUP"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.FAILED_SETUP], + timeout=330 + _BUMP_UP_SECONDS), ], f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. @@ -2851,7 +3113,10 @@ def test_managed_jobs_pipeline_failed_setup(generic_cloud: str): 'managed_jobs_pipeline_failed_setup', [ f'sky jobs launch -n {name} -y -d tests/test_yamls/failed_setup_pipeline.yaml', - 'sleep 600', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.FAILED_SETUP], + timeout=600), # Make sure the job failed quickly. f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "FAILED_SETUP"', # Task 0 should be SUCCEEDED. @@ -2885,8 +3150,10 @@ def test_managed_jobs_recovery_aws(aws_config_region): 'managed_jobs_recovery_aws', [ f'sky jobs launch --cloud aws --region {region} --use-spot -n {name} "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', - 'sleep 360', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=600), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the cluster manually. (f'aws ec2 terminate-instances --region {region} --instance-ids $(' @@ -2896,8 +3163,10 @@ def test_managed_jobs_recovery_aws(aws_config_region): '--output text)'), _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=200), f'RUN_ID=$(cat /tmp/{name}-run-id); echo "$RUN_ID"; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | grep "$RUN_ID"', ], f'sky jobs cancel -y -n {name}', @@ -2925,15 +3194,19 @@ def test_managed_jobs_recovery_gcp(): 'managed_jobs_recovery_gcp', [ f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot --cpus 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', - 'sleep 360', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=300), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the cluster manually. 
terminate_cmd, _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=200), f'RUN_ID=$(cat /tmp/{name}-run-id); echo "$RUN_ID"; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"', ], f'sky jobs cancel -y -n {name}', @@ -2956,8 +3229,10 @@ def test_managed_jobs_pipeline_recovery_aws(aws_config_region): 'managed_jobs_pipeline_recovery_aws', [ f'sky jobs launch -n {name} tests/test_yamls/pipeline_aws.yaml -y -d', - 'sleep 400', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=400), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids', # Terminate the cluster manually. @@ -2976,8 +3251,10 @@ def test_managed_jobs_pipeline_recovery_aws(aws_config_region): '--output text)'), _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=200), f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids-new', f'diff /tmp/{name}-run-ids /tmp/{name}-run-ids-new', @@ -3007,8 +3284,10 @@ def test_managed_jobs_pipeline_recovery_gcp(): 'managed_jobs_pipeline_recovery_gcp', [ f'sky jobs launch -n {name} tests/test_yamls/pipeline_gcp.yaml -y -d', - 'sleep 400', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=400), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids', # Terminate the cluster manually. 
@@ -3019,8 +3298,10 @@ def test_managed_jobs_pipeline_recovery_gcp(): f'cut -d\'_\' -f1 | rev | cut -d\'-\' -f1`; {terminate_cmd}'), _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 200', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=200), f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID: | grep "$RUN_ID"', f'RUN_IDS=$(sky jobs logs -n {name} --no-follow | grep -A 4 SKYPILOT_TASK_IDS | cut -d")" -f2); echo "$RUN_IDS" | tee /tmp/{name}-run-ids-new', f'diff /tmp/{name}-run-ids /tmp/{name}-run-ids-new', @@ -3046,8 +3327,13 @@ def test_managed_jobs_recovery_default_resources(generic_cloud: str): 'managed-spot-recovery-default-resources', [ f'sky jobs launch -n {name} --cloud {generic_cloud} --use-spot "sleep 30 && sudo shutdown now && sleep 1000" -y -d', - 'sleep 360', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING\|RECOVERING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ + sky.ManagedJobStatus.RUNNING, + sky.ManagedJobStatus.RECOVERING + ], + timeout=360), ], f'sky jobs cancel -y -n {name}', timeout=25 * 60, @@ -3067,8 +3353,10 @@ def test_managed_jobs_recovery_multi_node_aws(aws_config_region): 'managed_jobs_recovery_multi_node_aws', [ f'sky jobs launch --cloud aws --region {region} -n {name} --use-spot --num-nodes 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', - 'sleep 450', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=450), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the worker manually. (f'aws ec2 terminate-instances --region {region} --instance-ids $(' @@ -3079,8 +3367,10 @@ def test_managed_jobs_recovery_multi_node_aws(aws_config_region): '--output text)'), _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 560', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=560), f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2 | grep "$RUN_ID"', ], f'sky jobs cancel -y -n {name}', @@ -3108,15 +3398,19 @@ def test_managed_jobs_recovery_multi_node_gcp(): 'managed_jobs_recovery_multi_node_gcp', [ f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot --num-nodes 2 "echo SKYPILOT_TASK_ID: \$SKYPILOT_TASK_ID; sleep 1800" -y -d', - 'sleep 400', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=400), f'RUN_ID=$(sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2); echo "$RUN_ID" | tee /tmp/{name}-run-id', # Terminate the worker manually. 
terminate_cmd, _JOB_WAIT_NOT_RUNNING.format(job_name=name), f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RECOVERING"', - 'sleep 420', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=560), f'RUN_ID=$(cat /tmp/{name}-run-id); echo $RUN_ID; sky jobs logs -n {name} --no-follow | grep SKYPILOT_TASK_ID | cut -d: -f2 | grep "$RUN_ID"', ], f'sky jobs cancel -y -n {name}', @@ -3141,13 +3435,17 @@ def test_managed_jobs_cancellation_aws(aws_config_region): [ # Test cancellation during spot cluster being launched. f'sky jobs launch --cloud aws --region {region} -n {name} --use-spot "sleep 1000" -y -d', - 'sleep 60', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "STARTING\|RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ + sky.ManagedJobStatus.STARTING, sky.ManagedJobStatus.RUNNING + ], + timeout=60 + _BUMP_UP_SECONDS), f'sky jobs cancel -y -n {name}', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), (f's=$(aws ec2 describe-instances --region {region} ' f'--filters Name=tag:ray-cluster-name,Values={name_on_cloud}-* ' f'--query Reservations[].Instances[].State[].Name ' @@ -3155,12 +3453,16 @@ def test_managed_jobs_cancellation_aws(aws_config_region): ), # Test cancelling the spot cluster during spot job being setup. f'sky jobs launch --cloud aws --region {region} -n {name}-2 --use-spot tests/test_yamls/test_long_setup.yaml -y -d', - 'sleep 300', + # The job is set up in the cluster, will shown as RUNNING. + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=300 + _BUMP_UP_SECONDS), f'sky jobs cancel -y -n {name}-2', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), (f's=$(aws ec2 describe-instances --region {region} ' f'--filters Name=tag:ray-cluster-name,Values={name_2_on_cloud}-* ' f'--query Reservations[].Instances[].State[].Name ' @@ -3168,8 +3470,11 @@ def test_managed_jobs_cancellation_aws(aws_config_region): ), # Test cancellation during spot job is recovering. f'sky jobs launch --cloud aws --region {region} -n {name}-3 --use-spot "sleep 1000" -y -d', - 'sleep 300', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RUNNING"', + # The job is running in the cluster, will shown as RUNNING. + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-3', + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=300 + _BUMP_UP_SECONDS), # Terminate the cluster manually. 
(f'aws ec2 terminate-instances --region {region} --instance-ids $(' f'aws ec2 describe-instances --region {region} ' @@ -3179,10 +3484,10 @@ def test_managed_jobs_cancellation_aws(aws_config_region): _JOB_WAIT_NOT_RUNNING.format(job_name=f'{name}-3'), f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RECOVERING"', f'sky jobs cancel -y -n {name}-3', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-3', + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), # The cluster should be terminated (shutting-down) after cancellation. We don't use the `=` operator here because # there can be multiple VM with the same name due to the recovery. (f's=$(aws ec2 describe-instances --region {region} ' @@ -3217,34 +3522,42 @@ def test_managed_jobs_cancellation_gcp(): [ # Test cancellation during spot cluster being launched. f'sky jobs launch --cloud gcp --zone {zone} -n {name} --use-spot "sleep 1000" -y -d', - 'sleep 60', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "STARTING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.STARTING], + timeout=60 + _BUMP_UP_SECONDS), f'sky jobs cancel -y -n {name}', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), # Test cancelling the spot cluster during spot job being setup. f'sky jobs launch --cloud gcp --zone {zone} -n {name}-2 --use-spot tests/test_yamls/test_long_setup.yaml -y -d', - 'sleep 300', + # The job is set up in the cluster, will shown as RUNNING. + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=300 + _BUMP_UP_SECONDS), f'sky jobs cancel -y -n {name}-2', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name}-2 | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-2', + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), # Test cancellation during spot job is recovering. f'sky jobs launch --cloud gcp --zone {zone} -n {name}-3 --use-spot "sleep 1000" -y -d', - 'sleep 300', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RUNNING"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-3', + job_status=[sky.ManagedJobStatus.RUNNING], + timeout=300 + _BUMP_UP_SECONDS), # Terminate the cluster manually. 
terminate_cmd, _JOB_WAIT_NOT_RUNNING.format(job_name=f'{name}-3'), f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "RECOVERING"', f'sky jobs cancel -y -n {name}-3', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLING\|CANCELLED"', - 'sleep 120', - f'{_GET_JOB_QUEUE} | grep {name}-3 | head -n1 | grep "CANCELLED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=f'{name}-3', + job_status=[sky.ManagedJobStatus.CANCELLED], + timeout=120 + _BUMP_UP_SECONDS), # The cluster should be terminated (STOPPING) after cancellation. We don't use the `=` operator here because # there can be multiple VM with the same name due to the recovery. (f's=$({query_state_cmd}) && echo "$s" && echo; [[ -z "$s" ]] || echo "$s" | grep -v -E "PROVISIONING|STAGING|RUNNING|REPAIRING|TERMINATED|SUSPENDING|SUSPENDED|SUSPENDED"' @@ -3334,8 +3647,12 @@ def test_managed_jobs_storage(generic_cloud: str): *STORAGE_SETUP_COMMANDS, f'sky jobs launch -n {name}{use_spot} --cloud {generic_cloud}{region_flag} {file_path} -y', region_validation_cmd, # Check if the bucket is created in the correct region - 'sleep 60', # Wait the spot queue to be updated - f'{_GET_JOB_QUEUE} | grep {name} | grep SUCCEEDED', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.SUCCEEDED], + timeout=60 + _BUMP_UP_SECONDS), + # Wait for the job to be cleaned up. + 'sleep 20', f'[ $(aws s3api list-buckets --query "Buckets[?contains(Name, \'{storage_name}\')].Name" --output text | wc -l) -eq 0 ]', # Check if file was written to the mounted output bucket output_check_cmd @@ -3359,10 +3676,17 @@ def test_managed_jobs_tpu(): 'test-spot-tpu', [ f'sky jobs launch -n {name} --use-spot examples/tpu/tpuvm_mnist.yaml -y -d', - 'sleep 5', - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep STARTING', - 'sleep 900', # TPU takes a while to launch - f'{_GET_JOB_QUEUE} | grep {name} | head -n1 | grep "RUNNING\|SUCCEEDED"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.STARTING], + timeout=60 + _BUMP_UP_SECONDS), + # TPU takes a while to launch + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[ + sky.ManagedJobStatus.RUNNING, sky.ManagedJobStatus.SUCCEEDED + ], + timeout=900 + _BUMP_UP_SECONDS), ], f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. @@ -3379,9 +3703,19 @@ def test_managed_jobs_inline_env(generic_cloud: str): test = Test( 'test-managed-jobs-inline-env', [ - f'sky jobs launch -n {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"', - 'sleep 20', - f'{_GET_JOB_QUEUE} | grep {name} | grep SUCCEEDED', + f'sky jobs launch -n {name} -y --cloud {generic_cloud} --env TEST_ENV="hello world" -- "echo "\\$TEST_ENV"; ([[ ! -z \\"\$TEST_ENV\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_IPS}\\" ]] && [[ ! -z \\"\${constants.SKYPILOT_NODE_RANK}\\" ]] && [[ ! 
-z \\"\${constants.SKYPILOT_NUM_NODES}\\" ]]) || exit 1"', + _get_cmd_wait_until_managed_job_status_contains_matching_job_name( + job_name=name, + job_status=[sky.ManagedJobStatus.SUCCEEDED], + timeout=20 + _BUMP_UP_SECONDS), + f'JOB_ROW=$(sky jobs queue | grep {name} | head -n1) && ' + f'echo "$JOB_ROW" && echo "$JOB_ROW" | grep "SUCCEEDED" && ' + f'JOB_ID=$(echo "$JOB_ROW" | awk \'{{print $1}}\') && ' + f'echo "JOB_ID=$JOB_ID" && ' + # Test that logs are still available after the job finishes. + 'unset SKYPILOT_DEBUG; s=$(sky jobs logs $JOB_ID --refresh) && echo "$s" && echo "$s" | grep "hello world" && ' + # Make sure we skip the unnecessary logs. + 'echo "$s" | head -n1 | grep "Waiting for"', ], f'sky jobs cancel -y -n {name}', # Increase timeout since sky jobs queue -r can be blocked by other spot tests. @@ -3488,8 +3822,12 @@ def test_azure_start_stop_two_nodes(): f'sky start -y {name} -i 1', f'sky exec --num-nodes=2 {name} examples/azure_start_stop.yaml', f'sky logs {name} 2 --status', # Ensure the job succeeded. - 'sleep 200', - f's=$(sky status -r {name}) && echo "$s" && echo "$s" | grep "INIT\|STOPPED"' + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[ + sky.ClusterStatus.INIT, sky.ClusterStatus.STOPPED + ], + timeout=200 + _BUMP_UP_SECONDS) + f'|| {{ ssh {name} "cat ~/.sky/skylet.log"; exit 1; }}' ], f'sky down -y {name}', @@ -3857,6 +4195,15 @@ def test_skyserve_kubernetes_http(): run_one_test(test) +@pytest.mark.oci +@pytest.mark.serve +def test_skyserve_oci_http(): + """Test skyserve on OCI""" + name = _get_service_name() + test = _get_skyserve_http_test(name, 'oci', 20) + run_one_test(test) + + @pytest.mark.no_fluidstack # Fluidstack does not support T4 gpus for now @pytest.mark.serve def test_skyserve_llm(generic_cloud: str): @@ -4491,7 +4838,10 @@ def test_core_api_sky_launch_fast(generic_cloud: str): idle_minutes_to_autostop=1, fast=True) # Sleep to let the cluster autostop - time.sleep(120) + _get_cmd_wait_until_cluster_status_contains( + cluster_name=name, + cluster_status=[sky.ClusterStatus.STOPPED], + timeout=120) # Run it again - should work with fast=True sky.launch(task, cluster_name=name, diff --git a/tests/test_yamls/test_only_setup.yaml b/tests/test_yamls/test_only_setup.yaml new file mode 100644 index 00000000000..245d2b1de69 --- /dev/null +++ b/tests/test_yamls/test_only_setup.yaml @@ -0,0 +1,2 @@ +setup: | + echo "hello world" diff --git a/tests/unit_tests/test_admin_policy.py b/tests/unit_tests/test_admin_policy.py index 48e47a6007c..c9e7ad35af2 100644 --- a/tests/unit_tests/test_admin_policy.py +++ b/tests/unit_tests/test_admin_policy.py @@ -16,6 +16,10 @@ POLICY_PATH = os.path.join(os.path.dirname(os.path.dirname(sky.__file__)), 'examples', 'admin_policy') +if not os.path.exists(POLICY_PATH): + # This is used for GitHub Actions, as we copy the examples to the package. 
+ POLICY_PATH = os.path.join(os.path.dirname(__file__), 'examples', + 'admin_policy') @pytest.fixture @@ -172,7 +176,7 @@ def _gen_cluster_record(status: sky.ClusterStatus, autostop: int) -> dict: idle_minutes_to_autostop=None) -@mock.patch('sky.provision.kubernetes.utils.get_all_kube_config_context_names', +@mock.patch('sky.provision.kubernetes.utils.get_all_kube_context_names', return_value=['kind-skypilot', 'kind-skypilot2', 'kind-skypilot3']) def test_dynamic_kubernetes_contexts_policy(add_example_policy_paths, task): _, config = _load_task_and_apply_policy( diff --git a/tests/unit_tests/test_backend_utils.py b/tests/unit_tests/test_backend_utils.py index 5da4410abb9..c9aa21567c2 100644 --- a/tests/unit_tests/test_backend_utils.py +++ b/tests/unit_tests/test_backend_utils.py @@ -22,6 +22,8 @@ return_value='~/.aws/credentials') @mock.patch('sky.backends.backend_utils._get_yaml_path_from_cluster_name', return_value='/tmp/fake/path') +@mock.patch('sky.backends.backend_utils._deterministic_cluster_yaml_hash', + return_value='fake-hash') @mock.patch('sky.utils.common_utils.fill_template') def test_write_cluster_config_w_remote_identity(mock_fill_template, *mocks) -> None: diff --git a/tests/unit_tests/test_dag.py b/tests/unit_tests/test_dag.py new file mode 100644 index 00000000000..9bc626b0085 --- /dev/null +++ b/tests/unit_tests/test_dag.py @@ -0,0 +1,148 @@ +"""Unit tests for sky.dag.""" +import pytest + +import sky + + +@pytest.fixture +def empty_dag(): + """Fixture for an empty DAG.""" + with sky.Dag() as dag: + yield dag + + +@pytest.fixture +def single_task_dag(): + """Fixture for a DAG with single task.""" + with sky.Dag() as dag: + sky.Task() + yield dag + + +@pytest.fixture +def linear_three_task_dag(): + """Fixture for a DAG with three tasks in a linear chain.""" + with sky.Dag() as dag: + task1 = sky.Task() + task2 = sky.Task() + task3 = sky.Task() + dag.add_edge(task1, task2) + dag.add_edge(task2, task3) + yield dag + + +@pytest.fixture +def branching_dag(): + """Fixture for a DAG with one task branching into two tasks.""" + with sky.Dag() as dag: + dag._tasks = [sky.Task() for _ in range(3)] + dag.add_edge(dag.tasks[0], dag.tasks[1]) + dag.add_edge(dag.tasks[0], dag.tasks[2]) + yield dag + + +@pytest.fixture +def merging_dag(): + """Fixture for a DAG with two tasks merging into one task.""" + with sky.Dag() as dag: + dag._tasks = [sky.Task() for _ in range(3)] + dag.add_edge(dag.tasks[0], dag.tasks[2]) + dag.add_edge(dag.tasks[1], dag.tasks[2]) + yield dag + + +@pytest.fixture +def multi_parent_dag(): + """Fixture for a DAG with two tasks merging into one task.""" + with sky.Dag() as dag: + dag._tasks = [sky.Task() for _ in range(3)] + dag.add_edge(dag.tasks[0], dag.tasks[2]) + dag.add_edge(dag.tasks[1], dag.tasks[2]) + yield dag + + +@pytest.fixture +def diamond_dag(): + """Fixture for a diamond-shaped DAG: A -> (B,C) -> D.""" + with sky.Dag() as dag: + dag._tasks = [sky.Task() for _ in range(4)] + dag.add_edge(dag.tasks[0], dag.tasks[1]) + dag.add_edge(dag.tasks[0], dag.tasks[2]) + dag.add_edge(dag.tasks[1], dag.tasks[3]) + dag.add_edge(dag.tasks[2], dag.tasks[3]) + yield dag + + +@pytest.fixture +def cyclic_dag(): + """Fixture for a DAG with a cycle (invalid DAG).""" + with sky.Dag() as dag: + task1 = sky.Task() + task2 = sky.Task() + dag.add_edge(task1, task2) + dag.add_edge(task2, task1) # Create cycle + yield dag + + +@pytest.mark.parametrize('dag_fixture,description', [ + ('empty_dag', 'Empty DAG'), + ('single_task_dag', 'Single task DAG'), + ('linear_three_task_dag', 
'Linear chain of three tasks'), +]) +def test_is_chain_true_cases(request, dag_fixture, description): + """Test cases where is_chain() should return True.""" + dag = request.getfixturevalue(dag_fixture) + assert dag.is_chain(), f"Failed for case: {description}" + + +@pytest.mark.parametrize('dag_fixture,description', [ + ('branching_dag', 'Branching DAG'), + ('merging_dag', 'Merging DAG'), + ('diamond_dag', 'Diamond DAG'), +]) +def test_is_chain_false_cases(request, dag_fixture, description): + """Test cases where is_chain() should return False.""" + dag = request.getfixturevalue(dag_fixture) + assert not dag.is_chain(), f"Failed for case: {description}" + + +@pytest.mark.parametrize('dag_fixture,description,expected_old,expected_new', [ + ('linear_three_task_dag', 'Linear chain of three tasks', True, True), + ('multi_parent_dag', 'DAG with two tasks merging into one task', True, + False), +]) +def test_is_chain_regression(request, dag_fixture, description, expected_old, + expected_new): + """Regression test comparing new implementation with old behavior.""" + dag = request.getfixturevalue(dag_fixture) + + def old_is_chain(dag): + # Old implementation + is_chain = True + visited_zero_out_degree = False + for node in dag.graph.nodes: + out_degree = dag.graph.out_degree(node) + if out_degree > 1: + is_chain = False + break + elif out_degree == 0: + if visited_zero_out_degree: + is_chain = False + break + else: + visited_zero_out_degree = True + return is_chain + + assert dag.is_chain() == expected_new, f"Failed for case: {description}" + assert old_is_chain(dag) == expected_old, f"Failed for case: {description}" + + +# TODO(andy): Currently cyclic DAGs are not detected and is_chain() simply +# returns False. Once we implement cycle detection that raises an error, +# update this test to use pytest.raises. +@pytest.mark.xfail(reason="Cycle detection not implemented yet") +def test_is_chain_with_cycle(cyclic_dag): + """Test is_chain() with cyclic graph. 
+ """ + with pytest.raises(ValueError): + cyclic_dag.is_chain() diff --git a/tests/unit_tests/test_recovery_strategy.py b/tests/unit_tests/test_recovery_strategy.py new file mode 100644 index 00000000000..da8e8142da0 --- /dev/null +++ b/tests/unit_tests/test_recovery_strategy.py @@ -0,0 +1,48 @@ +from unittest import mock + +from sky.exceptions import ClusterDoesNotExist +from sky.jobs import recovery_strategy + + +@mock.patch('sky.down') +@mock.patch('sky.usage.usage_lib.messages.usage.set_internal') +def test_terminate_cluster_retry_on_value_error(mock_set_internal, + mock_sky_down) -> None: + # Set up mock to fail twice with ValueError, then succeed + mock_sky_down.side_effect = [ + ValueError('Mock error 1'), + ValueError('Mock error 2'), + None, + ] + + # Call should succeed after retries + recovery_strategy.terminate_cluster('test-cluster') + + # Verify sky.down was called 3 times + assert mock_sky_down.call_count == 3 + mock_sky_down.assert_has_calls([ + mock.call('test-cluster'), + mock.call('test-cluster'), + mock.call('test-cluster'), + ]) + + # Verify usage.set_internal was called before each sky.down + assert mock_set_internal.call_count == 3 + + +@mock.patch('sky.down') +@mock.patch('sky.usage.usage_lib.messages.usage.set_internal') +def test_terminate_cluster_handles_nonexistent_cluster(mock_set_internal, + mock_sky_down) -> None: + # Set up mock to raise ClusterDoesNotExist + mock_sky_down.side_effect = ClusterDoesNotExist('test-cluster') + + # Call should succeed silently + recovery_strategy.terminate_cluster('test-cluster') + + # Verify sky.down was called once + assert mock_sky_down.call_count == 1 + mock_sky_down.assert_called_once_with('test-cluster') + + # Verify usage.set_internal was called once + assert mock_set_internal.call_count == 1 diff --git a/tests/unit_tests/test_resources.py b/tests/unit_tests/test_resources.py index 5006fc454aa..65c90544f49 100644 --- a/tests/unit_tests/test_resources.py +++ b/tests/unit_tests/test_resources.py @@ -140,6 +140,7 @@ def test_aws_make_deploy_variables(*mocks) -> None: config = resource.make_deploy_variables(cluster_name, region, zones, + num_nodes=1, dryrun=True) expected_config_base = { @@ -180,6 +181,7 @@ def test_aws_make_deploy_variables(*mocks) -> None: config = resource.make_deploy_variables(cluster_name, region, zones, + num_nodes=1, dryrun=True) assert config == expected_config, ('unexpected resource ' 'variables generated') @@ -195,6 +197,7 @@ def test_aws_make_deploy_variables(*mocks) -> None: config = resource.make_deploy_variables(cluster_name, region, zones, + num_nodes=1, dryrun=True) assert config == expected_config, ('unexpected resource ' 'variables generated')