Skip to content

Commit

Permalink
Add basic AI Horde support
Browse files Browse the repository at this point in the history
  • Loading branch information
lmg-anon authored Dec 15, 2024
1 parent eb98d22 commit b05fe94
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 20 deletions.
156 changes: 138 additions & 18 deletions mikupad.html
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,7 @@
const API_LLAMA_CPP = 0;
const API_KOBOLD_CPP = 2;
const API_OPENAI_COMPAT = 3;
const API_AI_HORDE = 4;

// Polyfill for Chromium, whose ReadableStream lacks Symbol.asyncIterator
if (!(Symbol.asyncIterator in ReadableStream.prototype)) {
Expand Down Expand Up @@ -1459,6 +1460,8 @@
urlString = urlString.replace(/\/v1\/?$/, ""); // remove "/v1" from the end of the string
if (endpointAPI == API_KOBOLD_CPP)
urlString = urlString.replace(/\/api\/?$/, ""); // remove "/api" from the end of the string
if (endpointAPI == API_AI_HORDE)
urlString = "https://aihorde.net/api";
urlString = urlString.replace(/\/$/, ""); // remove "/" from the end of the string

return urlString;
Expand Down Expand Up @@ -1489,6 +1492,8 @@
if (tokenCount != -1)
return tokenCount;
return 0;
default:
return 0;
}
}

Expand Down Expand Up @@ -1518,6 +1523,8 @@
if (tokens !== null)
return tokens;
return [];
default:
return [];
}
}

Expand All @@ -1526,6 +1533,8 @@
switch (endpointAPI) {
case API_OPENAI_COMPAT:
return await openaiModels({ endpoint, endpointAPIKey, signal, ...options });
case API_AI_HORDE:
return await aiHordeModels({ endpoint, endpointAPIKey, signal, ...options });
default:
return [];
}
Expand All @@ -1540,6 +1549,8 @@
return yield* await koboldCppCompletion({ endpoint, signal, ...options });
case API_OPENAI_COMPAT:
return yield* await openaiCompletion({ endpoint, endpointAPIKey, signal, ...options });
case API_AI_HORDE:
return yield* await aiHordeCompletion({ endpoint, endpointAPIKey, signal, ...options });
}
}

Expand All @@ -1550,6 +1561,8 @@
return await koboldCppAbortCompletion({ endpoint, ...options });
case API_OPENAI_COMPAT:
return await openaiOobaAbortCompletion({ endpoint, ...options });
case API_AI_HORDE:
return await aiHordeAbortCompletion({ endpoint, ...options });
}
}

Expand Down Expand Up @@ -1737,15 +1750,19 @@

}

function koboldCppConvertOptions(options) {
function koboldCppConvertOptions(options, endpoint) {
const isHorde = endpoint.toLowerCase().includes("aihorde.net");
const swapOption = (lhs, rhs) => {
if (lhs in options) {
options[rhs] = options[lhs];
delete options[lhs];
}
};
if (options.n_predict === -1) {
options.n_predict = 1024;
options.n_predict = isHorde ? 512 : 1024;
}
if (options.n_predict < 16 && isHorde) {
options.n_predict = 16;
}
swapOption("n_ctx", "max_context_length");
swapOption("n_predict", "max_length");
Expand All @@ -1767,7 +1784,7 @@
...(proxyEndpoint ? { 'X-Real-URL': endpoint } : {})
},
body: JSON.stringify({
...koboldCppConvertOptions(options),
...koboldCppConvertOptions(options, endpoint),
stream: true,
}),
signal,
Expand Down Expand Up @@ -2093,6 +2110,88 @@
}
}

async function aiHordeModels({ endpoint, endpointAPIKey, proxyEndpoint, signal, ...options }) {
	// Query the AI Horde for the set of text models currently being served.
	// When a proxy endpoint is configured, the request goes through it and the
	// real target URL travels in the `X-Real-URL` header.
	const baseUrl = proxyEndpoint ?? endpoint;
	const res = await fetch(`${baseUrl}/v2/status/models?type=text`, {
		method: 'GET',
		headers: {
			'Content-Type': 'application/json',
			...(proxyEndpoint ? { 'X-Real-URL': endpoint } : {})
		},
		signal,
	});

	if (!res.ok)
		throw new Error(`HTTP ${res.status}`);

	const models = await res.json();

	// The query string already restricts results to text models; keep a
	// defensive filter anyway and return just the model names.
	const names = [];
	for (const model of models) {
		if (model.type === "text")
			names.push(model.name);
	}
	return names;
}

