From 39b3b30cdbce904b62a749b562f2c9295c07d38f Mon Sep 17 00:00:00 2001
From: cblmemo
Date: Wed, 4 Dec 2024 11:32:38 -0800
Subject: [PATCH] [Serve] Add and adopt least load policy as default policy.

Add LeastLoadPolicy, which keeps a per-replica counter of in-flight requests
and selects the ready replica with the lowest counter, and register it as the
default policy in place of RoundRobinPolicy.

LoadBalancingPolicy gains no-op pre_execute_hook/post_execute_hook methods.
The load balancer calls pre_execute_hook before proxying a request and
post_execute_hook from the response's background task, after the upstream
response has been closed.

set_ready_replicas rebuilds the counter map from the new replica list so that
replicas which are no longer ready are dropped; post_execute_hook therefore
only decrements counters that are still tracked and positive.

---
NOTE(review): pre_execute_hook runs before the try block, while
post_execute_hook only runs from the StreamingResponse background task. The
except branch visible in the second load_balancer.py hunk does not call it, so
a failed proxy attempt appears to leave the counter incremented -- confirm
against the full function and balance it in a follow-up.

 sky/serve/load_balancer.py           |  8 ++++-
 sky/serve/load_balancing_policies.py | 54 ++++++++++++++++++++++++++--
 2 files changed, 59 insertions(+), 3 deletions(-)

diff --git a/sky/serve/load_balancer.py b/sky/serve/load_balancer.py
index 30697532a22..babd5c09722 100644
--- a/sky/serve/load_balancer.py
+++ b/sky/serve/load_balancer.py
@@ -128,6 +128,7 @@ async def _proxy_request_to(
         encountered if anything goes wrong.
         """
         logger.info(f'Proxy request to {url}')
+        self._load_balancing_policy.pre_execute_hook(url, request)
         try:
             # We defer the get of the client here on purpose, for case when the
             # replica is ready in `_proxy_with_retries` but refreshed before
@@ -147,11 +148,16 @@ async def _proxy_request_to(
                 content=await request.body(),
                 timeout=constants.LB_STREAM_TIMEOUT)
             proxy_response = await client.send(proxy_request, stream=True)
+
+            async def background_func():
+                await proxy_response.aclose()
+                self._load_balancing_policy.post_execute_hook(url, request)
+
             return fastapi.responses.StreamingResponse(
                 content=proxy_response.aiter_raw(),
                 status_code=proxy_response.status_code,
                 headers=proxy_response.headers,
-                background=background.BackgroundTask(proxy_response.aclose))
+                background=background.BackgroundTask(background_func))
         except (httpx.RequestError, httpx.HTTPStatusError) as e:
             logger.error(f'Error when proxy request to {url}: '
                          f'{common_utils.format_exception(e)}')
diff --git a/sky/serve/load_balancing_policies.py b/sky/serve/load_balancing_policies.py
index aec6eb01487..fb59d196f40 100644
--- a/sky/serve/load_balancing_policies.py
+++ b/sky/serve/load_balancing_policies.py
@@ -1,7 +1,9 @@
 """LoadBalancingPolicy: Policy to select endpoint."""
+import collections
 import random
+import threading
 import typing
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 from sky import sky_logging
 
@@ -65,8 +67,16 @@ def select_replica(self, request: 'fastapi.Request') -> Optional[str]:
     def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
         raise NotImplementedError
 
+    def pre_execute_hook(self, replica_url: str,
+                         request: 'fastapi.Request') -> None:
+        pass
 
-class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin', default=True):
+    def post_execute_hook(self, replica_url: str,
+                          request: 'fastapi.Request') -> None:
+        pass
+
+
+class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin'):
     """Round-robin load balancing policy."""
 
     def __init__(self) -> None:
@@ -90,3 +100,43 @@ def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
         ready_replica_url = self.ready_replicas[self.index]
         self.index = (self.index + 1) % len(self.ready_replicas)
         return ready_replica_url
+
+
+class LeastLoadPolicy(LoadBalancingPolicy, name='least_load', default=True):
+    """Least load load balancing policy."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.load_map: Dict[str, int] = collections.defaultdict(int)
+        self.lock = threading.Lock()
+
+    def set_ready_replicas(self, ready_replicas: List[str]) -> None:
+        if set(self.ready_replicas) == set(ready_replicas):
+            return
+        with self.lock:
+            # Rebuild so stale replicas are dropped and new ones start at 0.
+            self.load_map = collections.defaultdict(
+                int, {r: self.load_map.get(r, 0) for r in ready_replicas})
+            self.ready_replicas = ready_replicas
+
+    def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
+        del request  # Unused.
+        with self.lock:
+            if not self.ready_replicas:
+                return None
+            return min(self.ready_replicas,
+                       key=lambda replica: self.load_map.get(replica, 0))
+
+    def pre_execute_hook(self, replica_url: str,
+                         request: 'fastapi.Request') -> None:
+        del request  # Unused.
+        with self.lock:
+            self.load_map[replica_url] += 1
+
+    def post_execute_hook(self, replica_url: str,
+                          request: 'fastapi.Request') -> None:
+        del request  # Unused.
+        with self.lock:
+            # Replica may have been dropped meanwhile; never go below zero.
+            if self.load_map.get(replica_url, 0) > 0:
+                self.load_map[replica_url] -= 1