-
Notifications
You must be signed in to change notification settings - Fork 1
/
config.py
70 lines (57 loc) · 1.83 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import string
# CUDA device indices to use; the first entry is the primary device.
DEVICE_IDS = [1]
PRIMARY_DEVICE = f"cuda:{DEVICE_IDS[0]}"
# ---------------------------------------------------------------------------
# Experiment parameters
# ---------------------------------------------------------------------------
# Name of the experiment (used to save files).
EXPERIMENT = "TextOCR_PreEncoder_Overlap"
# Pretrained model to use.
# Alternative checkpoint: './results/models/TextOCR_base_from_scratch.pt'
SAVED_MODEL = "./results/models/scratch.pt"
RANDOM_SEED = 999
BATCH_SIZE = 192
EPOCHS = 8
MAX_TEXT_LENGTH = 25
# Digits, ASCII letters and punctuation: string.printable without its six
# trailing whitespace characters (same value as string.printable[:-6]).
CHARS = string.digits + string.ascii_letters + string.punctuation
# Once validation accuracy (%) passes this threshold, the highest-accuracy
# models are saved to ./results/models.
MODEL_SAVE_THRESHOLD = 0
# ---------------------------------------------------------------------------
# Model design
# ---------------------------------------------------------------------------
# Allowed values are listed next to each option.
ENCODER = "Transformer"  # LSTM | Transformer | Oscar
DECODER = "Transformer"  # LSTM | Transformer | Linear

# Dimensions
EMBED_DIM = 256
HIDDEN_DIM = 512

# Semantic vector processing
SEMANTIC_VECTOR = "overlap"  # overlap | scene | combined
SEMANTIC_SOURCE = "vinvl"  # coco | vg | vinvl | zero | rand
# .25 | .50 | .75 | resize (a value of .25/.50/.75 means IoU assignment is used)
SEMANTIC_ASSIGNMENT = "resize"
SEMANTIC_EMBEDDING = "linear"  # bert | linear
# ---------------------------------------------------------------------------
# Fusion strategies
# ---------------------------------------------------------------------------
PRINT_ATTENTION_SCORES = False

# Encoder-side options
PRE_ENCODER_MLP = False
OSCAR_ENCODER = False  # No pretrained models yet

# Decoder-side options
PRE_DECODER_MLP = False
CLS_DECODER_INIT = False
MULTIHEAD_PRE_TARGET = False
MULTIHEAD_PRE_MEMORY = False
MULTIHEAD_POST_MEMORY = False
POST_DECODER_MLP = False
# ---------------------------------------------------------------------------
# Local paths
# ---------------------------------------------------------------------------
# Path to the COCO-Text JSON annotation file.
COCOTEXT_API_PATH = "./annotations/COCO_Text_2014.json"
# Folder holding the COCO train2014 images.
COCOTEXT_IMAGE_PATH = "/data_ssd1/jplacidi/coco_data/images/train2014/"
# Folder holding the Deep Text datasets.
DEEP_TEXT_DATASET_PATH = "/data_ssd1/jplacidi/deep_text_datasets/"
# TextOCR folder containing the train, val and test .json files.
TEXTOCR_ANNO_PATH = "/data_ssd1/jplacidi/TextOCR/"
# TextOCR image folder.
TEXTOCR_IMAGE_PATH = "/data_ssd1/jplacidi/TextOCR/"