Issue Loading FSDP wrapped module using FULL_STATE_DICT type. #141

Open
hbikki opened this issue May 3, 2023 · 3 comments


hbikki commented May 3, 2023

🐛 Describe the bug

Hello, I am training a pretrained Hugging Face model, "t5-small". Using the torchsnapshot examples from the documentation, I am able to save and load checkpoints for the LOCAL_STATE_DICT type, and I am also able to save the model checkpoint for FULL_STATE_DICT. However, when loading the FULL_STATE_DICT checkpoint, I hit the issue below.

Versions:
pytorch = 2.0.0+cu117
torchx-nightly>=2023.3.15
torchsnapshot=0.1.0

Host Details:
The training below was tested on a single node with NPROC_PER_NODE set to 8.

Code:

Model training code:

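# Imports are assumed; the original post does not show them.
import functools

import torch
from torch.distributed import init_process_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block

# load_model() and local_rank() are helpers from the training script (not shown).
def train() -> None:
    init_process_group(backend="nccl")
    torch.cuda.empty_cache()
    torch.cuda.set_device(local_rank())
    model = load_model("t5-small")

    # Wrap each T5Block in its own FSDP unit.
    fsdp_model = FSDP(
        model,
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy, transformer_layer_cls={T5Block}
        ),
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        device_id=local_rank(),
    )
    # ... training loop ...
    # ... save_checkpoint() ...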
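
The `local_rank()` helper called in the training code is not shown in the issue. A minimal sketch of what it might look like, assuming the job is launched with torchrun (which exports `LOCAL_RANK` and `WORLD_SIZE` to each worker); the `world_size()` helper and the fallback defaults are illustrative assumptions, not the author's code:

```python
import os


def local_rank() -> int:
    """Return this worker's rank on its node (assumes a torchrun launch)."""
    # torchrun sets LOCAL_RANK for every worker; fall back to 0 for a plain
    # single-process run. The fallback is an assumption for illustration.
    return int(os.environ.get("LOCAL_RANK", "0"))


def world_size() -> int:
    """Return the total number of workers (hypothetical helper)."""
    # With the single-node, 8-process setup described above this would be 8.
    return int(os.environ.get("WORLD_SIZE", "1"))
```

With 8 processes on one node, each worker would see a distinct `LOCAL_RANK` from 0 to 7 and bind to the matching GPU via `torch.cuda.set_device(local_rank())`.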

stateDictType is set to FULL_STATE_DICT.
Related saving/loading code:

    def save_checkpoint(self) -> None:
        # Select the state-dict type that FSDP uses while the snapshot is taken.
        with FSDP.state_dict_type(checkpoint.model, self.stateDictType):
            Snapshot.take(path=str(save_dir), app_state=app_state)

    def load_checkpoint(self) -> None:
        # Restore under the same state-dict type that was used for saving.
        with FSDP.state_dict_type(checkpoint.model, self.stateDictType):
            Snapshot(path=str(load_dir)).restore(app_state=app_state)
   

Error stack trace:
https://pastebin.com/ih9qSbwR

.snapshot_metadata for the model on local rank:
https://pastebin.com/t6grkKyX

Does anyone know how to resolve this? Thanks!

@kiukchung

@hbikki can you edit the markdown so that the stack trace displays in a code block? Could you also include the full stack trace? If it's too long, feel free to put it on Pastebin and link it here.
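
With 8 workers, tracebacks from different ranks tend to interleave in the console, which makes the full trace hard to copy. One way to collect a complete trace per rank is sketched below; `run_and_capture` is a hypothetical helper name, and the `step` callable stands in for something like `load_checkpoint`:

```python
import traceback
from typing import Callable


def run_and_capture(step: Callable[[], None], rank: int) -> str:
    """Run `step`; return "" on success or the full traceback text on failure."""
    try:
        step()
    except Exception:
        # format_exc() renders the exception currently being handled,
        # including chained exceptions, as a single string.
        return f"[rank {rank}]\n{traceback.format_exc()}"
    return ""
```

The returned text can then be written to a per-rank file and pasted into the issue inside a fenced code block.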

Contributor

yifuwang commented May 4, 2023

Hey @hbikki, could you please share the snapshot metadata in question? It's the .snapshot_metadata file under the snapshot folder/prefix in question.

Author

hbikki commented May 5, 2023

Hello, I have updated the issue with the requested data. Thanks.
