Add customizable ban tokens (#3899)
sALTaccount authored Sep 15, 2023
1 parent fb864da commit f01b9aa
Showing 16 changed files with 56 additions and 5 deletions.
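
In short, the commit threads a new custom_token_bans generation parameter (a comma-separated string of token IDs) from the API examples, presets, and UI through to each backend, where the listed tokens are suppressed during sampling. As a rough sketch of how a client might set it against the blocking API used in the api-examples below (the prompt and token IDs are illustrative, not recommendations):

import requests

URI = 'http://localhost:5000/api/v1/generate'  # default blocking-API address in the examples

request = {
    'prompt': 'Once upon a time,',
    'max_new_tokens': 50,
    # Ban token IDs 15 and 299 (arbitrary examples; real IDs can be
    # looked up in the model's tokenizer.json).
    'custom_token_bans': '15,299',
}

response = requests.post(URI, json=request)
if response.status_code == 200:
    print(response.json()['results'][0]['text'])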
1 change: 1 addition & 0 deletions api-examples/api-example-chat-stream.py
@@ -70,6 +70,7 @@ async def run(user_input, history):
'add_bos_token': True,
'truncation_length': 2048,
'ban_eos_token': False,
'custom_token_bans': '',
'skip_special_tokens': True,
'stopping_strings': []
}
1 change: 1 addition & 0 deletions api-examples/api-example-chat.py
@@ -64,6 +64,7 @@ def run(user_input, history):
'add_bos_token': True,
'truncation_length': 2048,
'ban_eos_token': False,
'custom_token_bans': '',
'skip_special_tokens': True,
'stopping_strings': []
}
1 change: 1 addition & 0 deletions api-examples/api-example-stream.py
@@ -53,6 +53,7 @@ async def run(context):
'add_bos_token': True,
'truncation_length': 2048,
'ban_eos_token': False,
'custom_token_bans': '',
'skip_special_tokens': True,
'stopping_strings': []
}
1 change: 1 addition & 0 deletions api-examples/api-example.py
@@ -45,6 +45,7 @@ def run(prompt):
'add_bos_token': True,
'truncation_length': 2048,
'ban_eos_token': False,
'custom_token_bans': '',
'skip_special_tokens': True,
'stopping_strings': []
}
1 change: 1 addition & 0 deletions extensions/api/util.py
@@ -49,6 +49,7 @@ def build_parameters(body, chat=False):
'seed': int(body.get('seed', -1)),
'add_bos_token': bool(body.get('add_bos_token', True)),
'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
'custom_token_bans': str(body.get('custom_token_bans', '')),
'ban_eos_token': bool(body.get('ban_eos_token', False)),
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
'custom_stopping_strings': '', # leave this blank
1 change: 1 addition & 0 deletions extensions/openai/defaults.py
@@ -37,6 +37,7 @@
'guidance_scale': 1,
'negative_prompt': '',
'ban_eos_token': False,
'custom_token_bans': '',
'skip_special_tokens': True,
'custom_stopping_strings': '',
# 'logits_processor' - conditionally passed
5 changes: 5 additions & 0 deletions modules/exllama.py
@@ -108,6 +108,11 @@ def generate_with_streaming(self, prompt, state):
else:
self.generator.disallow_tokens(None)

if state['custom_token_bans']:
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
if len(to_ban) > 0:
self.generator.disallow_tokens(self.tokenizer, to_ban)

# Case 1: no CFG
if state['guidance_scale'] == 1:
self.generator.end_beam_search()
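
The comma-splitting pattern just added here, [int(x) for x in state['custom_token_bans'].split(',')], recurs in every backend below and assumes well-formed input. A hypothetical, slightly more defensive variant (not part of this commit) would also skip empty entries left by a trailing comma:

def parse_token_bans(spec):
    # int() tolerates surrounding whitespace, so '15, 299' parses fine,
    # but a trailing comma leaves an empty entry that int() rejects,
    # so filter those out first.
    return [int(x) for x in spec.split(',') if x.strip()]

print(parse_token_bans('15, 299,'))  # [15, 299]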
7 changes: 6 additions & 1 deletion modules/exllamav2.py
@@ -30,7 +30,7 @@ def from_pretrained(self, path_to_model):
config.max_seq_len = shared.args.max_seq_len
config.scale_pos_emb = shared.args.compress_pos_emb
config.scale_alpha_value = shared.args.alpha_value

model = ExLlamaV2(config)

split = None
@@ -60,6 +60,11 @@ def generate_with_streaming(self, prompt, state):
if state['ban_eos_token']:
settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])

if state['custom_token_bans']:
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
if len(to_ban) > 0:
settings.disallow_tokens(self.tokenizer, to_ban)

ids = self.tokenizer.encode(prompt)
ids = ids[:, -get_max_prompt_length(state):]
initial_len = ids.shape[-1]
20 changes: 17 additions & 3 deletions modules/llamacpp_model.py
@@ -31,6 +31,13 @@ def ban_eos_logits_processor(eos_token, input_ids, logits):
return logits


def custom_token_ban_logits_processor(token_ids, input_ids, logits):
for token_id in token_ids:
logits[token_id] = -float('inf')

return logits


class LlamaCppModel:
def __init__(self):
self.initialized = False
@@ -104,6 +111,15 @@ def generate(self, prompt, state, callback=None):
prompt = prompt[-get_max_prompt_length(state):]
prompt = self.decode(prompt).decode('utf-8')

logit_processors = LogitsProcessorList()
if state['ban_eos_token']:
logit_processors.append(partial(ban_eos_logits_processor, self.model.tokenizer.eos_token_id))

if state['custom_token_bans']:
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
if len(to_ban) > 0:
logit_processors.append(partial(custom_token_ban_logits_processor, to_ban))

completion_chunks = self.model.create_completion(
prompt=prompt,
max_tokens=state['max_new_tokens'],
@@ -116,9 +132,7 @@ def generate(self, prompt, state, callback=None):
mirostat_tau=state['mirostat_tau'],
mirostat_eta=state['mirostat_eta'],
stream=True,
logits_processor=LogitsProcessorList([
partial(ban_eos_logits_processor, self.model.token_eos()),
]) if state['ban_eos_token'] else None,
logits_processor=logit_processors,
)

output = ""
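
The new custom_token_ban_logits_processor above works by setting each banned token's logit to negative infinity before sampling, which drives its softmax probability to exactly zero. A small self-contained illustration (NumPy is used here only for the demonstration):

import numpy as np

logits = np.array([2.0, 1.0, 0.5, 3.0])
for token_id in (0, 3):  # pretend tokens 0 and 3 are banned
    logits[token_id] = -float('inf')

# exp(-inf) == 0.0, so the banned tokens can never be sampled
probs = np.exp(logits) / np.exp(logits).sum()
print(probs)  # banned positions come out as exactly 0.0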
9 changes: 9 additions & 0 deletions modules/loaders.py
@@ -150,6 +150,7 @@
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'custom_token_bans',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
@@ -176,6 +177,7 @@
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'custom_token_bans',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
@@ -191,6 +193,7 @@
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'custom_token_bans',
'auto_max_new_tokens',
},
'ExLlamav2': {
@@ -201,6 +204,7 @@
'repetition_penalty_range',
'seed',
'ban_eos_token',
'custom_token_bans',
'auto_max_new_tokens',
},
'ExLlamav2_HF': {
@@ -225,6 +229,7 @@
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'custom_token_bans',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
@@ -255,6 +260,7 @@
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'custom_token_bans',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
@@ -285,6 +291,7 @@
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'custom_token_bans',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
@@ -299,6 +306,7 @@
'mirostat_tau',
'mirostat_eta',
'ban_eos_token',
'custom_token_bans',
},
'llamacpp_HF': {
'temperature',
@@ -322,6 +330,7 @@
'guidance_scale',
'negative_prompt',
'ban_eos_token',
'custom_token_bans',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
1 change: 1 addition & 0 deletions modules/presets.py
@@ -28,6 +28,7 @@ def default_preset():
'num_beams': 1,
'length_penalty': 1,
'early_stopping': False,
'custom_token_bans': '',
}


1 change: 1 addition & 0 deletions modules/shared.py
@@ -49,6 +49,7 @@
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'ban_eos_token': False,
'custom_token_bans': '',
'add_bos_token': True,
'skip_special_tokens': True,
'stream': True,
8 changes: 8 additions & 0 deletions modules/text_generation.py
@@ -266,6 +266,14 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
if state['ban_eos_token']:
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]

if state['custom_token_bans']:
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
if len(to_ban) > 0:
if generate_params.get('suppress_tokens', None):
generate_params['suppress_tokens'] += to_ban
else:
generate_params['suppress_tokens'] = to_ban

generate_params.update({'use_cache': not shared.args.no_cache})
if shared.args.deepspeed:
generate_params.update({'synced_gpus': True})
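
For the Hugging Face code path, the bans ride on the suppress_tokens argument that transformers' generate() already supports, rather than a custom logits processor. A standalone sketch of the same mechanism in plain transformers; the GPT-2 checkpoint and the banned string are assumptions for illustration:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

inputs = tokenizer('The capital of France is', return_tensors='pt')
# suppress_tokens pushes these token IDs to -inf at every generation step
banned = tokenizer(' Paris', add_special_tokens=False).input_ids
output = model.generate(**inputs, max_new_tokens=10, suppress_tokens=banned)
print(tokenizer.decode(output[0], skip_special_tokens=True))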
1 change: 1 addition & 0 deletions modules/ui.py
@@ -118,6 +118,7 @@ def list_interface_input_elements():
'guidance_scale',
'add_bos_token',
'ban_eos_token',
'custom_token_bans',
'truncation_length',
'custom_stopping_strings',
'skip_special_tokens',
2 changes: 1 addition & 1 deletion modules/ui_parameters.py
@@ -118,8 +118,8 @@ def create_ui(default_preset):
with gr.Column():
shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
shared.gradio['custom_token_bans'] = gr.Textbox(value=shared.settings['custom_token_bans'] or None, label='Custom token bans', info='Specific token IDs to ban from generating, comma-separated. The IDs can be found in a tokenizer.json file.')
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')

shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming')

1 change: 1 addition & 0 deletions settings-template.yaml
@@ -19,6 +19,7 @@ custom_stopping_strings: ''
auto_max_new_tokens: false
max_tokens_second: 0
ban_eos_token: false
custom_token_bans: ''
add_bos_token: true
skip_special_tokens: true
stream: true
