switch from training_args.bin to training_args.json #35010

Open

wants to merge 23 commits into main from switch-training_args-file-format

Changes from 15 commits (23 commits total)
a6484c0
capture init parameters in training_args
not-lain Nov 29, 2024
780d3e7
update relevant attributes
not-lain Nov 29, 2024
66c655b
attribute calling for training args
not-lain Nov 29, 2024
6d3a186
add class attribute to load the training_args from a local file
not-lain Nov 29, 2024
2694533
Merge branch 'main' into switch-training_args-file-format
not-lain Nov 29, 2024
5acc885
Update src/transformers/training_args.py
not-lain Nov 29, 2024
110e303
Merge branch 'main' into switch-training_args-file-format
not-lain Dec 9, 2024
9470d77
Merge branch 'huggingface:main' into switch-training_args-file-format
not-lain Dec 11, 2024
79d2a3a
ensure backward compatibility
not-lain Dec 11, 2024
c7567f7
fix version parsing
not-lain Dec 11, 2024
68089ed
fix deprecated filename
not-lain Dec 11, 2024
2a6ee0d
update test
not-lain Dec 12, 2024
8228d0f
format with ruff
not-lain Dec 12, 2024
ba373dc
fix trainer logger tests
not-lain Dec 12, 2024
10b05dd
serialize all parameters
not-lain Dec 12, 2024
7633c80
Update src/transformers/trainer.py
not-lain Dec 24, 2024
e3c9e57
Refactor serialization logic in Trainer to use binary serialization flag
not-lain Dec 27, 2024
0d4366b
Merge branch 'main' into switch-training_args-file-format
not-lain Dec 27, 2024
689348d
format with ruff
not-lain Dec 27, 2024
56bdbca
Merge branch 'switch-training_args-file-format' of https://github.com…
not-lain Dec 27, 2024
2c2198c
switch to dynamic serialization
not-lain Dec 27, 2024
9741c8f
Refactor training_args serialization to only include defined dataclas…
not-lain Dec 27, 2024
ef16016
ensure consistent training logs to ensure that tensorboard logs contin…
not-lain Dec 27, 2024
2 changes: 1 addition & 1 deletion docs/source/en/deepspeed.md
@@ -895,7 +895,7 @@ drwxrwxr-x 2 stas stas 4.0K Mar 25 19:52 global_step1/
-rw-rw-r-- 1 stas stas 774K Mar 27 20:42 spiece.model
-rw-rw-r-- 1 stas stas 1.9K Mar 27 20:42 tokenizer_config.json
-rw-rw-r-- 1 stas stas 339 Mar 27 20:42 trainer_state.json
-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.bin
-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.json
-rwxrw-r-- 1 stas stas 5.5K Mar 27 13:16 zero_to_fp32.py*
```

2 changes: 1 addition & 1 deletion docs/source/ja/main_classes/deepspeed.md
@@ -1705,7 +1705,7 @@ drwxrwxr-x 2 stas stas 4.0K Mar 25 19:52 global_step1/
-rw-rw-r-- 1 stas stas 774K Mar 27 20:42 spiece.model
-rw-rw-r-- 1 stas stas 1.9K Mar 27 20:42 tokenizer_config.json
-rw-rw-r-- 1 stas stas 339 Mar 27 20:42 trainer_state.json
-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.bin
-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.json
-rwxrw-r-- 1 stas stas 5.5K Mar 27 13:16 zero_to_fp32.py*
```

2 changes: 1 addition & 1 deletion docs/source/ko/deepspeed.md
@@ -894,7 +894,7 @@ drwxrwxr-x 2 stas stas 4.0K Mar 25 19:52 global_step1/
-rw-rw-r-- 1 stas stas 774K Mar 27 20:42 spiece.model
-rw-rw-r-- 1 stas stas 1.9K Mar 27 20:42 tokenizer_config.json
-rw-rw-r-- 1 stas stas 339 Mar 27 20:42 trainer_state.json
-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.bin
-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.json
-rwxrw-r-- 1 stas stas 5.5K Mar 27 13:16 zero_to_fp32.py*
```

2 changes: 1 addition & 1 deletion docs/source/zh/main_classes/deepspeed.md
@@ -1589,7 +1589,7 @@ drwxrwxr-x 2 stas stas 4.0K Mar 25 19:52 global_step1/
-rw-rw-r-- 1 stas stas 774K Mar 27 20:42 spiece.model
-rw-rw-r-- 1 stas stas 1.9K Mar 27 20:42 tokenizer_config.json
-rw-rw-r-- 1 stas stas 339 Mar 27 20:42 trainer_state.json
-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.bin
-rw-rw-r-- 1 stas stas 2.3K Mar 27 20:42 training_args.json
-rwxrw-r-- 1 stas stas 5.5K Mar 27 13:16 zero_to_fp32.py*
```

4 changes: 2 additions & 2 deletions examples/legacy/question-answering/run_squad.py
@@ -245,7 +245,7 @@ def train(args, train_dataset, model, tokenizer):
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

torch.save(args, os.path.join(output_dir, "training_args.bin"))
args.to_json_file(os.path.join(output_dir, "training_args.json"))
logger.info("Saving model checkpoint to %s", output_dir)

torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
@@ -792,7 +792,7 @@ def main():
tokenizer.save_pretrained(args.output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
args.to_json_file(os.path.join(args.output_dir, "training_args.json"))

# Load a trained model and vocabulary that you have fine-tuned
model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir) # , force_download=True)
4 changes: 2 additions & 2 deletions examples/legacy/run_swag.py
@@ -396,7 +396,7 @@ def train(args, train_dataset, model, tokenizer):
) # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_vocabulary(output_dir)
torch.save(args, os.path.join(output_dir, "training_args.bin"))
args.to_json_file(os.path.join(output_dir, "training_args.json"))
logger.info("Saving model checkpoint to %s", output_dir)

if args.max_steps > 0 and global_step > args.max_steps:
@@ -678,7 +678,7 @@ def main():
tokenizer.save_pretrained(args.output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
args.to_json_file(os.path.join(args.output_dir, "training_args.json"))

# Load a trained model and vocabulary that you have fine-tuned
model = AutoModelForMultipleChoice.from_pretrained(args.output_dir)
@@ -240,7 +240,7 @@ def train(args, train_dataset, model, tokenizer):
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

torch.save(args, os.path.join(output_dir, "training_args.bin"))
args.to_json_file(os.path.join(output_dir, "training_args.json"))
logger.info("Saving model checkpoint to %s", output_dir)

torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
@@ -712,7 +712,7 @@ def main():
tokenizer.save_pretrained(args.output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
args.to_json_file(os.path.join(args.output_dir, "training_args.json"))

# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
4 changes: 2 additions & 2 deletions examples/research_projects/deebert/run_glue_deebert.py
@@ -221,7 +221,7 @@ def train(args, train_dataset, model, tokenizer, train_highway=False):
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, "training_args.bin"))
args.to_json_file(os.path.join(output_dir, "training_args.json"))
logger.info("Saving model checkpoint to %s", output_dir)

if args.max_steps > 0 and global_step > args.max_steps:
@@ -672,7 +672,7 @@ def main():
tokenizer.save_pretrained(args.output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
args.to_json_file(os.path.join(args.output_dir, "training_args.json"))

# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
@@ -285,7 +285,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

torch.save(args, os.path.join(output_dir, "training_args.bin"))
args.to_json_file(os.path.join(output_dir, "training_args.json"))
logger.info("Saving model checkpoint to %s", output_dir)

torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
@@ -836,7 +836,7 @@ def main():
tokenizer.save_pretrained(args.output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
args.to_json_file(os.path.join(args.output_dir, "training_args.json"))

# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
4 changes: 2 additions & 2 deletions examples/research_projects/mm-imdb/run_mmimdb.py
@@ -203,7 +203,7 @@ def train(args, train_dataset, model, tokenizer, criterion):
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
torch.save(args, os.path.join(output_dir, "training_args.bin"))
args.to_json_file(os.path.join(output_dir, "training_args.json"))
logger.info("Saving model checkpoint to %s", output_dir)

if args.max_steps > 0 and global_step > args.max_steps:
@@ -540,7 +540,7 @@ def main():
tokenizer.save_pretrained(args.output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
args.to_json_file(os.path.join(args.output_dir, "training_args.json"))

# Load a trained model and vocabulary that you have fine-tuned
model = MMBTForClassification(config, transformer, img_encoder)
@@ -392,7 +392,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

torch.save(args, os.path.join(output_dir, "training_args.bin"))
args.to_json_file(os.path.join(output_dir, "training_args.json"))
logger.info("Saving model checkpoint to %s", output_dir)

torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
@@ -927,7 +927,7 @@ def main():
tokenizer.save_pretrained(args.output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
args.to_json_file(os.path.join(args.output_dir, "training_args.json"))

# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
@@ -413,7 +413,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

torch.save(args, os.path.join(output_dir, "training_args.bin"))
args.to_json_file(os.path.join(output_dir, "training_args.json"))
logger.info("Saving model checkpoint to %s", output_dir)

torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
@@ -1094,7 +1094,7 @@ def main():
tokenizer.save_pretrained(args.output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
args.to_json_file(os.path.join(args.output_dir, "training_args.json"))

# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir) # , force_download=True)
@@ -122,7 +122,7 @@ Having finished the training you should find the following files/folders under t
- `special_tokens_map.json` - the special token map of the tokenizer
- `tokenizer_config.json` - the parameters of the tokenizer
- `vocab.json` - the vocabulary of the tokenizer
- `checkpoint-{...}/` - the saved checkpoints saved during training. Each checkpoint should contain the files: `config.json`, `optimizer.pt`, `pytorch_model.bin`, `scheduler.pt`, `training_args.bin`. The files `config.json` and `pytorch_model.bin` define your model.
- `checkpoint-{...}/` - the saved checkpoints saved during training. Each checkpoint should contain the files: `config.json`, `optimizer.pt`, `pytorch_model.bin`, `scheduler.pt`, `training_args.json`. The files `config.json` and `pytorch_model.bin` define your model.

If you are happy with your training results it is time to upload your model!
Download the following files to your local computer: **`preprocessor_config.json`, `special_tokens_map.json`, `tokenizer_config.json`, `vocab.json`, `config.json`, `pytorch_model.bin`**. Those files fully define a XLSR-Wav2Vec2 model checkpoint.
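Because the checkpoint folders described above now store the training arguments as plain JSON, they can be inspected without unpickling anything through torch. A minimal sketch, assuming a checkpoint written by this branch; the `checkpoint-500` directory name is a placeholder:

```python
import json

# "checkpoint-500" is a placeholder for any checkpoint directory produced during training.
with open("checkpoint-500/training_args.json", encoding="utf-8") as f:
    training_args = json.load(f)

# Field names follow TrainingArguments, e.g. learning_rate and num_train_epochs.
print(training_args["learning_rate"], training_args["num_train_epochs"])
```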
36 changes: 32 additions & 4 deletions src/transformers/trainer.py
@@ -301,14 +301,21 @@ def safe_globals():


# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINING_ARGS_NAME = "training_args.json"
DEPRECATED_ARGS_NAME = "trainer_state.bin"
not-lain marked this conversation as resolved.
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
FSDP_MODEL_NAME = "pytorch_model_fsdp"

# Safe serialization check
safe_serialize = os.environ.get("TRAINER_SAFE_SERIALIZE")

if safe_serialize or int(__version__.split(".")[0]) >= 5:
safe_serialize = True

not-lain marked this conversation as resolved.

class Trainer:
"""
@@ -3830,7 +3837,14 @@ def _save_tpu(self, output_dir: Optional[str] = None):

if xm.is_master_ordinal(local=False):
os.makedirs(output_dir, exist_ok=True)
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
if safe_serialize:
self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME))
else:
logger.info(
f"trainer API will deprecate the {DEPRECATED_ARGS_NAME} in 5.0.0, to switch to a safe serialization method, "
"you can set os.environ['TRAINER_SAFE_SERIALIZE']= 'true' "
)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
@@ -3927,7 +3941,14 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
self.processing_class.save_pretrained(output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
if safe_serialize:
self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME))
else:
logger.info(
f"trainer API will deprecate the {DEPRECATED_ARGS_NAME} in 5.0.0, to switch to a safe serialization method, "
"you can set os.environ['TRAINER_SAFE_SERIALIZE']= 'true'"
)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

def store_flos(self):
# Storing the number of floating-point operations that went into the model
@@ -4637,7 +4658,14 @@ def _push_from_checkpoint(self, checkpoint_folder):
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
# Same for the training arguments
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
if safe_serialize:
self.args.to_json_file(os.path.join(output_dir, TRAINING_ARGS_NAME))
else:
logger.info(
f"trainer API will deprecate the {DEPRECATED_ARGS_NAME} in 5.0.0, to switch to a safe serialization method, "
"you can set os.environ['TRAINER_SAFE_SERIALIZE']= 'true'"
)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

if self.args.save_strategy == SaveStrategy.STEPS:
commit_message = f"Training in progress, step {self.state.global_step}"
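Taken together, the changes in `trainer.py` gate the JSON format behind an opt-in environment variable until the 5.0.0 major release. A minimal usage sketch, assuming this branch is installed; note that the flag is read when `transformers.trainer` is first imported:

```python
import os

# Opt in before importing Trainer, because the diff reads the variable at module import time.
os.environ["TRAINER_SAFE_SERIALIZE"] = "true"

from transformers import Trainer, TrainingArguments

args = TrainingArguments(output_dir="out", per_device_train_batch_size=8)
# With the flag set (or on transformers >= 5.0.0), Trainer checkpoints and pushes
# call args.to_json_file(...) and write training_args.json instead of pickling
# the arguments to training_args.bin.
```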
51 changes: 39 additions & 12 deletions src/transformers/training_args.py
@@ -22,7 +22,7 @@
from datetime import timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Type, TypeVar, Union

from huggingface_hub import get_full_repo_name
from packaging import version
@@ -67,6 +67,8 @@
log_levels = logging.get_log_levels_dict().copy()
trainer_log_levels = dict(**log_levels, passive=-1)

T = TypeVar("T", bound="TrainingArguments")

if is_torch_available():
import torch
import torch.distributed as dist
@@ -218,7 +220,6 @@ def _convert_str_dict(passed_value: dict):
return passed_value


# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
@dataclass
class TrainingArguments:
"""
@@ -2529,25 +2530,33 @@ def to_dict(self):
d = {field.name: getattr(self, field.name) for field in fields(self) if field.init}

for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
d[k] = [x.value for x in v]
if k.endswith("_token"):
d[k] = f"<{k.upper()}>"
# Handle the accelerator_config if passed
if is_accelerate_available() and isinstance(v, AcceleratorConfig):
d[k] = v.to_dict()
d[k] = serialize_parameter(k, v)
self._dict_torch_dtype_to_str(d)
Member: We can remove this and the function since you took care of that in serialize_parameter. Can you add a comment in serialize_parameter to explain what is happening to torch_dtype?

Contributor Author: I hope I did not misunderstand you, let me know if this has been resolved after the new changes.

Member: Looks fine, thanks! Since we are not using the _dict_torch_dtype_to_str method, you can also remove it in this PR.

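To make the point discussed above concrete, this is all `serialize_parameter` does to a dtype (a tiny sketch, assuming PyTorch is installed):

```python
import torch

# str(torch.float16) == "torch.float16"; keeping only the part after the dot
# stores the human-readable string "float16" in the JSON file.
print(str(torch.float16).split(".")[1])  # -> float16
```
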
return d

def to_json_string(self):
"""
Serializes this instance to a JSON string.
Serializes the TrainingArguments into a JSON string.
"""
return json.dumps(self.to_dict(), indent=2)

def to_json_file(self, json_file_path: str):
"""
Save this instance's parameters to a json file.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())

@classmethod
def from_json_file(cls: Type[T], json_file_path: str) -> T:
"""
Loads and initializes the TrainingArguments from a json file.
"""
with open(json_file_path, "r", encoding="utf-8") as reader:
params = json.load(reader)
return cls(**params)

def to_sanitized_dict(self) -> Dict[str, Any]:
"""
Sanitized serialization to use with TensorBoard’s hparams
@@ -3104,3 +3113,21 @@ class ParallelMode(Enum):
SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel"
SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel"
TPU = "tpu"


def serialize_parameter(k, v):
if k == "torch_dtype" and not isinstance(v, str):
return str(v).split(".")[1]
if isinstance(v, dict):
return {key: serialize_parameter(key, value) for key, value in v.items()}
if isinstance(v, Enum):
return v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
l = [x.value for x in v]
return l
if k.endswith("_token"):
return f"<{k.upper()}>"
# Handle the accelerator_config if passed
if is_accelerate_available() and isinstance(v, AcceleratorConfig):
return v.to_dict()
return v
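A minimal round-trip sketch for the new `to_json_file` / `from_json_file` pair, assuming this branch is installed; note that `*_token` fields are masked by `to_dict`, so secrets are not recoverable from the file:

```python
import os

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", learning_rate=3e-5, num_train_epochs=2)

# to_dict() routes every field through serialize_parameter: enums become their values,
# torch dtypes become plain strings, and *_token fields are masked.
os.makedirs("out", exist_ok=True)
args.to_json_file(os.path.join("out", "training_args.json"))

restored = TrainingArguments.from_json_file(os.path.join("out", "training_args.json"))
print(restored.learning_rate)  # 3e-05
```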
2 changes: 1 addition & 1 deletion tests/deepspeed/test_deepspeed.py
@@ -740,7 +740,7 @@ def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):

def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
# adapted from TrainerIntegrationCommon.check_saved_checkpoints
file_list = [SAFE_WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
file_list = [SAFE_WEIGHTS_NAME, "training_args.json", "trainer_state.json", "config.json"]

if stage == ZERO2:
ds_file_list = ["mp_rank_00_model_states.pt"]
7 changes: 6 additions & 1 deletion tests/trainer/test_trainer.py
@@ -572,14 +572,19 @@ def get_regression_trainer(
class TrainerIntegrationCommon:
def check_saved_checkpoints(self, output_dir, freq, total, is_pretrained=True, safe_weights=True):
weights_file = WEIGHTS_NAME if not safe_weights else SAFE_WEIGHTS_NAME
file_list = [weights_file, "training_args.bin", "optimizer.pt", "scheduler.pt", "trainer_state.json"]
file_list = [weights_file, "optimizer.pt", "scheduler.pt", "trainer_state.json"]
safe_serialized = ["training_args.bin", "training_args.json"] # default to json in version 5.x.x
not-lain marked this conversation as resolved.
if is_pretrained:
file_list.append("config.json")
for step in range(freq, total, freq):
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
self.assertTrue(os.path.isdir(checkpoint))
for filename in file_list:
self.assertTrue(os.path.isfile(os.path.join(checkpoint, filename)))
self.assertTrue(
(os.path.isfile(os.path.join(checkpoint, safe_serialized[0])))
or (os.path.isfile(os.path.join(checkpoint, safe_serialized[1])))
)

def check_best_model_has_been_loaded(
self, output_dir, freq, total, trainer, metric, greater_is_better=False, is_pretrained=True, safe_weights=True
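The updated test accepts either file name while both formats exist. A hypothetical helper, not part of this PR, that mirrors the same backward compatibility when reloading a checkpoint:

```python
import os

import torch
from transformers import TrainingArguments


def load_training_args(checkpoint_dir: str) -> TrainingArguments:
    """Prefer the new JSON file and fall back to the legacy pickled .bin."""
    json_path = os.path.join(checkpoint_dir, "training_args.json")
    if os.path.isfile(json_path):
        return TrainingArguments.from_json_file(json_path)
    # Legacy path: training_args.bin is a pickled dataclass, so full unpickling is required.
    return torch.load(os.path.join(checkpoint_dir, "training_args.bin"), weights_only=False)
```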