[shardformer] add bert finetune example
flybird11111 committed Sep 1, 2023
1 parent dced3e1 commit db74e72
Showing 2 changed files with 34 additions and 24 deletions.
examples/language/bert/data.py (4 additions, 6 deletions)
@@ -91,21 +91,19 @@ def train_dataloader(self):

     def val_dataloader(self):
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_dataloader(self.dataset["validation"],
-                                                  batch_size=self.eval_batch_size,
-                                                  drop_last=True)
+            return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                 for x in self.eval_splits
             ]

     def test_dataloader(self):
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size, drop_last=True)
+            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
                 for x in self.eval_splits
             ]

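Note: the change above drops drop_last=True from the validation and test dataloaders, so evaluation no longer discards the trailing partial batch and every sample in the split is scored. A minimal sketch of the resulting call, assuming `plugin` is a ColossalAI booster plugin whose prepare_dataloader keeps its default drop_last=False, and `dataset["validation"]` is an already tokenized split (batch size here is illustrative):

# Sketch only: `plugin` and `dataset` follow the example's setup; the
# batch size value is an assumption, not taken from the example.
val_dataloader = plugin.prepare_dataloader(
    dataset["validation"],
    batch_size=32,      # eval_batch_size in the example
    shuffle=False,      # keep evaluation order deterministic
)                       # no drop_last argument, so the last partial batch is kept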
examples/language/bert/finetune.py (30 additions, 18 deletions)
@@ -64,16 +64,22 @@ def evaluate_subset(dataloader: DataLoader):
         batch = move_to_cuda(batch)
         labels = batch["labels"]
         batch_size = batch["input_ids"].shape[0]
-        if booster.plugin.stage_manager is not None:
+        if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
+            pg_mesh = booster.plugin.pg_mesh
+            pp_group = booster.plugin.pp_group
+            current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
+            current_rank = dist.get_rank()
+
             batch = iter([batch])
+
             outputs = booster.execute_pipeline(batch,
                                                model,
                                                criterion,
                                                optimizer,
                                                return_loss=True,
                                                return_outputs=True)

-            if dist.get_rank() == dist.get_world_size() - 1:
+            if booster.plugin.stage_manager.is_last_stage():
                 val_loss = outputs["loss"]

                 #TODO get merged output
@@ -88,11 +94,19 @@ def evaluate_subset(dataloader: DataLoader):
                     preds = torch.argmax(logits, axis=1)
                 elif num_labels == 1:
                     preds = logits.squeeze()
-                dist.broadcast(preds, src=dist.get_world_size() - 1)
+
+                dist.broadcast(preds, src=current_rank)
+                dist.broadcast(val_loss, src=current_rank)
+
                 metric.add_batch(predictions=preds, references=labels)
-            else:
+            elif current_rank in current_pp_group_ranks:
                 val_loss = torch.empty((1,), device=get_current_device())
                 preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device())
-                dist.broadcast(preds, src=dist.get_world_size() - 1)
+
+                dist.broadcast(preds, src=current_pp_group_ranks[-1])
+                dist.broadcast(val_loss, src=current_pp_group_ranks[-1])
+
+                accum_loss.add_(val_loss)
                 metric.add_batch(predictions=preds, references=labels)

         else:
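Note: in the hunk above only the last pipeline stage receives the loss and logits from booster.execute_pipeline, so it broadcasts `preds` and `val_loss` to the other ranks of its pipeline group, which wait with pre-allocated buffers and then update the metric. A minimal sketch of that result-sharing pattern (illustrative names; the sketch passes the pipeline group explicitly, while the diff relies on the default process group):

# Sketch of the broadcast pattern, not the example's exact code.
import torch.distributed as dist

def share_eval_results(preds, val_loss, pp_ranks, pp_group):
    # pp_ranks lists the global ranks of one pipeline group; its last entry
    # is the last pipeline stage, which owns the real preds / val_loss.
    # Every other rank passes pre-allocated buffers of matching shape and
    # dtype and receives the broadcast in place.
    src = pp_ranks[-1]
    dist.broadcast(preds, src=src, group=pp_group)
    dist.broadcast(val_loss, src=src, group=pp_group)
    return preds, val_loss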
@@ -130,11 +144,13 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion:
                 train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):

     model.train()
-    with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
+    with tqdm(train_dataloader,
+              desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
+              disable=not (coordinator.is_master() or booster.plugin.stage_manager.is_last_stage())) as pbar:
         for batch in pbar:
             # Forward pass
             batch = move_to_cuda(batch)
-            if booster.plugin.stage_manager is not None:
+            if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
                 batch = iter([batch])
                 outputs = booster.execute_pipeline(batch,
                                                    model,
@@ -145,7 +161,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion:
                 # Backward and optimize
                 if booster.plugin.stage_manager.is_last_stage():
                     loss = outputs['loss']
-                    pbar.set_postfix({'loss': loss})
+                    pbar.set_postfix({'loss': loss.item()})
             else:
                 outputs = model(**batch)
                 loss = _criterion(outputs, None)
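Note: both training hunks rely on the same branch: when the plugin exposes a non-None stage_manager, the batch is wrapped in an iterator and driven through booster.execute_pipeline, and the loss exists (and is logged via loss.item()) only on the last pipeline stage; otherwise the model is called directly. A condensed sketch of that control flow, assuming `booster`, `model`, `optimizer`, and `_criterion` follow the Booster API used in the example; the optimizer-step wiring here is an assumption:

# Condensed sketch of one training step; helper wiring is assumed, not
# copied from the example.
def train_step(batch, model, optimizer, _criterion, booster, pbar):
    use_pipeline = hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None
    if use_pipeline:
        # execute_pipeline consumes an iterator of micro-batches and runs
        # forward and backward across the pipeline stages internally.
        outputs = booster.execute_pipeline(iter([batch]), model, _criterion,
                                           optimizer, return_loss=True)
        if booster.plugin.stage_manager.is_last_stage():
            # Only the last stage holds the loss tensor.
            pbar.set_postfix({'loss': outputs['loss'].item()})
    else:
        outputs = model(**batch)
        loss = _criterion(outputs, None)
        booster.backward(loss, optimizer)
        pbar.set_postfix({'loss': loss.item()})
    optimizer.step()
    optimizer.zero_grad()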
@@ -238,16 +254,12 @@ def main():

     cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)

-    # lazy_init
-    use_lazy_init = args.use_lazy_init
-    ctx = LazyInitContext() if use_lazy_init else nullcontext()
-    with ctx:
-        if model_name == "bert-base-uncased":
-            model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
-        elif model_name == "albert-xxlarge-v2":
-            model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
-        else:
-            raise RuntimeError
+    if model_name == "bert-base-uncased":
+        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
+    elif model_name == "albert-xxlarge-v2":
+        model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
+    else:
+        raise RuntimeError

     # optimizer
     no_decay = ["bias", "LayerNorm.weight"]
