Skip to content

Commit

Permalink
[Serve] Add and adopt least load policy as default poicy.
Browse files Browse the repository at this point in the history
  • Loading branch information
cblmemo committed Dec 4, 2024
1 parent 51a7e17 commit 39b3b30
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 3 deletions.
8 changes: 7 additions & 1 deletion sky/serve/load_balancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ async def _proxy_request_to(
encountered if anything goes wrong.
"""
logger.info(f'Proxy request to {url}')
self._load_balancing_policy.pre_execute_hook(url, request)
try:
# We defer the get of the client here on purpose, for case when the
# replica is ready in `_proxy_with_retries` but refreshed before
Expand All @@ -147,11 +148,16 @@ async def _proxy_request_to(
content=await request.body(),
timeout=constants.LB_STREAM_TIMEOUT)
proxy_response = await client.send(proxy_request, stream=True)

async def background_func():
await proxy_response.aclose()
self._load_balancing_policy.post_execute_hook(url, request)

return fastapi.responses.StreamingResponse(
content=proxy_response.aiter_raw(),
status_code=proxy_response.status_code,
headers=proxy_response.headers,
background=background.BackgroundTask(proxy_response.aclose))
background=background.BackgroundTask(background_func))
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.error(f'Error when proxy request to {url}: '
f'{common_utils.format_exception(e)}')
Expand Down
54 changes: 52 additions & 2 deletions sky/serve/load_balancing_policies.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""LoadBalancingPolicy: Policy to select endpoint."""
import collections
import random
import threading
import typing
from typing import List, Optional
from typing import Dict, List, Optional

from sky import sky_logging

Expand Down Expand Up @@ -65,8 +67,16 @@ def select_replica(self, request: 'fastapi.Request') -> Optional[str]:
def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
raise NotImplementedError

def pre_execute_hook(self, replica_url: str,
request: 'fastapi.Request') -> None:
pass

class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin', default=True):
def post_execute_hook(self, replica_url: str,
request: 'fastapi.Request') -> None:
pass


class RoundRobinPolicy(LoadBalancingPolicy, name='round_robin'):
"""Round-robin load balancing policy."""

def __init__(self) -> None:
Expand All @@ -90,3 +100,43 @@ def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
ready_replica_url = self.ready_replicas[self.index]
self.index = (self.index + 1) % len(self.ready_replicas)
return ready_replica_url


class LeastLoadPolicy(LoadBalancingPolicy, name='least_load', default=True):
"""Least load load balancing policy."""

def __init__(self) -> None:
super().__init__()
self.load_map: Dict[str, int] = collections.defaultdict(int)
self.lock = threading.Lock()

def set_ready_replicas(self, ready_replicas: List[str]) -> None:
if set(self.ready_replicas) == set(ready_replicas):
return
with self.lock:
self.ready_replicas = ready_replicas
for r in self.ready_replicas:
if r not in ready_replicas:
del self.load_map[r]
for replica in ready_replicas:
self.load_map[replica] = self.load_map.get(replica, 0)

def _select_replica(self, request: 'fastapi.Request') -> Optional[str]:
del request # Unused.
if not self.ready_replicas:
return None
with self.lock:
return min(self.ready_replicas,
key=lambda replica: self.load_map.get(replica, 0))

def pre_execute_hook(self, replica_url: str,
request: 'fastapi.Request') -> None:
del request # Unused.
with self.lock:
self.load_map[replica_url] += 1

def post_execute_hook(self, replica_url: str,
request: 'fastapi.Request') -> None:
del request # Unused.
with self.lock:
self.load_map[replica_url] -= 1

0 comments on commit 39b3b30

Please sign in to comment.