From 941fce9d4188597f937819fe1b9daf9d74b24f82 Mon Sep 17 00:00:00 2001 From: AlexCuadron Date: Thu, 10 Oct 2024 11:28:59 +0200 Subject: [PATCH 1/4] geo_data_policy initial comit --- sky/serve/load_balancer.py | 25 +++++-- sky/serve/load_balancing_policies.py | 99 +++++++++++++++++++++++++++- 2 files changed, 117 insertions(+), 7 deletions(-) diff --git a/sky/serve/load_balancer.py b/sky/serve/load_balancer.py index c15f71e214a..17b3b26c5f2 100644 --- a/sky/serve/load_balancer.py +++ b/sky/serve/load_balancer.py @@ -27,7 +27,7 @@ class SkyServeLoadBalancer: policy. """ - def __init__(self, controller_url: str, load_balancer_port: int) -> None: + def __init__(self, controller_url: str, load_balancer_port: int, load_balancing_policy: lb_policies.LoadBalancingPolicy) -> None: """Initialize the load balancer. Args: @@ -37,8 +37,7 @@ def __init__(self, controller_url: str, load_balancer_port: int) -> None: self._app = fastapi.FastAPI() self._controller_url: str = controller_url self._load_balancer_port: int = load_balancer_port - self._load_balancing_policy: lb_policies.LoadBalancingPolicy = ( - lb_policies.RoundRobinPolicy()) + self._load_balancing_policy: load_balancing_policy self._request_aggregator: serve_utils.RequestsAggregator = ( serve_utils.RequestTimestamp()) # TODO(tian): httpx.Client has a resource limit of 100 max connections @@ -223,9 +222,19 @@ async def startup(): uvicorn.run(self._app, host='0.0.0.0', port=self._load_balancer_port) -def run_load_balancer(controller_addr: str, load_balancer_port: int): +def run_load_balancer(controller_addr: str, load_balancer_port: int, policy_name: str): + if policy_name == 'round_robin': + policy = lb_policies.RoundRobinPolicy() + elif policy_name == 'geo_data': + #TODO(acuadron): Right now the locations of the VMs have to be inputed manually during the GeoDataPolicy + # instantiation, change this behaviour. We should, store the location of the replicas during their creation. + policy = lb_policies.GeoDataPolicy({}) + else: + raise ValueError(f"Unknown load balancing policy: {policy_name}") + load_balancer = SkyServeLoadBalancer(controller_url=controller_addr, - load_balancer_port=load_balancer_port) + load_balancer_port=load_balancer_port, + load_balancing_policy=policy) load_balancer.run() @@ -241,5 +250,9 @@ def run_load_balancer(controller_addr: str, load_balancer_port: int): required=True, default=8890, help='The port where the load balancer listens to.') + parser.add_argument('--load-balancing-policy', + choices=['round_robin', 'geo_data'], + default='round_robin', #TODO(acuadron): Change it to geo_data when ready + help='The load balancing policy to use.') args = parser.parse_args() - run_load_balancer(args.controller_addr, args.load_balancer_port) + run_load_balancer(args.controller_addr, args.load_balancer_port, args.load_balancing_policy) diff --git a/sky/serve/load_balancing_policies.py b/sky/serve/load_balancing_policies.py index 34c1fa4249b..7ae0edfd403 100644 --- a/sky/serve/load_balancing_policies.py +++ b/sky/serve/load_balancing_policies.py @@ -1,7 +1,9 @@ """LoadBalancingPolicy: Policy to select endpoint.""" import random import typing -from typing import List, Optional +from typing import Dict, List, Optional, Tuple + +import httpx from sky import sky_logging @@ -68,3 +70,98 @@ def _select_replica(self, request: 'fastapi.Request') -> Optional[str]: ready_replica_url = self.ready_replicas[self.index] self.index = (self.index + 1) % len(self.ready_replicas) return ready_replica_url + + +class GeoDataPolicy(LoadBalancingPolicy): + """Geo-data load balancing policy using an online GeoIP service.""" + + def __init__(self, replica_locations: Dict[str, Tuple[float, + float]]) -> None: + """ + Initialize with a mapping from replica URLs to their (latitude, longitude). + """ + super().__init__() + self.replica_locations = replica_locations # type: Dict[str, Tuple[float, float]] + + def set_ready_replicas(self, ready_replicas: List[str]) -> None: + # Ensure all replicas have associated locations + for replica in ready_replicas: + if replica not in self.replica_locations: + # Every replica must have a valid location + raise ValueError( + f"Replica {replica} does not have a corresponding location." + ) + self.ready_replicas = ready_replicas + + def _select_replica(self, request: 'fastapi.Request') -> Optional[str]: + user_location = self._get_user_location(request) + if not user_location: + # If user location can't be determined, select a random replica + random.shuffle(self.ready_replicas) + return self.ready_replicas[0] + + # Find the closest replica + min_distance = float('inf') + nearest_replica = None + for replica in self.ready_replicas: + replica_location = self.replica_locations[replica] + distance = self._calculate_distance(user_location, replica_location) + if distance < min_distance: + min_distance = distance + nearest_replica = replica + + return nearest_replica + + async def _get_user_location( + self, request: 'fastapi.Request') -> Optional[Tuple[float, float]]: + # Extract the user's IP address + ip_address = request.client.host + if not ip_address: + logger.warning('Could not extract IP address from request.') + return None + + # Perform GeoIP lookup using httpx, limited to 150 requests per minute. + # TODO(acuadron): Use IP caching to reduce the number of requests. + try: + async with httpx.AsyncClient() as client: + response = await client.get( + f'http://ip-api.com/json/{ip_address}', timeout=2) + if response.status_code == 200: + data = response.json() + if data['status'] == 'success': + latitude = data['lat'] + longitude = data['lon'] + return (latitude, longitude) + else: + logger.warning( + f"GeoIP lookup failed: {data.get('message', 'Unknown error')}" + ) + return None + else: + logger.warning( + f'GeoIP lookup failed with status code {response.status_code}' + ) + return None + except Exception as e: + logger.warning(f'Failed to get location for IP {ip_address}: {e}') + return None + + def _calculate_distance(self, loc1: Tuple[float, float], + loc2: Tuple[float, float]) -> float: + # Haversine formula to calculate the great-circle distance + import math + lat1, lon1 = loc1 + lat2, lon2 = loc2 + R = 6371 # Earth radius in kilometers + + phi1 = math.radians(lat1) + phi2 = math.radians(lat2) + delta_phi = math.radians(lat2 - lat1) + delta_lambda = math.radians(lon2 - lon1) + + a = (math.sin(delta_phi / 2)**2 + + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2)**2) + c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) + + distance = R * c + return distance From 4ad361654b4818f6cc1c9af805612bb8b3fbf59c Mon Sep 17 00:00:00 2001 From: AlexCuadron Date: Thu, 10 Oct 2024 11:38:28 +0200 Subject: [PATCH 2/4] maked sync --- sky/serve/load_balancing_policies.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sky/serve/load_balancing_policies.py b/sky/serve/load_balancing_policies.py index 7ae0edfd403..becf5906a05 100644 --- a/sky/serve/load_balancing_policies.py +++ b/sky/serve/load_balancing_policies.py @@ -112,7 +112,7 @@ def _select_replica(self, request: 'fastapi.Request') -> Optional[str]: return nearest_replica - async def _get_user_location( + def _get_user_location( self, request: 'fastapi.Request') -> Optional[Tuple[float, float]]: # Extract the user's IP address ip_address = request.client.host @@ -121,10 +121,10 @@ async def _get_user_location( return None # Perform GeoIP lookup using httpx, limited to 150 requests per minute. - # TODO(acuadron): Use IP caching to reduce the number of requests. + # TODO(acuadron): Use IP caching to reduce the number of requests. Async? try: - async with httpx.AsyncClient() as client: - response = await client.get( + with httpx.Client() as client: + response = client.get( f'http://ip-api.com/json/{ip_address}', timeout=2) if response.status_code == 200: data = response.json() From 5bc7315aaab61dc40edb44f5df07ed87b461dffc Mon Sep 17 00:00:00 2001 From: AlexCuadron Date: Thu, 10 Oct 2024 11:39:52 +0200 Subject: [PATCH 3/4] made round robing the default --- sky/serve/load_balancer.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/sky/serve/load_balancer.py b/sky/serve/load_balancer.py index 17b3b26c5f2..9ce7ace92a9 100644 --- a/sky/serve/load_balancer.py +++ b/sky/serve/load_balancer.py @@ -27,7 +27,9 @@ class SkyServeLoadBalancer: policy. """ - def __init__(self, controller_url: str, load_balancer_port: int, load_balancing_policy: lb_policies.LoadBalancingPolicy) -> None: + def __init__( + self, controller_url: str, load_balancer_port: int, + load_balancing_policy: lb_policies.LoadBalancingPolicy) -> None: """Initialize the load balancer. Args: @@ -37,7 +39,8 @@ def __init__(self, controller_url: str, load_balancer_port: int, load_balancing_ self._app = fastapi.FastAPI() self._controller_url: str = controller_url self._load_balancer_port: int = load_balancer_port - self._load_balancing_policy: load_balancing_policy + self._load_balancing_policy: lb_policies.LoadBalancingPolicy = ( + load_balancing_policy) self._request_aggregator: serve_utils.RequestsAggregator = ( serve_utils.RequestTimestamp()) # TODO(tian): httpx.Client has a resource limit of 100 max connections @@ -222,15 +225,14 @@ async def startup(): uvicorn.run(self._app, host='0.0.0.0', port=self._load_balancer_port) -def run_load_balancer(controller_addr: str, load_balancer_port: int, policy_name: str): - if policy_name == 'round_robin': - policy = lb_policies.RoundRobinPolicy() - elif policy_name == 'geo_data': +def run_load_balancer(controller_addr: str, load_balancer_port: int, + policy_name: str): + # By default, the round robin policy is used. + policy: lb_policies.LoadBalancingPolicy = lb_policies.RoundRobinPolicy() + if policy_name == 'geo_data': #TODO(acuadron): Right now the locations of the VMs have to be inputed manually during the GeoDataPolicy # instantiation, change this behaviour. We should, store the location of the replicas during their creation. - policy = lb_policies.GeoDataPolicy({}) - else: - raise ValueError(f"Unknown load balancing policy: {policy_name}") + policy: lb_policies.LoadBalancingPolicy = lb_policies.GeoDataPolicy({}) load_balancer = SkyServeLoadBalancer(controller_url=controller_addr, load_balancer_port=load_balancer_port, @@ -250,9 +252,11 @@ def run_load_balancer(controller_addr: str, load_balancer_port: int, policy_name required=True, default=8890, help='The port where the load balancer listens to.') - parser.add_argument('--load-balancing-policy', - choices=['round_robin', 'geo_data'], - default='round_robin', #TODO(acuadron): Change it to geo_data when ready - help='The load balancing policy to use.') + parser.add_argument( + '--load-balancing-policy', + choices=['round_robin', 'geo_data'], + default='round_robin', #TODO(acuadron): Change it to geo_data when ready + help='The load balancing policy to use.') args = parser.parse_args() - run_load_balancer(args.controller_addr, args.load_balancer_port, args.load_balancing_policy) + run_load_balancer(args.controller_addr, args.load_balancer_port, + args.load_balancing_policy) From 4df4adec7d6e8d96ecf96b14ddd5827dff1c7239 Mon Sep 17 00:00:00 2001 From: AlexCuadron Date: Thu, 10 Oct 2024 13:10:42 +0200 Subject: [PATCH 4/4] linting --- sky/serve/load_balancer.py | 18 ++++++----- sky/serve/load_balancing_policies.py | 45 ++++++++++++++-------------- 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/sky/serve/load_balancer.py b/sky/serve/load_balancer.py index 9ce7ace92a9..64a11bcbe10 100644 --- a/sky/serve/load_balancer.py +++ b/sky/serve/load_balancer.py @@ -230,9 +230,11 @@ def run_load_balancer(controller_addr: str, load_balancer_port: int, # By default, the round robin policy is used. policy: lb_policies.LoadBalancingPolicy = lb_policies.RoundRobinPolicy() if policy_name == 'geo_data': - #TODO(acuadron): Right now the locations of the VMs have to be inputed manually during the GeoDataPolicy - # instantiation, change this behaviour. We should, store the location of the replicas during their creation. - policy: lb_policies.LoadBalancingPolicy = lb_policies.GeoDataPolicy({}) + #TODO(acuadron): Right now the locations of the VMs have + # to be inputed manually during the GeoDataPolicy + # instantiation, change this behaviour. We should, store + # the location of the replicas during their creation. + policy = lb_policies.GeoDataPolicy({}) load_balancer = SkyServeLoadBalancer(controller_url=controller_addr, load_balancer_port=load_balancer_port, @@ -252,11 +254,11 @@ def run_load_balancer(controller_addr: str, load_balancer_port: int, required=True, default=8890, help='The port where the load balancer listens to.') - parser.add_argument( - '--load-balancing-policy', - choices=['round_robin', 'geo_data'], - default='round_robin', #TODO(acuadron): Change it to geo_data when ready - help='The load balancing policy to use.') + #TODO(acuadron): Change default to geo_data when ready + parser.add_argument('--load-balancing-policy', + choices=['round_robin', 'geo_data'], + default='round_robin', + help='The load balancing policy to use.') args = parser.parse_args() run_load_balancer(args.controller_addr, args.load_balancer_port, args.load_balancing_policy) diff --git a/sky/serve/load_balancing_policies.py b/sky/serve/load_balancing_policies.py index becf5906a05..b67b33b7615 100644 --- a/sky/serve/load_balancing_policies.py +++ b/sky/serve/load_balancing_policies.py @@ -1,4 +1,5 @@ """LoadBalancingPolicy: Policy to select endpoint.""" +import math import random import typing from typing import Dict, List, Optional, Tuple @@ -77,26 +78,23 @@ class GeoDataPolicy(LoadBalancingPolicy): def __init__(self, replica_locations: Dict[str, Tuple[float, float]]) -> None: - """ - Initialize with a mapping from replica URLs to their (latitude, longitude). - """ super().__init__() - self.replica_locations = replica_locations # type: Dict[str, Tuple[float, float]] + self.replica_locations = replica_locations def set_ready_replicas(self, ready_replicas: List[str]) -> None: # Ensure all replicas have associated locations for replica in ready_replicas: if replica not in self.replica_locations: # Every replica must have a valid location - raise ValueError( - f"Replica {replica} does not have a corresponding location." - ) + raise ValueError(f'Replica {replica} does not have \ + a corresponding location.') self.ready_replicas = ready_replicas def _select_replica(self, request: 'fastapi.Request') -> Optional[str]: user_location = self._get_user_location(request) if not user_location: - # If user location can't be determined, select a random replica + # If user location can't be determined, + # select a random replica random.shuffle(self.ready_replicas) return self.ready_replicas[0] @@ -117,15 +115,18 @@ def _get_user_location( # Extract the user's IP address ip_address = request.client.host if not ip_address: - logger.warning('Could not extract IP address from request.') + logger.warning('Could not extract IP \ + address from request.') return None # Perform GeoIP lookup using httpx, limited to 150 requests per minute. - # TODO(acuadron): Use IP caching to reduce the number of requests. Async? + # TODO(acuadron): + # - Use IP caching to reduce the number of requests. + # - Make Async? try: with httpx.Client() as client: - response = client.get( - f'http://ip-api.com/json/{ip_address}', timeout=2) + response = client.get(f'http://ip-api.com/json/{ip_address}', + timeout=2) if response.status_code == 200: data = response.json() if data['status'] == 'success': @@ -133,26 +134,24 @@ def _get_user_location( longitude = data['lon'] return (latitude, longitude) else: - logger.warning( - f"GeoIP lookup failed: {data.get('message', 'Unknown error')}" - ) + logger.warning(f'GeoIP lookup failed: \ + {data.get("message", "Unknown error")}') return None else: - logger.warning( - f'GeoIP lookup failed with status code {response.status_code}' - ) + logger.warning(f'GeoIP lookup failed with \ + status code {response.status_code}') return None - except Exception as e: - logger.warning(f'Failed to get location for IP {ip_address}: {e}') + except httpx.RequestError as e: + logger.warning(f'Failed to get location \ + for IP {ip_address}: {e}') return None def _calculate_distance(self, loc1: Tuple[float, float], loc2: Tuple[float, float]) -> float: # Haversine formula to calculate the great-circle distance - import math lat1, lon1 = loc1 lat2, lon2 = loc2 - R = 6371 # Earth radius in kilometers + earth_radius = 6371 # Earth radius in kilometers phi1 = math.radians(lat1) phi2 = math.radians(lat2) @@ -163,5 +162,5 @@ def _calculate_distance(self, loc1: Tuple[float, float], math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2)**2) c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) - distance = R * c + distance = earth_radius * c return distance