-
Notifications
You must be signed in to change notification settings - Fork 18
/
train.py
138 lines (117 loc) · 5.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import argparse
import wandb
import torch
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pathlib import Path
from datetime import datetime
from models.GTM import GTM
from models.FCN import FCN
from utils.data_multitrends import ZeroShotDataset
# Force TOKENIZERS_PARALLELISM off for this process (and any child processes it spawns).
# NOTE(review): presumably this silences the HuggingFace `tokenizers` fork/parallelism
# warning when DataLoader workers fork — confirm which model component uses a tokenizer.
os.environ.update(TOKENIZERS_PARALLELISM="false")
def _build_model(args, cat_dict, col_dict, fab_dict):
    """Instantiate the model selected by ``args.model_type`` ('FCN', otherwise GTM)."""
    if args.model_type == 'FCN':
        return FCN(
            embedding_dim=args.embedding_dim,
            hidden_dim=args.hidden_dim,
            output_dim=args.output_dim,
            cat_dict=cat_dict,
            col_dict=col_dict,
            fab_dict=fab_dict,
            use_trends=args.use_trends,
            use_text=args.use_text,
            use_img=args.use_img,
            trend_len=args.trend_len,
            num_trends=args.num_trends,
            use_encoder_mask=args.use_encoder_mask,
            gpu_num=args.gpu_num,
        )
    # NOTE(review): ``use_trends`` is only forwarded to FCN; GTM is constructed without
    # it, so ``--use_trends`` has no effect for GTM runs — confirm against GTM's signature.
    return GTM(
        embedding_dim=args.embedding_dim,
        hidden_dim=args.hidden_dim,
        output_dim=args.output_dim,
        num_heads=args.num_attn_heads,
        num_layers=args.num_hidden_layers,
        cat_dict=cat_dict,
        col_dict=col_dict,
        fab_dict=fab_dict,
        use_text=args.use_text,
        use_img=args.use_img,
        trend_len=args.trend_len,
        num_trends=args.num_trends,
        use_encoder_mask=args.use_encoder_mask,
        autoregressive=args.autoregressive,
        gpu_num=args.gpu_num,
    )


def _build_trainer(args, model):
    """Create the W&B-logged Lightning trainer plus its best-``val_mae`` checkpoint callback.

    Returns:
        ``(trainer, checkpoint_callback)``; the callback is returned so the caller
        can report ``best_model_path`` after fitting.
    """
    # Define model saving procedure: keep only the single checkpoint with the lowest val_mae.
    dt_string = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
    model_savename = args.model_type + '_' + args.wandb_run
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=os.path.join(args.log_dir, args.model_type),
        filename=model_savename + '---{epoch}---' + dt_string,
        monitor='val_mae',
        mode='min',
        save_top_k=1,
    )

    wandb.init(entity=args.wandb_entity, project=args.wandb_proj, name=args.wandb_run)
    # WandbLogger() with no arguments attaches to the run started by wandb.init above.
    wandb_logger = pl_loggers.WandbLogger()
    wandb_logger.watch(model)
    # If you wish to use Tensorboard you can change the logger to:
    # tb_logger = pl_loggers.TensorBoardLogger(args.log_dir + '/', name=model_savename)

    # NOTE(review): ``gpus=`` is deprecated since PyTorch Lightning 1.7 and removed in 2.0
    # (replacement: ``accelerator='gpu', devices=[...]``). Presumably the project pins an
    # older Lightning — confirm before upgrading.
    trainer = pl.Trainer(gpus=[args.gpu_num], max_epochs=args.epochs, check_val_every_n_epoch=5,
                         logger=wandb_logger, callbacks=[checkpoint_callback])
    return trainer, checkpoint_callback


