Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SCP] fix firewall rule #4445

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions sky/clouds/scp.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ class SCP(clouds.Cloud):
(f'Spot instances are not supported in {_REPR}.'),
clouds.CloudImplementationFeatures.CUSTOM_DISK_TIER:
(f'Custom disk tiers are not supported in {_REPR}.'),
clouds.CloudImplementationFeatures.OPEN_PORTS:
(f'Opening ports is currently not supported on {_REPR}.'),
}

_INDENT_PREFIX = ' '
Expand Down Expand Up @@ -236,7 +234,7 @@ def _get_default_ami(cls, region_name: str, instance_type: str) -> str:
if acc is not None:
assert len(acc) == 1, acc
image_id = service_catalog.get_image_id_from_tag(
'skypilot:gpu-ubuntu-1804', region_name, clouds='scp')
'skypilot:gpu-ubuntu-2204', region_name, clouds='scp')
if image_id is not None:
return image_id
# Raise ResourcesUnavailableError to make sure the failover in
Expand Down
154 changes: 151 additions & 3 deletions sky/clouds/utils/scp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def del_security_group(self, sg_id):
url = f'{API_ENDPOINT}/security-group/v2/security-groups/{sg_id}'
return self._delete(url)

def del_firwall_rules(self, firewall_id, rule_id_list):
def del_firewall_rules(self, firewall_id, rule_id_list):
    """Request deletion of specific rules from one firewall.

    Issues ``self._delete`` against the firewall's ``rules`` endpoint with
    a body asking for a ``PARTIAL`` deletion restricted to ``rule_id_list``.

    Args:
        firewall_id: ID interpolated into the firewall URL.
        rule_id_list: Value sent as the ``ruleIds`` field of the body.

    Returns:
        Whatever ``self._delete`` returns for the request.
    """
    payload = {'ruleDeletionType': 'PARTIAL', 'ruleIds': rule_id_list}
    return self._delete(
        f'{API_ENDPOINT}/firewall/v2/firewalls/{firewall_id}/rules',
        request_body=payload)
Expand All @@ -422,11 +422,11 @@ def get_vm_info(self, vm_id):
url = f'{API_ENDPOINT}/virtual-server/v3/virtual-servers/{vm_id}'
return self._get(url, contents_key=None)

def get_firewal_rule_info(self, firewall_id, rule_id):
def get_firewall_rule_info(self, firewall_id, rule_id):
    """Fetch one rule of a firewall.

    Args:
        firewall_id: ID of the firewall, interpolated into the URL.
        rule_id: ID of the rule, interpolated into the URL.

    Returns:
        Whatever ``self._get`` returns when called with
        ``contents_key=None``.
    """
    rule_path = f'/firewall/v2/firewalls/{firewall_id}/rules/{rule_id}'
    return self._get(f'{API_ENDPOINT}{rule_path}', contents_key=None)

def list_firwalls(self):
def list_firewalls(self):
    """List firewalls by calling ``self._get`` on the firewall v2 endpoint.

    Returns:
        Whatever ``self._get`` returns for that URL.
    """
    return self._get(f'{API_ENDPOINT}/firewall/v2/firewalls')

Expand All @@ -442,3 +442,151 @@ def start_instance(self, vm_id):
def stop_instance(self, vm_id):
    """Post an empty body to the ``stop`` action of one virtual server.

    Args:
        vm_id: ID of the virtual server, interpolated into the URL.

    Returns:
        Whatever ``self._post`` returns for the request.
    """
    servers_endpoint = f'{API_ENDPOINT}/virtual-server/v2/virtual-servers'
    return self._post(url=f'{servers_endpoint}/{vm_id}/stop',
                      request_body={})

def list_security_group_rules(self, sg_id):
    """List the rules of one security group.

    Args:
        sg_id: ID of the security group, interpolated into the URL.

    Returns:
        Whatever ``self._get`` returns for the group's ``rules`` endpoint.
    """
    group_endpoint = f'{API_ENDPOINT}/security-group/v2/security-groups'
    return self._get(f'{group_endpoint}/{sg_id}/rules')

