Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up hugginface/composer conversion script tests #950

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 91 additions & 68 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,45 +378,7 @@ def test_huggingface_conversion_callback_interval(
delete_transformers_cache()


@pytest.mark.world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize(
'model,tie_word_embeddings',
[('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)],
)
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('1ba', '1ba', '1ba', 1, 1)])
@patch('os.cpu_count', MagicMock(return_value=1))
def test_huggingface_conversion_callback(
model: str,
tmp_path: pathlib.Path,
tie_word_embeddings: bool,
fsdp_state_dict_type: Optional[str],
hf_save_interval: str,
save_interval: str,
max_duration: str,
expected_hf_checkpoints: int,
expected_normal_checkpoints: int,
):
delete_transformers_cache()

dist.initialize_dist(get_device('gpu'))

max_seq_len = 16
device_batch_size = 1
dataset_size = 2
precision_str = 'bfloat16'
precision = torch.bfloat16
batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2))

checkpointer_callback = HuggingFaceCheckpointer(
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=hf_save_interval,
precision=precision_str,
mlflow_registered_model_name='dummy-registered-name')

def _get_model_info(model: str, tie_word_embeddings: bool, max_seq_len: int):
# get small version of each model
model_cfg = None
tokenizer_name = None
Expand Down Expand Up @@ -473,23 +435,10 @@ def test_huggingface_conversion_callback(
raise ValueError(f'Unknown model {model}')
assert model_cfg is not None
assert tokenizer_name is not None
model_cfg = om.create(model_cfg)

fsdp_config = {
'sharding_strategy': 'FULL_SHARD',
'mixed_precision': 'PURE',
'activation_checkpointing': False,
'activation_checkpointing_reentrant': False,
'activation_cpu_offload': False,
'limit_all_gathers': True,
'state_dict_type': fsdp_state_dict_type,
}
return model_cfg, tokenizer_name

tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small')
tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl')
if dist.get_global_rank() == 0:
make_tiny_ft_dataset(path=tiny_dataset_path, size=dataset_size)

def _get_dataloader_cfg(tiny_dataset_folder_path: str, max_seq_len: int):
dataloader_cfg = {
'name': 'finetuning',
'dataset': {
Expand All @@ -508,8 +457,95 @@ def test_huggingface_conversion_callback(
'persistent_workers': False,
'timeout': 0
}
return dataloader_cfg


def _get_fsdp_config(fsdp_state_dict_type: Optional[str]):
fsdp_config = {
'sharding_strategy': 'FULL_SHARD',
'mixed_precision': 'PURE',
'activation_checkpointing': False,
'activation_checkpointing_reentrant': False,
'activation_cpu_offload': False,
'limit_all_gathers': True,
'state_dict_type': fsdp_state_dict_type,
}
return fsdp_config


def _get_optimizer_config():
optimizer_config = {
'name': 'decoupled_adamw',
'lr': 6e-4,
'betas': [0.9, 0.95],
'eps': 1e-8,
'weight_decay': 0.0,
}
return optimizer_config


def _get_mlflow_logger_mock():
mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock)
mlflow_logger_mock.register_model = MagicMock()
mlflow_logger_mock.model_registry_prefix = ''
mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
mlflow_logger_mock._run_id = 'mlflow-run-id'
return mlflow_logger_mock


@pytest.mark.world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize(
'model,tie_word_embeddings',
[('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)],
)
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('1ba', '1ba', '1ba', 1, 1)])
@patch('os.cpu_count', MagicMock(return_value=1))
def test_huggingface_conversion_callback(
model: str,
tmp_path: pathlib.Path,
tie_word_embeddings: bool,
fsdp_state_dict_type: Optional[str],
hf_save_interval: str,
save_interval: str,
max_duration: str,
expected_hf_checkpoints: int,
expected_normal_checkpoints: int,
):
delete_transformers_cache()

dist.initialize_dist(get_device('gpu'))

max_seq_len = 16
device_batch_size = 1
dataset_size = 2
precision_str = 'bfloat16'
precision = torch.bfloat16
batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2))

checkpointer_callback = HuggingFaceCheckpointer(
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=hf_save_interval,
precision=precision_str,
mlflow_registered_model_name='dummy-registered-name')

tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small')
tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl')
if dist.get_global_rank() == 0:
make_tiny_ft_dataset(path=tiny_dataset_path, size=dataset_size)

model_cfg, tokenizer_name = _get_model_info(model, tie_word_embeddings,
max_seq_len)
model_cfg = om.create(model_cfg)
dataloader_cfg = _get_dataloader_cfg(tiny_dataset_folder_path, max_seq_len)
dataloader_cfg = om.create(dataloader_cfg)
fsdp_config = _get_fsdp_config(fsdp_state_dict_type)
optimizer_config = _get_optimizer_config()

tokenizer = build_tokenizer(
tokenizer_name=tokenizer_name,
Expand All @@ -525,24 +561,11 @@ def test_huggingface_conversion_callback(
original_model = COMPOSER_MODEL_REGISTRY[model_cfg['name']](model_cfg,
tokenizer)

optimizer_config = {
'name': 'decoupled_adamw',
'lr': 6e-4,
'betas': [0.9, 0.95],
'eps': 1e-8,
'weight_decay': 0.0,
}
optimizer_name = optimizer_config.pop('name')
optimizer = build_optimizer(original_model, optimizer_name,
optimizer_config)

mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock)
mlflow_logger_mock.register_model = MagicMock()
mlflow_logger_mock.model_registry_prefix = ''
mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
mlflow_logger_mock._run_id = 'mlflow-run-id'
mlflow_logger_mock = _get_mlflow_logger_mock()
trainer = Trainer(
model=original_model,
device='gpu',
Expand Down
Loading