Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
actions-user committed Jul 4, 2024
2 parents 3e21aa2 + 2803a91 commit 043616f
Show file tree
Hide file tree
Showing 10 changed files with 354 additions and 110 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ To control custom models, use `+` to add a custom model, use `-` to hide a model

User `-all` to disable all default models, `+all` to enable all default models.

### `DEFAULT_MODEL` (optional)

Change default model

### `WHITE_WEBDEV_ENDPOINTS` (optional)

You can use this option if you want to increase the number of webdav service addresses you are allowed to access, as required by the format:
Expand Down
15 changes: 10 additions & 5 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,23 +106,23 @@ Azure 密钥。

Azure Api 版本,你可以在这里找到:[Azure 文档](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions)

### `GOOGLE_API_KEY` (optional)
### `GOOGLE_API_KEY` (可选)

Google Gemini Pro 密钥.

### `GOOGLE_URL` (optional)
### `GOOGLE_URL` (可选)

Google Gemini Pro Api Url.

### `ANTHROPIC_API_KEY` (optional)
### `ANTHROPIC_API_KEY` (可选)

anthropic claude Api Key.

### `ANTHROPIC_API_VERSION` (optional)
### `ANTHROPIC_API_VERSION` (可选)

anthropic claude Api version.

### `ANTHROPIC_URL` (optional)
### `ANTHROPIC_URL` (可选)

anthropic claude Api Url.

Expand Down Expand Up @@ -156,7 +156,12 @@ anthropic claude Api Url.
用来控制模型列表,使用 `+` 增加一个模型,使用 `-` 来隐藏一个模型,使用 `模型名=展示名` 来自定义模型的展示名,用英文逗号隔开。

### `DEFAULT_MODEL` (可选)

更改默认模型

### `DEFAULT_INPUT_TEMPLATE` (可选)

自定义默认的 template,用于初始化『设置』中的『用户输入预处理』配置项

## 开发
Expand Down
4 changes: 3 additions & 1 deletion app/api/google/[...path]/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ async function handle(
);
}