def _check_existing_security_group_in_rule(self, sg_id, port):
    """Tell whether ``port`` is absent from the group's inbound rules.

    Note the polarity: the result is ``True`` when NO rule with
    ``ruleDirection == 'IN'`` lists ``port`` in its ``tcpServices``, and
    ``False`` when at least one does.

    Args:
        sg_id: Security group ID passed to ``list_security_group_rules``.
        port: Value looked up in each inbound rule's ``tcpServices``.

    Returns:
        ``True`` if a new inbound rule is needed, ``False`` otherwise.
    """
    inbound_rules = [
        entry for entry in self.list_security_group_rules(sg_id)
        if entry['ruleDirection'] == 'IN'
    ]
    return not any(port in entry['tcpServices'] for entry in inbound_rules)

def _check_existing_security_group_out_rule(self, sg_id, port):
    """Tell whether ``port`` is absent from the group's outbound rules.

    Note the polarity: the result is ``True`` when NO rule with
    ``ruleDirection == 'OUT'`` lists ``port`` in its ``tcpServices``, and
    ``False`` when at least one does.

    Args:
        sg_id: Security group ID passed to ``list_security_group_rules``.
        port: Value looked up in each outbound rule's ``tcpServices``.

    Returns:
        ``True`` if a new outbound rule is needed, ``False`` otherwise.
    """
    outbound_rules = [
        entry for entry in self.list_security_group_rules(sg_id)
        if entry['ruleDirection'] == 'OUT'
    ]
    return not any(port in entry['tcpServices'] for entry in outbound_rules)

def add_new_security_group_in_rule(self, sg_id, port):
    """Add an inbound TCP rule for ``port`` unless one already lists it.

    The rule is posted with source ``0.0.0.0/0`` and the description
    ``'skyserve rule'``.

    Args:
        sg_id: ID of the security group to extend.
        port: Value sent as the TCP ``serviceValue``.

    Returns:
        The result of ``self._post`` when a rule is submitted, or ``None``
        when ``_check_existing_security_group_in_rule`` reports that the
        port is already covered.
    """
    if not self._check_existing_security_group_in_rule(sg_id, port):
        return None
    tcp_service = {'serviceType': 'TCP', 'serviceValue': port}
    rule_spec = {
        'ruleDirection': 'IN',
        'services': [tcp_service],
        'sourceIpAddresses': ['0.0.0.0/0'],
        'ruleDescription': 'skyserve rule',
    }
    group_endpoint = f'{API_ENDPOINT}/security-group/v2/security-groups'
    return self._post(f'{group_endpoint}/{sg_id}/rules', rule_spec)

def add_new_security_group_out_rule(self, sg_id, port):
    """Add an outbound TCP rule for ``port`` unless one already lists it.

    The rule is posted with destination ``0.0.0.0/0`` and the description
    ``'skyserve rule'``.

    Args:
        sg_id: ID of the security group to extend.
        port: Value sent as the TCP ``serviceValue``.

    Returns:
        The result of ``self._post`` when a rule is submitted, or ``None``
        when ``_check_existing_security_group_out_rule`` reports that the
        port is already covered.
    """
    if not self._check_existing_security_group_out_rule(sg_id, port):
        return None
    tcp_service = {'serviceType': 'TCP', 'serviceValue': port}
    rule_spec = {
        'ruleDirection': 'OUT',
        'services': [tcp_service],
        'destinationIpAddresses': ['0.0.0.0/0'],
        'ruleDescription': 'skyserve rule',
    }
    group_endpoint = f'{API_ENDPOINT}/security-group/v2/security-groups'
    return self._post(f'{group_endpoint}/{sg_id}/rules', rule_spec)

def list_firewall_rules(self, firewall_id):
    """List the rules of one firewall.

    Args:
        firewall_id: ID of the firewall, interpolated into the URL.

    Returns:
        Whatever ``self._get`` returns for the firewall's ``rules``
        endpoint.
    """
    return self._get(
        f'{API_ENDPOINT}/firewall/v2/firewalls/{firewall_id}/rules')

