-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
1 parent
707ff10
commit 5a5f10e
Showing
10 changed files
with
206 additions
and
213 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
77 changes: 77 additions & 0 deletions
77
server/src/main/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package com.epam.aidial.core.server.upstream; | ||
|
||
import com.epam.aidial.core.config.Upstream; | ||
import lombok.Getter; | ||
import lombok.extern.slf4j.Slf4j; | ||
|
||
import java.util.Comparator; | ||
import java.util.List; | ||
import java.util.Random; | ||
|
||
/** | ||
* Load balancer distributes load in the proportion of probability of upstream weights. | ||
* The higher upstream weight, the higher probability the upstream takes more load. | ||
*/ | ||
@Slf4j | ||
class RandomizedWeightedBalancer implements Comparable<RandomizedWeightedBalancer> { | ||
|
||
private final int tier; | ||
@Getter | ||
private final List<UpstreamState> upstreamStates; | ||
private final Random generator; | ||
|
||
RandomizedWeightedBalancer(String deploymentName, List<Upstream> upstreams, Random generator) { | ||
if (upstreams == null || upstreams.isEmpty()) { | ||
throw new IllegalArgumentException("Upstream list is null or empty for deployment: " + deploymentName); | ||
} | ||
int tier = upstreams.get(0).getTier(); | ||
for (Upstream upstream : upstreams) { | ||
if (upstream.getTier() != tier) { | ||
throw new IllegalArgumentException("Tier mismatch for deployment " + deploymentName); | ||
} | ||
} | ||
this.tier = tier; | ||
this.upstreamStates = upstreams.stream() | ||
.filter(upstream -> upstream.getWeight() > 0) | ||
.map(UpstreamState::new) | ||
.toList(); | ||
this.generator = generator; | ||
if (this.upstreamStates.isEmpty()) { | ||
log.warn("No available upstreams for deployment {} and tier {}", deploymentName, tier); | ||
} | ||
} | ||
|
||
public Upstream next() { | ||
if (upstreamStates.isEmpty()) { | ||
return null; | ||
} | ||
|
||
List<Upstream> availableUpstreams = upstreamStates.stream().filter(UpstreamState::isUpstreamAvailable) | ||
.map(UpstreamState::getUpstream).toList(); | ||
if (availableUpstreams.isEmpty()) { | ||
return null; | ||
} | ||
int total = availableUpstreams.stream().map(Upstream::getWeight).reduce(0, Integer::sum); | ||
// make sure the upper bound `total` is inclusive | ||
int random = generator.nextInt(total + 1); | ||
int current = 0; | ||
|
||
Upstream result = null; | ||
|
||
for (Upstream upstream : availableUpstreams) { | ||
current += upstream.getWeight(); | ||
if (current >= random) { | ||
result = upstream; | ||
break; | ||
} | ||
} | ||
|
||
return result; | ||
} | ||
|
||
@Override | ||
public int compareTo(RandomizedWeightedBalancer randomizedWeightedBalancer) { | ||
return Integer.compare(tier, randomizedWeightedBalancer.tier); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
88 changes: 0 additions & 88 deletions
88
server/src/main/java/com/epam/aidial/core/server/upstream/WeightedRoundRobinBalancer.java
This file was deleted.
Oops, something went wrong.
74 changes: 74 additions & 0 deletions
74
...er/src/test/java/com/epam/aidial/core/server/upstream/RandomizedWeightedBalancerTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package com.epam.aidial.core.server.upstream; | ||
|
||
import com.epam.aidial.core.config.Upstream; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.jupiter.api.extension.ExtendWith; | ||
import org.mockito.Mock; | ||
import org.mockito.junit.jupiter.MockitoExtension; | ||
|
||
import java.util.List; | ||
import java.util.Random; | ||
|
||
import static org.junit.jupiter.api.Assertions.assertEquals; | ||
import static org.junit.jupiter.api.Assertions.assertNotNull; | ||
import static org.junit.jupiter.api.Assertions.assertNull; | ||
import static org.mockito.Mockito.when; | ||
|
||
@ExtendWith(MockitoExtension.class) | ||
public class RandomizedWeightedBalancerTest { | ||
|
||
@Mock | ||
private Random generator; | ||
|
||
@Test | ||
void testWeightedLoadBalancer() { | ||
List<Upstream> upstreams = List.of( | ||
new Upstream("endpoint1", null, null, 1, 0), | ||
new Upstream("endpoint2", null, null, 2, 0), | ||
new Upstream("endpoint3", null, null, 3, 0), | ||
new Upstream("endpoint4", null, null, 4, 0) | ||
); | ||
|
||
RandomizedWeightedBalancer balancer = new RandomizedWeightedBalancer("model1", upstreams, generator); | ||
|
||
when(generator.nextInt(11)).thenReturn(0); | ||
|
||
Upstream upstream = balancer.next(); | ||
assertNotNull(upstream); | ||
assertEquals(upstreams.get(0), upstream); | ||
|
||
when(generator.nextInt(11)).thenReturn(2); | ||
|
||
upstream = balancer.next(); | ||
assertNotNull(upstream); | ||
assertEquals(upstreams.get(1), upstream); | ||
|
||
when(generator.nextInt(11)).thenReturn(6); | ||
|
||
upstream = balancer.next(); | ||
assertNotNull(upstream); | ||
assertEquals(upstreams.get(2), upstream); | ||
|
||
when(generator.nextInt(11)).thenReturn(10); | ||
|
||
upstream = balancer.next(); | ||
assertNotNull(upstream); | ||
assertEquals(upstreams.get(3), upstream); | ||
|
||
} | ||
|
||
@Test | ||
void testZeroWeightLoadBalancer() { | ||
List<Upstream> upstreams = List.of( | ||
new Upstream("endpoint1", null, null, 0, 1), | ||
new Upstream("endpoint2", null, null, -9, 1) | ||
); | ||
RandomizedWeightedBalancer balancer = new RandomizedWeightedBalancer("model1", upstreams, generator); | ||
|
||
for (int i = 0; i < 10; i++) { | ||
Upstream upstream = balancer.next(); | ||
assertNull(upstream); | ||
} | ||
} | ||
|
||
} |
Oops, something went wrong.