Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] User Location LB policy [WIP] #2

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions sky/serve/load_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ class SkyServeLoadBalancer:
policy.
"""

def __init__(self, controller_url: str, load_balancer_port: int) -> None:
def __init__(
self, controller_url: str, load_balancer_port: int,
load_balancing_policy: lb_policies.LoadBalancingPolicy) -> None:
"""Initialize the load balancer.

Args:
Expand All @@ -38,7 +40,7 @@ def __init__(self, controller_url: str, load_balancer_port: int) -> None:
self._controller_url: str = controller_url
self._load_balancer_port: int = load_balancer_port
self._load_balancing_policy: lb_policies.LoadBalancingPolicy = (
lb_policies.RoundRobinPolicy())
load_balancing_policy)
self._request_aggregator: serve_utils.RequestsAggregator = (
serve_utils.RequestTimestamp())
# TODO(tian): httpx.Client has a resource limit of 100 max connections
Expand Down Expand Up @@ -223,9 +225,20 @@ async def startup():
uvicorn.run(self._app, host='0.0.0.0', port=self._load_balancer_port)


def run_load_balancer(controller_addr: str,
                      load_balancer_port: int,
                      policy_name: str = 'round_robin'):
    """Build a load balancer with the requested policy and run it.

    Args:
        controller_addr: URL of the controller, passed to the load balancer
            as its `controller_url`.
        load_balancer_port: Port the load balancer listens on.
        policy_name: Name of the load balancing policy. 'geo_data' selects
            `GeoDataPolicy`; any other value selects `RoundRobinPolicy`.
            Defaults to 'round_robin' so that callers which still pass
            only the first two arguments keep working.
    """
    policy: lb_policies.LoadBalancingPolicy
    if policy_name == 'geo_data':
        # TODO(acuadron): Right now the locations of the VMs have to be
        # entered manually when GeoDataPolicy is instantiated; change this
        # behaviour. We should store the location of the replicas during
        # their creation.
        policy = lb_policies.GeoDataPolicy({})
    else:
        # By default, the round robin policy is used.
        policy = lb_policies.RoundRobinPolicy()

    load_balancer = SkyServeLoadBalancer(controller_url=controller_addr,
                                         load_balancer_port=load_balancer_port,
                                         load_balancing_policy=policy)
    load_balancer.run()


