Skip to content

Commit

Permalink
[shardformer] fix
Browse files Browse the repository at this point in the history
[example] update opt example

[example] resolve comments

fix

fix
  • Loading branch information
flybird11111 committed Sep 8, 2023
1 parent d25fbde commit e84b267
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 340 deletions.
1 change: 0 additions & 1 deletion colossalai/shardformer/modeling/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,6 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

attention_input_shape = (bsz, -1, self.num_heads, self.head_dim)
# get query proj
Expand Down
27 changes: 9 additions & 18 deletions examples/language/bert/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,15 @@ def evaluate_subset(dataloader: DataLoader):
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()

dataloader = iter(dataloader)
accum_loss = torch.zeros(1, device=get_current_device())
for batch in dataloader:
batch = move_to_cuda(batch)
labels = batch["labels"]
batch_size = batch["input_ids"].shape[0]
if use_pipeline:
pg_mesh = booster.plugin.pg_mesh
pp_group = booster.plugin.pp_group
current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
current_rank = dist.get_rank()
# Can't pass dataloader to execute_pipeline directly, Because we need the actual batch size from batch to broadcast output.
batch = iter([batch])
outputs = booster.execute_pipeline(batch,
model,
Expand All @@ -83,30 +80,24 @@ def evaluate_subset(dataloader: DataLoader):
return_outputs=True)

if is_pp_last_stage:
val_loss = outputs["loss"]

logits = outputs["outputs"]["logits"]

val_loss = outputs["loss"]
accum_loss.add_(val_loss)

if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()

dist.broadcast(preds, src=current_rank, group=pp_group)
dist.broadcast(val_loss, src=current_rank, group=pp_group)
dist.broadcast_object_list([preds, val_loss], src=current_pp_group_ranks[-1], group=pp_group)

metric.add_batch(predictions=preds, references=labels)
elif current_rank in current_pp_group_ranks:
val_loss = torch.empty((1,), device=get_current_device())
preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device())
object_list = [None, None]
dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)

dist.broadcast(preds, src=current_pp_group_ranks[-1], group=pp_group)
dist.broadcast(val_loss, src=current_pp_group_ranks[-1], group=pp_group)

accum_loss.add_(val_loss)
metric.add_batch(predictions=preds, references=labels)
metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
accum_loss.add_(object_list[1].to(get_current_device()))

else:
batch = move_to_cuda(batch)
Expand Down Expand Up @@ -148,14 +139,14 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion:

model.train()
optimizer.zero_grad()
train_dataloader = iter(train_dataloader)
train_dataloader_iter = iter(train_dataloader)
with tqdm(range(total_step),
desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
# Forward pass
for _ in pbar:
if use_pipeline:
outputs = booster.execute_pipeline(train_dataloader,
outputs = booster.execute_pipeline(train_dataloader_iter,
model,
_criterion,
optimizer,
Expand All @@ -166,7 +157,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion:
loss = outputs['loss']
pbar.set_postfix({'loss': loss.item()})
else:
data = next(train_dataloader)
data = next(train_dataloader_iter)
data = move_to_cuda(data)
outputs = model(**data)
loss = _criterion(outputs, None)
Expand Down
98 changes: 0 additions & 98 deletions examples/language/llama2/data.py

This file was deleted.

Loading

0 comments on commit e84b267

Please sign in to comment.