Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCuadron committed Oct 10, 2024
1 parent 5bc7315 commit 4df4ade
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 31 deletions.
18 changes: 10 additions & 8 deletions sky/serve/load_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,11 @@ def run_load_balancer(controller_addr: str, load_balancer_port: int,
# By default, the round robin policy is used.
policy: lb_policies.LoadBalancingPolicy = lb_policies.RoundRobinPolicy()
if policy_name == 'geo_data':
#TODO(acuadron): Right now the locations of the VMs have to be inputed manually during the GeoDataPolicy
# instantiation, change this behaviour. We should, store the location of the replicas during their creation.
policy: lb_policies.LoadBalancingPolicy = lb_policies.GeoDataPolicy({})
#TODO(acuadron): Right now the locations of the VMs have
# to be inputed manually during the GeoDataPolicy
# instantiation, change this behaviour. We should, store
# the location of the replicas during their creation.
policy = lb_policies.GeoDataPolicy({})

load_balancer = SkyServeLoadBalancer(controller_url=controller_addr,
load_balancer_port=load_balancer_port,
Expand All @@ -252,11 +254,11 @@ def run_load_balancer(controller_addr: str, load_balancer_port: int,
required=True,
default=8890,
help='The port where the load balancer listens to.')
parser.add_argument(
'--load-balancing-policy',
choices=['round_robin', 'geo_data'],
default='round_robin', #TODO(acuadron): Change it to geo_data when ready
help='The load balancing policy to use.')
#TODO(acuadron): Change default to geo_data when ready
parser.add_argument('--load-balancing-policy',
choices=['round_robin', 'geo_data'],
default='round_robin',
help='The load balancing policy to use.')
args = parser.parse_args()
run_load_balancer(args.controller_addr, args.load_balancer_port,
args.load_balancing_policy)
45 changes: 22 additions & 23 deletions sky/serve/load_balancing_policies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""LoadBalancingPolicy: Policy to select endpoint."""
import math
import random
import typing
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -77,26 +78,23 @@ class GeoDataPolicy(LoadBalancingPolicy):

def __init__(self, replica_locations: Dict[str, Tuple[float,
float]]) -> None:
"""
Initialize with a mapping from replica URLs to their (latitude, longitude).
"""
super().__init__()
self.replica_locations = replica_locations # type: Dict[str, Tuple[float, float]]
self.replica_locations = replica_locations

def set_ready_replicas(self, ready_replicas: List[str]) -> None:
# Ensure all replicas have associated locations
for replica in ready_replicas:
if replica not in self.replica_locations:
# Every replica must have a valid location
raise ValueError(
f"Replica {replica} does not have a corresponding location."
)
raise ValueError(f'Replica {replica} does not have \
a corresponding location.')
self.ready_replicas = ready_replicas

def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
user_location = self._get_user_location(request)
if not user_location:
# If user location can't be determined, select a random replica
# If user location can't be determined,
# select a random replica
random.shuffle(self.ready_replicas)
return self.ready_replicas[0]

Expand All @@ -117,42 +115,43 @@ def _get_user_location(
# Extract the user's IP address
ip_address = request.client.host
if not ip_address:
logger.warning('Could not extract IP address from request.')
logger.warning('Could not extract IP \
address from request.')
return None

# Perform GeoIP lookup using httpx, limited to 150 requests per minute.
# TODO(acuadron): Use IP caching to reduce the number of requests. Async?
# TODO(acuadron):
# - Use IP caching to reduce the number of requests.
# - Make Async?
try:
with httpx.Client() as client:
response = client.get(
f'http://ip-api.com/json/{ip_address}', timeout=2)
response = client.get(f'http://ip-api.com/json/{ip_address}',
timeout=2)
if response.status_code == 200:
data = response.json()
if data['status'] == 'success':
latitude = data['lat']
longitude = data['lon']
return (latitude, longitude)
else:
logger.warning(
f"GeoIP lookup failed: {data.get('message', 'Unknown error')}"
)
logger.warning(f'GeoIP lookup failed: \
{data.get("message", "Unknown error")}')
return None
else:
logger.warning(
f'GeoIP lookup failed with status code {response.status_code}'
)
logger.warning(f'GeoIP lookup failed with \
status code {response.status_code}')
return None
except Exception as e:
logger.warning(f'Failed to get location for IP {ip_address}: {e}')
except httpx.RequestError as e:
logger.warning(f'Failed to get location \
for IP {ip_address}: {e}')
return None

def _calculate_distance(self, loc1: Tuple[float, float],
loc2: Tuple[float, float]) -> float:
# Haversine formula to calculate the great-circle distance
import math
lat1, lon1 = loc1
lat2, lon2 = loc2
R = 6371 # Earth radius in kilometers
earth_radius = 6371 # Earth radius in kilometers

phi1 = math.radians(lat1)
phi2 = math.radians(lat2)
Expand All @@ -163,5 +162,5 @@ def _calculate_distance(self, loc1: Tuple[float, float],
math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2)**2)
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

distance = R * c
distance = earth_radius * c
return distance

0 comments on commit 4df4ade

Please sign in to comment.