This repository was archived by the owner on Aug 1, 2024 and is now read-only.

Commit

Restrict fairscale usage only to FSDP-specific inference example (#256)
* Restrict fairscale usage only to FSDP-specific inference example
nikita-smetanin authored Aug 23, 2022
1 parent 40febf7 commit 839c5b8
Showing 3 changed files with 34 additions and 31 deletions.
esm/pretrained.py (20 changes: 6 additions & 14 deletions)
@@ -3,15 +3,13 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from argparse import Namespace
 import re
-import warnings
 import urllib
+import warnings
+from argparse import Namespace
 from pathlib import Path
 
 import torch
-from fairscale.nn.wrap import wrap
-from typing import Dict
-from torch import Tensor
+
 import esm
 from esm.model.esm2 import ESM2
@@ -60,6 +58,7 @@ def _download_model_and_regression_data(model_name):
         regression_data = None
     return model_data, regression_data
 
+
 def load_model_and_alphabet_hub(model_name):
     model_data, regression_data = _download_model_and_regression_data(model_name)
     return load_model_and_alphabet_core(model_name, model_data, regression_data)
@@ -85,6 +84,7 @@ def has_emb_layer_norm_before(model_state):

 def _load_model_and_alphabet_core_v1(model_data):
     import esm  # since esm.inverse_folding is imported below, you actually have to re-import esm here
+
     alphabet = esm.Alphabet.from_architecture(model_data["args"].arch)
 
     if model_data["args"].arch == "roberta_large":
@@ -162,7 +162,7 @@ def update_name(s):


 def _load_model_and_alphabet_core_v2(model_data):
-    def upgrade_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def upgrade_state_dict(state_dict):
         """Removes prefixes 'model.encoder.sentence_encoder.' and 'model.encoder.'."""
         prefixes = ["encoder.sentence_encoder.", "encoder."]
         pattern = re.compile("^" + "|".join(prefixes))
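The remainder of upgrade_state_dict falls between this hunk and the next, so it is not shown here. Judging from the docstring and the compiled pattern, it presumably just strips the matched prefix from every key, along these lines (a sketch, not the verbatim source):

        # sketch: rename each key by dropping the matched prefix, then return the cleaned dict
        state_dict = {pattern.sub("", name): param for name, param in state_dict.items()}
        return state_dict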
@@ -180,14 +180,6 @@ def upgrade_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
         alphabet=alphabet,
         token_dropout=cfg.token_dropout,
     )
-
-    # Wrap is for use with FSDP. This falls back to no-op if FSDP is not enabled
-    for name, child in model.named_children():
-        if name == "layers":
-            for layer_name, layer in child.named_children():
-                wrapped_layer = wrap(layer)
-                setattr(child, layer_name, wrapped_layer)
-
     return model, alphabet, state_dict


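With the wrap() loop removed, esm.pretrained no longer imports fairscale, so the ordinary (non-FSDP) loading path runs with only torch and esm installed. A minimal sketch of that path, assuming the standard pretrained API; the checkpoint name below is just an example:

import torch

import esm

# Regular single-device loading; nothing on this path imports fairscale anymore.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

labels, strs, tokens = batch_converter(
    [("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG")]
)
with torch.no_grad():
    out = model(tokens, repr_layers=[33])  # final-layer representations; contacts not requested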
esm/version.py (2 changes: 1 addition & 1 deletion)
@@ -3,4 +3,4 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-version = "1.0.1"
+version = "1.0.2"
examples/esm2_infer_fairscale_fsdp_cpu_offloading.py (43 changes: 27 additions & 16 deletions)
@@ -4,38 +4,49 @@

 import esm
 
-
 # init the distributed world with world_size 1
-url="tcp://localhost:23456"
-torch.distributed.init_process_group(backend='nccl', init_method=url, world_size=1, rank=0)
+url = "tcp://localhost:23456"
+torch.distributed.init_process_group(backend="nccl", init_method=url, world_size=1, rank=0)
 
 # download model data from the hub
-model_data, regression_data = esm.pretrained._download_model_and_regression_data("esm2_t48_15B_UR50D")
+model_data, regression_data = esm.pretrained._download_model_and_regression_data(
+    "esm2_t48_15B_UR50D"
+)
 if regression_data is not None:
-    model_data["model"].update(regression_data["model"])
+    model_data["model"].update(regression_data["model"])
 
 # initialize the model with FSDP wrapper
 fsdp_params = dict(
-    mixed_precision=True,
-    flatten_parameters=True,
-    state_dict_device=torch.device("cpu"), # reduce GPU mem usage
-    cpu_offload=True, # enable cpu offloading
+    mixed_precision=True,
+    flatten_parameters=True,
+    state_dict_device=torch.device("cpu"),  # reduce GPU mem usage
+    cpu_offload=True,  # enable cpu offloading
 )
 with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
-    model, vocab, _ = esm.pretrained._load_model_and_alphabet_core_v2(model_data)
-    batch_converter = vocab.get_batch_converter()
-    model.eval()
-    model = wrap(model)
+    model, vocab, _ = esm.pretrained._load_model_and_alphabet_core_v2(model_data)
+    batch_converter = vocab.get_batch_converter()
+    model.eval()
+
+    # Wrap each layer in FSDP separately
+    for name, child in model.named_children():
+        if name == "layers":
+            for layer_name, layer in child.named_children():
+                wrapped_layer = wrap(layer)
+                setattr(child, layer_name, wrapped_layer)
+    model = wrap(model)
 
 data = [
     ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
     ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
-    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
-    ("protein3", "K A <mask> I S Q"),
+    (
+        "protein2 with mask",
+        "KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE",
+    ),
+    ("protein3", "K A <mask> I S Q"),
 ]
 
 batch_labels, batch_strs, batch_tokens = batch_converter(data)
 batch_tokens = batch_tokens.cuda()
 with torch.no_grad():
-    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
+    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
 print(results)
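The hunk above starts at line 4, so the example's file-level imports sit just above the shown context and are not part of this diff. From the names used in the body (torch, FSDP, enable_wrap, wrap), they are presumably:

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap

Wrapping each transformer layer in its own FSDP instance, rather than only the top-level module, is what lets the 15B-parameter checkpoint be sharded and CPU-offloaded layer by layer instead of being materialized on a single GPU.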