async function* aiHordeCompletion({ endpoint, endpointAPIKey, proxyEndpoint, signal, ...options }) {
	// Submit an async text generation request to the AI Horde, then poll its
	// status endpoint until the generation completes.
	// Yields progress objects: { status: 'queued', taskId }, then
	// { status: 'queue', position }, and finally { status: 'done', content }.
	const { model, prompt, ...params } = options;
	const submitRes = await fetch(`${proxyEndpoint ?? endpoint}/v2/generate/text/async`, {
		method: 'POST',
		headers: {
			'Content-Type': 'application/json',
			// '0000000000' is the AI Horde anonymous API key.
			'Apikey': endpointAPIKey?.trim() ? endpointAPIKey : '0000000000',
			...(proxyEndpoint ? { 'X-Real-URL': endpoint } : {})
		},
		body: JSON.stringify({
			...(model ? { models: [ model ] } : {}),
			// Horde's params payload follows the KoboldCpp option names.
			params: { ...koboldCppConvertOptions(params, endpoint) },
			prompt: prompt
		}),
		signal,
	});
	if (!submitRes.ok)
		throw new Error(`HTTP ${submitRes.status}`);
	const { id: taskId } = await submitRes.json();
	if (!taskId)
		throw new Error('AI Horde did not return a task id');

	// Let the caller record the task id so the request can be cancelled later.
	yield { status: 'queued', taskId: taskId };

	// Poll for results
	while (true) {
		const checkRes = await fetch(`${proxyEndpoint ?? endpoint}/v2/generate/text/status/${taskId}`, {
			headers: {
				...(proxyEndpoint ? { 'X-Real-URL': endpoint } : {})
			},
			signal,
		});

		if (!checkRes.ok)
			throw new Error(`HTTP ${checkRes.status}`);
		const status = await checkRes.json();

		// A faulted request will never produce output; bail out instead of
		// polling forever.
		if (status.faulted)
			throw new Error('AI Horde request faulted');

		yield { status: 'queue', position: status.queue_position };

		if (status.done) {
			if (status.generations && status.generations.length > 0) {
				yield { status: 'done', content: status.generations[0].text };
			}
			break;
		}

		// Wait before polling again
		await new Promise(resolve => setTimeout(resolve, 1000));
	}
}

async function aiHordeAbortCompletion({ endpoint, proxyEndpoint, hordeTaskId, ...options }) {
	// Cancel an in-flight AI Horde generation by DELETE-ing its status URL.
	// Best-effort: failures are reported but never thrown, since the caller is
	// already tearing the request down.

	// Nothing to cancel if no task was ever queued (e.g. aborted before the
	// submit request returned an id). Without this guard we'd DELETE
	// ".../status/undefined".
	if (!hordeTaskId)
		return;

	try {
		await fetch(`${proxyEndpoint ?? endpoint}/v2/generate/text/status/${hordeTaskId}`, {
			method: 'DELETE',
			headers: {
				...(proxyEndpoint ? { 'X-Real-URL': endpoint } : {})
			},
		});
	} catch (e) {
		reportError(e);
	}
}