const fetchUrl = `${baseUrl}/${path}?key=${key}`;
const fetchUrl = `${baseUrl}/${path}?key=${key}${
req?.nextUrl?.searchParams?.get("alt") == "sse" ? "&alt=sse" : ""
}`;
const fetchOptions: RequestInit = {
headers: {
"Content-Type": "application/json",
Expand Down
10 changes: 7 additions & 3 deletions app/client/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,17 @@ export function getHeaders() {
const modelConfig = useChatStore.getState().currentSession().mask.modelConfig;
const isGoogle = modelConfig.model.startsWith("gemini");
const isAzure = accessStore.provider === ServiceProvider.Azure;
const authHeader = isAzure ? "api-key" : "Authorization";
const isAnthropic = accessStore.provider === ServiceProvider.Anthropic;
const authHeader = isAzure ? "api-key" : isAnthropic ? 'x-api-key' : "Authorization";
const apiKey = isGoogle
? accessStore.googleApiKey
: isAzure
? accessStore.azureApiKey
: isAnthropic
? accessStore.anthropicApiKey
: accessStore.openaiApiKey;
const clientConfig = getClientConfig();
const makeBearer = (s: string) => `${isAzure ? "" : "Bearer "}${s.trim()}`;
const makeBearer = (s: string) => `${isAzure || isAnthropic ? "" : "Bearer "}${s.trim()}`;
const validString = (x: string) => x && x.length > 0;

// when using google api in app, not set auth header
Expand All @@ -181,7 +184,8 @@ export function getHeaders() {
accessStore.enabledAccessControl() &&
validString(accessStore.accessCode)
) {
headers[authHeader] = makeBearer(
// access_code must send with header named `Authorization`, will using in auth middleware.
headers['Authorization'] = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}
Expand Down
33 changes: 4 additions & 29 deletions app/client/platforms/anthropic.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { ACCESS_CODE_PREFIX, Anthropic, ApiPath } from "@/app/constant";
import { ChatOptions, LLMApi, MultimodalContent } from "../api";
import { ChatOptions, getHeaders, LLMApi, MultimodalContent, } from "../api";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { getClientConfig } from "@/app/config/client";
import { DEFAULT_API_HOST } from "@/app/constant";
Expand Down Expand Up @@ -190,11 +190,10 @@ export class ClaudeApi implements LLMApi {
body: JSON.stringify(requestBody),
signal: controller.signal,
headers: {
"Content-Type": "application/json",
Accept: "application/json",
"x-api-key": accessStore.anthropicApiKey,
...getHeaders(), // get common headers
"anthropic-version": accessStore.anthropicApiVersion,
Authorization: getAuthKey(accessStore.anthropicApiKey),
// do not send `anthropicApiKey` in browser!!!
// Authorization: getAuthKey(accessStore.anthropicApiKey),
},
};

Expand Down Expand Up @@ -389,27 +388,3 @@ function trimEnd(s: string, end = " ") {

return s;
}

function bearer(value: string) {
return `Bearer ${value.trim()}`;
}

function getAuthKey(apiKey = "") {
const accessStore = useAccessStore.getState();
const isApp = !!getClientConfig()?.isApp;
let authKey = "";

if (apiKey) {
// use user's api key first
authKey = bearer(apiKey);
} else if (
accessStore.enabledAccessControl() &&
!isApp &&
!!accessStore.accessCode
) {
// or use access code
authKey = bearer(ACCESS_CODE_PREFIX + accessStore.accessCode);
}

return authKey;
}
146 changes: 83 additions & 63 deletions app/client/platforms/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { getClientConfig } from "@/app/config/client";
import { DEFAULT_API_HOST } from "@/app/constant";
import Locale from "../../locales";
import {
EventStreamContentType,
fetchEventSource,
} from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format";
import {
getMessageTextContent,
getMessageImages,
Expand All @@ -20,7 +26,7 @@ export class GeminiProApi implements LLMApi {
);
}
async chat(options: ChatOptions): Promise<void> {
// const apiClient = this;
const apiClient = this;
let multimodal = false;
const messages = options.messages.map((v) => {
let parts: any[] = [{ text: getMessageTextContent(v) }];
Expand Down Expand Up @@ -120,7 +126,9 @@ export class GeminiProApi implements LLMApi {

if (!baseUrl) {
baseUrl = isApp
? DEFAULT_API_HOST + "/api/proxy/google/" + Google.ChatPath(modelConfig.model)
? DEFAULT_API_HOST +
"/api/proxy/google/" +
Google.ChatPath(modelConfig.model)
: this.path(Google.ChatPath(modelConfig.model));
}

Expand All @@ -139,16 +147,15 @@ export class GeminiProApi implements LLMApi {
() => controller.abort(),
REQUEST_TIMEOUT_MS,
);

if (shouldStream) {
let responseText = "";
let remainText = "";
let finished = false;

let existingTexts: string[] = [];
const finish = () => {
finished = true;
options.onFinish(existingTexts.join(""));
options.onFinish(responseText + remainText);
};

// animate response to make it looks smooth
Expand All @@ -173,72 +180,85 @@ export class GeminiProApi implements LLMApi {
// start animaion
animateResponseText();

fetch(
baseUrl.replace("generateContent", "streamGenerateContent"),
chatPayload,
)
.then((response) => {
const reader = response?.body?.getReader();
const decoder = new TextDecoder();
let partialData = "";
controller.signal.onabort = finish;

// https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/Streaming_REST.ipynb
const chatPath =
baseUrl.replace("generateContent", "streamGenerateContent") +
(baseUrl.indexOf("?") > -1 ? "&alt=sse" : "?alt=sse");
fetchEventSource(chatPath, {
...chatPayload,
async onopen(res) {
clearTimeout(requestTimeoutId);
const contentType = res.headers.get("content-type");
console.log(
"[Gemini] request response content type: ",
contentType,
);

return reader?.read().then(function processText({
done,
value,
}): Promise<any> {
if (done) {
if (response.status !== 200) {
try {
let data = JSON.parse(ensureProperEnding(partialData));
if (data && data[0].error) {
options.onError?.(new Error(data[0].error.message));
} else {
options.onError?.(new Error("Request failed"));
}
} catch (_) {
options.onError?.(new Error("Request failed"));
}
}
if (contentType?.startsWith("text/plain")) {
responseText = await res.clone().text();
return finish();
}

if (
!res.ok ||
!res.headers
.get("content-type")
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [responseText];
let extraInfo = await res.clone().text();
try {
const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}

console.log("Stream complete");
// options.onFinish(responseText + remainText);
finished = true;
return Promise.resolve();
if (res.status === 401) {
responseTexts.push(Locale.Error.Unauthorized);
}

partialData += decoder.decode(value, { stream: true });
if (extraInfo) {
responseTexts.push(extraInfo);
}

try {
let data = JSON.parse(ensureProperEnding(partialData));
responseText = responseTexts.join("\n\n");

const textArray = data.reduce(
(acc: string[], item: { candidates: any[] }) => {
const texts = item.candidates.map((candidate) =>
candidate.content.parts
.map((part: { text: any }) => part.text)
.join(""),
);
return acc.concat(texts);
},
[],
);
return finish();
}
},
onmessage(msg) {
if (msg.data === "[DONE]" || finished) {
return finish();
}
const text = msg.data;
try {
const json = JSON.parse(text);
const delta = apiClient.extractMessage(json);

if (textArray.length > existingTexts.length) {
const deltaArray = textArray.slice(existingTexts.length);
existingTexts = textArray;
remainText += deltaArray.join("");
}
} catch (error) {
// console.log("[Response Animation] error: ", error,partialData);
// skip error message when parsing json
if (delta) {
remainText += delta;
}

return reader.read().then(processText);
});
})
.catch((error) => {
console.error("Error:", error);
});
const blockReason = json?.promptFeedback?.blockReason;
if (blockReason) {
// being blocked
console.log(`[Google] [Safety Ratings] result:`, blockReason);
}
} catch (e) {
console.error("[Request] parse error", text, msg);
}
},
onclose() {
finish();
},
onerror(e) {
options.onError?.(e);
throw e;
},
openWhenHidden: true,
});
} else {
const res = await fetch(baseUrl, chatPayload);
clearTimeout(requestTimeoutId);
Expand All @@ -252,7 +272,7 @@ export class GeminiProApi implements LLMApi {
),
);
}
const message = this.extractMessage(resJson);
const message = apiClient.extractMessage(resJson);
options.onFinish(message);
}
} catch (e) {
Expand Down
25 changes: 25 additions & 0 deletions app/masks/build.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import fs from "fs";
import path from "path";
import { CN_MASKS } from "./cn";
import { TW_MASKS } from "./tw";
import { EN_MASKS } from "./en";

import { type BuiltinMask } from "./typing";

const BUILTIN_MASKS: Record<string, BuiltinMask[]> = {
cn: CN_MASKS,
tw: TW_MASKS,
en: EN_MASKS,
};

const dirname = path.dirname(__filename);

fs.writeFile(
dirname + "/../../public/masks.json",
JSON.stringify(BUILTIN_MASKS, null, 4),
function (error) {
if (error) {
console.error("[Build] failed to build masks", error);
}
},
);
20 changes: 17 additions & 3 deletions app/masks/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ export const BUILTIN_MASK_STORE = {
},
};

export const BUILTIN_MASKS: BuiltinMask[] = [...CN_MASKS, ...TW_MASKS, ...EN_MASKS].map(
(m) => BUILTIN_MASK_STORE.add(m),
);
export const BUILTIN_MASKS: BuiltinMask[] = [];

if (typeof window != "undefined") {
// run in browser skip in next server
fetch("/masks.json")
.then((res) => res.json())
.catch((error) => {
console.error("[Fetch] failed to fetch masks", error);
return { cn: [], tw: [], en: [] };
})
.then((masks) => {
const { cn = [], tw = [], en = [] } = masks;
return [...cn, ...tw, ...en].map((m) => {
BUILTIN_MASKS.push(BUILTIN_MASK_STORE.add(m));
});
});
}
Loading

0 comments on commit 043616f

Please sign in to comment.