def run(args):
    """Train a GTM or FCN zero-shot sales-forecasting model from parsed CLI arguments.

    Reads ``train.csv``, ``test.csv``, ``gtrends.csv`` and the category/color/fabric
    label encodings (``*_labels.pt``) from ``args.data_folder``. Builds the train
    loader (``args.batch_size``) and the test loader (batch size 1, used for
    validation) and the model chosen by ``args.model_type``. Fits it with PyTorch
    Lightning while logging to Weights & Biases, then prints the path of the best
    checkpoint by ``val_mae``.

    Args:
        args: ``argparse.Namespace`` produced by this script's argument parser.
    """
    print(args)
    # Seeds for reproducibility (By default we use the number 21)
    pl.seed_everything(args.seed)

    # Join with pathlib so --data_folder works with or without a trailing slash.
    data_folder = Path(args.data_folder)
    images_folder = data_folder / 'images'

    # Load sales data
    train_df = pd.read_csv(data_folder / 'train.csv', parse_dates=['release_date'])
    test_df = pd.read_csv(data_folder / 'test.csv', parse_dates=['release_date'])

    # Load category and color encodings
    # NOTE(review): torch.load unpickles these files; only use trusted dataset folders.
    cat_dict = torch.load(data_folder / 'category_labels.pt')
    col_dict = torch.load(data_folder / 'color_labels.pt')
    fab_dict = torch.load(data_folder / 'fabric_labels.pt')

    # Load Google trends
    gtrends = pd.read_csv(data_folder / 'gtrends.csv', index_col=[0], parse_dates=True)

    train_loader = ZeroShotDataset(train_df, images_folder, gtrends, cat_dict, col_dict,
                                   fab_dict, args.trend_len).get_loader(batch_size=args.batch_size, train=True)
    test_loader = ZeroShotDataset(test_df, images_folder, gtrends, cat_dict, col_dict,
                                  fab_dict, args.trend_len).get_loader(batch_size=1, train=False)

    # Create model (after the loaders, matching the original seeded construction order)
    model = _build_model(args, cat_dict, col_dict, fab_dict)

    # Model Training
    trainer, checkpoint_callback = _build_trainer(args, model)
    trainer.fit(model, train_dataloaders=train_loader,
                val_dataloaders=test_loader)

    # Print out path of best model
    print(checkpoint_callback.best_model_path)
def build_parser():
    """Return the command-line parser for the zero-shot sales forecasting trainer.

    Integer flags such as ``--use_img`` / ``--use_text`` / ``--use_trends`` are
    passed through to the model constructors unchanged (0 or 1 by convention).
    """
    parser = argparse.ArgumentParser(description='Zero-shot sales forecasting')

    # General arguments
    parser.add_argument('--data_folder', type=str, default='dataset/')
    parser.add_argument('--log_dir', type=str, default='log')
    parser.add_argument('--seed', type=int, default=21)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--gpu_num', type=int, default=0)

    # Model specific arguments
    # choices= rejects typos up front; run() would otherwise silently fall back to GTM.
    parser.add_argument('--model_type', type=str, default='GTM', choices=['GTM', 'FCN'],
                        help='Choose between GTM or FCN')
    parser.add_argument('--use_trends', type=int, default=1)
    parser.add_argument('--use_img', type=int, default=1)
    parser.add_argument('--use_text', type=int, default=1)
    parser.add_argument('--trend_len', type=int, default=52)
    parser.add_argument('--num_trends', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--embedding_dim', type=int, default=32)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--output_dim', type=int, default=12)
    parser.add_argument('--use_encoder_mask', type=int, default=1)
    parser.add_argument('--autoregressive', type=int, default=0)
    parser.add_argument('--num_attn_heads', type=int, default=4)
    parser.add_argument('--num_hidden_layers', type=int, default=1)

    # wandb arguments
    parser.add_argument('--wandb_entity', type=str, default='username-here')
    parser.add_argument('--wandb_proj', type=str, default='GTM')
    parser.add_argument('--wandb_run', type=str, default='Run1')
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    run(args)