function importSillyTavernWorldInfo(json, setWorldInfo, importBehavior) {
setWorldInfo(prevWorldInfo => {
let updatedEntries;
Expand Down Expand Up @@ -3411,6 +3510,7 @@
llamaCppSetLogitBiasParams();
break;
case API_KOBOLD_CPP:
case API_AI_HORDE:
koboldCppSetLogitBiasParams();
break;
case API_OPENAI_COMPAT:
Expand Down Expand Up @@ -5194,6 +5294,7 @@
const keyState = useRef({});
const sessionReconnectTimer = useRef();
const useScrollSmoothing = useRef(true);
const hordeTaskId = useRef();
const [templates, setTemplates] = useDBTemplates(defaultPresets.instructTemplates);
const [templateReplacements, setTemplateReplacements] = useState(false);
const [templatesImport, setTemplatesImport] = useState(false);
Expand Down Expand Up @@ -5275,6 +5376,7 @@
const [grammar, setGrammar] = useSessionState('grammar', '');
const [contextMenuState, setContextMenuState] = useState({ visible: false, x: 0, y: 0 });
const [instructModalState, setInstructModalState] = useState({});
const [hordeQueuePos, setHordeQueuePos] = useState(undefined);

function replacePlaceholders(string,placeholders) {
// give placeholders as json object
Expand Down Expand Up @@ -5717,6 +5819,7 @@
abortCompletion({
endpoint,
endpointAPI,
...(endpointAPI == API_AI_HORDE ? { hordeTaskId: hordeTaskId.current } : {}),
...(isMikupadEndpoint ? { proxyEndpoint: sessionStorage.proxyEndpoint } : {})
});
ac.abort();
Expand Down Expand Up @@ -5774,7 +5877,7 @@
for await (const chunk of completion({
endpoint,
endpointAPI,
...(endpointAPI == API_OPENAI_COMPAT || endpointAPI == API_LLAMA_CPP ? {
...(endpointAPI == API_OPENAI_COMPAT || endpointAPI == API_LLAMA_CPP || endpointAPI == API_AI_HORDE ? {
endpointAPIKey,
model: endpointModel
} : {}),
Expand Down Expand Up @@ -5851,8 +5954,19 @@
ac.signal.throwIfAborted();
if (chunk.stopping_word)
chunk.content = chunk.stopping_word;
if (!chunk.content)
if (endpointAPI === API_AI_HORDE) {
switch (chunk.status) {
case 'queued':
hordeTaskId.current = chunk.taskId;
continue;
case 'queue':
setHordeQueuePos(chunk.position);
continue;
}
}
if (!chunk.content) {
continue;
}
if (startTime === 0) {
startTime = performance.now();
} else {
Expand Down Expand Up @@ -5892,6 +6006,8 @@
undoStack.current.pop();
}
setTokensPerSec(0.0);
hordeTaskId.current = undefined;
setHordeQueuePos(undefined);
}

// Chat Mode
Expand Down Expand Up @@ -6119,7 +6235,7 @@
}, [modalState["context"], promptText, endpoint, endpointAPI]);

useEffect(() => {
if (endpointAPI !== API_OPENAI_COMPAT) {
if (endpointAPI !== API_OPENAI_COMPAT && endpointAPI !== API_AI_HORDE) {
return;
}
setRejectedAPIKey(false);
Expand All @@ -6129,7 +6245,7 @@
const models = await getModels({
endpoint,
endpointAPI,
...(endpointAPI == API_OPENAI_COMPAT ? { endpointAPIKey } : {}),
endpointAPIKey,
signal: ac.signal,
...(isMikupadEndpoint ? { proxyEndpoint: sessionStorage.proxyEndpoint } : {})
});
Expand Down Expand Up @@ -6673,12 +6789,13 @@
<${Sessions} sessionStorage=${sessionStorage} disabled=${!!cancel}/>
</${CollapsibleGroup}>
<${CollapsibleGroup} label="Parameters" expanded>
<${InputBox} label="Server"
className="${isMixedContent() ? 'mixed-content' : ''}"
tooltip="${isMixedContent() ? 'This URL might be blocked due to mixed content. If the prediction fails, download mikupad.html and run it locally.' : ''}"
readOnly=${!!cancel}
value=${endpoint}
onValueChange=${setEndpoint}/>
${(endpointAPI != API_AI_HORDE) && html`
<${InputBox} label="Server"
className="${isMixedContent() ? 'mixed-content' : ''}"
tooltip="${isMixedContent() ? 'This URL might be blocked due to mixed content. If the prediction fails, download mikupad.html and run it locally.' : ''}"
readOnly=${!!cancel}
value=${endpoint}
onValueChange=${setEndpoint}/>`}
<${SelectBox}
label="API"
disabled=${!!cancel}
Expand All @@ -6687,9 +6804,10 @@
options=${[
{ name: 'llama.cpp' , value: API_LLAMA_CPP },
{ name: 'KoboldCpp' , value: API_KOBOLD_CPP },
{ name: 'OpenAI compatible', value: API_OPENAI_COMPAT },
{ name: 'OpenAI Compatible', value: API_OPENAI_COMPAT },
{ name: 'AI Horde' , value: API_AI_HORDE },
]}/>
${(endpointAPI === API_LLAMA_CPP || endpointAPI === API_OPENAI_COMPAT) && html`
${(endpointAPI === API_LLAMA_CPP || endpointAPI === API_OPENAI_COMPAT || endpointAPI == API_AI_HORDE) && html`
<div className="hbox-flex" style=${{"flex-wrap": "unset"}}>
<${InputBox} label="API Key" type="${!showAPIKey ? "password" : "text"}"
className="${rejectedAPIKey ? 'rejected' : ''}"
Expand All @@ -6706,7 +6824,7 @@
: html`<${SVG_HideKey}/>`}
</button>
</div>`}
${endpointAPI == API_OPENAI_COMPAT && html`
${(endpointAPI == API_OPENAI_COMPAT || endpointAPI == API_AI_HORDE) && html`
<${InputBox} label="Model"
datalist=${openaiModels}
readOnly=${!!cancel}
Expand Down Expand Up @@ -6755,7 +6873,7 @@
readOnly=${!!cancel} value=${seed} onValueChange=${setSeed}/>
<${InputBox} tooltip="Currently not accurate to the token count, it will be used as an estimate." label="Max Context Length" type="text" inputmode="numeric"
readOnly=${!!cancel} value=${contextLength} onValueChange=${setContextLength}/>
<${InputBox} label="Max Predict Tokens${endpointAPI != API_LLAMA_CPP ? ' (-1 = 1024)' : ' (-1 = infinite)'}" type="text" inputmode="numeric"
<${InputBox} label="Max Predict Tokens${endpointAPI != API_LLAMA_CPP ? (endpointAPI == API_AI_HORDE ? ' (-1 = 512)' : ' (-1 = 1024)') : ' (-1 = infinite)'}" type="text" inputmode="numeric"
readOnly=${!!cancel} value=${maxPredictTokens} onValueChange=${setMaxPredictTokens}/>
<${InputBox} label="Stopping Strings (JSON array)" type="text" pattern="^\\[.*?\\]$"
className="${stoppingStringsError ? 'rejected' : ''}"
Expand Down Expand Up @@ -6861,7 +6979,7 @@
<div className="hbox">
<${InputSlider} label="DynaTemp Range" type="number" step="0.01"
readOnly=${!!cancel} value=${dynaTempRange} onValueChange=${setDynaTempRange}/>
${(endpointAPI != API_KOBOLD_CPP) && html`
${(endpointAPI != API_KOBOLD_CPP && endpointAPI != API_AI_HORDE) && html`
<${InputSlider} label="DynaTemp Exp" type="number" step="0.01"
readOnly=${!!cancel} value=${dynaTempExp} onValueChange=${setDynaTempExp}/>`}
</div>`}
Expand Down Expand Up @@ -7029,6 +7147,8 @@
</${CollapsibleGroup}>
${!!tokens && html`
<${InputBox} label="Tokens" value="${tokens}${tokensPerSec ? ` (${tokensPerSec.toFixed(2)} T/s)` : ``}" readOnly/>`}
${!!hordeQueuePos && html`
<${InputBox} label="Queue Position" value="${hordeQueuePos}" readOnly/>`}
<div className="buttons">
<button
title="Run next prediction (Ctrl + Enter)"
Expand Down
42 changes: 40 additions & 2 deletions server/server.js
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ app.post('/proxy/*', async (req, res) => {
headers: {
...req.headers,
'Content-Type': 'application/json',
'Host': new URL(targetBaseUrl).hostname // Update the Host header for the target server
'Host': new URL(targetBaseUrl).hostname, // Update the Host header for the target server
'Accept-Encoding': 'identity'
},
responseType: 'stream'
});
Expand Down Expand Up @@ -146,10 +147,47 @@ app.get('/proxy/*', async (req, res) => {

try {
const response = await axios.get(`${targetBaseUrl}/${path}`, {
params: req.query,
headers: {
...req.headers,
'Content-Type': 'application/json',
'Host': new URL(targetBaseUrl).hostname // Update the Host header for the target server
'Host': new URL(targetBaseUrl).hostname, // Update the Host header for the target server
'Accept-Encoding': 'identity'
}
});

res.send(response.data);
} catch (error) {
if (error.response) {
res.status(error.response.status).send(error.response.data);
} else if (error.request) {
res.status(504).send('No response from target server.');
} else {
res.status(500).send(`Error setting up request to target server: ${error.message}`);
}
}
});

// Dynamic DELETE proxy route
app.delete('/proxy/*', async (req, res) => {
// Capture the part of the URL after '/proxy'
const path = req.params[0];

// Target server base URL
const targetBaseUrl = req.headers['x-real-url'];
delete req.headers['x-real-url'];

headersToRemove.forEach(header => {
delete req.headers[header.toLowerCase()];
});

try {
const response = await axios.delete(`${targetBaseUrl}/${path}`, {
headers: {
...req.headers,
'Content-Type': 'application/json',
'Host': new URL(targetBaseUrl).hostname, // Update the Host header for the target server
'Accept-Encoding': 'identity'
}
});

Expand Down

0 comments on commit b05fe94

Please sign in to comment.