-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #83 from fractalego/development
Development
- Loading branch information
Showing
237 changed files
with
2,015 additions
and
9,414 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import asyncio | ||
import pandas as pd | ||
|
||
from wafl.config import Configuration | ||
from wafl.connectors.remote.remote_llm_connector import RemoteLLMConnector | ||
|
||
|
||
def get_prompt(df, theme, num_examples=9):
    """Build a few-shot prompt asking an LLM to write a dialogue about ``theme``.

    Up to ``num_examples`` rows are drawn at random from ``df`` and each one is
    rendered as a completed ``<task>...</task>`` example. The prompt ends with
    an open ``<task>`` for ``theme`` that stops right after "The rules are as
    follows:", leaving the rules and the conversation for the model to write.

    Args:
        df: DataFrame providing the "Theme", "Rules" and "Conversation"
            columns read for every example.
        theme: Theme of the dialogue that should be generated.
        num_examples: Maximum number of example tasks to include.

    Returns:
        The complete prompt as a single string.
    """
    # DataFrame.sample raises ValueError when asked for more rows than the
    # frame holds, so never request more than are available.
    examples = df.sample(min(num_examples, len(df)))

    tasks = []
    for _, row in examples.iterrows():
        task = "\n".join(
            [
                "<task>",
                f'Create a plausible dialogue about the theme "{row["Theme"]}" based on the following summary and rules.',
                "The rules are as follows:",
                f'{row["Rules"]}',
                "The conversation goes as follows:",
                f'{row["Conversation"]}',
                "</task>",
            ]
        )
        tasks.append(task + "\n\n")

    return (
        "".join(tasks)
        + f'<task>\nCreate plausible dialogue about the theme "{theme}" based on the following summary and rules.\n\nThe rules are as follows:\n'
    )
|
||
|
||
if __name__ == "__main__":
    # Script entry point: build a few-shot prompt from the example CSV and
    # print whatever the remote LLM generates for it.
    local_config = Configuration.load_local_config()
    llm_model = local_config.get_value("llm_model")
    # NOTE(review): last_strings presumably lists the strings that end
    # generation (here the closing task tag) - confirm in RemoteLLMConnector.
    connector = RemoteLLMConnector(llm_model, last_strings=["</task>"])

    examples = pd.read_csv("data/complex_instructions.csv")
    generation_prompt = get_prompt(examples, "playing a song that the user likes")

    # predict() is a coroutine, so drive it to completion with asyncio.run.
    completion = asyncio.run(
        connector.predict(generation_prompt, temperature=0.5, num_tokens=1500)
    )
    print(completion)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import random | ||
|
||
import pandas as pd | ||
from datasets import Dataset | ||
from transformers import ( | ||
AutoTokenizer, | ||
AutoModelForCausalLM, | ||
TrainingArguments, | ||
Trainer, | ||
DataCollatorForLanguageModeling, | ||
) | ||
|
||
# Hugging Face identifier of the base model; used to load both the tokenizer
# and the causal-LM weights.
model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.1"
# Sequence length (in tokens) every prompt is padded to: 1024 + 512.
max_length = 1536
|
||
|
||
def get_prompts(df):
    """Render one training prompt per row of ``df``, in random order.

    Each prompt states the row's theme, a summary section, a ``<rules>``
    section and the row's conversation. The rules section holds the row's own
    rule plus one or two rules sampled at random from the whole frame (the
    sample may include the row's own rule again), shuffled together.

    Args:
        df: DataFrame providing the "Theme", "Rules" and "Conversation"
            columns.

    Returns:
        A list of prompt strings, each terminated by a blank line.
    """
    # The summary is always this placeholder: the original code initialised it
    # to "" and then immediately replaced the empty value.
    memory = "The user has no memory."

    prompts = []
    for _, row in df.sample(frac=1).iterrows():
        current_rule = row["Rules"]
        # DataFrame.sample raises ValueError when asked for more rows than the
        # frame holds, which a one-row frame would hit whenever 2 is drawn.
        num_extra_rules = min(random.choice([1, 2]), len(df))
        rules = df.sample(num_extra_rules)["Rules"].tolist() + [current_rule]
        random.shuffle(rules)
        rules = "\n".join(rules)

        body = "\n".join(
            [
                f'The user is talking with a chatbot about the theme "{row["Theme"]}" based on the following summary.',
                "<summary>",
                f"{memory}",
                "</summary>",
                "The rules are as follows:",
                "<rules>",
                f"{rules}",
                "</rules>",
                "The conversation goes as follows:",
                f'{row["Conversation"]}',
            ]
        )
        # strip() keeps the original behaviour of dropping any whitespace at
        # the edges (e.g. trailing blanks in the conversation text).
        prompts.append(body.strip() + "\n\n")

    return prompts
