-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathtestcode003.py
55 lines (42 loc) · 1.46 KB
/
testcode003.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import os
from datetime import datetime
from transformers import HfArgumentParser
from dataclasses import dataclass, field
from typing import Optional
import sys
from transformers import TrainingArguments
@dataclass
class ModelArguments:
    """Command-line arguments selecting which pretrained model to fine-tune."""

    # Required field: local path or huggingface.co hub identifier.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
def gen_data(rows: int = 10, cols: int = 10) -> torch.Tensor:
    """Return a random (rows, cols) tensor, placed on the GPU when one exists.

    The original hard-coded ``.to('cuda')``, which raises on CPU-only hosts;
    fall back to CPU so this smoke test runs anywhere. Defaults preserve the
    original 10x10 shape, so existing callers are unaffected.

    Args:
        rows: number of rows in the generated matrix (default 10).
        cols: number of columns in the generated matrix (default 10).

    Returns:
        A ``torch.Tensor`` of shape ``(rows, cols)`` with standard-normal values.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.randn(rows, cols).to(device)
def main() -> int:
    """Parse model/training arguments, run a CUDA smoke test, and print rank info.

    Accepts either a single ``.json`` file path or normal CLI flags, parsed by
    ``HfArgumentParser`` into ``ModelArguments`` and ``TrainingArguments``.

    Returns:
        The number of CUDA devices visible to this process.
    """
    parser = HfArgumentParser((ModelArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single .json argument means "load all arguments from this file".
        model_args, train_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, train_args = parser.parse_args_into_dataclasses()
    print(model_args)
    print(train_args)
    gen_data()
    # Fall back to single-process defaults when not launched via torchrun,
    # where direct indexing (os.environ["LOCAL_RANK"]) would raise KeyError.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    cur_datetime = datetime.now()  # fixed typo: was "cur_dateime"
    print("*" * 80)
    value = torch.cuda.device_count()
    print(
        f"----> cur_datetime: {cur_datetime},world size: {world_size}, local_rank: {local_rank}, gpu count: {value}")
    return value
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()