Expand All @@ -241,5 +254,11 @@ def run_load_balancer(controller_addr: str, load_balancer_port: int):
required=True,
default=8890,
help='The port where the load balancer listens to.')
#TODO(acuadron): Change default to geo_data when ready
parser.add_argument('--load-balancing-policy',
choices=['round_robin', 'geo_data'],
default='round_robin',
help='The load balancing policy to use.')
args = parser.parse_args()
run_load_balancer(args.controller_addr, args.load_balancer_port)
run_load_balancer(args.controller_addr, args.load_balancer_port,
args.load_balancing_policy)
98 changes: 97 additions & 1 deletion sky/serve/load_balancing_policies.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""LoadBalancingPolicy: Policy to select endpoint."""
import math
import random
import typing
from typing import List, Optional
from typing import Dict, List, Optional, Tuple

import httpx

from sky import sky_logging

Expand Down Expand Up @@ -68,3 +71,96 @@ def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
ready_replica_url = self.ready_replicas[self.index]
self.index = (self.index + 1) % len(self.ready_replicas)
return ready_replica_url


class GeoDataPolicy(LoadBalancingPolicy):
    """Geo-data load balancing policy using an online GeoIP service.

    Each request is routed to the ready replica whose configured
    (latitude, longitude) is closest, by great-circle distance, to the
    location looked up for the client's IP address. When the client's
    location cannot be determined, a random ready replica is used instead.
    """

    def __init__(self, replica_locations: Dict[str, Tuple[float,
                                                          float]]) -> None:
        """Initialize the policy.

        Args:
            replica_locations: Mapping from replica identifier (the same
                strings later passed to `set_ready_replicas`) to a
                (latitude, longitude) pair in degrees.
        """
        super().__init__()
        self.replica_locations = replica_locations

    def set_ready_replicas(self, ready_replicas: List[str]) -> None:
        """Set the replicas that requests may be routed to.

        Args:
            ready_replicas: Identifiers of the replicas that are ready.

        Raises:
            ValueError: If any replica has no entry in
                `replica_locations`.
        """
        # Ensure all replicas have associated locations.
        for replica in ready_replicas:
            if replica not in self.replica_locations:
                # Every replica must have a valid location. Implicit string
                # concatenation is used (not a backslash inside the literal)
                # so that source indentation never leaks into the message.
                raise ValueError(f'Replica {replica} does not have '
                                 'a corresponding location.')
        self.ready_replicas = ready_replicas

    def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
        """Pick the ready replica nearest to the requesting client.

        Args:
            request: The incoming request; its client address is used for
                the GeoIP lookup.

        Returns:
            The selected replica, or None when no replica is ready.
        """
        if not self.ready_replicas:
            # Nothing to route to; avoid indexing into an empty list.
            return None

        user_location = self._get_user_location(request)
        if user_location is None:
            # If the user location can't be determined, select a random
            # replica. random.choice is used so that the shared
            # ready_replicas list is not reordered as a side effect.
            return random.choice(self.ready_replicas)

        # Find the closest replica; on ties the first one in list order
        # wins.
        return min(
            self.ready_replicas,
            key=lambda replica: self._calculate_distance(
                user_location, self.replica_locations[replica]))

    def _get_user_location(
            self, request: 'fastapi.Request') -> Optional[Tuple[float, float]]:
        """Look up the (latitude, longitude) of the request's client IP.

        Args:
            request: The incoming request.

        Returns:
            A (latitude, longitude) pair, or None if the client address is
            unavailable or the lookup fails for any reason. Failures are
            logged as warnings and never raised.
        """
        # Extract the user's IP address. `request.client` can be None
        # (e.g. when the server does not expose the peer address).
        client_addr = request.client
        ip_address = client_addr.host if client_addr is not None else None
        if not ip_address:
            logger.warning('Could not extract IP address from request.')
            return None

        # Perform GeoIP lookup using httpx. The service is rate limited
        # (the original note cites 150 requests per minute).
        # TODO(acuadron):
        #  - Use IP caching to reduce the number of requests.
        #  - Make Async?
        try:
            with httpx.Client() as client:
                response = client.get(f'http://ip-api.com/json/{ip_address}',
                                      timeout=2)
        except httpx.RequestError as e:
            logger.warning(f'Failed to get location for IP {ip_address}: {e}')
            return None

        if response.status_code != 200:
            logger.warning('GeoIP lookup failed with '
                           f'status code {response.status_code}')
            return None

        try:
            data = response.json()
        except ValueError:
            # A 200 response with a non-JSON body must not crash routing.
            logger.warning('GeoIP lookup failed: response is not valid JSON')
            return None

        if not isinstance(data, dict) or data.get('status') != 'success':
            message = 'Unknown error'
            if isinstance(data, dict):
                message = data.get('message', 'Unknown error')
            logger.warning(f'GeoIP lookup failed: {message}')
            return None

        latitude = data.get('lat')
        longitude = data.get('lon')
        if latitude is None or longitude is None:
            logger.warning('GeoIP lookup failed: missing coordinates')
            return None
        return (latitude, longitude)

    def _calculate_distance(self, loc1: Tuple[float, float],
                            loc2: Tuple[float, float]) -> float:
        """Return the great-circle distance between two points.

        Uses the haversine formula with an Earth radius of 6371 km.

        Args:
            loc1: (latitude, longitude) of the first point, in degrees.
            loc2: (latitude, longitude) of the second point, in degrees.

        Returns:
            The distance in kilometers.
        """
        lat1, lon1 = loc1
        lat2, lon2 = loc2
        earth_radius = 6371  # Earth radius in kilometers

        phi1 = math.radians(lat1)
        phi2 = math.radians(lat2)
        delta_phi = math.radians(lat2 - lat1)
        delta_lambda = math.radians(lon2 - lon1)

        a = (math.sin(delta_phi / 2)**2 +
             math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2)**2)
        # Floating-point rounding can push `a` marginally outside [0, 1]
        # (e.g. for antipodal points), which would make sqrt(1 - a) fail.
        a = min(1.0, max(0.0, a))
        c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

        return earth_radius * c
Loading