Skip to content
This repository has been archived by the owner on Mar 5, 2024. It is now read-only.

Commit

Permalink
Added recommand.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhouEEEEEE committed Mar 7, 2023
1 parent 9568695 commit 6a7d616
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions server/AI/recommand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from data_classes import WardrobeDataset
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
import numpy as np
from transformers import TrainingArguments, Trainer
import os
import torch
from torch.utils.data import Dataset

def recommand_outfit(weather_lst, occasion_lst, color_lst, budget_lst, style_lst):
'''
:param weather_lst:
:param occasion_lst:
:param color_lst:
:param budget_lst:
:param style_lst:
:return:
This function is used to recommand outfit based on the user's preferences
'''
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
tokenizer.pad_token = tokenizer.eos_token
data = WardrobeDataset(weather_lst, occasion_lst, color_lst, budget_lst, style_lst, tokenizer)
# Load the model
model = AutoModelForSequenceClassification.from_pretrained("./outfit_recommand_model")
trainer = Trainer(model=model)
os.environ["WANDB_DISABLED"] = "true"
predictions = trainer.predict(data)

return np.argmax(predictions.predictions, axis=-1)

0 comments on commit 6a7d616

Please sign in to comment.