diff --git a/pyproject.toml b/pyproject.toml index e53ece3..16c782c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "psycopg2-binary", "optuna", "optuna-dashboard", + "ansible-core>=2.15.0", ] dynamic = ["version"] @@ -42,4 +43,4 @@ exclude = ["tests*"] [project.urls] "Homepage" = "https://github.com/dream3d-ai/torch-submit" -"Bug Tracker" = "https://github.com/dream3d-ai/torch-submit/issues" \ No newline at end of file +"Bug Tracker" = "https://github.com/dream3d-ai/torch-submit/issues" diff --git a/tests/test_cluster.py b/tests/test_cluster.py new file mode 100644 index 0000000..bae9c2b --- /dev/null +++ b/tests/test_cluster.py @@ -0,0 +1,96 @@ +import pytest +from pathlib import Path +from typer.testing import CliRunner +from rich.console import Console + +from torch_submit.commands.cluster import app, config +from torch_submit.config import Node + +runner = CliRunner() +console = Console() + +@pytest.fixture +def mock_cluster(tmp_path): + """Create a mock cluster configuration for testing.""" + head_node = Node( + public_ip="1.2.3.4", + private_ip="10.0.0.1", + num_gpus=2, + nproc=4, + ssh_user="test", + ssh_pub_key_path=str(tmp_path / "test_key.pub"), + ssh_port=22, + ansible_playbook=None + ) + worker_nodes = [ + Node( + public_ip="1.2.3.5", + private_ip="10.0.0.2", + num_gpus=4, + nproc=8, + ssh_user="test", + ssh_pub_key_path=str(tmp_path / "test_key.pub"), + ssh_port=22, + ansible_playbook=None + ) + ] + config.add_cluster("test-cluster", head_node, worker_nodes) + return "test-cluster" + +def test_provision_cluster(tmp_path, mock_cluster): + """Test the cluster provision command.""" + # Create test playbook + playbook = tmp_path / "test.yml" + playbook.write_text(""" + - hosts: all + tasks: + - name: Test task + debug: + msg: "Test message" + """) + + # Test provision command with both head and worker playbooks + result = runner.invoke(app, [ + "provision", + mock_cluster, + "--head-playbook", str(playbook), + "--worker-playbook", str(playbook) + ]) + assert result.exit_code == 0 + + # Verify cluster configuration was updated + cluster = config.get_cluster(mock_cluster) + assert cluster.head_node.ansible_playbook == str(playbook) + assert cluster.worker_nodes[0].ansible_playbook == str(playbook) + +def test_provision_cluster_not_found(): + """Test provision command with non-existent cluster.""" + result = runner.invoke(app, [ + "provision", + "nonexistent-cluster", + "--head-playbook", "test.yml" + ]) + assert result.exit_code == 1 + assert "not found" in result.stdout + +def test_provision_cluster_head_only(tmp_path, mock_cluster): + """Test provision command with head node only.""" + playbook = tmp_path / "head.yml" + playbook.write_text(""" + - hosts: all + tasks: + - name: Head node task + debug: + msg: "Head node test" + """) + + result = runner.invoke(app, [ + "provision", + mock_cluster, + "--head-playbook", str(playbook) + ]) + assert result.exit_code == 0 + + cluster = config.get_cluster(mock_cluster) + assert cluster.head_node.ansible_playbook == str(playbook) + assert cluster.worker_nodes[0].ansible_playbook is None diff --git a/torch_submit/commands/cluster.py b/torch_submit/commands/cluster.py index dbc1b67..24571bd 100644 --- a/torch_submit/commands/cluster.py +++ b/torch_submit/commands/cluster.py @@ -5,6 +5,7 @@ from rich.table import Table from ..config import Config, Node +from ..executor import AnsibleExecutor app = typer.Typer() console = Console() @@ -119,7 +120,7 @@ def remove_cluster(name: str): Remove a cluster configuration. Prompts the user for confirmation before removing the specified cluster configuration from the config. - + Args: name (str): The name of the cluster to remove. """ @@ -136,7 +137,7 @@ def edit_cluster(name: str): Edit an existing cluster configuration. Prompts the user for new cluster details and updates the specified cluster configuration in the config. - + Args: name (str): The name of the cluster to edit. """ @@ -167,7 +168,7 @@ def edit_cluster(name: str): nproc = typer.prompt("Number of processes on worker node", default=worker.nproc, type=int) ssh_user = typer.prompt("SSH user for worker node (optional)", default=worker.ssh_user or "") ssh_pub_key_path = typer.prompt("SSH public key path for worker node (optional)", default=worker.ssh_pub_key_path or "") - + worker_node = Node(public_ip, private_ip or None, num_gpus, nproc, ssh_user, ssh_pub_key_path) worker_nodes.append(worker_node) @@ -176,4 +177,48 @@ def edit_cluster(name: str): # Update the cluster configuration config.update_cluster(name, head_node, worker_nodes) - console.print(f"Cluster [bold green]{name}[/bold green] updated successfully.") \ No newline at end of file + console.print(f"Cluster [bold green]{name}[/bold green] updated successfully.") + + +@app.command("provision") +def provision_cluster( + cluster_name: str = typer.Argument(..., help="Name of the cluster to provision"), + head_playbook: str = typer.Option(None, help="Path to ansible playbook for head node"), + worker_playbook: str = typer.Option(None, help="Path to ansible playbook for worker nodes"), +): + """Provision a cluster using ansible playbooks. + + Args: + cluster_name (str): Name of the cluster to provision. + head_playbook (str, optional): Path to ansible playbook for head node. + worker_playbook (str, optional): Path to ansible playbook for worker nodes. + """ + try: + cluster = config.get_cluster(cluster_name) + except ValueError: + console.print(f"[bold red]Error:[/bold red] Cluster '{cluster_name}' not found.") + raise typer.Exit(code=1) + + executor = AnsibleExecutor(cluster_name) + + # Update playbook paths + if head_playbook: + cluster.head_node.ansible_playbook = head_playbook + if worker_playbook: + for node in cluster.worker_nodes: + node.ansible_playbook = worker_playbook + + # Save updated configuration + config.save_config() + + # Execute playbooks + if head_playbook: + console.print(f"Provisioning head node {cluster.head_node.public_ip}...") + executor.execute_playbook(cluster.head_node) + + if worker_playbook: + for node in cluster.worker_nodes: + console.print(f"Provisioning worker node {node.public_ip}...") + executor.execute_playbook(node) + + console.print("[bold green]Cluster provisioning completed.[/bold green]") diff --git a/torch_submit/config.py b/torch_submit/config.py index 6577c50..e600313 100644 --- a/torch_submit/config.py +++ b/torch_submit/config.py @@ -18,6 +18,7 @@ class Node: nproc (int): The number of processes that can run on the node. ssh_user (Optional[str]): The SSH username for accessing the node, if available. ssh_pub_key_path (Optional[str]): The path to the SSH public key file, if available. + ansible_playbook (Optional[str]): The path to the ansible playbook for node provisioning. """ public_ip: str @@ -27,6 +28,7 @@ class Node: ssh_user: Optional[str] ssh_pub_key_path: Optional[str] ssh_port: Optional[int] + ansible_playbook: Optional[str] def __post_init__(self): """Initialize the Node object after creation.""" @@ -284,6 +286,7 @@ def save_config(self): "ssh_user": cluster.head_node.ssh_user or None, "ssh_pub_key_path": cluster.head_node.ssh_pub_key_path or None, "ssh_port": cluster.head_node.ssh_port or None, + "ansible_playbook": cluster.head_node.ansible_playbook or None, }, "worker_nodes": [ { @@ -294,6 +297,7 @@ def save_config(self): "ssh_user": node.ssh_user or None, "ssh_pub_key_path": node.ssh_pub_key_path or None, "ssh_port": node.ssh_port or None, + "ansible_playbook": node.ansible_playbook or None, } for node in cluster.worker_nodes ], diff --git a/torch_submit/executor.py b/torch_submit/executor.py index 1eecb78..2034b9b 100644 --- a/torch_submit/executor.py +++ b/torch_submit/executor.py @@ -487,6 +487,31 @@ def _prepare_command(self, rank: int): return f"{self.get_command(rank)} -- {self.job.command}" +class AnsibleExecutor(BaseExecutor): + """Executes ansible playbooks across cluster nodes.""" + + def __init__(self, cluster_name: str): + self.cluster = Config().get_cluster(cluster_name) + + def execute_playbook(self, node: Node): + if not node.ansible_playbook: + return + + with NodeConnection(node) as conn: + # Copy playbook to remote + remote_playbook = f"/tmp/playbook_{node.public_ip}.yml" + conn.put(node.ansible_playbook, remote_playbook) + + # Run ansible-playbook + try: + conn.run(f"ansible-playbook {remote_playbook}") + except Exception as e: + console.print(f"[bold red]Error running ansible playbook on {node.public_ip}:[/bold red] {str(e)}") + finally: + # Clean up remote playbook + conn.run(f"rm -f {remote_playbook}") + + class JobExecutionManager: @staticmethod def submit_job(job: Job):