|
||
|
||
def preprocess_function(sample):
    """Tokenize a batch of prompts for causal language-model fine-tuning.

    Relies on the module-level ``tokenizer`` (assigned in the ``__main__``
    section) and ``max_length``.

    Args:
        sample: Batch mapping whose "prompt" entry holds the prompt text(s).

    Returns:
        The tokenizer output with an extra "labels" entry that is a copy of
        "input_ids".
    """
    model_inputs = tokenizer(
        sample["prompt"],
        return_tensors="pt",
        max_length=max_length,
        padding="max_length",
        # Without truncation, prompts longer than max_length keep their full
        # length and the resulting rows no longer share one size.
        truncation=True,
    )
    # The labels are the input ids themselves, so copy them instead of
    # tokenizing the same text a second time.
    model_inputs["labels"] = model_inputs["input_ids"].clone()
    return model_inputs
|
||
|
||
def model_init():
    """Load the base model with only the LM head and key projections trainable.

    Every parameter is frozen first; afterwards the output head and the
    ``k_proj`` projection of each layer's self-attention are unfrozen.

    Returns:
        The causal language model loaded from ``model_name_or_path``.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    for parameter in model.parameters():
        parameter.requires_grad = False

    model.model.enable_input_require_grads()

    # The original code set `.training = True` on these modules, but that flag
    # only switches train/eval behaviour; it does not undo the freeze above,
    # so no weight could ever be updated. Re-enable gradients explicitly.
    # NOTE(review): assumes the intent was to fine-tune exactly these modules.
    model.lm_head.requires_grad_(True)
    for layer in model.model.layers:
        layer.self_attn.k_proj.requires_grad_(True)

    return model
|
||
|
||
def create_dataset_from_file(filepath):
    """Load the CSV at ``filepath`` and wrap its rendered prompts in a Dataset.

    Args:
        filepath: Path of the CSV file handed to ``pandas.read_csv``.

    Returns:
        A ``Dataset`` with a single "prompt" column holding the strings
        produced by ``get_prompts``.
    """
    return Dataset.from_dict({"prompt": get_prompts(pd.read_csv(filepath))})
|
||
|
||
if __name__ == "__main__":
    # Script entry point: fine-tune the base model on the prompt CSV, save the
    # result locally and publish model and tokenizer to the Hugging Face Hub.
    #
    # `tokenizer` must stay a module-level name: preprocess_function reads it
    # as a global, so this setup cannot be moved into a function.
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    # Reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    dataset = create_dataset_from_file("data/complex_instructions.csv")
    train_dataset = dataset.map(
        preprocess_function, batched=True, batch_size=1, num_proc=4
    )
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    learning_rate = 1e-6
    output_dir_name = f"checkpoint_lr{learning_rate}"
    # No evaluation dataset is given to the Trainer below, so step-based
    # evaluation (evaluation_strategy / eval_steps) is left disabled; asking
    # for it without an eval_dataset makes the Trainer fail.
    training_args = TrainingArguments(
        output_dir=output_dir_name,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        use_cpu=True,
        learning_rate=learning_rate,
        num_train_epochs=2,
        logging_steps=200,
        save_total_limit=1,
    )

    model = model_init()
    trainer = Trainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )
    trainer.train()
    trainer.save_model("wafl-mistral")

    # Publish the fine-tuned weights together with the tokenizer.
    model = trainer.model
    model.push_to_hub("fractalego/wafl-mistral")
    tokenizer.push_to_hub("fractalego/wafl-mistral")
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
43 changes: 0 additions & 43 deletions
43
documentation/build/html/_sources/directory_structure.rst.txt
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ License | |
|
||
This software is licensed under the MIT License: | ||
|
||
Copyright (c) 2023 [email protected] | ||
Copyright (c) 2024 [email protected] | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: | ||
|
||
|
This file was deleted.
Oops, something went wrong.
28 changes: 0 additions & 28 deletions
28
documentation/build/html/_sources/rules_and_backtracking.rst.txt
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.