Support new webllm syntax
Signed-off-by: Jay Wang <[email protected]>
xiaohk committed Apr 29, 2024
1 parent 7713346 commit 155461f
Showing 5 changed files with 89 additions and 101 deletions.
2 changes: 1 addition & 1 deletion examples/rag-playground/package.json
@@ -13,7 +13,7 @@
},
"devDependencies": {
"@floating-ui/dom": "^1.6.1",
"@mlc-ai/web-llm": "^0.2.18",
"@mlc-ai/web-llm": "0.2.35",
"@types/d3-array": "^3.2.1",
"@types/d3-format": "^3.0.4",
"@types/d3-random": "^3.0.3",
@@ -57,8 +57,8 @@ const apiKeyDescriptionMap: Record<ModelFamily, TemplateResult> = {
const localModelSizeMap: Record<SupportedLocalModel, string> = {
[SupportedLocalModel['tinyllama-1.1b']]: '630 MB',
[SupportedLocalModel['llama-2-7b']]: '3.6 GB',
- [SupportedLocalModel['phi-2']]: '1.5 GB'
- // [SupportedLocalModel['gpt-2']]: '311 MB'
+ [SupportedLocalModel['phi-2']]: '1.5 GB',
+ [SupportedLocalModel['gemma-2b']]: '1.3 GB'
// [SupportedLocalModel['mistral-7b-v0.2']]: '3.5 GB'
};

@@ -586,7 +586,7 @@ export class MememoPlayground extends LitElement {
}

// case SupportedLocalModel['mistral-7b-v0.2']:
- // case SupportedLocalModel['gpt-2']:
+ case SupportedLocalModel['gemma-2b']:
case SupportedLocalModel['phi-2']:
case SupportedLocalModel['llama-2-7b']:
case SupportedLocalModel['tinyllama-1.1b']: {
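For context, every local model falls through to the same branch here; only the model ID later handed to the WebLLM worker differs, so supporting gemma-2b is just one more case label. A minimal sketch of the dispatch pattern, with a hypothetical handleLocal callback standing in for the playground's worker call (the real handler body is folded out of this diff):

function dispatchLocalModel(
  model: SupportedLocalModel,
  handleLocal: (model: SupportedLocalModel) => void
) {
  switch (model) {
    // case SupportedLocalModel['mistral-7b-v0.2']:
    case SupportedLocalModel['gemma-2b']:
    case SupportedLocalModel['phi-2']:
    case SupportedLocalModel['llama-2-7b']:
    case SupportedLocalModel['tinyllama-1.1b']: {
      // All local models share one code path; the fall-through cases
      // make adding a model a one-line change.
      handleLocal(model);
      break;
    }
  }
}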
11 changes: 4 additions & 7 deletions examples/rag-playground/src/components/playground/user-config.ts
@@ -3,11 +3,10 @@ import { get, set, del, clear } from 'idb-keyval';
const PREFIX = 'user-config';

export enum SupportedLocalModel {
+ 'gemma-2b' = 'Gemma (2B)',
'llama-2-7b' = 'Llama 2 (7B)',
// 'mistral-7b-v0.2' = 'Mistral (7B)',
'phi-2' = 'Phi 2 (2.7B)',
'tinyllama-1.1b' = 'TinyLlama (1.1B)'
- // 'gpt-2' = 'GPT 2 (124M)'
}

export enum SupportedRemoteModel {
@@ -27,9 +26,8 @@ export const supportedModelReverseLookup: Record<
[SupportedRemoteModel['gemini-pro']]: 'gemini-pro',
[SupportedLocalModel['tinyllama-1.1b']]: 'tinyllama-1.1b',
[SupportedLocalModel['llama-2-7b']]: 'llama-2-7b',
- [SupportedLocalModel['phi-2']]: 'phi-2'
- // [SupportedLocalModel['gpt-2']]: 'gpt-2'
- // [SupportedLocalModel['mistral-7b-v0.2']]: 'mistral-7b-v0.2'
+ [SupportedLocalModel['phi-2']]: 'phi-2',
+ [SupportedLocalModel['gemma-2b']]: 'gemma-2b'
};

export enum ModelFamily {
@@ -48,8 +46,7 @@ export const modelFamilyMap: Record<
[SupportedRemoteModel['gemini-pro']]: ModelFamily.google,
[SupportedLocalModel['tinyllama-1.1b']]: ModelFamily.local,
[SupportedLocalModel['llama-2-7b']]: ModelFamily.local,
- // [SupportedLocalModel['gpt-2']]: ModelFamily.local
- // [SupportedLocalModel['mistral-7b-v0.2']]: ModelFamily.local
+ [SupportedLocalModel['gemma-2b']]: ModelFamily.local,
[SupportedLocalModel['phi-2']]: ModelFamily.local
};

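Taken together: the enum values double as display names for the model dropdown, supportedModelReverseLookup recovers the short model ID, and modelFamilyMap picks the backend. A small usage sketch under that reading (the import path is assumed):

import {
  ModelFamily,
  modelFamilyMap,
  SupportedLocalModel,
  supportedModelReverseLookup
} from './user-config';

// Resolve a menu selection to the pieces the playground needs.
const selected = SupportedLocalModel['gemma-2b']; // display name 'Gemma (2B)'
const shortID = supportedModelReverseLookup[selected]; // 'gemma-2b'
const family = modelFamilyMap[selected]; // ModelFamily.local

if (family === ModelFamily.local) {
  console.log(`Load ${shortID} in the WebLLM worker`);
}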
171 changes: 81 additions & 90 deletions examples/rag-playground/src/llms/web-llm.ts
@@ -30,101 +30,78 @@ export type TextGenLocalWorkerMessage =
//==========================================================================||
// Worker Initialization ||
//==========================================================================||
- const APP_CONFIGS: webllm.AppConfig = {
- model_list: [
- {
- model_url:
- 'https://huggingface.co/mlc-ai/TinyLlama-1.1B-Chat-v0.4-q4f16_1-MLC/resolve/main/',
- local_id: 'TinyLlama-1.1B-Chat-v0.4-q4f16_1',
- model_lib_url:
- 'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/TinyLlama-1.1B-Chat-v0.4/TinyLlama-1.1B-Chat-v0.4-q4f16_1-ctx1k-webgpu.wasm'
- },
- {
- model_url:
- 'https://huggingface.co/mlc-ai/Llama-2-7b-chat-hf-q4f16_1-MLC/resolve/main/',
- local_id: 'Llama-2-7b-chat-hf-q4f16_1',
- model_lib_url:
- 'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-ctx1k-webgpu.wasm'
- },
- {
- model_url: 'https://huggingface.co/mlc-ai/gpt2-q0f16-MLC/resolve/main/',
- local_id: 'gpt2-q0f16',
- model_lib_url:
- 'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/gpt2/gpt2-q0f16-ctx1k-webgpu.wasm'
- },
- {
- model_url:
- 'https://huggingface.co/mlc-ai/Mistral-7B-Instruct-v0.2-q3f16_1-MLC/resolve/main/',
- local_id: 'Mistral-7B-Instruct-v0.2-q3f16_1',
- model_lib_url:
- 'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-sw4k_cs1k-webgpu.wasm'
- },
- {
- model_url:
- 'https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC/resolve/main/',
- local_id: 'Phi2-q4f16_1',
- model_lib_url:
- 'https://raw.githubusercontent.com/mlc-ai/binary-mlc-llm-libs/main/phi-2/phi-2-q4f16_1-ctx2k-webgpu.wasm',
- vram_required_MB: 3053.97,
- low_resource_required: false,
- required_features: ['shader-f16']
- }
- ]
- };
+ enum Role {
+ user = 'user',
+ assistant = 'assistant'
+ }

const CONV_TEMPLATES: Record<
SupportedLocalModel,
Partial<ConvTemplateConfig>
> = {
[SupportedLocalModel['tinyllama-1.1b']]: {
- system: '<|im_start|><|im_end|> ',
- roles: ['<|im_start|>user', '<|im_start|>assistant'],
+ system_template: '<|im_start|><|im_end|> ',
+ roles: {
+ [Role.user]: '<|im_start|>user',
+ [Role.assistant]: '<|im_start|>assistant'
+ },
offset: 0,
seps: ['', ''],
- separator_style: 'Two',
- stop_str: '<|im_end|>',
- add_bos: false,
- stop_tokens: [2]
+ stop_str: ['<|im_end|>'],
+ stop_token_ids: [2]
},
[SupportedLocalModel['llama-2-7b']]: {
- system: '[INST] <<SYS>><</SYS>>\n\n ',
- roles: ['[INST]', '[/INST]'],
+ system_template: '[INST] <<SYS>><</SYS>>\n\n ',
+ roles: {
+ [Role.user]: '[INST]',
+ [Role.assistant]: '[/INST]'
+ },
offset: 0,
seps: [' ', ' '],
- separator_style: 'Two',
- stop_str: '[INST]',
- add_bos: true,
- stop_tokens: [2]
+ role_content_sep: ' ',
+ role_empty_sep: ' ',
+ stop_str: ['[INST]'],
+ system_prefix_token_ids: [1],
+ stop_token_ids: [2],
+ add_role_after_system_message: false
},
[SupportedLocalModel['phi-2']]: {
- system: '',
- roles: ['Instruct', 'Output'],
+ system_template: '',
+ system_message: '',
+ roles: {
+ [Role.user]: 'Instruct',
+ [Role.assistant]: 'Output'
+ },
offset: 0,
seps: ['\n'],
- separator_style: 'Two',
- stop_str: '<|endoftext|>',
- add_bos: false,
- stop_tokens: [50256]
+ stop_str: ['<|endoftext|>'],
+ stop_token_ids: [50256]
},
+ [SupportedLocalModel['gemma-2b']]: {
+ system_template: '',
+ system_message: '',
+ roles: {
+ [Role.user]: '<start_of_turn>user',
+ [Role.assistant]: '<start_of_turn>model'
+ },
+ offset: 0,
+ seps: ['<end_of_turn>\n', '<end_of_turn>\n'],
+ role_content_sep: '\n',
+ role_empty_sep: '\n',
+ stop_str: ['<end_of_turn>'],
+ system_prefix_token_ids: [2],
+ stop_token_ids: [1, 107]
+ }
};

const modelMap: Record<SupportedLocalModel, string> = {
[SupportedLocalModel['tinyllama-1.1b']]: 'TinyLlama-1.1B-Chat-v0.4-q4f16_1',
[SupportedLocalModel['llama-2-7b']]: 'Llama-2-7b-chat-hf-q4f16_1',
- [SupportedLocalModel['phi-2']]: 'Phi2-q4f16_1'
- // [SupportedLocalModel['gpt-2']]: 'gpt2-q0f16'
- // [SupportedLocalModel['mistral-7b-v0.2']]: 'Mistral-7B-Instruct-v0.2-q3f16_1'
+ [SupportedLocalModel['phi-2']]: 'Phi2-q4f16_1',
+ [SupportedLocalModel['gemma-2b']]: 'gemma-2b-it-q4f16_1'
};

- const chat = new webllm.ChatModule();

- // To reset the temperature, WebLLM requires reloading the model, so we just
- // fix the temperature for now.
- let _temperature = 0.2;

- let _modelLoadingComplete: Promise<void> | null = null;

- chat.setInitProgressCallback((report: webllm.InitProgressReport) => {
+ const initProgressCallback = (report: webllm.InitProgressReport) => {
// Update the main thread about the progress
console.log(report.text);
const message: TextGenLocalWorkerMessage = {
@@ -135,7 +112,9 @@ chat.setInitProgressCallback((report: webllm.InitProgressReport) => {
}
};
postMessage(message);
- });
+ };

+ let engine: Promise<webllm.EngineInterface> | null = null;

//==========================================================================||
// Worker Event Handlers ||
@@ -179,15 +158,25 @@ const startLoadModel = async (
model: SupportedLocalModel,
temperature: number
) => {
- _temperature = temperature;
const curModel = modelMap[model];
- const chatOption: webllm.ChatOptions = {
- temperature: temperature,
- conv_config: CONV_TEMPLATES[model],
- conv_template: 'custom'
- };
- _modelLoadingComplete = chat.reload(curModel, chatOption, APP_CONFIGS);
- await _modelLoadingComplete;

+ // Only use a custom conv template for Llama to override the pre-included
+ // system prompt from WebLLM
+ let chatOption: webllm.ChatOptions | undefined = undefined;
+
+ if (model === SupportedLocalModel['llama-2-7b']) {
+ chatOption = {
+ conv_config: CONV_TEMPLATES[model],
+ conv_template: 'custom'
+ };
+ }
+
+ engine = webllm.CreateEngine(curModel, {
+ initProgressCallback: initProgressCallback,
+ chatOpts: chatOption
+ });
+
+ await engine;

try {
// Send back the data to the main thread
@@ -220,24 +209,26 @@
*/
const startTextGen = async (prompt: string, temperature: number) => {
try {
- if (_modelLoadingComplete) {
- await _modelLoadingComplete;
- }
-
- const truncated = prompt.slice(0, 2000);
-
- const response = await chat.generate(truncated);
+ const curEngine = await engine!;
+ const response = await curEngine.chat.completions.create({
+ messages: [{ role: 'user', content: prompt }],
+ n: 1,
+ max_gen_len: 2048,
+ // Override temperature to 0 because local models are very unstable
+ temperature: 0
+ // logprobs: false
+ });

// Reset the chat cache to avoid memorizing previous messages
- await chat.resetChat();
+ await curEngine.resetChat();

// Send back the data to the main thread
const message: TextGenLocalWorkerMessage = {
command: 'finishTextGen',
payload: {
requestID: 'web-llm',
apiKey: '',
- result: response,
+ result: response.choices[0].message.content || '',
prompt: prompt,
detail: ''
}
@@ -263,7 +254,7 @@

export const hasLocalModelInCache = async (model: SupportedLocalModel) => {
const curModel = modelMap[model];
- const inCache = await webllm.hasModelInCache(curModel, APP_CONFIGS);
+ const inCache = await webllm.hasModelInCache(curModel);
return inCache;
};

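The net effect of this file's changes: the 0.2.18-era ChatModule (chat.reload / chat.generate) is gone, replaced by the 0.2.35 engine factory, an init-progress callback passed at creation time, and an OpenAI-style completions call. A condensed sketch of the new flow, using only the calls that appear in the diff above (the model ID and prompt are example values):

import * as webllm from '@mlc-ai/web-llm';

const demo = async () => {
  // Create an engine for a prebuilt model ID; progress is streamed to the
  // callback while weights download or load from the browser cache.
  const engine = await webllm.CreateEngine('gemma-2b-it-q4f16_1', {
    initProgressCallback: report => console.log(report.text)
  });

  // OpenAI-style chat completion (non-streaming).
  const response = await engine.chat.completions.create({
    messages: [{ role: 'user', content: 'What is a vector index?' }],
    temperature: 0
  });
  console.log(response.choices[0].message.content);

  // Clear conversation state so the next prompt starts fresh.
  await engine.resetChat();
};

void demo();

This is also why package.json now pins an exact 0.2.35 rather than a ^0.2.18 range: the factory-based engine API is not compatible with the ChatModule code it replaces.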
