convert_checkpoint.py
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

from transformers import (
    AutoConfig,
    HfArgumentParser,
    set_seed,
)

from lsg_converter.albert.convert_albert_checkpoint import *
from lsg_converter.bart.convert_bart_checkpoint import *
from lsg_converter.barthez.convert_barthez_checkpoint import *
from lsg_converter.bert.convert_bert_checkpoint import *
from lsg_converter.camembert.convert_camembert_checkpoint import *
from lsg_converter.distilbert.convert_distilbert_checkpoint import *
from lsg_converter.electra.convert_electra_checkpoint import *
from lsg_converter.mbart.convert_mbart_checkpoint import *
from lsg_converter.pegasus.convert_pegasus_checkpoint import *
from lsg_converter.roberta.convert_roberta_checkpoint import *
from lsg_converter.xlm_roberta.convert_xlm_roberta_checkpoint import *
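
# Maps Hugging Face `config.model_type` values to the matching LSG conversion
# script; main() below uses this table to dispatch on the loaded checkpoint.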
_AUTH_MODELS = {
    "albert": AlbertConversionScript,
    "bart": BartConversionScript,
    "barthez": BarthezConversionScript,
    "bert": BertConversionScript,
    "camembert": CamembertConversionScript,
    "distilbert": DistilBertConversionScript,
    "electra": ElectraConversionScript,
    "mbart": MBartConversionScript,
    "pegasus": PegasusConversionScript,
    "roberta": RobertaConversionScript,
    "xlm-roberta": XLMRobertaConversionScript,
}
@dataclass
class FileArguments:
    """
    Arguments for the checkpoint conversion.
    """

    initial_model: str = field(
        metadata={"help": "Model to convert to its LSG variant"}
    )
    model_name: str = field(
        metadata={"help": "Name/path of the newly created model"}
    )
    max_sequence_length: int = field(
        default=4096,
        metadata={"help": "Max sequence length"}
    )
    architecture: Optional[str] = field(
        default=None,
        metadata={"help": "Architecture (model specific, optional, e.g. BartForConditionalGeneration)"}
    )
    random_global_init: bool = field(
        default=False,
        metadata={"help": "Randomly initialize global tokens (except the first one)"}
    )
    global_positional_stride: int = field(
        default=64,
        metadata={"help": "Positional stride of global tokens (copied from the original positional embedding)"}
    )
    keep_first_global_token: bool = field(
        default=False,
        metadata={"help": "Do not replace an existing first global token (only used if the initial model is already an LSG model)"}
    )
    resize_lsg: bool = field(
        default=False,
        metadata={"help": "Only resize the positional embedding of an LSG model (skip global tokens)"}
    )
    model_kwargs: Optional[str] = field(
        default="{}",
        metadata={"help": "Model kwargs, e.g. \"{'sparsity_type': 'none', 'mask_first_token': true}\""}
    )
    use_token_ids: bool = field(
        default=True,
        metadata={"help": "Add token_type_ids (0) to global embeddings if the model allows it"}
    )
    use_auth_token: bool = field(
        default=False,
        metadata={"help": "Whether to use an auth token (for private models)"}
    )
    run_test: bool = field(
        default=False,
        metadata={"help": "Test the architecture of the new model"}
    )
    seed: int = field(
        default=123,
        metadata={"help": "Seed for random initialization"}
    )
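
# A minimal sketch of how --model_kwargs is expected to look on the command
# line. The accepted keys depend on the LSG config of the target model; the
# names used below (sparsity_type, sparsity_factor) are illustrative:
#   --model_kwargs "{'sparsity_type': 'lsh', 'sparsity_factor': 4}"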
def main():
    # See all possible arguments in FileArguments above
    # or by passing the --help flag to this script.
    parser = HfArgumentParser((FileArguments,))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a
        # json file, let's parse it to get our arguments.
        args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        args = parser.parse_args_into_dataclasses()
    args = args[0]

    set_seed(args.seed)

    # Load the config of the initial model to infer its model type
    config = AutoConfig.from_pretrained(args.initial_model, trust_remote_code=True, use_auth_token=args.use_auth_token)
    model_type = config.model_type

    if model_type in _AUTH_MODELS:
        converter = _AUTH_MODELS[model_type](
            initial_model=args.initial_model,
            model_name=args.model_name,
            max_sequence_length=args.max_sequence_length,
            architecture=args.architecture,
            random_global_init=args.random_global_init,
            global_positional_stride=args.global_positional_stride,
            keep_first_global_token=args.keep_first_global_token,
            resize_lsg=args.resize_lsg,
            model_kwargs=args.model_kwargs,
            use_token_ids=args.use_token_ids,
            use_auth_token=args.use_auth_token,
            config=config,
            save_model=True,
            seed=args.seed
        )
        converter.process()

        if args.run_test:
            converter.run_test()
    else:
        s = "\n * " + "\n * ".join(_AUTH_MODELS)
        warnings.warn(f"Model type <{model_type}> cannot be handled by this script. Model type must be one of: {s}")


if __name__ == "__main__":
    main()
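
# Example invocations (a minimal sketch; the model names are illustrative):
#
#   python convert_checkpoint.py \
#       --initial_model roberta-base \
#       --model_name lsg-roberta-base \
#       --max_sequence_length 4096 \
#       --run_test
#
# The same arguments can be supplied as a single JSON file:
#
#   python convert_checkpoint.py args.json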