forked from contrebande-labs/charred
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathargs.py
118 lines (114 loc) · 3.5 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--tokenizer_max_length",
type=int,
default=1024,
help="Maximum length of tokenized string. Longer strings will be truncated. Shorter strings will be padded.",
)
parser.add_argument(
"--output_dir",
type=str,
default="/data/output",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--dataset_dir",
type=str,
default="/data/dataset/output/charred",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default="/data/dataset/cache",
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument(
"--seed", type=int, default=0, help="A seed for reproducible training."
)
parser.add_argument(
"--resolution",
type=int,
default=1024,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--train_batch_size",
type=int,
default=8,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=10_000,
help="Total number of training steps per epoch to perform.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--adam_beta1",
type=float,
default=0.9,
help="The beta1 parameter for the Adam optimizer.",
)
parser.add_argument(
"--adam_beta2",
type=float,
default=0.999,
help="The beta2 parameter for the Adam optimizer.",
)
parser.add_argument(
"--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
)
parser.add_argument(
"--adam_epsilon",
type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer",
)
parser.add_argument(
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
)
parser.add_argument(
"--push_to_hub",
type=bool,
default=False,
help="Whether or not to push the model to the Hub.",
)
parser.add_argument(
"--hub_model_id",
type=str,
default="character-aware-diffusion/charred",
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
parser.add_argument(
"--log_wandb",
type=bool,
default=True,
choices=[True, False],
help=("Whether to use WandB to log the metrics or not"),
)
args = parser.parse_args()
return args