From 95686959603832a7eba3c0e7b0ca052b09f21a93 Mon Sep 17 00:00:00 2001 From: ZhouEEEEEE Date: Mon, 6 Mar 2023 22:05:39 -0500 Subject: [PATCH 1/2] Create data_classes.py which includes data transformation from database to AI usage data --- server/AI/data_classes.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 server/AI/data_classes.py diff --git a/server/AI/data_classes.py b/server/AI/data_classes.py new file mode 100644 index 0000000..76967e2 --- /dev/null +++ b/server/AI/data_classes.py @@ -0,0 +1,36 @@ +# This file is used for handel data in the database and send to the AI for prediction +from transformers import TrainingArguments, Trainer +import os +import torch +from torch.utils.data import Dataset +# +class WardrobeDataset(Dataset): + def __init__(self, weather_lst, occasion_lst, color_lst, budget_lst, style_lst, tokenizer): + self.input_ids = [] + self.attention_mask = [] + self.labels = [] + # self.map_label = label_maps + + for weather, occasion, color, budget, style in zip(weather_lst, occasion_lst, color_lst, budget_lst, style_lst): + # prep_txt = f'Content: {txt}\nLabel: {self.map_label[label]}' + + prep_txt = f"Today’s weather is {weather}. I’m having a {occasion}. I prefer my clothing " \ + f"color in {color}. Please give my an outfit in {style}. " \ + f"Please suggest clothes that in budget {budget} if not selected " \ + f"from my wardrobe" + + encodings_dict = tokenizer(prep_txt, truncation=True, padding="max_length") + + self.input_ids.append(torch.tensor(encodings_dict['input_ids'])) + self.attention_mask.append(torch.tensor(encodings_dict['attention_mask'])) + + def __len__(self): + return len(self.input_ids) + + def __getitem__(self, idx): + dic = { + 'input_ids': self.input_ids[idx], + 'attention_mask': self.attention_mask[idx] + } + # return self.input_ids[idx], self.attention_mask[idx], self.labels[idx] + return dic \ No newline at end of file From 6a7d6162f0d1aad3c81760f8fc2c7b9fefaa82a1 Mon Sep 17 00:00:00 2001 From: ZhouEEEEEE Date: Mon, 6 Mar 2023 22:06:24 -0500 Subject: [PATCH 2/2] Added recommand.py --- server/AI/recommand.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 server/AI/recommand.py diff --git a/server/AI/recommand.py b/server/AI/recommand.py new file mode 100644 index 0000000..1fcb8ef --- /dev/null +++ b/server/AI/recommand.py @@ -0,0 +1,31 @@ +from data_classes import WardrobeDataset +from datasets import load_dataset, load_metric +from transformers import AutoTokenizer +from transformers import AutoModelForSequenceClassification +import numpy as np +from transformers import TrainingArguments, Trainer +import os +import torch +from torch.utils.data import Dataset + +def recommand_outfit(weather_lst, occasion_lst, color_lst, budget_lst, style_lst): + ''' + :param weather_lst: + :param occasion_lst: + :param color_lst: + :param budget_lst: + :param style_lst: + :return: + This function is used to recommand outfit based on the user's preferences + ''' + # Load the tokenizer + tokenizer = AutoTokenizer.from_pretrained("gpt2-medium") + tokenizer.pad_token = tokenizer.eos_token + data = WardrobeDataset(weather_lst, occasion_lst, color_lst, budget_lst, style_lst, tokenizer) + # Load the model + model = AutoModelForSequenceClassification.from_pretrained("./outfit_recommand_model") + trainer = Trainer(model=model) + os.environ["WANDB_DISABLED"] = "true" + predictions = trainer.predict(data) + + return np.argmax(predictions.predictions, axis=-1) \ No newline at end of file