diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
index 8a28b1286cfa..adb8f62a5084 100644
--- a/colossalai/booster/booster.py
+++ b/colossalai/booster/booster.py
@@ -139,7 +139,7 @@ def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
             loss (torch.Tensor): The loss to be backpropagated.
             optimizer (Optimizer): The optimizer to be updated.
         """
-        # TODO: implement this method with plugin
+        # TODO(frank lee): implement this method with plugin
         optimizer.backward(loss)
 
     def execute_pipeline(self,
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index 09cb7bfe1407..577bef076a7e 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -29,8 +29,6 @@ class Randomizer:
     _INDEX = 0
 
     def __init__(self, seed: int):
-        # TODO: remove colossalai.context.random
-
         self.seed = seed
 
         # Handle CUDA rng state
diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index eaafd67b8968..5bd1c531cc68 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -57,7 +57,7 @@ def bert_model_forward(
         hidden_states: Optional[torch.FloatTensor] = None,    # this is from the previous stage
         stage_index: Optional[List[int]] = None,
     ):
-        # TODO: add explaination of the output here.
+        # TODO(jianghai): add explaination of the output here.
         r"""
         encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -113,7 +113,7 @@ def bert_model_forward(
         batch_size, seq_length = input_shape
         device = hidden_states.device
 
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if output_attentions:
         logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
         output_attentions = False
@@ -272,7 +272,7 @@ def bert_for_pretraining_forward(
     logger = logging.get_logger(__name__)
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(jianghai) left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if output_attentions:
         logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
         output_attentions = False
@@ -534,7 +534,7 @@ def bert_for_next_sentence_prediction_forward(
     stage_index: Optional[List[int]] = None,
     **kwargs,
 ):
-    #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
+    # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
         Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 57c45bc6adfa..12276635ecfa 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -252,7 +252,7 @@ def custom_forward(*inputs):
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
 
-    # TODO: deal with all_hidden_states, all_self_attentions, presents
+    # TODO(jianghai): deal with all_hidden_states, all_self_attentions, presents
     if output_hidden_states:
         all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -307,7 +307,7 @@ def bloom_for_causal_lm_forward(self: BloomForCausalLM,
         raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
 
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if output_attentions:
         logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
         output_attentions = False
@@ -402,7 +402,7 @@ def bloom_for_sequence_classification_forward(
 
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if output_attentions:
         logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
         output_attentions = False
@@ -431,7 +431,7 @@ def bloom_for_sequence_classification_forward(
     all_cross_attentions = None
     if stage_manager.is_last_stage():
         batch_size = hidden_states.shape[0]
-        #update batch size
+        # update batch size
         hidden_states = transformer_outputs[0]
         logits = self.score(hidden_states)
 
@@ -525,7 +525,7 @@ def bloom_for_token_classification_forward(
 
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if output_attentions:
         logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
         output_attentions = False
@@ -611,7 +611,7 @@ def bloom_for_question_answering_forward(
     logger = logging.get_logger(__name__)
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if output_attentions:
         logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
         output_attentions = False
diff --git a/colossalai/shardformer/modeling/chatglm.py b/colossalai/shardformer/modeling/chatglm.py
index a95966c3b99e..409e2e1f5497 100644
--- a/colossalai/shardformer/modeling/chatglm.py
+++ b/colossalai/shardformer/modeling/chatglm.py
@@ -152,7 +152,7 @@ def chatglm_model_forward(
                             if output_hidden_states is not None else self.config.output_hidden_states)
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if past_key_values:
         logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
         past_key_values = None
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index a12a9796fa8a..47835d5d5468 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -57,7 +57,7 @@ def gpt2_model_forward(
     logger = logging.get_logger(__name__)
 
     # Preprocess passed in arguments
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if past_key_values:
         logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
         past_key_values = None
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 2f54daac586a..f1d2998bbee4 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -65,7 +65,7 @@ def llama_model_forward(
     seq_length_with_past = seq_length
     past_key_values_length = 0
 
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if output_attentions:
         logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
         output_attentions = False
@@ -216,7 +216,7 @@ def llama_for_causal_lm_forward(
                             if output_hidden_states is not None else self.config.output_hidden_states)
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if output_attentions:
         logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
         output_attentions = False
@@ -301,7 +301,7 @@ def llama_for_sequence_classification_forward(
     logger = logging.get_logger(__name__)
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if output_attentions:
         logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
         output_attentions = False
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index 9afdfff4d71d..b4251f33b457 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -148,7 +148,7 @@ def opt_model_forward(
                 "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
             use_cache = False
 
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if past_key_values:
         logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
         past_key_values = None
diff --git a/colossalai/shardformer/modeling/t5.py b/colossalai/shardformer/modeling/t5.py
index d622da452366..9cc071f91dfc 100644
--- a/colossalai/shardformer/modeling/t5.py
+++ b/colossalai/shardformer/modeling/t5.py
@@ -50,7 +50,7 @@ def t5_stack_forward(
 
     logger = logging.get_logger(__name__)
 
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if past_key_values:
         logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
         past_key_values = None
@@ -285,7 +285,7 @@ def t5_model_forward(
 
     logger = logging.get_logger(__name__)
 
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if past_key_values:
         logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
         past_key_values = None
@@ -422,7 +422,7 @@ def t5_for_conditional_generation_forward(
 
     logger = logging.get_logger(__name__)
 
-    # TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
+    # TODO(baizhou): left the recording kv-value tensors as () or None type, this feature may be added in the future.
     if past_key_values:
         logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
         past_key_values = None
diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
index eb0ea4c7502b..9fc0b7488803 100644
--- a/colossalai/shardformer/modeling/vit.py
+++ b/colossalai/shardformer/modeling/vit.py
@@ -96,7 +96,7 @@ def pp_forward(
         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")
 
-        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+        # TODO(FoolPlayer): maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
         expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
         if pixel_values.dtype != expected_dtype:
             pixel_values = pixel_values.to(expected_dtype)
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index ec6e0cd0d4be..0c28f115d018 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -29,7 +29,6 @@ class ShardConfig:
     enable_flash_attention: bool = False
     enable_jit_fused: bool = False
 
-    # TODO: add support for tensor parallel
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
index e29afe786c46..1cd3b90db917 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
@@ -15,7 +15,7 @@ def test_gpt():
     for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
         model = model_fn()
 
-        # TODO: support the following models
+        # TODO(ver217): support the following models
         # 1. GPT2DoubleHeadsModel
         # as they are not supported, let's skip them
         if model.__class__.__name__ in ['GPT2DoubleHeadsModel', 'GPT2ForQuestionAnswering']:
diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
index 6f24dc9608bd..b5709d1451f2 100644
--- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
@@ -27,7 +27,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
     return rearanged_tensor
 
 
-# TODO: solve lazy_init True is not working
+# TODO(FoolPlayer): solve lazy_init True is not working
 @parameterize('lazy_init', [False])
 def check_linear_conv_1d_col(lazy_init: bool):
     ctx = LazyInitContext() if lazy_init else nullcontext()
diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py
deleted file mode 100644
index 31e76ef5107c..000000000000
--- a/tests/test_shardformer/test_model/test_pure_pipeline.py
+++ /dev/null
@@ -1,171 +0,0 @@
-import copy
-import random
-from typing import Any, Callable, Iterator, List, Optional, Tuple
-
-import numpy as np
-import pytest
-import torch
-import torch.distributed as dist
-from torch.nn import Module
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-
-import colossalai
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.logging import disable_existing_loggers
-from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import (
-    assert_hf_output_close,
-    clear_cache_before_run,
-    parameterize,
-    rerun_if_address_is_in_use,
-    spawn,
-)
-from tests.kit.model_zoo import model_zoo
-from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
-
-DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
-
-
-class PipelineOptimizer(OptimizerWrapper):
-
-    def __init__(self, optim: Optimizer, model: Module):
-        super().__init__(optim)
-        params = set(model.parameters())
-        new_param_groups = []
-        for group in optim.param_groups:
-            params = [p for p in group['params'] if p in params]
-            new_param_groups.append({**group, 'params': params})
-        optim.__setstate__({'param_groups': new_param_groups})
-        # TODO: support amp
-
-
-class PipelinedModel(ModelWrapper):
-
-    def __init__(self, module: Module, shard_config: ShardConfig, stage_manager: PipelineStageManager) -> None:
-        self.stage_manager = stage_manager
-        shardformer = ShardFormer(shard_config)
-        module, self.shared_params = shardformer.optimize(module)
-        self.shared_param_process_groups = []
-        super().__init__(module)
-
-
-def prepare_dataloader(dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0):
-    sampler = DistributedSampler(
-        dataset,
-        # rank=self.pg_mesh.coordinate(DP_AXIS),
-        shuffle=shuffle)
-
-    # Deterministic dataloader
-    def seed_worker(worker_id):
-        worker_seed = seed
-        np.random.seed(worker_seed)
-        torch.manual_seed(worker_seed)
-        random.seed(worker_seed)
-
-    return DataLoader(
-        dataset,
-        batch_size=batch_size,
-        sampler=sampler,
-        worker_init_fn=seed_worker,
-        drop_last=drop_last,
-        pin_memory=pin_memory,
-        num_workers=num_workers,
-    )
-
-
-def execute_pipeline(
-    data_iter: Iterator,
-    model: PipelinedModel,
-    criterion: Callable[[Any, Any], torch.Tensor],
-    optimizer: PipelineOptimizer,
-    return_loss: bool = True,
-    return_outputs: bool = False,
-    schedule: OneForwardOneBackwardSchedule = None,
-) -> dict:
-    # return loss or outputs if needed
-    outputs = schedule.forward_backward_step(model, optimizer, data_iter, criterion, return_loss, return_outputs)
-    return outputs
-
-
-class data_loader():
-
-    def __getitem__(self, x):
-        return torch.ones((4, 128), dtype=torch.int).cuda() * 10
-
-
-def loss(y, x):
-    return (y[0].float().mean() - x[0].float().mean())
-
-
-@parameterize('enable_fused_normalization', [False])
-@parameterize('enable_tensor_parallelism', [False])
-@parameterize('use_lazy_init', [False])
-def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
-    PP_DIM = 0
-    PP_SIZE = 2
-    RANK_TO_COORDINATE = {
-        0: (0, 0),
-        1: (0, 1),
-        2: (1, 0),
-        3: (1, 1),
-    }
-    PP_RANKS_IN_GROUP = {
-        0: [0, 1],
-        1: [0, 1],
-        2: [2, 3],
-        3: [2, 3],
-    }
-
-    pg_mesh = ProcessGroupMesh(PP_SIZE)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        if name != 'transformers_llama':
-            continue
-        num_microbatches = 2
-        org_model = model_fn().cuda()
-        data_iter = iter(data_loader())
-
-        model_copy = copy.deepcopy(org_model)
-        batch = next(data_iter)
-        with torch.no_grad():
-            y = model_copy(batch)
-        org_loss = loss(y, batch)
-        optimizer = torch.optim.AdamW(org_model.parameters(), lr=1e-3)
-        schedule = OneForwardOneBackwardSchedule(num_microbatches, stage_manager)
-        shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
-                                   enable_tensor_parallelism=enable_tensor_parallelism,
-                                   pipeline_stage_manager=stage_manager)
-        pipelined_model = PipelinedModel(org_model, shard_config, stage_manager)
-        pp_optimizer = PipelineOptimizer(optimizer, pipelined_model)
-        results = execute_pipeline(data_iter, pipelined_model, loss, pp_optimizer, schedule=schedule)
-
-        if stage_manager.is_last_stage():
-            assert results['loss'] == org_loss
-        else:
-            assert results['loss'] is None
-            assert results['outputs'] is None
-        torch.cuda.empty_cache()
-
-
-def check_llama(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_llama_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_llama():
-    spawn(check_llama, 2)
-
-
-if __name__ == "__main__":
-    test_llama()
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index 145ccf97c388..ed0d1d8e401d 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -101,7 +101,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 }])
 def run_bloom_test(test_config):
 
-    # TODO: add test_config for TP+DP after supporting & debugging it
+    # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
 
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm.py b/tests/test_shardformer/test_model/test_shard_chatglm.py
index e9c74b300daa..bb77759048b3 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm.py
@@ -125,7 +125,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 }])
 def run_chatglm_test(test_config):
 
-    # TODO: add test_config for TP+DP after supporting & debugging it
+    # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
 
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 8b7a6bf29c8b..ca086bf12776 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -110,7 +110,7 @@ def unwrap(module):
 @clear_cache_before_run()
 def run_gpt2_test(test_config):
 
-    # TODO: add test_config for TP+DP after supporting & debugging it
+    # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
 
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index fa4ee43e3114..30ebdfbe5cd9 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -133,7 +133,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 }])
 def run_llama_test(test_config):
 
-    # TODO: add test_config for TP+DP after supporting & debugging it
+    # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
 
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 403c3e75f52c..8d1154d82638 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -127,7 +127,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 }])
 def run_opt_test(test_config):
 
-    # TODO: add test_config for TP+DP after supporting & debugging it
+    # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
 
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index fb065b42250b..066f7ee815b4 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -105,10 +105,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @clear_cache_before_run()
 def run_t5_test(test_config):
 
-    # TODO: add plugin_config for TP+DP after supporting & debugging it
+    # TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it
     # {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
 
-    # TODO: add test_config for flash attention & jit operator after supporting
+    # TODO(baizhou): add test_config for flash attention & jit operator after supporting
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
 
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index 919bceffc847..18df8ef555f2 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -124,8 +124,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 }])
 def run_vit_test(test_config):
 
-    # TODO: add test_config for TP+DP after supporting & debugging it
-    # TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models
+    # TODO(baizhou): add test_config for TP+DP after supporting & debugging it
+    # TODO(baizhou): fix bug when settign lazy_init for Conv2D Layers in ViT models
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')