def _check_existing_firewall_in_rule(self, firewall_id, internal_ip, port):
    """Tell whether an inbound firewall rule for ``port`` is still missing.

    Note the polarity: the result is ``True`` when NO inbound rule that
    targets ``internal_ip`` lists ``port`` in its ``tcpServices``, and
    ``False`` as soon as one does.

    Args:
        firewall_id: Firewall ID passed to ``list_firewall_rules``.
        internal_ip: Address looked up in each inbound rule's
            ``destinationIpAddresses``.
        port: Value looked up in the matching rules' ``tcpServices``.

    Returns:
        ``True`` if a new inbound rule is needed, ``False`` otherwise.
    """
    for rule in self.list_firewall_rules(firewall_id):
        if rule['ruleDirection'] != 'IN':
            continue
        # Membership test rather than comparing only the first address:
        # an empty address list must not raise IndexError, and a rule that
        # names several destinations covers each of them.
        if internal_ip not in rule['destinationIpAddresses']:
            continue
        if port in rule['tcpServices']:
            return False
    return True

def _check_existing_firewall_out_rule(self, firewall_id, internal_ip, port):
    """Tell whether an outbound firewall rule for ``port`` is still missing.

    Note the polarity: the result is ``True`` when NO outbound rule that
    originates from ``internal_ip`` lists ``port`` in its ``tcpServices``,
    and ``False`` as soon as one does.

    Args:
        firewall_id: Firewall ID passed to ``list_firewall_rules``.
        internal_ip: Address looked up in each outbound rule's
            ``sourceIpAddresses``.
        port: Value looked up in the matching rules' ``tcpServices``.

    Returns:
        ``True`` if a new outbound rule is needed, ``False`` otherwise.
    """
    for rule in self.list_firewall_rules(firewall_id):
        if rule['ruleDirection'] != 'OUT':
            continue
        # Membership test rather than comparing only the first address:
        # an empty address list must not raise IndexError, and a rule that
        # names several sources covers each of them.
        if internal_ip not in rule['sourceIpAddresses']:
            continue
        if port in rule['tcpServices']:
            return False
    return True

def add_new_firewall_inbound_rule(self, firewall_id, internal_ip, port):
    """Add an inbound ALLOW rule for ``port`` unless one already lists it.

    The rule accepts TCP traffic from ``0.0.0.0/0`` to ``internal_ip``, is
    enabled, is placed with ``ruleLocationType`` ``'FIRST'`` and carries the
    description ``'skyserve rule'``.

    Args:
        firewall_id: ID of the firewall to extend.
        internal_ip: Address sent as the only destination.
        port: Value sent as the TCP ``serviceValue``.

    Returns:
        The result of ``self._post`` when a rule is submitted, or ``None``
        when ``_check_existing_firewall_in_rule`` reports that the port is
        already covered.
    """
    port_is_missing = self._check_existing_firewall_in_rule(
        firewall_id, internal_ip, port)
    if not port_is_missing:
        return None
    rule_spec = {
        'sourceIpAddresses': ['0.0.0.0/0'],
        'destinationIpAddresses': [internal_ip],
        'services': [{'serviceType': 'TCP', 'serviceValue': port}],
        'ruleDirection': 'IN',
        'ruleAction': 'ALLOW',
        'isRuleEnabled': True,
        'ruleLocationType': 'FIRST',
        'ruleDescription': 'skyserve rule',
    }
    return self._post(
        f'{API_ENDPOINT}/firewall/v2/firewalls/{firewall_id}/rules',
        rule_spec)

