feat: Implement user limits #82 (#188)
Co-authored-by: Aliaksandr Stsiapanay <[email protected]>
astsiapanay authored Mar 13, 2024
1 parent e1f7ee4 commit 5696b28
Showing 5 changed files with 180 additions and 38 deletions.
35 changes: 18 additions & 17 deletions README.md
@@ -37,7 +37,7 @@ Static settings are used on startup and cannot be changed while application is running
|--------------------------------------------|--------------------|-------------------------------------------------------------------------------------------------------------------
| config.files | aidial.config.json | Config files that together compose the whole config.
| config.reload | 60000 | Config reload interval in milliseconds.
| identityProviders | - | List of identity providers. **Note**. At least one identity provider must be provided.
| identityProviders | - | Map of identity providers. **Note**. At least one identity provider must be provided.
| identityProviders.*.jwksUrl | - | URL of the JWKS provider. **Required** if `disabledVerifyJwt` is set to `false`
| identityProviders.*.rolePath | - | Path to the user roles claim in the JWT token, e.g. `resource_access.chatbot-ui.roles` or just `roles`. **Required**.
| identityProviders.*.loggingKey | - | JWT claim used to extract user information for logging.
@@ -164,22 +164,23 @@ Dynamic settings include:
* Access Permissions
* Rate Limits

| Parameter | Description |
|---------------------------------------| ------------ |
| routes | Path(s) for specific upstream routing or to respond with a configured body. |
| applications | A list of deployed AI DIAL Applications and their parameters:<br />`<application_name>`: Unique application name. |
| applications.<application_name> | `endpoint`: AI DIAL Application API for chat completions.<br />`iconUrl`: Icon path for the AI DIAL Application on UI.<br />`description`: Brief AI DIAL Application description.<br />`displayName`: AI DIAL Application name on UI.<br />`inputAttachmentTypes`: A list of allowed MIME types for the input attachments.<br />`maxInputAttachments`: Maximum number of input attachments (default is zero when `inputAttachmentTypes` is unset; otherwise, unlimited) |
| models | A list of deployed models and their parameters:<br />`<model_name>`: Unique model name. |
| models.<model_name> | `type`: Model type—`chat` or `embedding`.<br />`iconUrl`: Icon path for the model on UI.<br />`description`: Brief model description.<br />`displayName`: Model name on UI.<br />`displayVersion`: Model version on UI.<br />`endpoint`: Model API for chat completions or embeddings.<br />`tokenizerModel`: Identifies the specific model whose tokenization algorithm exactly matches that of the referenced model. This is typically the name of the earliest-released model in a series of models sharing an identical tokenization algorithm (e.g. `gpt-3.5-turbo-0301`, `gpt-4-0314`, or `gpt-4-1106-vision-preview`). This parameter is essential for DIAL clients that reimplement tokenization algorithms on their side, instead of utilizing the `tokenizeEndpoint` provided by the model.<br />`features`: Model features.<br />`limits`: Model token limits.<br />`pricing`: Model pricing.<br />`upstreams`: Used for load-balancing—the request is sent to the model endpoint with `X-UPSTREAM-ENDPOINT` and `X-UPSTREAM-KEY` headers. |
| models.<model_name>.limits | `maxPromptTokens`: maximum number of tokens in a completion request.<br />`maxCompletionTokens`: maximum number of tokens in a completion response.<br />`maxTotalTokens`: maximum number of tokens in completion request and response combined.<br />Typically either `maxTotalTokens` is specified or `maxPromptTokens` and `maxCompletionTokens`. |
| models.<model_name>.pricing | `unit`: the pricing units (currently `token` and `char_without_whitespace` are supported).<br />`prompt`: per-unit price for the completion request in USD.<br />`completion`: per-unit price for the completion response in USD. |
| models.<model_name>.features | `rateEndpoint`: endpoint for rate requests *(exposed by core as `<deployment name>/rate`)*.<br />`tokenizeEndpoint`: endpoint for requests to the model tokenizer *(exposed by core as `<deployment name>/tokenize`)*.<br />`truncatePromptEndpoint`: endpoint for truncating prompt requests *(exposed by core as `<deployment name>/truncate_prompt`)*.<br />`systemPromptSupported`: whether the model supports a system prompt (default is `true`).<br />`toolsSupported`: whether the model supports tools (default is `false`).<br />`seedSupported`: whether the model supports the `seed` request parameter (default is `false`).<br />`urlAttachmentsSupported`: whether the model/application supports attachments with URLs (default is `false`) |
| models.<model_name>.upstreams | `endpoint`: Model endpoint.<br />`key`: Your API key. |
| models.<model_name>.defaultUserLimit | Default per-user limit for the given model, applied when the user's roles define no limit for it.<br />`minute`: Total tokens per minute that may be sent to the model, enforced with a floating-window counter.<br />`day`: Total tokens per day that may be sent to the model, enforced with a floating-window counter. |
| keys | API Keys parameters:<br />`<core_key>`: Your API key. |
| keys.<core_key> | `project`: Project name assigned to this key.<br />`role`: A configured role name that defines key permissions. |
| roles | Roles `<role_name>` with associated limits. Each API key refers to exactly one role; user roles extracted from the JWT are matched against the same map. A role associates models, applications, and assistants with limits. |
| roles.<role_name> | `limits`: Limits for models, applications, or assistants. |
| roles.<role_name>.limits | `minute`: Total tokens per minute that may be sent to the model, enforced with a floating-window counter.<br />`day`: Total tokens per day that may be sent to the model, enforced with a floating-window counter. |
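
The "floating window" wording above describes sliding-window rate limiting: usage is summed over the trailing minute or day rather than reset at fixed boundaries. As a rough, self-contained sketch of that idea — in-memory only and purely illustrative, since the actual implementation keeps counters in shared storage via `ResourceService`:

```java
import java.time.Clock;
import java.util.ArrayDeque;
import java.util.Deque;

// Illustrative sliding-window token counter; not the DIAL core implementation.
public final class SlidingWindowCounter {

    private record Entry(long timestampMillis, long tokens) {}

    private final Deque<Entry> entries = new ArrayDeque<>();
    private final long windowMillis;
    private final Clock clock;
    private long total;

    public SlidingWindowCounter(long windowMillis, Clock clock) {
        this.windowMillis = windowMillis;
        this.clock = clock;
    }

    // Records token usage at the current instant.
    public void add(long tokens) {
        evictExpired();
        entries.addLast(new Entry(clock.millis(), tokens));
        total += tokens;
    }

    // Returns true if adding `tokens` would exceed `limit` within the trailing window.
    public boolean wouldExceed(long tokens, long limit) {
        evictExpired();
        return total + tokens > limit;
    }

    private void evictExpired() {
        long cutoff = clock.millis() - windowMillis;
        while (!entries.isEmpty() && entries.peekFirst().timestampMillis() < cutoff) {
            total -= entries.pollFirst().tokens();
        }
    }
}
```

A request is rejected only while the trailing-window total is above the limit, so a burst at a window boundary cannot double the allowance the way a fixed-window counter would.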

## License

26 changes: 24 additions & 2 deletions sample/aidial.config.json
@@ -46,7 +46,12 @@
          "endpoint": "http://localhost:7003",
          "key": "modelKey3"
        }
      ]
      ],
      "userRoles": ["role1", "role2"],
      "defaultUserLimit": {
        "minute": "100000",
        "day": "10000000"
      }
    },
"embedding-ada": {
"type": "embedding",
Expand All @@ -57,7 +62,8 @@
"key": "modelKey4"
}
]
}
},
"userRoles": ["role3"]
},
"keys": {
"proxyKey1": {
Expand Down Expand Up @@ -86,6 +92,22 @@
"search_assistant": {},
"app": {}
}
},
"role1": {
"limits": {
"chat-gpt-35-turbo": {
"minute": "200000",
"day": "10000000"
}
}
},
"role2": {
"limits": {
"chat-gpt-35-turbo": {
"minute": "100000",
"day": "20000000"
}
}
}
}
}
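
Given this sample, a user whose JWT carries both `role1` and `role2` gets the most permissive value per window for `chat-gpt-35-turbo`: 200000 tokens per minute (from `role1`) and 20000000 tokens per day (from `role2`). This follows from the `Math.max` resolution in `RateLimiter.getLimitByUser` below.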
1 change: 1 addition & 0 deletions src/main/java/com/epam/aidial/core/config/Model.java
@@ -17,4 +17,5 @@ public class Model extends Deployment {
    private List<Upstream> upstreams = List.of();
    // If set, the model name is overridden with this name in the request body sent to the model adapter.
    private String overrideName;
    private Limit defaultUserLimit;
}
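
The new field reuses the existing `Limit` config class, which is not shown in this diff. Judging only by the accessors the rate limiter calls (`getMinute`, `getDay` and the corresponding setters), a minimal sketch might look like the following; the permissive defaults are an assumption, inferred from `DEFAULT_LIMIT = new Limit()` being used as the fallback when nothing is configured:

```java
import lombok.Data;

// Hypothetical sketch of the Limit config class; only the minute/day
// accessors are confirmed by this diff.
@Data
public class Limit {
    // Assumed permissive defaults so an unconfigured deployment is not throttled.
    private long minute = Long.MAX_VALUE;
    private long day = Long.MAX_VALUE;
}
```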
45 changes: 39 additions & 6 deletions src/main/java/com/epam/aidial/core/limiter/RateLimiter.java
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
import com.epam.aidial.core.config.Deployment;
import com.epam.aidial.core.config.Key;
import com.epam.aidial.core.config.Limit;
import com.epam.aidial.core.config.Model;
import com.epam.aidial.core.config.Role;
import com.epam.aidial.core.data.LimitStats;
import com.epam.aidial.core.data.ResourceType;
@@ -20,10 +21,16 @@
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.util.List;
import java.util.Map;
import java.util.Optional;

@Slf4j
@RequiredArgsConstructor
public class RateLimiter {

    private static final Limit DEFAULT_LIMIT = new Limit();

    private final Vertx vertx;

    private final ResourceService resourceService;
@@ -34,10 +41,7 @@ public Future<Void> increase(ProxyContext context) {
        if (resourceService == null) {
            return Future.succeededFuture();
        }
        Key key = context.getKey();
        if (key == null) {
            return Future.succeededFuture();
        }

        Deployment deployment = context.getDeployment();
        TokenUsage usage = context.getTokenUsage();
Expand All @@ -62,8 +66,7 @@ public Future<RateLimitResult> limit(ProxyContext context) {
Deployment deployment = context.getDeployment();
Limit limit;
if (key == null) {
// don't support user limits yet
return Future.succeededFuture(RateLimitResult.SUCCESS);
limit = getLimitByUser(context);
} else {
limit = getLimitByApiKey(context, deployment.getName());
}
@@ -176,8 +179,38 @@ private Limit getLimitByApiKey(ProxyContext context, String deploymentName) {
        return role.getLimits().get(deploymentName);
    }

    private Limit getLimitByUser(ProxyContext context) {
        List<String> userRoles = context.getUserRoles();
        Limit defaultUserLimit = getDefaultUserLimit(context.getDeployment());
        if (userRoles.isEmpty()) {
            return defaultUserLimit;
        }
        String deploymentName = context.getDeployment().getName();
        Map<String, Role> userRoleToDeploymentLimits = context.getConfig().getRoles();
        long minuteLimit = 0;
        long dayLimit = 0;
        // The user gets the most permissive value per window across all matching roles;
        // a role without a limit for this deployment falls back to the model default.
        for (String userRole : userRoles) {
            Limit limit = Optional.ofNullable(userRoleToDeploymentLimits.get(userRole))
                    .map(role -> role.getLimits().get(deploymentName))
                    .orElse(defaultUserLimit);
            minuteLimit = Math.max(minuteLimit, limit.getMinute());
            dayLimit = Math.max(dayLimit, limit.getDay());
        }
        Limit limit = new Limit();
        limit.setMinute(minuteLimit);
        limit.setDay(dayLimit);
        return limit;
    }

    private static String getPath(String deploymentName) {
        return String.format("%s/tokens", deploymentName);
    }

    private static Limit getDefaultUserLimit(Deployment deployment) {
        if (deployment instanceof Model model) {
            return model.getDefaultUserLimit() == null ? DEFAULT_LIMIT : model.getDefaultUserLimit();
        }
        return DEFAULT_LIMIT;
    }

}
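
To make the resolution rule in `getLimitByUser` concrete, here is a standalone sketch of the same max-across-roles logic; every name in it is illustrative and not part of this commit:

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;

// Standalone illustration of the resolution rule in getLimitByUser:
// each role contributes its per-deployment limit (or the model default),
// and the user gets the most permissive value per window.
final class EffectiveLimitExample {

    record Limits(long minute, long day) {}

    static Limits resolve(List<String> userRoles,
                          Map<String, Map<String, Limits>> roleLimits,
                          String deployment,
                          Limits defaultLimit) {
        if (userRoles.isEmpty()) {
            return defaultLimit;
        }
        long minute = 0;
        long day = 0;
        for (String role : userRoles) {
            Limits limit = Optional.ofNullable(roleLimits.get(role))
                    .map(limits -> limits.get(deployment))
                    .orElse(defaultLimit);
            minute = Math.max(minute, limit.minute());
            day = Math.max(day, limit.day());
        }
        return new Limits(minute, day);
    }

    public static void main(String[] args) {
        Map<String, Map<String, Limits>> roles = Map.of(
                "role1", Map.of("chat-gpt-35-turbo", new Limits(200_000, 10_000_000)),
                "role2", Map.of("chat-gpt-35-turbo", new Limits(100_000, 20_000_000)));
        // Matches the sample config: the user gets minute=200000, day=20000000.
        System.out.println(resolve(List.of("role1", "role2"), roles,
                "chat-gpt-35-turbo", new Limits(0, 0)));
    }
}
```

Running `main` prints minute=200000 and day=20000000 for the sample roles above, mirroring the max-across-roles behavior exercised by `testLimit_User_LimitFound` below.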
111 changes: 98 additions & 13 deletions src/test/java/com/epam/aidial/core/limiter/RateLimiterTest.java
@@ -15,7 +15,6 @@
import com.epam.aidial.core.storage.BlobStorage;
import com.epam.aidial.core.token.TokenUsage;
import com.epam.aidial.core.util.HttpStatus;
import com.epam.aidial.core.util.ProxyUtil;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpServerRequest;
@@ -34,7 +33,7 @@
import redis.embedded.RedisServer;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;

@@ -128,17 +127,6 @@ public void testLimit_EntityNotFound() {
        assertEquals(HttpStatus.FORBIDDEN, result.result().status());
    }

    @Test
    public void testLimit_SuccessUser() {
        ProxyContext proxyContext = new ProxyContext(new Config(), request, new ApiKeyData(), new ExtractedClaims("sub", Collections.emptyList(), "hash"), "trace-id", "span-id");

        Future<RateLimitResult> result = rateLimiter.limit(proxyContext);

        assertNotNull(result);
        assertNotNull(result.result());
        assertEquals(HttpStatus.OK, result.result().status());
    }

    @Test
    public void testLimit_ApiKeyLimitNotFound() {
        Key key = new Key();
@@ -340,4 +328,101 @@ public void testGetLimitStats_ApiKey() {

    }

    @Test
    public void testLimit_User_LimitFound() {
        Config config = new Config();

        Role role1 = new Role();
        Limit limit = new Limit();
        limit.setDay(10000);
        limit.setMinute(100);
        role1.setLimits(Map.of("model", limit));

        Role role2 = new Role();
        limit = new Limit();
        limit.setDay(20000);
        limit.setMinute(200);
        role2.setLimits(Map.of("model", limit));

        config.getRoles().put("role1", role1);
        config.getRoles().put("role2", role2);

        ApiKeyData apiKeyData = new ApiKeyData();
        ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData, new ExtractedClaims("sub", List.of("role1", "role2"), "user-hash"), "trace-id", "span-id");
        Model model = new Model();
        model.setName("model");
        proxyContext.setDeployment(model);

        when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
            Callable<?> callable = invocation.getArgument(0);
            return Future.succeededFuture(callable.call());
        });

        TokenUsage tokenUsage = new TokenUsage();
        tokenUsage.setTotalTokens(150);
        proxyContext.setTokenUsage(tokenUsage);

        Future<Void> increaseLimitFuture = rateLimiter.increase(proxyContext);
        assertNotNull(increaseLimitFuture);
        assertNull(increaseLimitFuture.cause());

        Future<RateLimitResult> checkLimitFuture = rateLimiter.limit(proxyContext);

        assertNotNull(checkLimitFuture);
        assertNotNull(checkLimitFuture.result());
        assertEquals(HttpStatus.OK, checkLimitFuture.result().status());

        increaseLimitFuture = rateLimiter.increase(proxyContext);
        assertNotNull(increaseLimitFuture);
        assertNull(increaseLimitFuture.cause());

        checkLimitFuture = rateLimiter.limit(proxyContext);

        assertNotNull(checkLimitFuture);
        assertNotNull(checkLimitFuture.result());
        assertEquals(HttpStatus.TOO_MANY_REQUESTS, checkLimitFuture.result().status());
    }

    @Test
    public void testLimit_User_LimitNotFound() {
        Config config = new Config();

        ApiKeyData apiKeyData = new ApiKeyData();
        ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData, new ExtractedClaims("sub", List.of("role1"), "user-hash"), "trace-id", "span-id");
        Model model = new Model();
        model.setName("model");
        proxyContext.setDeployment(model);

        when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> {
            Callable<?> callable = invocation.getArgument(0);
            return Future.succeededFuture(callable.call());
        });

        TokenUsage tokenUsage = new TokenUsage();
        tokenUsage.setTotalTokens(90);
        proxyContext.setTokenUsage(tokenUsage);

        Future<Void> increaseLimitFuture = rateLimiter.increase(proxyContext);
        assertNotNull(increaseLimitFuture);
        assertNull(increaseLimitFuture.cause());

        Future<RateLimitResult> checkLimitFuture = rateLimiter.limit(proxyContext);

        assertNotNull(checkLimitFuture);
        assertNotNull(checkLimitFuture.result());
        assertEquals(HttpStatus.OK, checkLimitFuture.result().status());

        increaseLimitFuture = rateLimiter.increase(proxyContext);
        assertNotNull(increaseLimitFuture);
        assertNull(increaseLimitFuture.cause());

        checkLimitFuture = rateLimiter.limit(proxyContext);

        assertNotNull(checkLimitFuture);
        assertNotNull(checkLimitFuture.result());
        assertEquals(HttpStatus.OK, checkLimitFuture.result().status());
    }

}
