Skip to content

Commit

Permalink
feat: Implement Randomized Weighted Balancer #589
Browse files Browse the repository at this point in the history
  • Loading branch information
astsiapanay committed Dec 2, 2024
1 parent 707ff10 commit 145b9a6
Show file tree
Hide file tree
Showing 10 changed files with 206 additions and 213 deletions.
3 changes: 2 additions & 1 deletion server/src/main/java/com/epam/aidial/core/server/AiDial.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -100,7 +101,7 @@ void start() throws Exception {
client = vertx.createHttpClient(new HttpClientOptions(settings("client")));

LogStore logStore = new GfLogStore(vertx);
UpstreamRouteProvider upstreamRouteProvider = new UpstreamRouteProvider(vertx);
UpstreamRouteProvider upstreamRouteProvider = new UpstreamRouteProvider(vertx, Random::new);

if (accessTokenValidator == null) {
accessTokenValidator = new AccessTokenValidator(settings("identityProviders"), vertx, client);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package com.epam.aidial.core.server.upstream;

import com.epam.aidial.core.config.Upstream;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

import java.util.Comparator;
import java.util.List;
import java.util.Random;

/**
 * Load balancer for the upstreams of a single tier. Each call to {@link #next()} picks one of the
 * currently available upstreams at random, with probability proportional to its weight:
 * an upstream with weight {@code w} is chosen with probability {@code w / sum(weights)}.
 *
 * <p>Balancers are ordered by tier (ascending) via {@link #compareTo(RandomizedWeightedBalancer)}.
 */
@Slf4j
class RandomizedWeightedBalancer implements Comparable<RandomizedWeightedBalancer> {

    /** Tier shared by every upstream handled by this balancer. */
    private final int tier;
    /** States of the upstreams with a positive weight; may be empty. */
    @Getter
    private final List<UpstreamState> upstreamStates;
    /** Source of randomness used to pick an upstream. */
    private final Random generator;

    /**
     * Creates a balancer for the given upstreams, all of which must belong to the same tier.
     * Upstreams whose weight is not positive are ignored; if none remain a warning is logged
     * and {@link #next()} will always return {@code null}.
     *
     * @param deploymentName deployment name, used in error and log messages only
     * @param upstreams      upstreams of one tier; must be neither {@code null} nor empty
     * @param generator      random generator used to select upstreams
     * @throws IllegalArgumentException if {@code upstreams} is {@code null} or empty,
     *                                  or if the upstreams do not all share the same tier
     */
    RandomizedWeightedBalancer(String deploymentName, List<Upstream> upstreams, Random generator) {
        if (upstreams == null || upstreams.isEmpty()) {
            throw new IllegalArgumentException("Upstream list is null or empty for deployment: " + deploymentName);
        }
        int tier = upstreams.get(0).getTier();
        for (Upstream upstream : upstreams) {
            if (upstream.getTier() != tier) {
                throw new IllegalArgumentException("Tier mismatch for deployment " + deploymentName);
            }
        }
        this.tier = tier;
        this.upstreamStates = upstreams.stream()
                .filter(upstream -> upstream.getWeight() > 0)
                .map(UpstreamState::new)
                .toList();
        this.generator = generator;
        if (this.upstreamStates.isEmpty()) {
            log.warn("No available upstreams for deployment {} and tier {}", deploymentName, tier);
        }
    }

    /**
     * Picks one of the currently available upstreams at random, proportionally to its weight.
     *
     * @return the selected upstream, or {@code null} if this balancer has no upstream with a
     *         positive weight or none of them is available at the moment
     */
    public Upstream next() {
        if (upstreamStates.isEmpty()) {
            return null;
        }

        List<Upstream> availableUpstreams = upstreamStates.stream()
                .filter(UpstreamState::isUpstreamAvailable)
                .map(UpstreamState::getUpstream)
                .toList();
        if (availableUpstreams.isEmpty()) {
            return null;
        }

        // Every weight is > 0 (filtered in the constructor), so total >= 1 and is a valid bound.
        int total = availableUpstreams.stream().mapToInt(Upstream::getWeight).sum();

        // Draw from [0, total): exactly `total` equally likely outcomes. The upstream occupying
        // the cumulative interval [previous, previous + weight) owns `weight` of them, which gives
        // the probability weight / total. Drawing from [0, total] and comparing with `>=` would
        // hand one extra outcome to the first upstream and skew the distribution towards it.
        int random = generator.nextInt(total);
        int cumulative = 0;

        for (Upstream upstream : availableUpstreams) {
            cumulative += upstream.getWeight();
            if (cumulative > random) {
                return upstream;
            }
        }

        // Not reachable when the generator honours its bound, since random < total == cumulative.
        return null;
    }

    /**
     * Orders balancers by tier in ascending order.
     */
    @Override
    public int compareTo(RandomizedWeightedBalancer randomizedWeightedBalancer) {
        return Integer.compare(tier, randomizedWeightedBalancer.tier);
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

Expand All @@ -19,16 +21,16 @@
*/
class TieredBalancer {

private final List<WeightedRoundRobinBalancer> tiers;
private final List<RandomizedWeightedBalancer> tiers;

private final List<UpstreamState> upstreamStates = new ArrayList<>();

private final List<Predicate<UpstreamState>> predicates = new ArrayList<>();

public TieredBalancer(String deploymentName, List<Upstream> upstreams) {
this.tiers = buildTiers(deploymentName, upstreams);
for (WeightedRoundRobinBalancer tier : tiers) {
upstreamStates.addAll(tier.getUpstreams());
public TieredBalancer(String deploymentName, List<Upstream> upstreams, Random random) {
this.tiers = buildTiers(deploymentName, upstreams, random);
for (RandomizedWeightedBalancer tier : tiers) {
upstreamStates.addAll(tier.getUpstreamStates());
}
predicates.add(state -> state.getStatus().is5xx()
&& state.getSource() == UpstreamState.RetryAfterSource.CORE);
Expand All @@ -42,10 +44,10 @@ public TieredBalancer(String deploymentName, List<Upstream> upstreams) {

@Nullable
synchronized Upstream next(Set<Upstream> usedUpstreams) {
for (WeightedRoundRobinBalancer tier : tiers) {
UpstreamState upstreamState = tier.next();
if (upstreamState != null) {
return upstreamState.getUpstream();
for (RandomizedWeightedBalancer tier : tiers) {
Upstream upstream = tier.next();
if (upstream != null) {
return upstream;
}
}
// fallback
Expand Down Expand Up @@ -82,13 +84,13 @@ private UpstreamState findUpstreamState(Upstream upstream) {
throw new IllegalArgumentException("Upstream is not found: " + upstream);
}

private static List<WeightedRoundRobinBalancer> buildTiers(String deploymentName, List<Upstream> upstreams) {
List<WeightedRoundRobinBalancer> balancers = new ArrayList<>();
private static List<RandomizedWeightedBalancer> buildTiers(String deploymentName, List<Upstream> upstreams, Random random) {
List<RandomizedWeightedBalancer> balancers = new ArrayList<>();
Map<Integer, List<Upstream>> groups = upstreams.stream()
.collect(Collectors.groupingBy(Upstream::getTier));

for (Map.Entry<Integer, List<Upstream>> entry : groups.entrySet()) {
balancers.add(new WeightedRoundRobinBalancer(deploymentName, entry.getValue()));
balancers.add(new RandomizedWeightedBalancer(deploymentName, entry.getValue(), random));
}

balancers.sort(Comparator.naturalOrder());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

/**
* This class caches load balancers for deployments and routes.
Expand All @@ -33,7 +35,10 @@ public class UpstreamRouteProvider {
*/
private final ConcurrentHashMap<String, BalancerWrapper> balancers = new ConcurrentHashMap<>();

public UpstreamRouteProvider(Vertx vertx) {
private final Supplier<Random> generatorFactory;

public UpstreamRouteProvider(Vertx vertx, Supplier<Random> generatorFactory) {
this.generatorFactory = generatorFactory;
vertx.setPeriodic(0, TimeUnit.MINUTES.toMillis(1), event -> evictExpiredBalancers());
}

Expand All @@ -55,7 +60,8 @@ private UpstreamRoute get(String key, List<Upstream> upstreams, int maxRetryAtte
&& maxRetryAttempts == cur.maxRetryAttempts) {
result = cur;
} else {
result = new BalancerWrapper(key, maxRetryAttempts, upstreams);
TieredBalancer balancer = new TieredBalancer(key, upstreams, generatorFactory.get());
result = new BalancerWrapper(balancer, maxRetryAttempts, upstreams);
}
result.lastAccessTime = System.currentTimeMillis();
return result;
Expand Down Expand Up @@ -122,8 +128,8 @@ private static class BalancerWrapper {

final List<Upstream> upstreams;

public BalancerWrapper(String key, int maxRetryAttempts, List<Upstream> upstreams) {
this.balancer = new TieredBalancer(key, upstreams);
public BalancerWrapper(TieredBalancer balancer, int maxRetryAttempts, List<Upstream> upstreams) {
this.balancer = balancer;
this.maxRetryAttempts = maxRetryAttempts;
this.upstreams = upstreams;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import java.util.concurrent.TimeUnit;

@Slf4j
class UpstreamState implements Comparable<UpstreamState> {
class UpstreamState {

static final long DEFAULT_RETRY_AFTER_SECONDS_VALUE = 30;
@Getter
Expand Down Expand Up @@ -89,11 +89,6 @@ boolean isUpstreamAvailable() {
return System.currentTimeMillis() > retryAfter;
}

@Override
public int compareTo(UpstreamState upstreamState) {
return Integer.compare(upstream.getWeight(), upstreamState.getUpstream().getWeight());
}

enum RetryAfterSource {
UPSTREAM, CORE
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package com.epam.aidial.core.server.upstream;

import com.epam.aidial.core.config.Upstream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.List;
import java.util.Random;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
public class RandomizedWeightedBalancerTest {

@Mock
private Random generator;

@Test
void testWeightedLoadBalancer() {
List<Upstream> upstreams = List.of(
new Upstream("endpoint1", null, null, 1, 0),
new Upstream("endpoint2", null, null, 2, 0),
new Upstream("endpoint3", null, null, 3, 0),
new Upstream("endpoint4", null, null, 4, 0)
);

RandomizedWeightedBalancer balancer = new RandomizedWeightedBalancer("model1", upstreams, generator);

when(generator.nextInt(11)).thenReturn(0);

Upstream upstream = balancer.next();
assertNotNull(upstream);
assertEquals(upstreams.get(0), upstream);

when(generator.nextInt(11)).thenReturn(2);

upstream = balancer.next();
assertNotNull(upstream);
assertEquals(upstreams.get(1), upstream);

when(generator.nextInt(11)).thenReturn(6);

upstream = balancer.next();
assertNotNull(upstream);
assertEquals(upstreams.get(2), upstream);

when(generator.nextInt(11)).thenReturn(10);

upstream = balancer.next();
assertNotNull(upstream);
assertEquals(upstreams.get(3), upstream);

}

@Test
void testZeroWeightLoadBalancer() {
List<Upstream> upstreams = List.of(
new Upstream("endpoint1", null, null, 0, 1),
new Upstream("endpoint2", null, null, -9, 1)
);
RandomizedWeightedBalancer balancer = new RandomizedWeightedBalancer("model1", upstreams, generator);

for (int i = 0; i < 10; i++) {
Upstream upstream = balancer.next();
assertNull(upstream);
}
}

}
Loading

0 comments on commit 145b9a6

Please sign in to comment.