feat: add ansible provision command #11

Open · wants to merge 1 commit into main
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
    "psycopg2-binary",
    "optuna",
    "optuna-dashboard",
    "ansible-core>=2.15.0",
]
dynamic = ["version"]
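
The packaging change is a single new runtime dependency, ansible-core>=2.15.0, which provides the ansible-playbook CLI. As a minimal sketch (not part of this diff), a caller could guard provisioning with a PATH check on whichever machine actually invokes ansible-playbook; the helper name below is hypothetical.

import shutil

def ansible_playbook_available() -> bool:
    """Hypothetical helper: report whether the ansible-playbook CLI is on PATH."""
    return shutil.which("ansible-playbook") is not None

if not ansible_playbook_available():
    raise RuntimeError("ansible-core>=2.15.0 provides ansible-playbook; install it before provisioning")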

@@ -42,4 +43,4 @@ exclude = ["tests*"]

[project.urls]
"Homepage" = "https://github.com/dream3d-ai/torch-submit"
"Bug Tracker" = "https://github.com/dream3d-ai/torch-submit/issues"
96 changes: 96 additions & 0 deletions tests/test_cluster.py
@@ -0,0 +1,96 @@
import pytest
from pathlib import Path
from typer.testing import CliRunner
from rich.console import Console

from torch_submit.commands.cluster import app, config
from torch_submit.config import Node

runner = CliRunner()
console = Console()

@pytest.fixture
def mock_cluster(tmp_path):
    """Create a mock cluster configuration for testing."""
    head_node = Node(
        public_ip="1.2.3.4",
        private_ip="10.0.0.1",
        num_gpus=2,
        nproc=4,
        ssh_user="test",
        ssh_pub_key_path=str(tmp_path / "test_key.pub"),
        ssh_port=22,
        ansible_playbook=None
    )
    worker_nodes = [
        Node(
            public_ip="1.2.3.5",
            private_ip="10.0.0.2",
            num_gpus=4,
            nproc=8,
            ssh_user="test",
            ssh_pub_key_path=str(tmp_path / "test_key.pub"),
            ssh_port=22,
            ansible_playbook=None
        )
    ]
    config.add_cluster("test-cluster", head_node, worker_nodes)
    return "test-cluster"

def test_provision_cluster(tmp_path, mock_cluster):
    """Test the cluster provision command."""
    # Create test playbook
    playbook = tmp_path / "test.yml"
    playbook.write_text("""
- hosts: all
  tasks:
    - name: Test task
      debug:
        msg: "Test message"
""")

    # Test provision command with both head and worker playbooks
    result = runner.invoke(app, [
        "provision",
        mock_cluster,
        "--head-playbook", str(playbook),
        "--worker-playbook", str(playbook)
    ])
    assert result.exit_code == 0

    # Verify cluster configuration was updated
    cluster = config.get_cluster(mock_cluster)
    assert cluster.head_node.ansible_playbook == str(playbook)
    assert cluster.worker_nodes[0].ansible_playbook == str(playbook)

def test_provision_cluster_not_found():
    """Test provision command with non-existent cluster."""
    result = runner.invoke(app, [
        "provision",
        "nonexistent-cluster",
        "--head-playbook", "test.yml"
    ])
    assert result.exit_code == 1
    assert "not found" in result.stdout

def test_provision_cluster_head_only(tmp_path, mock_cluster):
    """Test provision command with head node only."""
    playbook = tmp_path / "head.yml"
    playbook.write_text("""
- hosts: all
  tasks:
    - name: Head node task
      debug:
        msg: "Head node test"
""")

    result = runner.invoke(app, [
        "provision",
        mock_cluster,
        "--head-playbook", str(playbook)
    ])
    assert result.exit_code == 0

    cluster = config.get_cluster(mock_cluster)
    assert cluster.head_node.ansible_playbook == str(playbook)
    assert cluster.worker_nodes[0].ansible_playbook is None
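
The suite covers the combined, missing-cluster, and head-only paths. A worker-only case would follow the same pattern; the sketch below is illustrative and not part of this PR.

def test_provision_cluster_worker_only(tmp_path, mock_cluster):
    """Sketch of an additional case: provision only the worker nodes."""
    playbook = tmp_path / "worker.yml"
    playbook.write_text("""
- hosts: all
  tasks:
    - name: Worker task
      debug:
        msg: "Worker test"
""")

    result = runner.invoke(app, [
        "provision",
        mock_cluster,
        "--worker-playbook", str(playbook)
    ])
    assert result.exit_code == 0

    cluster = config.get_cluster(mock_cluster)
    assert cluster.head_node.ansible_playbook is None
    assert cluster.worker_nodes[0].ansible_playbook == str(playbook)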
53 changes: 49 additions & 4 deletions torch_submit/commands/cluster.py
@@ -5,6 +5,7 @@
from rich.table import Table

from ..config import Config, Node
from ..executor import AnsibleExecutor

app = typer.Typer()
console = Console()
@@ -119,7 +120,7 @@ def remove_cluster(name: str):
    Remove a cluster configuration.

    Prompts the user for confirmation before removing the specified cluster configuration from the config.

    Args:
        name (str): The name of the cluster to remove.
    """
@@ -136,7 +137,7 @@ def edit_cluster(name: str):
    Edit an existing cluster configuration.

    Prompts the user for new cluster details and updates the specified cluster configuration in the config.

    Args:
        name (str): The name of the cluster to edit.
    """
@@ -167,7 +168,7 @@ def edit_cluster(name: str):
        nproc = typer.prompt("Number of processes on worker node", default=worker.nproc, type=int)
        ssh_user = typer.prompt("SSH user for worker node (optional)", default=worker.ssh_user or "")
        ssh_pub_key_path = typer.prompt("SSH public key path for worker node (optional)", default=worker.ssh_pub_key_path or "")

        worker_node = Node(public_ip, private_ip or None, num_gpus, nproc, ssh_user, ssh_pub_key_path)
        worker_nodes.append(worker_node)

@@ -176,4 +177,48 @@

    # Update the cluster configuration
    config.update_cluster(name, head_node, worker_nodes)
    console.print(f"Cluster [bold green]{name}[/bold green] updated successfully.")


@app.command("provision")
def provision_cluster(
    cluster_name: str = typer.Argument(..., help="Name of the cluster to provision"),
    head_playbook: str = typer.Option(None, help="Path to ansible playbook for head node"),
    worker_playbook: str = typer.Option(None, help="Path to ansible playbook for worker nodes"),
):
    """Provision a cluster using ansible playbooks.

    Args:
        cluster_name (str): Name of the cluster to provision.
        head_playbook (str, optional): Path to ansible playbook for head node.
        worker_playbook (str, optional): Path to ansible playbook for worker nodes.
    """
    try:
        cluster = config.get_cluster(cluster_name)
    except ValueError:
        console.print(f"[bold red]Error:[/bold red] Cluster '{cluster_name}' not found.")
        raise typer.Exit(code=1)

    executor = AnsibleExecutor(cluster_name)

    # Update playbook paths
    if head_playbook:
        cluster.head_node.ansible_playbook = head_playbook
    if worker_playbook:
        for node in cluster.worker_nodes:
            node.ansible_playbook = worker_playbook

    # Save updated configuration
    config.save_config()

    # Execute playbooks
    if head_playbook:
        console.print(f"Provisioning head node {cluster.head_node.public_ip}...")
        executor.execute_playbook(cluster.head_node)

    if worker_playbook:
        for node in cluster.worker_nodes:
            console.print(f"Provisioning worker node {node.public_ip}...")
            executor.execute_playbook(node)

    console.print("[bold green]Cluster provisioning completed.[/bold green]")
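
The command resolves the cluster first, persists the playbook paths, and only then runs the executor, so a failed lookup exits before any configuration is written. A hedged sketch of driving it programmatically through the Typer app (the cluster name and playbook paths are placeholders):

from typer.testing import CliRunner
from torch_submit.commands.cluster import app

runner = CliRunner()
result = runner.invoke(app, [
    "provision",
    "my-cluster",  # placeholder cluster name
    "--head-playbook", "playbooks/head.yml",
    "--worker-playbook", "playbooks/worker.yml",
])
print(result.exit_code, result.stdout)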
4 changes: 4 additions & 0 deletions torch_submit/config.py
@@ -18,6 +18,7 @@ class Node:
        nproc (int): The number of processes that can run on the node.
        ssh_user (Optional[str]): The SSH username for accessing the node, if available.
        ssh_pub_key_path (Optional[str]): The path to the SSH public key file, if available.
        ansible_playbook (Optional[str]): The path to the ansible playbook for node provisioning.
    """

    public_ip: str
@@ -27,6 +28,7 @@
    ssh_user: Optional[str]
    ssh_pub_key_path: Optional[str]
    ssh_port: Optional[int]
    ansible_playbook: Optional[str]

    def __post_init__(self):
        """Initialize the Node object after creation."""
@@ -284,6 +286,7 @@ def save_config(self):
                "ssh_user": cluster.head_node.ssh_user or None,
                "ssh_pub_key_path": cluster.head_node.ssh_pub_key_path or None,
                "ssh_port": cluster.head_node.ssh_port or None,
                "ansible_playbook": cluster.head_node.ansible_playbook or None,
            },
            "worker_nodes": [
                {
@@ -294,6 +297,7 @@
                    "ssh_user": node.ssh_user or None,
                    "ssh_pub_key_path": node.ssh_pub_key_path or None,
                    "ssh_port": node.ssh_port or None,
                    "ansible_playbook": node.ansible_playbook or None,
                }
                for node in cluster.worker_nodes
            ],
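
Node gains an optional ansible_playbook path, and save_config round-trips it for both head and worker nodes. A construction sketch with the new field, using illustrative values only:

from torch_submit.config import Node

worker = Node(
    public_ip="203.0.113.10",  # illustrative addresses
    private_ip="10.0.0.7",
    num_gpus=8,
    nproc=8,
    ssh_user="ubuntu",
    ssh_pub_key_path="~/.ssh/id_rsa.pub",
    ssh_port=22,
    ansible_playbook="playbooks/gpu_worker.yml",  # new optional field
)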
25 changes: 25 additions & 0 deletions torch_submit/executor.py
@@ -487,6 +487,31 @@ def _prepare_command(self, rank: int):
        return f"{self.get_command(rank)} -- {self.job.command}"


class AnsibleExecutor(BaseExecutor):
    """Executes ansible playbooks across cluster nodes."""

    def __init__(self, cluster_name: str):
        self.cluster = Config().get_cluster(cluster_name)

    def execute_playbook(self, node: Node):
        if not node.ansible_playbook:
            return

        with NodeConnection(node) as conn:
            # Copy playbook to remote
            remote_playbook = f"/tmp/playbook_{node.public_ip}.yml"
            conn.put(node.ansible_playbook, remote_playbook)

            # Run ansible-playbook
            try:
                conn.run(f"ansible-playbook {remote_playbook}")
            except Exception as e:
                console.print(f"[bold red]Error running ansible playbook on {node.public_ip}:[/bold red] {str(e)}")
            finally:
                # Clean up remote playbook
                conn.run(f"rm -f {remote_playbook}")


class JobExecutionManager:
    @staticmethod
    def submit_job(job: Job):
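
AnsibleExecutor looks the cluster up from Config, and execute_playbook returns early for nodes without an ansible_playbook, so calling it for every node is safe. A hedged usage sketch (the cluster name is a placeholder):

from torch_submit.config import Config
from torch_submit.executor import AnsibleExecutor

cluster = Config().get_cluster("my-cluster")  # placeholder cluster name
executor = AnsibleExecutor("my-cluster")

# execute_playbook is a no-op for nodes whose ansible_playbook is None
for node in [cluster.head_node, *cluster.worker_nodes]:
    executor.execute_playbook(node)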