diff --git a/src/postprocessing/prepare_upload.py b/src/postprocessing/prepare_upload.py
index a1508ff..3abb9e5 100644
--- a/src/postprocessing/prepare_upload.py
+++ b/src/postprocessing/prepare_upload.py
@@ -10,8 +10,6 @@
 
 import torch
 
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
-
 from safetensors.torch import save_file
 from tqdm import tqdm
@@ -48,9 +46,15 @@ def prepare_upload(
         )
         model_state_dict[k] = v
 
-    original_model = AutoModelForCausalLM.from_pretrained(config.pretrained_model_name)
-    original_model.load_state_dict(model_state_dict)
-    state_dict = original_model.state_dict()
+    state_dict = model_state_dict
+    if config.precision == 32 or config.precision == "32":
+        safetensors_dtype = torch.float32
+    elif config.precision == 16 or config.precision == "16":
+        safetensors_dtype = torch.float16
+    elif config.precision == "bf16":
+        safetensors_dtype = torch.bfloat16
+    else:
+        raise ValueError(f"Invalid precision type: {config.precision}")
     keys = list(state_dict.keys())
     num_splits = config.num_safetensors
     split_size = len(keys) // num_splits
@@ -72,7 +76,8 @@ def prepare_upload(
     for i in tqdm(range(num_splits)):
         safe_tensors_name = f"model-{i+1:05d}-of-{num_splits:05d}.safetensors"
         part_state_dict = {
-            k: state_dict[k] for k in keys[i * split_size : (i + 1) * split_size]
+            k: state_dict[k].to(safetensors_dtype)
+            for k in keys[i * split_size : (i + 1) * split_size]
        }
         part_state_dict_mapping = {
             k: safe_tensors_name for k in keys[i * split_size : (i + 1) * split_size]
@@ -91,15 +96,6 @@ def prepare_upload(
             f,
             indent=2,
         )
-    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model_name)
-    tokenizer.save_pretrained(save_dir)
-    model_config = AutoConfig.from_pretrained(config.pretrained_model_name)
-    model_config._name_or_path = (
-        f"{config.user_name}/{config.model_type}-{config.upload_tag}"
-    )
-    if config.strategy.startswith("deepspeed"):
-        model_config.torch_dtype = "float32"
-    model_config.save_pretrained(save_dir)
 
 
 if __name__ == "__main__":
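For reference, here is a minimal, self-contained sketch (not part of the patch) of what the changed code path does: it mirrors the new precision-to-dtype branching and the cast-while-sharding loop, then checks that tensors round-trip through safetensors with the requested dtype. The `resolve_dtype` helper name, the toy state dict, the two-way split, and the `weight_map` index layout are all assumptions chosen for illustration, not names from this repository.

```python
import json
import tempfile
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file


def resolve_dtype(precision):
    # Mirrors the patched branch: accept 32/"32", 16/"16", or "bf16".
    if precision in (32, "32"):
        return torch.float32
    if precision in (16, "16"):
        return torch.float16
    if precision == "bf16":
        return torch.bfloat16
    raise ValueError(f"Invalid precision type: {precision}")


# Toy stand-in for the checkpoint's state dict.
state_dict = {f"layer.{i}.weight": torch.randn(4, 4) for i in range(4)}
dtype = resolve_dtype("bf16")

# Like the script, assume len(keys) divides evenly into num_splits.
num_splits = 2
keys = list(state_dict.keys())
split_size = len(keys) // num_splits

save_dir = Path(tempfile.mkdtemp())
weight_map = {}
for i in range(num_splits):
    name = f"model-{i+1:05d}-of-{num_splits:05d}.safetensors"
    # Cast one shard at a time, so only a shard-sized copy exists in memory.
    part = {
        k: state_dict[k].to(dtype)
        for k in keys[i * split_size : (i + 1) * split_size]
    }
    save_file(part, str(save_dir / name))
    weight_map.update({k: name for k in part})

# Assumed index layout, following the Hugging Face sharding convention.
with open(save_dir / "model.safetensors.index.json", "w") as f:
    json.dump({"weight_map": weight_map}, f, indent=2)

# Every tensor should come back in the requested dtype.
for k, name in weight_map.items():
    assert load_file(str(save_dir / name))[k].dtype == dtype
```

Casting per shard inside the loop, rather than converting the whole state dict up front (or rebuilding it through `AutoModelForCausalLM` as the old code did), keeps peak memory to roughly one shard's worth of converted tensors.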