From 5a5f10ed05f476ab41aa6baf1be88c21dec7c3b2 Mon Sep 17 00:00:00 2001 From: Aliaksandr Stsiapanay Date: Mon, 2 Dec 2024 15:03:50 +0300 Subject: [PATCH] feat: Implement Randomized Weighted Balancer #589 (#596) --- .../com/epam/aidial/core/server/AiDial.java | 3 +- .../upstream/RandomizedWeightedBalancer.java | 77 ++++++++++++ .../core/server/upstream/TieredBalancer.java | 26 ++-- .../upstream/UpstreamRouteProvider.java | 14 ++- .../core/server/upstream/UpstreamState.java | 7 +- .../upstream/WeightedRoundRobinBalancer.java | 88 -------------- .../RandomizedWeightedBalancerTest.java | 74 ++++++++++++ ...ancerTest.java => TieredBalancerTest.java} | 114 +++--------------- .../upstream/UpstreamRouteProviderTest.java | 8 +- .../server/upstream/UpstreamRouteTest.java | 8 +- 10 files changed, 206 insertions(+), 213 deletions(-) create mode 100644 server/src/main/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancer.java delete mode 100644 server/src/main/java/com/epam/aidial/core/server/upstream/WeightedRoundRobinBalancer.java create mode 100644 server/src/test/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancerTest.java rename server/src/test/java/com/epam/aidial/core/server/upstream/{LoadBalancerTest.java => TieredBalancerTest.java} (53%) diff --git a/server/src/main/java/com/epam/aidial/core/server/AiDial.java b/server/src/main/java/com/epam/aidial/core/server/AiDial.java index 4470196a..2eef7a22 100644 --- a/server/src/main/java/com/epam/aidial/core/server/AiDial.java +++ b/server/src/main/java/com/epam/aidial/core/server/AiDial.java @@ -60,6 +60,7 @@ import java.util.Map; import java.util.Objects; import java.util.Properties; +import java.util.Random; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; @@ -100,7 +101,7 @@ void start() throws Exception { client = vertx.createHttpClient(new HttpClientOptions(settings("client"))); LogStore logStore = new GfLogStore(vertx); - UpstreamRouteProvider upstreamRouteProvider = new UpstreamRouteProvider(vertx); + UpstreamRouteProvider upstreamRouteProvider = new UpstreamRouteProvider(vertx, Random::new); if (accessTokenValidator == null) { accessTokenValidator = new AccessTokenValidator(settings("identityProviders"), vertx, client); diff --git a/server/src/main/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancer.java b/server/src/main/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancer.java new file mode 100644 index 00000000..baa10617 --- /dev/null +++ b/server/src/main/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancer.java @@ -0,0 +1,77 @@ +package com.epam.aidial.core.server.upstream; + +import com.epam.aidial.core.config.Upstream; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; + +import java.util.Comparator; +import java.util.List; +import java.util.Random; + +/** + * Load balancer distributes load in the proportion of probability of upstream weights. + * The higher upstream weight, the higher probability the upstream takes more load. + */ +@Slf4j +class RandomizedWeightedBalancer implements Comparable { + + private final int tier; + @Getter + private final List upstreamStates; + private final Random generator; + + RandomizedWeightedBalancer(String deploymentName, List upstreams, Random generator) { + if (upstreams == null || upstreams.isEmpty()) { + throw new IllegalArgumentException("Upstream list is null or empty for deployment: " + deploymentName); + } + int tier = upstreams.get(0).getTier(); + for (Upstream upstream : upstreams) { + if (upstream.getTier() != tier) { + throw new IllegalArgumentException("Tier mismatch for deployment " + deploymentName); + } + } + this.tier = tier; + this.upstreamStates = upstreams.stream() + .filter(upstream -> upstream.getWeight() > 0) + .map(UpstreamState::new) + .toList(); + this.generator = generator; + if (this.upstreamStates.isEmpty()) { + log.warn("No available upstreams for deployment {} and tier {}", deploymentName, tier); + } + } + + public Upstream next() { + if (upstreamStates.isEmpty()) { + return null; + } + + List availableUpstreams = upstreamStates.stream().filter(UpstreamState::isUpstreamAvailable) + .map(UpstreamState::getUpstream).toList(); + if (availableUpstreams.isEmpty()) { + return null; + } + int total = availableUpstreams.stream().map(Upstream::getWeight).reduce(0, Integer::sum); + // make sure the upper bound `total` is inclusive + int random = generator.nextInt(total + 1); + int current = 0; + + Upstream result = null; + + for (Upstream upstream : availableUpstreams) { + current += upstream.getWeight(); + if (current >= random) { + result = upstream; + break; + } + } + + return result; + } + + @Override + public int compareTo(RandomizedWeightedBalancer randomizedWeightedBalancer) { + return Integer.compare(tier, randomizedWeightedBalancer.tier); + } + +} diff --git a/server/src/main/java/com/epam/aidial/core/server/upstream/TieredBalancer.java b/server/src/main/java/com/epam/aidial/core/server/upstream/TieredBalancer.java index 053ec6b4..a211c303 100644 --- a/server/src/main/java/com/epam/aidial/core/server/upstream/TieredBalancer.java +++ b/server/src/main/java/com/epam/aidial/core/server/upstream/TieredBalancer.java @@ -8,8 +8,10 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Random; import java.util.Set; import java.util.function.Predicate; +import java.util.function.Supplier; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -19,16 +21,16 @@ */ class TieredBalancer { - private final List tiers; + private final List tiers; private final List upstreamStates = new ArrayList<>(); private final List> predicates = new ArrayList<>(); - public TieredBalancer(String deploymentName, List upstreams) { - this.tiers = buildTiers(deploymentName, upstreams); - for (WeightedRoundRobinBalancer tier : tiers) { - upstreamStates.addAll(tier.getUpstreams()); + public TieredBalancer(String deploymentName, List upstreams, Random random) { + this.tiers = buildTiers(deploymentName, upstreams, random); + for (RandomizedWeightedBalancer tier : tiers) { + upstreamStates.addAll(tier.getUpstreamStates()); } predicates.add(state -> state.getStatus().is5xx() && state.getSource() == UpstreamState.RetryAfterSource.CORE); @@ -42,10 +44,10 @@ public TieredBalancer(String deploymentName, List upstreams) { @Nullable synchronized Upstream next(Set usedUpstreams) { - for (WeightedRoundRobinBalancer tier : tiers) { - UpstreamState upstreamState = tier.next(); - if (upstreamState != null) { - return upstreamState.getUpstream(); + for (RandomizedWeightedBalancer tier : tiers) { + Upstream upstream = tier.next(); + if (upstream != null) { + return upstream; } } // fallback @@ -82,13 +84,13 @@ private UpstreamState findUpstreamState(Upstream upstream) { throw new IllegalArgumentException("Upstream is not found: " + upstream); } - private static List buildTiers(String deploymentName, List upstreams) { - List balancers = new ArrayList<>(); + private static List buildTiers(String deploymentName, List upstreams, Random random) { + List balancers = new ArrayList<>(); Map> groups = upstreams.stream() .collect(Collectors.groupingBy(Upstream::getTier)); for (Map.Entry> entry : groups.entrySet()) { - balancers.add(new WeightedRoundRobinBalancer(deploymentName, entry.getValue())); + balancers.add(new RandomizedWeightedBalancer(deploymentName, entry.getValue(), random)); } balancers.sort(Comparator.naturalOrder()); diff --git a/server/src/main/java/com/epam/aidial/core/server/upstream/UpstreamRouteProvider.java b/server/src/main/java/com/epam/aidial/core/server/upstream/UpstreamRouteProvider.java index f0659a27..19edfe72 100644 --- a/server/src/main/java/com/epam/aidial/core/server/upstream/UpstreamRouteProvider.java +++ b/server/src/main/java/com/epam/aidial/core/server/upstream/UpstreamRouteProvider.java @@ -12,8 +12,10 @@ import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Random; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; /** * This class caches load balancers for deployments and routes. @@ -33,7 +35,10 @@ public class UpstreamRouteProvider { */ private final ConcurrentHashMap balancers = new ConcurrentHashMap<>(); - public UpstreamRouteProvider(Vertx vertx) { + private final Supplier generatorFactory; + + public UpstreamRouteProvider(Vertx vertx, Supplier generatorFactory) { + this.generatorFactory = generatorFactory; vertx.setPeriodic(0, TimeUnit.MINUTES.toMillis(1), event -> evictExpiredBalancers()); } @@ -55,7 +60,8 @@ private UpstreamRoute get(String key, List upstreams, int maxRetryAtte && maxRetryAttempts == cur.maxRetryAttempts) { result = cur; } else { - result = new BalancerWrapper(key, maxRetryAttempts, upstreams); + TieredBalancer balancer = new TieredBalancer(key, upstreams, generatorFactory.get()); + result = new BalancerWrapper(balancer, maxRetryAttempts, upstreams); } result.lastAccessTime = System.currentTimeMillis(); return result; @@ -122,8 +128,8 @@ private static class BalancerWrapper { final List upstreams; - public BalancerWrapper(String key, int maxRetryAttempts, List upstreams) { - this.balancer = new TieredBalancer(key, upstreams); + public BalancerWrapper(TieredBalancer balancer, int maxRetryAttempts, List upstreams) { + this.balancer = balancer; this.maxRetryAttempts = maxRetryAttempts; this.upstreams = upstreams; } diff --git a/server/src/main/java/com/epam/aidial/core/server/upstream/UpstreamState.java b/server/src/main/java/com/epam/aidial/core/server/upstream/UpstreamState.java index 255020de..f74ee1e2 100644 --- a/server/src/main/java/com/epam/aidial/core/server/upstream/UpstreamState.java +++ b/server/src/main/java/com/epam/aidial/core/server/upstream/UpstreamState.java @@ -9,7 +9,7 @@ import java.util.concurrent.TimeUnit; @Slf4j -class UpstreamState implements Comparable { +class UpstreamState { static final long DEFAULT_RETRY_AFTER_SECONDS_VALUE = 30; @Getter @@ -89,11 +89,6 @@ boolean isUpstreamAvailable() { return System.currentTimeMillis() > retryAfter; } - @Override - public int compareTo(UpstreamState upstreamState) { - return Integer.compare(upstream.getWeight(), upstreamState.getUpstream().getWeight()); - } - enum RetryAfterSource { UPSTREAM, CORE } diff --git a/server/src/main/java/com/epam/aidial/core/server/upstream/WeightedRoundRobinBalancer.java b/server/src/main/java/com/epam/aidial/core/server/upstream/WeightedRoundRobinBalancer.java deleted file mode 100644 index 69e53690..00000000 --- a/server/src/main/java/com/epam/aidial/core/server/upstream/WeightedRoundRobinBalancer.java +++ /dev/null @@ -1,88 +0,0 @@ -package com.epam.aidial.core.server.upstream; - -import com.epam.aidial.core.config.Upstream; -import lombok.Getter; -import lombok.extern.slf4j.Slf4j; - -import java.util.Comparator; -import java.util.List; -import java.util.PriorityQueue; - -/** - * Implementation of weighted round-robin load balancer. - * Load balancer tracks upstream statistics and guaranty spreading the load according to the upstreams weight - */ -@Slf4j -class WeightedRoundRobinBalancer implements Comparable { - - private final int tier; - @Getter - private final List upstreams; - private final long[] upstreamsWeights; - private final long[] upstreamsUsage; - private final long totalWeight; - private long totalUsage; - private final PriorityQueue upstreamPriority = new PriorityQueue<>((a, b) -> Double.compare(b.delta, a.delta)); - - WeightedRoundRobinBalancer(String deploymentName, List upstreams) { - if (upstreams == null || upstreams.isEmpty()) { - throw new IllegalArgumentException("Upstream list is null or empty for deployment: " + deploymentName); - } - int tier = upstreams.get(0).getTier(); - for (Upstream upstream : upstreams) { - if (upstream.getTier() != tier) { - throw new IllegalArgumentException("Tier mismatch for deployment " + deploymentName); - } - } - this.tier = tier; - this.upstreams = upstreams.stream() - .filter(upstream -> upstream.getWeight() > 0) - .map(UpstreamState::new) - .sorted(Comparator.reverseOrder()) - .toList(); - this.totalWeight = this.upstreams.stream().map(UpstreamState::getUpstream).mapToLong(Upstream::getWeight).sum(); - this.upstreamsUsage = new long[this.upstreams.size()]; - this.upstreamsWeights = this.upstreams.stream().map(UpstreamState::getUpstream).mapToLong(Upstream::getWeight).toArray(); - if (this.upstreams.isEmpty()) { - log.warn("No available upstreams for deployment {} and tier {}", deploymentName, tier); - } - } - - public UpstreamState next() { - if (upstreams.isEmpty()) { - return null; - } - try { - int size = upstreams.size(); - for (int i = 0; i < size; i++) { - UpstreamState upstreamState = upstreams.get(i); - double actualUsageRate = upstreamsUsage[i] == 0 ? 0 : (double) upstreamsUsage[i] / totalUsage; - double expectedUsageRate = (double) upstreamsWeights[i] / totalWeight; - double delta = expectedUsageRate - actualUsageRate; - // for precise load balancing we need to add all upstreams to the priority queue - upstreamPriority.offer(new UpstreamUsage(upstreamState, i, delta)); - } - // find the best available upstream and return it - while (!upstreamPriority.isEmpty()) { - UpstreamUsage candidate = upstreamPriority.poll(); - totalUsage += 1; - upstreamsUsage[candidate.upstreamIndex] += 1; - if (candidate.upstream.isUpstreamAvailable()) { - return candidate.upstream; - } - } - return null; - } finally { - // clear state - upstreamPriority.clear(); - } - } - - @Override - public int compareTo(WeightedRoundRobinBalancer weightedRoundRobinBalancer) { - return Integer.compare(tier, weightedRoundRobinBalancer.tier); - } - - private record UpstreamUsage(UpstreamState upstream, int upstreamIndex, double delta) { - } -} diff --git a/server/src/test/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancerTest.java b/server/src/test/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancerTest.java new file mode 100644 index 00000000..c2d41559 --- /dev/null +++ b/server/src/test/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancerTest.java @@ -0,0 +1,74 @@ +package com.epam.aidial.core.server.upstream; + +import com.epam.aidial.core.config.Upstream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.List; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.when; + +@ExtendWith(MockitoExtension.class) +public class RandomizedWeightedBalancerTest { + + @Mock + private Random generator; + + @Test + void testWeightedLoadBalancer() { + List upstreams = List.of( + new Upstream("endpoint1", null, null, 1, 0), + new Upstream("endpoint2", null, null, 2, 0), + new Upstream("endpoint3", null, null, 3, 0), + new Upstream("endpoint4", null, null, 4, 0) + ); + + RandomizedWeightedBalancer balancer = new RandomizedWeightedBalancer("model1", upstreams, generator); + + when(generator.nextInt(11)).thenReturn(0); + + Upstream upstream = balancer.next(); + assertNotNull(upstream); + assertEquals(upstreams.get(0), upstream); + + when(generator.nextInt(11)).thenReturn(2); + + upstream = balancer.next(); + assertNotNull(upstream); + assertEquals(upstreams.get(1), upstream); + + when(generator.nextInt(11)).thenReturn(6); + + upstream = balancer.next(); + assertNotNull(upstream); + assertEquals(upstreams.get(2), upstream); + + when(generator.nextInt(11)).thenReturn(10); + + upstream = balancer.next(); + assertNotNull(upstream); + assertEquals(upstreams.get(3), upstream); + + } + + @Test + void testZeroWeightLoadBalancer() { + List upstreams = List.of( + new Upstream("endpoint1", null, null, 0, 1), + new Upstream("endpoint2", null, null, -9, 1) + ); + RandomizedWeightedBalancer balancer = new RandomizedWeightedBalancer("model1", upstreams, generator); + + for (int i = 0; i < 10; i++) { + Upstream upstream = balancer.next(); + assertNull(upstream); + } + } + +} diff --git a/server/src/test/java/com/epam/aidial/core/server/upstream/LoadBalancerTest.java b/server/src/test/java/com/epam/aidial/core/server/upstream/TieredBalancerTest.java similarity index 53% rename from server/src/test/java/com/epam/aidial/core/server/upstream/LoadBalancerTest.java rename to server/src/test/java/com/epam/aidial/core/server/upstream/TieredBalancerTest.java index 178e916e..452e5bac 100644 --- a/server/src/test/java/com/epam/aidial/core/server/upstream/LoadBalancerTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/upstream/TieredBalancerTest.java @@ -4,111 +4,40 @@ import com.epam.aidial.core.config.Upstream; import com.epam.aidial.core.storage.http.HttpStatus; import io.vertx.core.Vertx; -import org.apache.commons.lang3.mutable.MutableInt; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.Map; +import java.util.Random; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) -public class LoadBalancerTest { +public class TieredBalancerTest { @Mock private Vertx vertx; - - @Test - void testWeightedLoadBalancer() { - List upstreams = List.of( - new Upstream("endpoint1", null, null, 1, 0), - new Upstream("endpoint2", null, null, 9, 0) - ); - WeightedRoundRobinBalancer balancer = new WeightedRoundRobinBalancer("model1", upstreams); - - Map usage = new HashMap<>(); - usage.put("endpoint1", new MutableInt(0)); - usage.put("endpoint2", new MutableInt(0)); - - for (int i = 0; i < 20; i++) { - UpstreamState upstream = balancer.next(); - assertNotNull(upstream); - String endpoint = upstream.getUpstream().getEndpoint(); - usage.get(endpoint).increment(); - } - - assertEquals(2, usage.get("endpoint1").getValue()); - assertEquals(18, usage.get("endpoint2").getValue()); - - upstreams = List.of( - new Upstream("endpoint1", null, null, 1, 0), - new Upstream("endpoint2", null, null, 1, 0), - new Upstream("endpoint3", null, null, 1, 0), - new Upstream("endpoint4", null, null, 1, 0) - ); - balancer = new WeightedRoundRobinBalancer("model1", upstreams); - - usage = new HashMap<>(); - usage.put("endpoint1", new MutableInt(0)); - usage.put("endpoint2", new MutableInt(0)); - usage.put("endpoint3", new MutableInt(0)); - usage.put("endpoint4", new MutableInt(0)); - - for (int i = 0; i < 100; i++) { - UpstreamState upstream = balancer.next(); - assertNotNull(upstream); - String endpoint = upstream.getUpstream().getEndpoint(); - usage.get(endpoint).increment(); - } - - assertEquals(25, usage.get("endpoint1").getValue()); - assertEquals(25, usage.get("endpoint2").getValue()); - assertEquals(25, usage.get("endpoint3").getValue()); - assertEquals(25, usage.get("endpoint4").getValue()); - upstreams = List.of( - new Upstream("endpoint1", null, null, 49, 0), - new Upstream("endpoint2", null, null, 44, 0), - new Upstream("endpoint3", null, null, 47, 0), - new Upstream("endpoint4", null, null, 59, 0) - ); - balancer = new WeightedRoundRobinBalancer("model1", upstreams); - - usage = new HashMap<>(); - usage.put("endpoint1", new MutableInt(0)); - usage.put("endpoint2", new MutableInt(0)); - usage.put("endpoint3", new MutableInt(0)); - usage.put("endpoint4", new MutableInt(0)); - - for (int i = 0; i < 398; i++) { - UpstreamState upstream = balancer.next(); - assertNotNull(upstream); - String endpoint = upstream.getUpstream().getEndpoint(); - usage.get(endpoint).increment(); - } - - assertEquals(98, usage.get("endpoint1").getValue()); - assertEquals(88, usage.get("endpoint2").getValue()); - assertEquals(94, usage.get("endpoint3").getValue()); - assertEquals(118, usage.get("endpoint4").getValue()); - } + @Mock + private Random generator; @Test - void testTieredLoadBalancer() { + void testTierPriority() { List upstreams = List.of( new Upstream("endpoint1", null, null, 1, 0), new Upstream("endpoint2", null, null, 9, 1) ); - TieredBalancer balancer = new TieredBalancer("model1", upstreams); + TieredBalancer balancer = new TieredBalancer("model1", upstreams, generator); // verify all requests go to the highest tier for (int j = 0; j < 50; j++) { @@ -119,12 +48,12 @@ void testTieredLoadBalancer() { } @Test - void testLoadBalancerFailure() throws InterruptedException { + void testFail() throws InterruptedException { List upstreams = List.of( new Upstream("endpoint1", null, null, 1, 0), new Upstream("endpoint2", null, null, 9, 1) ); - TieredBalancer balancer = new TieredBalancer("model1", upstreams); + TieredBalancer balancer = new TieredBalancer("model1", upstreams, generator); Upstream upstream = balancer.next(Set.of()); assertNotNull(upstream); @@ -148,27 +77,13 @@ void testLoadBalancerFailure() throws InterruptedException { assertEquals("endpoint1", upstream.getEndpoint()); } - @Test - void testZeroWeightLoadBalancer() { - List upstreams = List.of( - new Upstream("endpoint1", null, null, 0, 1), - new Upstream("endpoint2", null, null, -9, 1) - ); - WeightedRoundRobinBalancer balancer = new WeightedRoundRobinBalancer("model1", upstreams); - - for (int i = 0; i < 10; i++) { - UpstreamState upstream = balancer.next(); - assertNull(upstream); - } - } - @Test void test5xxErrorsHandling() { List upstreams = List.of( new Upstream("endpoint0", null, null, 1, 0), new Upstream("endpoint1", null, null, 1, 1) ); - TieredBalancer balancer = new TieredBalancer("model1", upstreams); + TieredBalancer balancer = new TieredBalancer("model1", upstreams, generator); Set used = new HashSet<>(); // report upstream failure 4 times @@ -193,8 +108,11 @@ void testUpstreamFallback() { .map(index -> new Upstream("endpoint" + index, null, null, 1, 1)) .toList(); model.setUpstreams(upstreams); + AtomicInteger counter = new AtomicInteger(); + when(generator.nextInt(5)).thenAnswer(cb -> counter.incrementAndGet()); + Supplier factory = () -> generator; - UpstreamRouteProvider upstreamRouteProvider = new UpstreamRouteProvider(vertx); + UpstreamRouteProvider upstreamRouteProvider = new UpstreamRouteProvider(vertx, factory); UpstreamRoute route1 = upstreamRouteProvider.get(model); assertEquals(upstreams.get(0), route1.get()); diff --git a/server/src/test/java/com/epam/aidial/core/server/upstream/UpstreamRouteProviderTest.java b/server/src/test/java/com/epam/aidial/core/server/upstream/UpstreamRouteProviderTest.java index e0421b5c..d6bfef5f 100644 --- a/server/src/test/java/com/epam/aidial/core/server/upstream/UpstreamRouteProviderTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/upstream/UpstreamRouteProviderTest.java @@ -11,6 +11,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import java.util.List; +import java.util.Random; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -21,9 +22,12 @@ public class UpstreamRouteProviderTest { @Mock private Vertx vertx; + @Mock + private Random generator; + @Test public void testGet_UpstreamsNotChanged() { - UpstreamRouteProvider provider = new UpstreamRouteProvider(vertx); + UpstreamRouteProvider provider = new UpstreamRouteProvider(vertx, () -> generator); Application application = new Application(); application.setName("app"); UpstreamRoute route1 = provider.get(application); @@ -44,7 +48,7 @@ public void testGet_UpstreamsChanged() { upstream1.setWeight(2); model.setUpstreams(List.of(upstream1)); - UpstreamRouteProvider provider = new UpstreamRouteProvider(vertx); + UpstreamRouteProvider provider = new UpstreamRouteProvider(vertx, () -> generator); UpstreamRoute route1 = provider.get(model); route1.fail(HttpStatus.TOO_MANY_REQUESTS); assertNull(route1.next()); diff --git a/server/src/test/java/com/epam/aidial/core/server/upstream/UpstreamRouteTest.java b/server/src/test/java/com/epam/aidial/core/server/upstream/UpstreamRouteTest.java index a9136d4e..3824c055 100644 --- a/server/src/test/java/com/epam/aidial/core/server/upstream/UpstreamRouteTest.java +++ b/server/src/test/java/com/epam/aidial/core/server/upstream/UpstreamRouteTest.java @@ -10,6 +10,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import java.util.List; +import java.util.Random; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -23,6 +24,9 @@ public class UpstreamRouteTest { @Mock private Vertx vertx; + @Mock + private Random generator; + @Test void testUpstreamRouteWithRetry() { Model model = new Model(); @@ -34,7 +38,7 @@ void testUpstreamRouteWithRetry() { new Upstream("endpoint4", null, null, 1, 1) )); - UpstreamRouteProvider upstreamRouteProvider = new UpstreamRouteProvider(vertx); + UpstreamRouteProvider upstreamRouteProvider = new UpstreamRouteProvider(vertx, () -> generator); UpstreamRoute route = upstreamRouteProvider.get(model); assertTrue(route.available()); @@ -80,7 +84,7 @@ void testUpstreamRouteWithRetry2() { new Upstream("endpoint2", null, null, 1, 1) )); - UpstreamRouteProvider upstreamRouteProvider = new UpstreamRouteProvider(vertx); + UpstreamRouteProvider upstreamRouteProvider = new UpstreamRouteProvider(vertx, () -> generator); UpstreamRoute route = upstreamRouteProvider.get(model); assertTrue(route.available());