-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[example] update vit example for hybrid parallel plugin (#4641)
* update vit example for hybrid plugin
* reset tp/pp size
* fix dataloader iteration bug
* update optimizer passing in evaluation/add grad_accum
* change criterion
* wrap tqdm
* change grad_accum to grad_checkpoint
* fix pbar
- Loading branch information
Showing
10 changed files
with
248 additions
and
194 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,124 +1,82 @@ | ||
from colossalai import get_default_parser | ||
|
||
|
||
def parse_demo_args():
    """Parse command-line arguments for the ViT finetuning demo.

    The parser returned by ``get_default_parser()`` is extended with the
    demo-specific options below, then ``sys.argv`` is parsed.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    import argparse  # local import: only needed for the boolean converter below

    def _str_to_bool(value):
        """Convert a command-line token such as 'true' or '0' into a bool."""
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ("true", "t", "yes", "y", "1"):
            return True
        if token in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")

    parser = get_default_parser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="google/vit-base-patch16-224",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./output_model",
        help="The path of your saved model after finetuning.",
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'.",
    )
    parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument(
        "--tp_size",
        type=int,
        default=1,
        help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.",
    )
    parser.add_argument(
        "--pp_size",
        type=int,
        default=1,
        help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=3e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--warmup_ratio",
        type=float,
        default=0.3,
        help="Ratio of warmup steps against total training steps.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.")
    # `type=bool` would turn every non-empty string (including "False") into
    # True, so the flag is parsed with an explicit string-to-bool converter.
    parser.add_argument(
        "--grad_checkpoint",
        type=_str_to_bool,
        default=True,
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")

    args = parser.parse_args()
    return args
|
||
|
||
def parse_benchmark_args():
    """Parse command-line arguments for the ViT benchmark script.

    The parser returned by ``get_default_parser()`` is extended with the
    benchmark-specific options below, then ``sys.argv`` is parsed.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    import argparse  # local import: only needed for the boolean converter below

    def _str_to_bool(value):
        """Convert a command-line token such as 'true' or '0' into a bool."""
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ("true", "t", "yes", "y", "1"):
            return True
        if token in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")

    parser = get_default_parser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="google/vit-base-patch16-224",
        help="Path to a pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="Batch size (per dp group) for the training dataloader.",
    )
    parser.add_argument("--num_labels", type=int, default=10, help="Number of labels for classification.")
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    # `type=bool` would turn every non-empty string (including "False") into
    # True, so the flag is parsed with an explicit string-to-bool converter.
    parser.add_argument(
        "--grad_checkpoint",
        type=_str_to_bool,
        default=True,
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.")
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
    parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).")

    args = parser.parse_args()
    return args
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,38 @@ | ||
import torch | ||
from torch.utils.data import Dataset | ||
from datasets import load_dataset | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class BeansDataset(Dataset):
    """In-memory dataset over one split of the 'beans' dataset.

    Every example of the split is run through ``image_processor`` once in
    the constructor and the results are kept in ``self.inputs``.
    """

    def __init__(self, image_processor, tp_size=1, split='train'):
        """Load the split and preprocess every example.

        Args:
            image_processor: Callable invoked as
                ``image_processor(image, return_tensors='pt')``.
            tp_size: The label list is padded with placeholder names until
                its length is a multiple of this value. Must be >= 1.
            split: Name of the split to index from ``load_dataset('beans')``.

        Raises:
            ValueError: If ``tp_size`` is smaller than 1.
        """
        super().__init__()
        self.image_processor = image_processor
        self.ds = load_dataset('beans')[split]
        # Pad a copy so the label list owned by the dataset's feature
        # description is not modified as a side effect.
        self.label_names = self._padded_label_names(self.ds.features['labels'].names, tp_size)
        self.num_labels = len(self.label_names)
        self.inputs = [self.process_example(example) for example in self.ds]

    @staticmethod
    def _padded_label_names(names, tp_size):
        """Return a new list of ``names`` padded to a multiple of ``tp_size``.

        Placeholder entries are named ``pad_label_<index>``, where ``<index>``
        is the position of the entry in the returned list.
        """
        if tp_size < 1:
            raise ValueError(f"tp_size must be a positive integer, got {tp_size}.")
        padded = list(names)
        # -len % tp_size is the distance up to the next multiple of tp_size.
        target_len = len(padded) + (-len(padded) % tp_size)
        padded.extend(f"pad_label_{index}" for index in range(len(padded), target_len))
        return padded

    def __len__(self):
        """Return the number of preprocessed examples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the preprocessed example stored at position ``idx``."""
        return self.inputs[idx]

    def process_example(self, example):
        """Run the image processor on one example and attach its label.

        Args:
            example: Mapping with at least the keys ``'image'`` and ``'labels'``.

        Returns:
            The object returned by the image processor, with an extra
            ``'labels'`` entry copied from ``example``.
        """
        features = self.image_processor(example['image'], return_tensors='pt')
        features['labels'] = example['labels']
        return features
|
||
|
||
def beans_collator(batch):
    """Collate a list of processed examples into one batch dictionary.

    Args:
        batch: Sequence of mappings, each holding a ``'pixel_values'`` tensor
            and a ``'labels'`` value.

    Returns:
        A dict with ``'pixel_values'`` (the per-example tensors concatenated
        along dim 0) and ``'labels'`` (an int64 tensor with one entry per
        example).
    """
    pixel_values = torch.cat([data['pixel_values'] for data in batch], dim=0)
    labels = torch.tensor([data['labels'] for data in batch], dtype=torch.int64)
    return {'pixel_values': pixel_values, 'labels': labels}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.