def add_new_firewall_outbound_rule(self, firewall_id, internal_ip, port):
    """Add an outbound ALLOW rule for ``port`` unless one already lists it.

    The rule accepts TCP traffic from ``internal_ip`` to ``0.0.0.0/0``, is
    enabled, is placed with ``ruleLocationType`` ``'FIRST'`` and carries the
    description ``'skyserve rule'``.

    Args:
        firewall_id: ID of the firewall to extend.
        internal_ip: Address sent as the only source.
        port: Value sent as the TCP ``serviceValue``.

    Returns:
        The result of ``self._post`` when a rule is submitted, or ``None``
        when ``_check_existing_firewall_out_rule`` reports that the port is
        already covered.
    """
    port_is_missing = self._check_existing_firewall_out_rule(
        firewall_id, internal_ip, port)
    if not port_is_missing:
        return None
    rule_spec = {
        'sourceIpAddresses': [internal_ip],
        'destinationIpAddresses': ['0.0.0.0/0'],
        'services': [{'serviceType': 'TCP', 'serviceValue': port}],
        'ruleDirection': 'OUT',
        'ruleAction': 'ALLOW',
        'isRuleEnabled': True,
        'ruleLocationType': 'FIRST',
        'ruleDescription': 'skyserve rule',
    }
    return self._post(
        f'{API_ENDPOINT}/firewall/v2/firewalls/{firewall_id}/rules',
        rule_spec)

def wait_firewall_inbound_rule_complete(self,
                                        firewall_id,
                                        rule_id,
                                        timeout=None,
                                        poll_interval=5):
    """Poll a firewall rule until its ``ruleState`` is ``'ACTIVE'``.

    Args:
        firewall_id: ID of the firewall that owns the rule.
        rule_id: ID of the rule to poll via ``get_firewall_rule_info``.
        timeout: Maximum number of seconds to wait. ``None`` (the default)
            waits indefinitely.
        poll_interval: Seconds to sleep before each query. Defaults to 5.

    Raises:
        TimeoutError: If ``timeout`` is set and the rule is still not
            ``'ACTIVE'`` once that many seconds have elapsed.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        # Sleep before querying; presumably this gives a freshly created
        # rule time to become visible - TODO confirm.
        time.sleep(poll_interval)
        rule_info = self.get_firewall_rule_info(firewall_id, rule_id)
        if rule_info['ruleState'] == 'ACTIVE':
            return
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f'Firewall rule {rule_id} on firewall {firewall_id} did not '
                f'become ACTIVE within {timeout} seconds.')

def wait_firewall_outbound_rule_complete(self,
                                         firewall_id,
                                         rule_id,
                                         timeout=None,
                                         poll_interval=5):
    """Poll a firewall rule until its ``ruleState`` is ``'ACTIVE'``.

    Args:
        firewall_id: ID of the firewall that owns the rule.
        rule_id: ID of the rule to poll via ``get_firewall_rule_info``.
        timeout: Maximum number of seconds to wait. ``None`` (the default)
            waits indefinitely.
        poll_interval: Seconds to sleep before each query. Defaults to 5.

    Raises:
        TimeoutError: If ``timeout`` is set and the rule is still not
            ``'ACTIVE'`` once that many seconds have elapsed.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        # Sleep before querying; presumably this gives a freshly created
        # rule time to become visible - TODO confirm.
        time.sleep(poll_interval)
        rule_info = self.get_firewall_rule_info(firewall_id, rule_id)
        if rule_info['ruleState'] == 'ACTIVE':
            return
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f'Firewall rule {rule_id} on firewall {firewall_id} did not '
                f'become ACTIVE within {timeout} seconds.')

def get_virtual_server_info(self, vm_id):
    """Fetch the details of one virtual server.

    This is an alias of ``get_vm_info``, which issues the same v3
    ``virtual-servers/{vm_id}`` GET with ``contents_key=None``; delegating
    keeps the endpoint defined in a single place.

    Args:
        vm_id: ID of the virtual server.

    Returns:
        Whatever ``get_vm_info`` returns for ``vm_id``.
    """
    return self.get_vm_info(vm_id)
1 change: 1 addition & 0 deletions sky/provision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from sky.provision import lambda_cloud
from sky.provision import oci
from sky.provision import runpod
from sky.provision import scp
from sky.provision import vsphere
from sky.utils import command_runner
from sky.utils import timeline
Expand Down
4 changes: 4 additions & 0 deletions sky/provision/scp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""SCP provisioner for SkyPilot."""

from sky.provision.scp.instance import cleanup_ports
from sky.provision.scp.instance import open_ports
73 changes: 73 additions & 0 deletions sky/provision/scp/instance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""SCP instance provisioning."""

import time
from typing import Any, Dict, List, Optional
from sky.clouds.utils import scp_utils


