diff --git a/docs/source/reference/config.rst b/docs/source/reference/config.rst index 286788625bd..2c4862fbfa1 100644 --- a/docs/source/reference/config.rst +++ b/docs/source/reference/config.rst @@ -456,6 +456,21 @@ Available fields and semantics: # Reference: https://learn.microsoft.com/en-us/azure/storage/common/storage-account-overview storage_account: user-storage-account-name + # Specify subnet_id to use for instances (optional). + # SkyPilot created new vnet and subnet by default but it will reuse exisiting subnet if specified. + subnet_id: /subscriptions/subscription-id/resourceGroups/resource-group-name/providers/Microsoft.Network/virtualNetworks/vnet-name/subnets/subnet-name + + # Should instances be assigned private IPs only? (optional) + # + # Set to true to use private IPs to communicate between the local client and + # any SkyPilot nodes. This requires the networking stack be properly set up. + # + # When set to true, SkyPilot will only use private subnets to launch nodes and won't expose + # instances on public IP addresses. + # Reference: https://learn.microsoft.com/en-us/azure/virtual-network/virtual-network-manage-subnet?tabs=azure-portal + # Default: false. + use_internal_ips: true + # Advanced Kubernetes configurations (optional). kubernetes: # The networking mode for accessing SSH jump pod (optional). diff --git a/sky/clouds/azure.py b/sky/clouds/azure.py index eb76d2b5e48..f3848e01b28 100644 --- a/sky/clouds/azure.py +++ b/sky/clouds/azure.py @@ -366,6 +366,14 @@ def make_deploy_resources_variables( if resource_group_name is None: resource_group_name = f'{cluster_name.name_on_cloud}-{region_name}' + # Determine subnet_id if configured + subnet_id = skypilot_config.get_nested(('azure', 'subnet_id'), None) + + # Determine if internal IPs should be used + use_internal_ips = skypilot_config.get_nested( + ('azure', 'use_internal_ips'), False) + + # Setup commands to eliminate the banner and restart sshd. # This script will modify /etc/ssh/sshd_config and add a bash script # into .bashrc. The bash script will restart sshd if it has not been @@ -423,6 +431,8 @@ def _failover_disk_tier() -> Optional[resources_utils.DiskTier]: 'azure_subscription_id': self.get_project_id(dryrun), 'resource_group': resource_group_name, 'use_external_resource_group': use_external_resource_group, + 'subnet_id': subnet_id, + 'use_internal_ips': use_internal_ips, } # Setting disk performance tier for high disk tier. diff --git a/sky/provision/azure/azure-config-template.json b/sky/provision/azure/azure-config-template.json index 0c70c4d3999..ecb97fff7cf 100644 --- a/sky/provision/azure/azure-config-template.json +++ b/sky/provision/azure/azure-config-template.json @@ -25,6 +25,12 @@ "metadata": { "description": "Name of the Network Security Group associated with the SkyPilot cluster." } + }, + "existingSubnet": { + "type": "string", + "metadata": { + "description": "Existing subnet id to use." + } } }, "variables": { @@ -86,6 +92,7 @@ "apiVersion": "2019-11-01", "name": "[variables('vnetName')]", "location": "[variables('location')]", + "condition": "[equals(parameters('existingSubnet'), '')]", "properties": { "addressSpace": { "addressPrefixes": [ diff --git a/sky/provision/azure/config.py b/sky/provision/azure/config.py index e7ab59daa33..a80a53dbd0f 100644 --- a/sky/provision/azure/config.py +++ b/sky/provision/azure/config.py @@ -86,6 +86,8 @@ def bootstrap_instances( 'use_external_resource_group field') use_external_resource_group = provider_config['use_external_resource_group'] + subnet_id = provider_config.get('subnet_id', '') + if 'tags' in provider_config: params['tags'] = provider_config['tags'] @@ -142,12 +144,15 @@ def bootstrap_instances( cluster_id, nsg_name = get_cluster_id_and_nsg_name( resource_group=provider_config['resource_group'], cluster_name_on_cloud=cluster_name_on_cloud) + + # subnet_mask is generated only for new subnets subnet_mask = provider_config.get('subnet_mask') - if subnet_mask is None: - # choose a random subnet, skipping most common value of 0 - random.seed(cluster_id) - subnet_mask = f'10.{random.randint(1, 254)}.0.0/16' - logger.info(f'Using subnet mask: {subnet_mask}') + # choose a random subnet, skipping most common value of 0 + random.seed(cluster_id) + subnet_mask = f'10.{random.randint(1, 254)}.0.0/16' + if subnet_id == '': + # subnet_mask is not used if subnet_id is provided + logger.info(f'Using subnet mask: {subnet_mask}') parameters = { 'properties': { @@ -165,7 +170,10 @@ def bootstrap_instances( }, 'location': { 'value': params['location'] - } + }, + 'existingSubnet': { + 'value': subnet_id + }, }, } } @@ -215,6 +223,7 @@ def bootstrap_instances( # append output resource ids to be used with vm creation provider_config['msi'] = outputs['msi']['value'] provider_config['nsg'] = outputs['nsg']['value'] - provider_config['subnet'] = outputs['subnet']['value'] + provider_config[ + 'subnet'] = outputs['subnet']['value'] if subnet_id == '' else subnet_id return config diff --git a/sky/templates/azure-ray.yml.j2 b/sky/templates/azure-ray.yml.j2 index 1140704a708..f9866a2f116 100644 --- a/sky/templates/azure-ray.yml.j2 +++ b/sky/templates/azure-ray.yml.j2 @@ -47,6 +47,10 @@ provider: # leakage. disable_launch_config_check: true + {%- if subnet_id is not none %} + subnet_id: {{subnet_id}} + {%- endif %} + use_internal_ips: {{use_internal_ips}} auth: ssh_user: azureuser diff --git a/sky/utils/schemas.py b/sky/utils/schemas.py index 851e77a57fc..9a4f08a5bfe 100644 --- a/sky/utils/schemas.py +++ b/sky/utils/schemas.py @@ -797,6 +797,12 @@ def get_config_schema(): 'resource_group_vm': { 'type': 'string', }, + 'subnet_id': { + 'type': 'string', + }, + 'use_internal_ips': { + 'type': 'boolean', + }, } }, 'kubernetes': {