diff --git a/api-examples/api-example-chat-stream.py b/api-examples/api-example-chat-stream.py index 5670d4cff1..bf4201ca23 100644 --- a/api-examples/api-example-chat-stream.py +++ b/api-examples/api-example-chat-stream.py @@ -70,6 +70,7 @@ async def run(user_input, history): 'add_bos_token': True, 'truncation_length': 2048, 'ban_eos_token': False, + 'custom_token_bans': '', 'skip_special_tokens': True, 'stopping_strings': [] } diff --git a/api-examples/api-example-chat.py b/api-examples/api-example-chat.py index 26c69b7384..42ba0a6268 100644 --- a/api-examples/api-example-chat.py +++ b/api-examples/api-example-chat.py @@ -64,6 +64,7 @@ def run(user_input, history): 'add_bos_token': True, 'truncation_length': 2048, 'ban_eos_token': False, + 'custom_token_bans': '', 'skip_special_tokens': True, 'stopping_strings': [] } diff --git a/api-examples/api-example-stream.py b/api-examples/api-example-stream.py index c042a50b60..5382216259 100644 --- a/api-examples/api-example-stream.py +++ b/api-examples/api-example-stream.py @@ -53,6 +53,7 @@ async def run(context): 'add_bos_token': True, 'truncation_length': 2048, 'ban_eos_token': False, + 'custom_token_bans': '', 'skip_special_tokens': True, 'stopping_strings': [] } diff --git a/api-examples/api-example.py b/api-examples/api-example.py index 4736275490..e6d79f9bc0 100644 --- a/api-examples/api-example.py +++ b/api-examples/api-example.py @@ -45,6 +45,7 @@ def run(prompt): 'add_bos_token': True, 'truncation_length': 2048, 'ban_eos_token': False, + 'custom_token_bans': '', 'skip_special_tokens': True, 'stopping_strings': [] } diff --git a/extensions/api/util.py b/extensions/api/util.py index 6d0cb170ec..499706caa2 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -49,6 +49,7 @@ def build_parameters(body, chat=False): 'seed': int(body.get('seed', -1)), 'add_bos_token': bool(body.get('add_bos_token', True)), 'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))), + 'custom_token_bans': str(body.get('custom_token_bans', '')), 'ban_eos_token': bool(body.get('ban_eos_token', False)), 'skip_special_tokens': bool(body.get('skip_special_tokens', True)), 'custom_stopping_strings': '', # leave this blank diff --git a/extensions/openai/defaults.py b/extensions/openai/defaults.py index c6a6adfde7..052862f74a 100644 --- a/extensions/openai/defaults.py +++ b/extensions/openai/defaults.py @@ -37,6 +37,7 @@ 'guidance_scale': 1, 'negative_prompt': '', 'ban_eos_token': False, + 'custom_token_bans': '', 'skip_special_tokens': True, 'custom_stopping_strings': '', # 'logits_processor' - conditionally passed diff --git a/modules/exllama.py b/modules/exllama.py index c9ff12283d..177f028f3c 100644 --- a/modules/exllama.py +++ b/modules/exllama.py @@ -108,6 +108,11 @@ def generate_with_streaming(self, prompt, state): else: self.generator.disallow_tokens(None) + if state['custom_token_bans']: + to_ban = [int(x) for x in state['custom_token_bans'].split(',')] + if len(to_ban) > 0: + self.generator.disallow_tokens(self.tokenizer, to_ban) + # Case 1: no CFG if state['guidance_scale'] == 1: self.generator.end_beam_search() diff --git a/modules/exllamav2.py b/modules/exllamav2.py index 6d4603c58f..a325a4d376 100644 --- a/modules/exllamav2.py +++ b/modules/exllamav2.py @@ -30,7 +30,7 @@ def from_pretrained(self, path_to_model): config.max_seq_len = shared.args.max_seq_len config.scale_pos_emb = shared.args.compress_pos_emb config.scale_alpha_value = shared.args.alpha_value - + model = ExLlamaV2(config) split = None @@ -60,6 +60,11 @@ def generate_with_streaming(self, prompt, state): if state['ban_eos_token']: settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id]) + if state['custom_token_bans']: + to_ban = [int(x) for x in state['custom_token_bans'].split(',')] + if len(to_ban) > 0: + settings.disallow_tokens(self.tokenizer, to_ban) + ids = self.tokenizer.encode(prompt) ids = ids[:, -get_max_prompt_length(state):] initial_len = ids.shape[-1] diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py index f09ca50547..5db6e27e69 100644 --- a/modules/llamacpp_model.py +++ b/modules/llamacpp_model.py @@ -31,6 +31,13 @@ def ban_eos_logits_processor(eos_token, input_ids, logits): return logits +def custom_token_ban_logits_processor(token_ids, input_ids, logits): + for token_id in token_ids: + logits[token_id] = -float('inf') + + return logits + + class LlamaCppModel: def __init__(self): self.initialized = False @@ -104,6 +111,15 @@ def generate(self, prompt, state, callback=None): prompt = prompt[-get_max_prompt_length(state):] prompt = self.decode(prompt).decode('utf-8') + logit_processors = LogitsProcessorList() + if state['ban_eos_token']: + logit_processors.append(partial(ban_eos_logits_processor, self.model.tokenizer.eos_token_id)) + + if state['custom_token_bans']: + to_ban = [int(x) for x in state['custom_token_bans'].split(',')] + if len(to_ban) > 0: + logit_processors.append(partial(custom_token_ban_logits_processor, to_ban)) + completion_chunks = self.model.create_completion( prompt=prompt, max_tokens=state['max_new_tokens'], @@ -116,9 +132,7 @@ def generate(self, prompt, state, callback=None): mirostat_tau=state['mirostat_tau'], mirostat_eta=state['mirostat_eta'], stream=True, - logits_processor=LogitsProcessorList([ - partial(ban_eos_logits_processor, self.model.token_eos()), - ]) if state['ban_eos_token'] else None, + logits_processor=logit_processors, ) output = "" diff --git a/modules/loaders.py b/modules/loaders.py index ff2f50501c..b7187e5f47 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -150,6 +150,7 @@ 'guidance_scale', 'negative_prompt', 'ban_eos_token', + 'custom_token_bans', 'add_bos_token', 'skip_special_tokens', 'auto_max_new_tokens', @@ -176,6 +177,7 @@ 'guidance_scale', 'negative_prompt', 'ban_eos_token', + 'custom_token_bans', 'add_bos_token', 'skip_special_tokens', 'auto_max_new_tokens', @@ -191,6 +193,7 @@ 'guidance_scale', 'negative_prompt', 'ban_eos_token', + 'custom_token_bans', 'auto_max_new_tokens', }, 'ExLlamav2': { @@ -201,6 +204,7 @@ 'repetition_penalty_range', 'seed', 'ban_eos_token', + 'custom_token_bans', 'auto_max_new_tokens', }, 'ExLlamav2_HF': { @@ -225,6 +229,7 @@ 'guidance_scale', 'negative_prompt', 'ban_eos_token', + 'custom_token_bans', 'add_bos_token', 'skip_special_tokens', 'auto_max_new_tokens', @@ -255,6 +260,7 @@ 'guidance_scale', 'negative_prompt', 'ban_eos_token', + 'custom_token_bans', 'add_bos_token', 'skip_special_tokens', 'auto_max_new_tokens', @@ -285,6 +291,7 @@ 'guidance_scale', 'negative_prompt', 'ban_eos_token', + 'custom_token_bans', 'add_bos_token', 'skip_special_tokens', 'auto_max_new_tokens', @@ -299,6 +306,7 @@ 'mirostat_tau', 'mirostat_eta', 'ban_eos_token', + 'custom_token_bans', }, 'llamacpp_HF': { 'temperature', @@ -322,6 +330,7 @@ 'guidance_scale', 'negative_prompt', 'ban_eos_token', + 'custom_token_bans', 'add_bos_token', 'skip_special_tokens', 'auto_max_new_tokens', diff --git a/modules/presets.py b/modules/presets.py index 32b7f71c52..96d6e994e4 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -28,6 +28,7 @@ def default_preset(): 'num_beams': 1, 'length_penalty': 1, 'early_stopping': False, + 'custom_token_bans': '', } diff --git a/modules/shared.py b/modules/shared.py index 2555eca499..30fa1393af 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -49,6 +49,7 @@ 'auto_max_new_tokens': False, 'max_tokens_second': 0, 'ban_eos_token': False, + 'custom_token_bans': '', 'add_bos_token': True, 'skip_special_tokens': True, 'stream': True, diff --git a/modules/text_generation.py b/modules/text_generation.py index 67833d8c2c..98682bb23a 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -266,6 +266,14 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings if state['ban_eos_token']: generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id] + if state['custom_token_bans']: + to_ban = [int(x) for x in state['custom_token_bans'].split(',')] + if len(to_ban) > 0: + if generate_params.get('suppress_tokens', None): + generate_params['suppress_tokens'] += to_ban + else: + generate_params['suppress_tokens'] = to_ban + generate_params.update({'use_cache': not shared.args.no_cache}) if shared.args.deepspeed: generate_params.update({'synced_gpus': True}) diff --git a/modules/ui.py b/modules/ui.py index 790bc3b59d..0a19b2315a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -118,6 +118,7 @@ def list_interface_input_elements(): 'guidance_scale', 'add_bos_token', 'ban_eos_token', + 'custom_token_bans', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index 169ab5002a..32fb1c023c 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -118,8 +118,8 @@ def create_ui(default_preset): with gr.Column(): shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.') shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.') + shared.gradio['custom_token_bans'] = gr.Textbox(value=shared.settings['custom_token_bans'] or None, label='Custom token bans', info='Specific token IDs to ban from generating, comma-separated. The IDs can be found in a tokenizer.json file.') shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.') - shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming') diff --git a/settings-template.yaml b/settings-template.yaml index d4a3c70957..66d98d396f 100644 --- a/settings-template.yaml +++ b/settings-template.yaml @@ -19,6 +19,7 @@ custom_stopping_strings: '' auto_max_new_tokens: false max_tokens_second: 0 ban_eos_token: false +custom_token_bans: '' add_bos_token: true skip_special_tokens: true stream: true