# Retry budget for creating a single firewall rule.
_FIREWALL_RULE_MAX_ATTEMPTS = 300
_FIREWALL_RULE_RETRY_INTERVAL_SECONDS = 10


def _add_firewall_rule_with_retry(add_rule, wait_rule_complete, firewall_id,
                                  internal_ip, port) -> None:
    """Create one firewall rule and wait for it, retrying on any error.

    Args:
        add_rule: Callable ``(firewall_id, internal_ip, port)`` that submits
            the rule and returns its info dict, or ``None`` when no rule
            had to be created.
        wait_rule_complete: Callable ``(firewall_id, rule_id)`` that blocks
            until the submitted rule is ready.
        firewall_id: ID of the firewall to extend.
        internal_ip: Internal IP address of the virtual server.
        port: Port value forwarded to ``add_rule``.
    """
    import logging  # pylint: disable=import-outside-toplevel

    for _ in range(_FIREWALL_RULE_MAX_ATTEMPTS):
        try:
            rule_info = add_rule(firewall_id, internal_ip, port)
            if rule_info is not None:
                wait_rule_complete(firewall_id, rule_info['resourceId'])
            return
        except Exception:  # pylint: disable=broad-except
            # Best effort: any failure is retried after a pause.
            time.sleep(_FIREWALL_RULE_RETRY_INTERVAL_SECONDS)
    # Still best effort (no exception), but leave a trace that the port
    # may not have been opened.
    logging.getLogger(__name__).warning(
        'Gave up adding firewall rule for port %s on firewall %s after %d '
        'attempts.', port, firewall_id, _FIREWALL_RULE_MAX_ATTEMPTS)


def open_ports(  # pylint: disable=unused-argument
    cluster_name_on_cloud: str,
    ports: List[str],
    provider_config: Optional[Dict[str, Any]] = None,
) -> None:
    """See sky/provision/__init__.py

    For each virtual server, adds inbound and outbound security-group rules
    for every entry of ``ports`` on the server's first security group, then
    adds inbound and outbound firewall rules on each firewall whose
    ``vpcId`` matches the server's VPC.

    Args:
        cluster_name_on_cloud: Unused here.
        ports: Ports to open. Each entry is passed to the SCP client as is;
            whether range strings are accepted is not visible here - TODO
            confirm.
        provider_config: Unused here.
    """
    scp_client = scp_utils.SCPClient()
    # NOTE(review): every instance returned by list_instances() is
    # processed; nothing filters by cluster_name_on_cloud - confirm that
    # this is intended.
    vm_list = scp_client.list_instances()

    for vm in vm_list:
        vm_info = scp_client.get_virtual_server_info(vm['virtualServerId'])
        sg_id = vm_info['securityGroupIds'][0]['securityGroupId']
        for port in ports:
            scp_client.add_new_security_group_in_rule(sg_id, port)
            scp_client.add_new_security_group_out_rule(sg_id, port)

        vpc_id = vm_info['vpcId']
        firewall_list = scp_client.list_firewalls()
        internal_ip = vm_info['ip']

        for firewall in firewall_list:
            if firewall['vpcId'] != vpc_id:
                continue
            firewall_id = firewall['firewallId']
            for port in ports:
                _add_firewall_rule_with_retry(
                    scp_client.add_new_firewall_inbound_rule,
                    scp_client.wait_firewall_inbound_rule_complete,
                    firewall_id, internal_ip, port)
                _add_firewall_rule_with_retry(
                    scp_client.add_new_firewall_outbound_rule,
                    scp_client.wait_firewall_outbound_rule_complete,
                    firewall_id, internal_ip, port)


def cleanup_ports(  # pylint: disable=unused-argument
    cluster_name_on_cloud: str,
    ports: List[str],
    provider_config: Optional[Dict[str, Any]] = None,
) -> None:
    """See sky/provision/__init__.py

    Intentionally a no-op: cleanup_ports is implemented in
    sky/skylet/providers/scp/node_provider.py$terminate_node because it
    cannot be reached for SCP after terminate_node.
    """
    return None
Loading
Loading