diff --git a/.github/workflows/nv-accelerate-v100.yml b/.github/workflows/nv-accelerate-v100.yml index 96d0ef3c9e28..93286b62610a 100644 --- a/.github/workflows/nv-accelerate-v100.yml +++ b/.github/workflows/nv-accelerate-v100.yml @@ -29,7 +29,7 @@ jobs: - name: Install pytorch run: | - pip install -U --cache-dir $TORCH_CACHE torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118 + pip install -U --cache-dir $TORCH_CACHE torch torchvision --index-url https://download.pytorch.org/whl/cu118 python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" diff --git a/.github/workflows/nv-inference.yml b/.github/workflows/nv-inference.yml index 579a638b567b..cc55136d1a4b 100644 --- a/.github/workflows/nv-inference.yml +++ b/.github/workflows/nv-inference.yml @@ -57,6 +57,6 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="1.13" --cuda_ver="11.6" + #pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="1.13" --cuda_ver="11.6" pytest $PYTEST_OPTS -m 'inference_ops' unit/ --torch_ver="1.13" --cuda_ver="11.6" pytest $PYTEST_OPTS --forked -n 4 -m 'inference' unit/ --torch_ver="1.13" --cuda_ver="11.6" diff --git a/.gitignore b/.gitignore index e284c4fd35a1..db6790886cb4 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,10 @@ docs/code-docs/build ## Testing data # Saved checkpoints for testing tests/unit/saved_checkpoint/ + +# HIP files created during AMD compilation +*_hip.cpp +*_hip.h +*.hip +*.cuh +*hip_layers.h diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py index 8ba652a8de92..3c5d799e293e 100644 --- a/accelerator/abstract_accelerator.py +++ b/accelerator/abstract_accelerator.py @@ -17,6 +17,18 @@ def __init__(self): def is_synchronized_device(self): ... + @abc.abstractmethod + def use_host_timers(self): + ... + + @abc.abstractmethod + def resolves_data_dependency(self): + ... + + @abc.abstractmethod + def handles_memory_backpressure(self): + ... 
+ # Device APIs @abc.abstractmethod def device_name(self, device_index): diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index a02777f5223b..b1aba75b4c5e 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -21,6 +21,15 @@ def __init__(self): def is_synchronized_device(self): return True + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + # Device APIs def device_name(self, device_index=None): return 'cpu' diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index c6285e8bc97a..2030f36631e9 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -42,6 +42,15 @@ def _init_pynvml(self): def is_synchronized_device(self): return False + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + # Device APIs def device_name(self, device_index=None): if device_index is None: diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py index 120e038dd227..30b115e8b1ab 100644 --- a/accelerator/hpu_accelerator.py +++ b/accelerator/hpu_accelerator.py @@ -30,6 +30,15 @@ def __init__(self): def is_synchronized_device(self): return False + def use_host_timers(self): + return False + + def resolves_data_dependency(self): + return True + + def handles_memory_backpressure(self): + return True + def device_name(self, device_index=None): if device_index is None: return 'hpu' @@ -147,7 +156,7 @@ def is_fp16_supported(self): def supported_dtypes(self): supported_dtypes = [torch.float, torch.bfloat16] if self.is_fp16_supported(): - supported_dtypes.append(torch.bfloat16) + supported_dtypes.append(torch.half) return supported_dtypes # Misc diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py index f6303cf9890f..972b33caece1 100644 --- a/accelerator/mps_accelerator.py +++ b/accelerator/mps_accelerator.py @@ -24,6 +24,15 @@ def __init__(self): def is_synchronized_device(self): return False + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + # Device APIs def device_name(self, device_index=None): if device_index is None: diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py index 4e20445d9d32..472157e32c02 100644 --- a/accelerator/npu_accelerator.py +++ b/accelerator/npu_accelerator.py @@ -28,6 +28,15 @@ def __init__(self): def is_synchronized_device(self): return False + def use_host_timers(self): + return self.is_synchronized_device() + + def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + # Device APIs def device_name(self, device_index=None): if device_index is None: diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py index c3d6630cc235..3f65263946ab 100644 --- a/accelerator/xpu_accelerator.py +++ b/accelerator/xpu_accelerator.py @@ -19,6 +19,15 @@ def __init__(self): def is_synchronized_device(self): return False + def use_host_timers(self): + return self.is_synchronized_device() + + 
def resolves_data_dependency(self): + return self.is_synchronized_device() + + def handles_memory_backpressure(self): + return self.is_synchronized_device() + # Device APIs def device_name(self, device_index=None): if device_index == None: diff --git a/csrc/aio/common/deepspeed_aio_common.cpp b/csrc/aio/common/deepspeed_aio_common.cpp index 32b0e8a32394..0f2895dfa328 100644 --- a/csrc/aio/common/deepspeed_aio_common.cpp +++ b/csrc/aio/common/deepspeed_aio_common.cpp @@ -268,6 +268,10 @@ void report_file_error(const char* filename, const std::string file_op, const in int open_file(const char* filename, const bool read_op) { const int flags = read_op ? (O_RDONLY | O_DIRECT) : (O_WRONLY | O_CREAT | O_DIRECT); +#if defined(__ENABLE_CANN__) + int* flags_ptr = (int*)&flags; + *flags_ptr = read_op ? (O_RDONLY) : (O_WRONLY | O_CREAT); +#endif const int mode = 0600; const auto fd = open(filename, flags, mode); if (fd == -1) { diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index a6feae06974b..6c7aa8b15ef9 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -12,10 +12,14 @@ from torch.optim.lr_scheduler import _LRScheduler from packaging import version as pkg_version -try: - import triton # noqa: F401 # type: ignore - HAS_TRITON = True -except ImportError: +# Skip Triton import for AMD due to pytorch-triton-rocm module breaking device API in DeepSpeed +if not (hasattr(torch.version, 'hip') and torch.version.hip is not None): + try: + import triton # noqa: F401 # type: ignore + HAS_TRITON = True + except ImportError: + HAS_TRITON = False +else: HAS_TRITON = False from . import ops diff --git a/deepspeed/inference/v2/allocator.py b/deepspeed/inference/v2/allocator.py index bebdcf83aee3..fcc0d94c0f82 100644 --- a/deepspeed/inference/v2/allocator.py +++ b/deepspeed/inference/v2/allocator.py @@ -5,17 +5,27 @@ from functools import reduce from typing import Iterable - +from collections import defaultdict import torch from deepspeed.accelerator import get_accelerator -def empty_from(tensor: torch.Tensor, shape: Iterable[int]) -> torch.Tensor: - shape_size = reduce(lambda x, y: x * y, shape) - if shape_size == 0: - raise ValueError("Cannot create empty tensor with size 0") - return tensor.flatten()[:shape_size].view(shape) +class Allocator: + cache = defaultdict(dict) + + def empty_from(tensor: torch.Tensor, shape: Iterable[int]) -> torch.Tensor: + try: + return Allocator.cache[tensor][shape] + except KeyError: + shape_size = reduce(lambda x, y: x * y, shape) + if shape_size == 0: + raise ValueError("Cannot create empty tensor with size 0") + Allocator.cache[tensor][shape] = tensor.flatten()[:shape_size].view(shape) + return Allocator.cache[tensor][shape] + + +empty_from = Allocator.empty_from def on_device(method) -> torch.Tensor: diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 0aefd1946c36..aaa836bf1c31 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -18,7 +18,7 @@ align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank, is_model_parallel_parameter, see_memory_usage, graph_process) -from deepspeed.utils import link_hp_params, fragment_address +from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE, SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS, @@ -39,7 +39,8 @@ def 
__init__(self, dp_process_group=None, timers=None, grad_acc_dtype=None, - graph_harvesting=False): + graph_harvesting=False, + immediate_grad_update=False): super().__init__() see_memory_usage('begin bf16_optimizer', force=True) self.timers = timers @@ -50,6 +51,7 @@ def __init__(self, assert grad_acc_dtype in [torch.float32, torch.bfloat16 ], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}" self.grad_acc_dtype = grad_acc_dtype + self.immediate_grad_update = immediate_grad_update self.clip_grad = clip_grad self.norm_type = norm_type @@ -163,8 +165,12 @@ def _setup_for_real_optimizer(self): self.initialize_optimizer_states() see_memory_usage('end initialize_optimizer', force=True) + if self.immediate_grad_update: + self.create_grad_acc_hooks() + # Need optimizer states initialized before linking lp to optimizer state self._link_all_hp_params() + self._hp_optimizer_states_linked = False self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() @@ -199,9 +205,15 @@ def _link_all_hp_params(self): param_group_index=i, partition_start=partition_id * partition_size, partition_size=partition_size, - partition_optimizer_state=self.optimizer.state[flat_hp_partition], dp_group=self.real_dp_process_group[i]) + def _lazy_init_hp_params_optimizer_state(self): + if not self._hp_optimizer_states_linked: + for i, _ in enumerate(self.optimizer.param_groups): + lazy_init_hp_params_optimizer_state(self.bf16_groups[i], self.fp32_groups_flat_partition[i], + self.optimizer.state) + self._hp_optimizer_states_linked = True + def initialize_optimizer_states(self): """Take an optimizer step with zero-valued gradients to allocate internal optimizer state. @@ -215,8 +227,6 @@ def initialize_optimizer_states(self): param_partition.grad = grad_partition.to( param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition - self.optimizer.step() - if self.grad_acc_dtype is not torch.float32: for param_partition in self.fp32_groups_flat_partition: param_partition.grad = None @@ -263,6 +273,9 @@ def step(self, closure=None): self.optimizer.step() + # We need to link optimizer state after the first step() call + self._lazy_init_hp_params_optimizer_state() + self.update_lp_params() self.clear_hp_grads() @@ -283,27 +296,37 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg self.update_hp_grads(clear_lp_grads=clear_lp_grads) @torch.no_grad() - def update_hp_grads(self, clear_lp_grads=False): + def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads): + if lp.grad is None: + return - def _update_hp_grads_func(clear_lp_grads=False): - for i, group in enumerate(self.bf16_groups): - for j, lp in enumerate(group): - if lp.grad is None: - continue - hp_grad = self.fp32_groups_gradients[i][j] - assert hp_grad is not None, \ - f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]' - hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) - lp._hp_grad = hp_grad - self.fp32_groups_has_gradients[i][j] = True - # clear gradients - if clear_lp_grads: - lp.grad._zero() + hp_grad = self.fp32_groups_gradients[group_idx][param_idx] + assert hp_grad is not None, \ + f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{group_idx}][{param_idx}]' + + hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape)) + lp._hp_grad = hp_grad + self.fp32_groups_has_gradients[group_idx][param_idx] = True + + # clear gradients + if 
clear_lp_grads: + lp.grad._zero() + + @torch.no_grad() + def _update_hp_grads_func(self, clear_lp_grads=False): + for i, group in enumerate(self.bf16_groups): + for j, lp in enumerate(group): + self._update_hp_grad(lp, i, j, clear_lp_grads) + + @torch.no_grad() + def update_hp_grads(self, clear_lp_grads=False): + if self.immediate_grad_update: + return if self.graph_harvesting: - graph_process(False, _update_hp_grads_func, clear_lp_grads) + graph_process(False, self._update_hp_grads_func, clear_lp_grads) else: - _update_hp_grads_func(clear_lp_grads) + self._update_hp_grads_func(clear_lp_grads) #cpu op for i, group in enumerate(self.bf16_groups): for j, lp in enumerate(group): @@ -441,6 +464,28 @@ def _load_hp_checkpoint_state(self, checkpoint_dir): lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank, tp_world_size) + def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx): + assert self.immediate_grad_update + self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False) + + def create_grad_acc_hooks(self): + self.grad_accs = [] + for i, param_group in enumerate(self.bf16_groups): + for j, param in enumerate(param_group): + if param.requires_grad: + + def wrapper(param, i, j): + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + + def accumulate_hp_grads_and_remove_lp(*notneeded): + self.accumulate_hp_grads_and_remove_lp(param, i, j) + + grad_acc.register_hook(accumulate_hp_grads_and_remove_lp) + self.grad_accs.append(grad_acc) + + wrapper(param, i, j) + def _get_padded_tensor(src_tensor, size): if src_tensor.numel() >= size: diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py index 603f563fca60..b2b612c85180 100644 --- a/deepspeed/runtime/compiler.py +++ b/deepspeed/runtime/compiler.py @@ -13,7 +13,7 @@ def is_compile_supported(): - return hasattr(torch, "compile") + return hasattr(torch, "compiler") def disable(func): diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 20fbf475ca90..975fb1f21501 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -169,6 +169,14 @@ def get_bfloat16_enabled(param_dict): return False +def get_bfloat16_immediate_grad_update(param_dict): + for key in [BFLOAT16, BFLOAT16_OLD]: + if key in param_dict.keys(): + return get_scalar_param(param_dict[key], BFLOAT16_IMMEDIATE_GRAD_UPDATE, + BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT) + return False + + def get_fp16_master_weights_and_grads_enabled(param_dict): if get_fp16_enabled(param_dict): return get_scalar_param(param_dict[FP16], FP16_MASTER_WEIGHTS_AND_GRADS, FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT) @@ -818,6 +826,7 @@ def _initialize_params(self, param_dict): self.fp16_enabled = get_fp16_enabled(param_dict) self.fp16_auto_cast = get_fp16_auto_cast(param_dict) self.bfloat16_enabled = get_bfloat16_enabled(param_dict) + self.bfloat16_immediate_grad_update = get_bfloat16_immediate_grad_update(param_dict) assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict) diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 82d8a0557a41..679230ca7d4c 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -126,6 +126,10 @@ BFLOAT16_ENABLED = "enabled" BFLOAT16_ENABLED_DEFAULT = False +# BFLOAT16 optimizer immediate gradient update 
+BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update" +BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = False + ######################################### # FP16 support ######################################### diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py index cb0d366ce798..f8ce0e3fa2bf 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py @@ -13,7 +13,7 @@ from torch.utils.data import BatchSampler, SequentialSampler, DataLoader, Subset from deepspeed.utils import logger -from .indexed_dataset import MMapIndexedDataset +from .indexed_dataset import MMapIndexedDataset, valid_dtypes from .utils import split_dataset, split_index, create_mmap_dataset_builder, close_mmap_dataset_builder, find_fit_int_dtype @@ -36,7 +36,8 @@ def __init__(self, custom_map_init=None, custom_map_update=None, custom_map_finalize=None, - custom_reduce=None): + custom_reduce=None, + sample_indices=None): super().__init__() self.dataset = dataset self.num_workers = num_workers @@ -55,15 +56,14 @@ def __init__(self, self.custom_map_update = custom_map_update self.custom_map_finalize = custom_map_finalize self.custom_reduce = custom_reduce + self.sample_indices = sample_indices def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtypes, save_path, worker_id): metric_results = [] for m_idx in range(len(metric_names)): metric_name, metric_type, metric_dtype = metric_names[m_idx], \ metric_types[m_idx], metric_dtypes[m_idx] - assert metric_dtype not in [ - np.float64, np.double - ], "Currently floating point metric values are not supported. Please change your metric into integer values (and potentially multiply a larger coefficient to keep the precision)." + assert metric_dtype in valid_dtypes, f"metric_dtype {metric_dtype} not supported. 
Supported dtypes {valid_dtypes}" metric_save_path = f"{save_path}/{metric_name}/worker{worker_id}_thread{thread_id}/" os.makedirs(metric_save_path, exist_ok=True) if metric_type == 'single_value_per_sample': @@ -84,16 +84,34 @@ def init_metric_results(self, thread_id, metric_names, metric_types, metric_dtyp metric_results.append({"metric_value": metric_value, "metric_value_fname": metric_value_fname}) return metric_results - def update_metric_results(self, data, metric_types, metric_functions, metric_results): + def update_metric_results(self, + data, + metric_types, + metric_dtypes, + metric_functions, + metric_results, + batch_start_idx=0): for m_idx in range(len(metric_types)): - metric_type, metric_function, metric_result = metric_types[m_idx], \ - metric_functions[m_idx], metric_results[m_idx] + metric_type, metric_dtype, metric_function, metric_result = metric_types[m_idx], \ + metric_dtypes[m_idx], metric_functions[m_idx], metric_results[m_idx] + metric_values = metric_function(data) + + assert torch.is_tensor(metric_values) or isinstance(metric_values, np.ndarray), \ + "metric_function must return a tensor or array" + assert metric_values.dtype == metric_dtype, \ + f"metric_function result dtype {metric_values.dtype} does not match metric_dtype {metric_dtype}" + if isinstance(metric_values, np.ndarray): + metric_values = torch.from_numpy(metric_values) + if metric_type == 'single_value_per_sample': - metric_values = metric_function(data) for row in range(metric_values.size()[0]): + sample_idx = batch_start_idx + row # sample idx following dataset iteration order + if 'index' in data: # Megatron use case, sample idx provided in 'index' field + sample_idx = data['index'][row][0].item() + elif self.sample_indices is not None: # user defined shuffling of indices + sample_idx = self.sample_indices[sample_idx] metric_result["sample_to_metric_builder"].add_item(metric_values[row].reshape(-1)) - metric_result["metric_to_sample_dict"][metric_values[row].item()].append( - data['index'][row][0].item()) + metric_result["metric_to_sample_dict"][metric_values[row].item()].append(sample_idx) for m_value in metric_result["metric_to_sample_dict"]: if len(metric_result["metric_to_sample_dict"][m_value]) > 100: metric_fname = metric_result["metric_to_sample_fname"] @@ -102,7 +120,6 @@ def update_metric_results(self, data, metric_types, metric_functions, metric_res writer.writerows([metric_result["metric_to_sample_dict"][m_value]]) metric_result["metric_to_sample_dict"][m_value] = [] elif metric_type == 'accumulate_value_over_samples': - metric_values = metric_function(data) if metric_result["metric_value"] is None: metric_result["metric_value"] = metric_values else: @@ -136,15 +153,12 @@ def run_map_helper(self, thread_id): f"on data subset {start_idx} to {end_idx}") thread_dataset = Subset(self.dataset, list(range(start_idx, end_idx))) sampler = BatchSampler(SequentialSampler(thread_dataset), batch_size=self.batch_size, drop_last=False) - if self.collate_fn is None: - iterator = iter(DataLoader(thread_dataset, batch_sampler=sampler, num_workers=0, pin_memory=False)) - else: - iterator = iter( - DataLoader(thread_dataset, - batch_sampler=sampler, - num_workers=0, - collate_fn=self.collate_fn, - pin_memory=False)) + iterator = iter( + DataLoader(thread_dataset, + batch_sampler=sampler, + num_workers=0, + collate_fn=self.collate_fn, + pin_memory=False)) if self.custom_map_init is None: metric_results = self.init_metric_results(thread_id, self.metric_names, self.metric_types, self.metric_dtypes, 
self.save_path, self.worker_id) @@ -157,10 +171,13 @@ def run_map_helper(self, thread_id): while True: try: data = next(iterator) + batch_start_idx = start_idx + processed_sample if self.custom_map_update is None: - self.update_metric_results(data, self.metric_types, self.metric_functions, metric_results) + self.update_metric_results(data, self.metric_types, self.metric_dtypes, self.metric_functions, + metric_results, batch_start_idx) else: - self.custom_map_update(data, self.metric_types, self.metric_functions, metric_results) + self.custom_map_update(data, self.metric_types, self.metric_dtypes, self.metric_functions, + metric_results, batch_start_idx) processed_sample += self.batch_size duration = (time.time() - start) / 3600.0 remain_duration = duration * total_sample / processed_sample - duration diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py b/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py index ef845e4bc490..100bef3f7946 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py @@ -119,9 +119,15 @@ def set_custom_curriculum_learning_schedule(self, schedule_func_dict): if metric in schedule_func_dict: self.curriculum_schedulers[metric].set_custom_get_difficulty(schedule_func_dict[metric]) - def get_start_end_idx(self): - start_idx = self.data_parallel_rank * self.micro_batch_size - end_idx = start_idx + self.micro_batch_size + def get_start_end_idx(self, batch_len=None): + """ + given the length of a minibatch (defaults to micro-batch size * data_parallel_size), + return the start and end indices of the current data parallel rank + """ + batch_len = batch_len or self.micro_batch_times_data_parallel_size + start_idx_fn = lambda r: round(r * batch_len / self.data_parallel_group.size()) + start_idx = start_idx_fn(self.data_parallel_rank) + end_idx = start_idx_fn(self.data_parallel_rank + 1) return start_idx, end_idx def get_sample_based_on_metric_value(self, metric, value_start, value_end): @@ -281,12 +287,17 @@ def get_next_global_batch(self): for cidx in range(len(samples_per_cluster)): batch += self.get_sample_from_cluster(cidx, samples_per_cluster[cidx]) self.np_rng.shuffle(batch) + + # broadcast tensor must have same shape across participants. 
So we fill batch with -1s when not full + assert len(batch) <= self.global_batch_size + batch += [-1] * (self.global_batch_size - len(batch)) batch = torch.tensor(batch, device=get_accelerator().current_device_name(), dtype=torch.long).view(-1) else: batch = torch.empty(self.global_batch_size, device=get_accelerator().current_device_name(), dtype=torch.long) dist.broadcast(batch, 0, group=self.data_parallel_group) + batch = batch[batch != -1] # remove trailing -1s used to fill incomplete batch tensor self.batch = batch.tolist() def __iter__(self): @@ -297,7 +308,7 @@ def __iter__(self): self.batch = self.batch[self.micro_batch_times_data_parallel_size:] if len(current_batch) == self.micro_batch_times_data_parallel_size or \ (len(current_batch) > 0 and not self.drop_last): - start_idx, end_idx = self.get_start_end_idx() + start_idx, end_idx = self.get_start_end_idx(len(current_batch)) yield current_batch[start_idx:end_idx] self.consumed_samples += len(current_batch) current_batch = [] diff --git a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py index 60115fa6efef..7a6963bc27eb 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/indexed_dataset.py @@ -98,25 +98,26 @@ def write_longs(f, a): f.write(np.array(a, dtype=np.int64)) +# valid metric_dtypes as numpy and torch types dtypes = { - 1: np.uint8, - 2: np.int8, - 3: np.int16, - 4: np.int32, - 5: np.int64, - 6: np.float64, - 7: np.double, - 8: np.uint16, - 9: np.uint32, - 10: np.uint64 + 1: (np.uint8, torch.uint8), + 2: (np.int8, torch.int8), + 3: (np.int16, torch.int16), + 4: (np.int32, torch.int32), + 5: (np.int64, torch.int64), + 6: (np.uint16, None), + 7: (np.uint32, None), + 8: (np.uint64, None), } +valid_dtypes = set([dt[0] for dt in dtypes.values()] + [dt[1] for dt in dtypes.values() if dt[1] is not None]) + def code(dtype): - for k in dtypes.keys(): - if dtypes[k] == dtype: - return k - raise ValueError(dtype) + for c, (np_dt, torch_dt) in dtypes.items(): + if dtype in [np_dt, torch_dt]: + return c + raise ValueError(f"{dtype} not supported. 
Supported types: {valid_dtypes}") def index_file_path(prefix_path): @@ -153,7 +154,7 @@ def read_index(self, path): version = f.read(8) assert struct.unpack(' None: @@ -1478,7 +1477,8 @@ def _configure_bf16_optimizer(self, optimizer): dp_process_group=self.seq_data_parallel_group, timers=timers, grad_acc_dtype=self.get_data_types()[1], - graph_harvesting=self.graph_harvesting()) + graph_harvesting=self.graph_harvesting(), + immediate_grad_update=self._config.bfloat16_immediate_grad_update) return optimizer diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 82f200fccf9f..d7a35b7dbbe9 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -363,44 +363,53 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = list(filter(lambda p: p.grad is not None, parameters)) - max_norm = float(max_norm) norm_type = float(norm_type) + all_norms = [] if norm_type == inf: - total_norm = max(p.grad.data.abs().max() for p in parameters) - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + for p in parameters: + all_norms.append(p.grad.data.abs().max().float()) + total_norm = torch.stack(all_norms).max() + origin_device = total_norm.device.type + total_norm = total_norm.to(get_accelerator().device_name()) # Take max across all GPUs. if mpu is not None: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) - total_norm = total_norm_cuda[0].item() + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group()) else: total_norm = 0 for p in parameters: if mpu is not None: if (mpu.get_model_parallel_rank() == 0) or is_model_parallel_parameter(p): - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item()**norm_type + param_norm = p.grad.data.detach().float().norm(norm_type) + all_norms.append(param_norm) else: - param_norm = p.grad.data.float().norm(norm_type) - total_norm += param_norm.item()**norm_type - + param_norm = p.grad.data.detach().float().norm(norm_type) + all_norms.append(param_norm) + if len(all_norms) > 0: + total_norm = torch.stack(all_norms).square().sum().float() + else: + total_norm = torch.FloatTensor([0.0]).to(parameters[0].device) + origin_device = total_norm.device.type + total_norm = total_norm.to(get_accelerator().device_name()) # Sum across all model parallel GPUs. - total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) if mpu is not None: - dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group()) + total_norm = total_norm.pow(1. 
/ norm_type) # Need to average total_norm across different GPUs due to the presence of moe params pg = groups._get_data_parallel_group() scaled_norm = total_norm * 1.0 / float(dist.get_world_size(group=pg)) + scaled_norm_tensor = scaled_norm - scaled_norm_tensor = get_accelerator().FloatTensor([float(scaled_norm)]) dist.all_reduce(scaled_norm_tensor, group=pg) - total_norm = scaled_norm_tensor.item() + total_norm = scaled_norm_tensor + total_norm = total_norm.to(origin_device) + max_norm = torch.tensor([float(max_norm)], device=parameters[0].device) clip_coef = max_norm / (total_norm + 1e-6) - if clip_coef < 1: - for p in parameters: - p.grad.data.mul_(clip_coef) + tmp_tensor = torch.tensor([1.0], device=parameters[0].device) + clip_coef = torch.max(tmp_tensor, clip_coef) + for p in parameters: + p.grad.data.mul_(clip_coef) return total_norm diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 99a9d100082b..5cf655d8741a 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -1008,9 +1008,10 @@ def _update_persist_config(self, ds_config): def _zero_init_param(self, param): self._convert_to_deepspeed_param(param) if dist.get_world_group() == self.get_dp_process_group(): - dist.broadcast(param, 0, self.get_dp_process_group()) + dist.broadcast(param.data, 0, self.get_dp_process_group()) else: - dist.broadcast(param, dist.get_global_rank(self.get_dp_process_group(), 0), self.get_dp_process_group()) + dist.broadcast(param.data, dist.get_global_rank(self.get_dp_process_group(), 0), + self.get_dp_process_group()) param.partition() def _convert_to_zero_parameters(self, param_list): @@ -2177,7 +2178,7 @@ def __exit__(self, *exc): self.params[0].partition(param_list=self.params, has_been_updated=False) return - handles = [dist.broadcast(p, self.src_rank, group=p.ds_process_group, async_op=True) for p in self.params] + handles = [dist.broadcast(p.data, self.src_rank, group=p.ds_process_group, async_op=True) for p in self.params] for h in handles: h.wait() self.params[0].partition(param_list=self.params, has_been_updated=True) diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index cfeae9e7839a..8fc962c4f2a7 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -308,13 +308,13 @@ def fetch_sub_module(self, current_submodule: Module, forward: bool) -> None: self.__inflight_param_registry.pop(param).wait() - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().handles_memory_backpressure(): event = get_accelerator().Event() event.record() self.__ongoing_fetch_events.append(event) assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): get_accelerator().current_stream().wait_stream(self.__allgather_stream) self.__profiler.stop_event(wait_event_name, wait_numel) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index b76b781346e7..42008236a9ea 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -198,7 +198,7 @@ def __init__( # backup fused_adam optimizer init if self.offload_optimizer and self.partial_offload != 1.0: - backup_gpu_tensor = torch.randn(1, device='cuda').to(self.dtype) + backup_gpu_tensor = torch.randn(1, 
device=get_accelerator().device_name()).to(self.dtype) backup_gpu_param = torch.nn.Parameter(backup_gpu_tensor) assert type(init_optimizer) == DeepSpeedCPUAdam, 'Hybrid Optimizer Only Supports DeepSpeedCPUAdam' self.backup_optimizer = FusedAdam([backup_gpu_param], @@ -1015,10 +1015,6 @@ def initialize_optimizer_states(self): else: self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements) - # Initialize the optimizer states with the flattened fp32 partition. - if not is_adagrad: - self._optimizer_step(i) - if swappable_param_subgroup: self._partitioned_params_swap_out(i) @@ -1087,7 +1083,7 @@ def independent_gradient_partition_epilogue(self): self.__reduce_and_partition_ipg_grads() self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): self.reduce_and_partition_stream.synchronize() for param_id in self.params_already_reduced.keys(): @@ -1231,7 +1227,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param): @instrument_w_nvtx @torch.no_grad() def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): self.reduce_and_partition_stream.wait_stream(get_accelerator().default_stream()) if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel() <= self.reduce_bucket_size: @@ -1280,7 +1276,7 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: self.params_in_ipg_bucket.clear() - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().handles_memory_backpressure(): event = get_accelerator().Event() event.record() self.param_reduce_events.append(event) @@ -2153,7 +2149,7 @@ def has_overflow(self, partition_gradients=True): overflow_gpu = self.inf_or_nan_tracker.clone().to(torch.uint8) self.inf_or_nan_tracker.zero_() - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): get_accelerator().default_stream().wait_stream(self.reduce_and_partition_stream) dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group) @@ -2224,7 +2220,7 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: """get fp32 gradient partition dictionary accessed as grad_dict[parameter_group_index][parameter_index] """ - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): self.reduce_and_partition_stream.synchronize() grad_dict = collections.defaultdict(dict) if self.offload_optimizer: @@ -2254,7 +2250,7 @@ def get_fp32_grad_for_param(self, param) -> Tensor: if not param.requires_grad: return None - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): self.reduce_and_partition_stream.synchronize() if self.offload_optimizer: @@ -2266,7 +2262,7 @@ def get_fp32_grad_for_param(self, param) -> Tensor: return self._fp32_state_allgather(param, fp32_grad) def _get_fp32_opt_state_partition(self, param, optim_state_key=None): - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): self.reduce_and_partition_stream.synchronize() group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] @@ -2323,7 +2319,7 @@ def get_local_fp32_grad_for_param(self, param) -> Tensor: if not param.requires_grad: return None - if not 
get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): self.reduce_and_partition_stream.synchronize() if self.offload_optimizer: diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 56607b349ae7..18b58403f1d7 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -28,7 +28,7 @@ from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER, SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS) -from deepspeed.utils import link_hp_params +from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.utils import groups @@ -88,6 +88,12 @@ def _get_padded_tensor(src_tensor, size): return padded_tensor +def _pad_tensor_by_size(src_tensor, pad_size, dtype, device): + padded_tensor = torch.zeros(src_tensor.numel() + pad_size, dtype=dtype, device=device) + padded_tensor.data[:src_tensor.numel()].copy_(src_tensor.data) + return padded_tensor + + class DeepSpeedZeroOptimizer(ZeROOptimizer): """ DeepSpeedZeroOptimizer designed to reduce the memory footprint @@ -536,6 +542,8 @@ def __init__(self, see_memory_usage(f"After initializing ZeRO optimizer", force=True) self._link_all_hp_params() + self._hp_optimizer_states_linked = False + self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() @@ -578,9 +586,15 @@ def _link_all_hp_params(self): param_group_index=i, partition_start=partition_id * partition_size, partition_size=partition_size, - partition_optimizer_state=self.optimizer.state[flat_hp_partition], dp_group=self.real_dp_process_group[i]) + def _lazy_init_hp_params_optimizer_state(self): + if not self._hp_optimizer_states_linked: + for i, _ in enumerate(self.optimizer.param_groups): + lazy_init_hp_params_optimizer_state(self.bit16_groups[i], self.single_partition_of_fp32_groups[i], + self.optimizer.state) + self._hp_optimizer_states_linked = True + def is_moe_group(self, group): return 'moe' in group and group['moe'] @@ -664,8 +678,6 @@ def initialize_optimizer_states(self): # which do lazy initialization of the state at the first call to step. 
if isinstance(self.optimizer, torch.optim.Adagrad): self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults) - else: - self.optimizer.step() if not self.cpu_offload: for group in self.single_partition_of_fp32_groups: @@ -744,7 +756,8 @@ def independent_gradient_partition_epilogue(self): self.params_already_reduced[i] = False if self.overlap_comm: - get_accelerator().synchronize() + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() # It is safe to clear previously reduced grads of other partitions self._clear_previous_reduced_grads() @@ -1020,7 +1033,7 @@ def allreduce_and_scatter(self, bucket, numel_per_bucket=500000000, log=None, di def average_tensor(self, tensor): if self.overlap_comm: stream = self.reduction_stream - if not get_accelerator().is_synchronized_device(): + if not get_accelerator().resolves_data_dependency(): stream.wait_stream(get_accelerator().current_stream()) else: stream = get_accelerator().current_stream() @@ -1501,7 +1514,8 @@ def _clear_previous_reduced_grads(self): def allreduce_and_copy(self, small_bucket, rank=None, log=None, divide=True, process_group=None): process_group = self.dp_process_group if process_group is None else process_group if self.overlap_comm: - get_accelerator().synchronize() + if not get_accelerator().resolves_data_dependency(): + get_accelerator().synchronize() # It is safe to clear the previously reduced grads of other partitions self._clear_previous_reduced_grads() stream = self.reduction_stream @@ -1791,6 +1805,9 @@ def _optimizer_step(self, group_no): self.optimizer.step() self.optimizer.param_groups = original_param_groups + # We need to link optimizer state after the first step() call + self._lazy_init_hp_params_optimizer_state() + def step(self, closure=None): """ Not supporting closure. 
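
Illustrative note (not part of the patch): the `immediate_grad_update` flag threaded through the constants, config parser, BF16 optimizer, and engine above is driven purely by user configuration. A minimal sketch of enabling it, assuming the usual `bf16` config section name; batch size and the commented-out initialize call are placeholders:

```python
# Sketch: enabling the new BF16 immediate gradient update path via ds_config.
# Assumes the standard "bf16" section name; train_batch_size is a placeholder.
import deepspeed

ds_config = {
    "train_batch_size": 8,
    "bf16": {
        "enabled": True,
        # New knob added in this diff; defaults to False when omitted.
        "immediate_grad_update": True,
    },
}

# model and optimizer come from user code (placeholders here):
# engine, optimizer, _, _ = deepspeed.initialize(model=model,
#                                                optimizer=optimizer,
#                                                config=ds_config)
```
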
@@ -2197,7 +2214,7 @@ def refresh_fp32_params(self): # Extract optimizer state for current partition from merged states of all partitions def _partition_base_optimizer_state(self, state_key, all_partition_states, group_id): partition_id = dist.get_rank(group=self.real_dp_process_group[group_id]) - alignment = dist.get_world_size(group=self.real_dp_process_group[group_id]) + alignment = self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[group_id]) if torch.is_tensor(all_partition_states[0]): flat_merged_partitions = self.flatten_dense_tensors_aligned(all_partition_states, alignment) dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions, group_id) @@ -2206,19 +2223,39 @@ def _partition_base_optimizer_state(self, state_key, all_partition_states, group # Assume non-tensor states are not partitioned and equal across ranks, so return first one return all_partition_states[0] - def _restore_base_optimizer_state(self, base_optimizer_group_states): + def _restore_step_from_elastic_checkpoint(self, all_state_dict): + assert BASE_OPTIMIZER_STATE_STEP in all_state_dict[0] + assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] + for sd in all_state_dict), "State dicts of all partitions must have the same step value" + return all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] + + def _restore_base_optimizer_state(self, base_optimizer_group_states, base_optimizer_state_step, group_paddings): if type(base_optimizer_group_states) == dict: base_optimizer_group_states = base_optimizer_group_states['state'] + + saved_keys = base_optimizer_group_states[0].keys() + for i, group in enumerate(self.optimizer.param_groups): p = group['params'][0] - for key, saved in base_optimizer_group_states[i].items(): - if torch.is_tensor(self.optimizer.state[p][key]): - dst_tensor = self.optimizer.state[p][key] - src_tensor = _get_padded_tensor(saved, dst_tensor.numel()) - self.optimizer.state[p][key].data.copy_(src_tensor.data) + padding = 0 if group_paddings is None else group_paddings[i] + for key in saved_keys: + saved = base_optimizer_group_states[i][key] + + if torch.is_tensor(saved): + if key in self.optimizer.state[p]: + dst_tensor = self.optimizer.state[p][key] + src_tensor = _get_padded_tensor(saved, dst_tensor.numel()) + self.optimizer.state[p][key].data.copy_(src_tensor.data) + else: + self.optimizer.state[p][key] = _pad_tensor_by_size( + saved, padding, torch.float32, + torch.device('cpu') if self.cpu_offload else self.device) else: self.optimizer.state[p][key] = saved + for param_group in self.optimizer.param_groups: + param_group['step'] = base_optimizer_state_step + def get_ep_ranks(self, rank=0, group_name=None): from deepspeed.utils import groups expert_parallel_size_ = groups._get_expert_parallel_world_size(group_name) @@ -2246,15 +2283,8 @@ def _restore_elastic_base_optimizer_state(self, all_state_dict): partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i) base_optimizer_group_states.append(partition_states) - self._restore_base_optimizer_state(base_optimizer_group_states) - - # Restore step - if BASE_OPTIMIZER_STATE_STEP in all_state_dict[0]: - assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] - for sd in all_state_dict), "State dicts of all partitions must have the same step value" - loaded_param_groups_step = all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] - for param_group in self.optimizer.param_groups: - param_group['step'] = 
loaded_param_groups_step + self._restore_base_optimizer_state(base_optimizer_group_states, + self._restore_step_from_elastic_checkpoint(all_state_dict), None) def load_state_dict(self, state_dict_list, @@ -2366,7 +2396,9 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l self._restore_elastic_base_optimizer_state(state_dict_list) else: # loading an elastic checkpoint into rigid exec - self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE]) + self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE], + current_rank_sd[BASE_OPTIMIZER_STATE_STEP], + current_rank_sd[GROUP_PADDINGS]) # At this point, the optimizer's references to the model's fp32 parameters are up to date. # The optimizer's hyperparameters and internal buffers are also up to date. diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index 1f86306aefec..33ea8ba60818 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -17,6 +17,6 @@ from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter -from .mixed_precision_linkage import link_hp_params +from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state from deepspeed.runtime.dataloader import RepeatingLoader from .numa import get_numactl_cmd diff --git a/deepspeed/utils/mixed_precision_linkage.py b/deepspeed/utils/mixed_precision_linkage.py index b1afa8f00aa3..7dea6ba322db 100644 --- a/deepspeed/utils/mixed_precision_linkage.py +++ b/deepspeed/utils/mixed_precision_linkage.py @@ -9,13 +9,19 @@ def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, - param_group_index, partition_start, partition_size, partition_optimizer_state, dp_group): + param_group_index, partition_start, partition_size, dp_group): local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group) for lp_param, lp_start in local_lp_param_and_offset: lp_param._hp_mapping = get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, param_group_index, - partition_start, partition_size, partition_optimizer_state) + partition_start, partition_size) + + +def lazy_init_hp_params_optimizer_state(lp_param_list, flat_hp_partition, optimizer_state): + for lp in lp_param_list: + if lp._hp_mapping is not None: + lp._hp_mapping.set_optim_state_fragment(flat_hp_partition, optimizer_state[flat_hp_partition]) def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group): diff --git a/deepspeed/utils/tensor_fragment.py b/deepspeed/utils/tensor_fragment.py index 5f94070dc4c7..49eefafcfbcc 100644 --- a/deepspeed/utils/tensor_fragment.py +++ b/deepspeed/utils/tensor_fragment.py @@ -21,11 +21,11 @@ class tensor_fragment: lp_fragment_address: fragment_address hp_fragment: torch.Tensor hp_fragment_address: fragment_address - optim_fragment: Dict gradient_dict: Dict offload_gradient_dict: Dict use_offload: bool param_group_index: int + optim_fragment: Dict = None def update_hp(self): self.hp_fragment.data.copy_(self.lp_fragment.data) @@ -39,6 +39,13 @@ def get_optim_state_fragment(self, key): else: raise ValueError(f'{key} not found in optimizer state fragment') + def 
set_optim_state_fragment(self, flat_hp_partition, optim_fragment): + self.optim_fragment = { + key: value.narrow(0, self.hp_fragment_address.start, self.hp_fragment_address.numel) + for key, value in optim_fragment.items() + if torch.is_tensor(value) and value.shape == flat_hp_partition.shape + } + def get_hp_fragment_address(self): return self.hp_fragment_address @@ -255,7 +262,7 @@ def safe_set_local_fp32_param(param, value): def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, - param_group_index, partition_start, partition_size, optimizer_state_dict): + param_group_index, partition_start, partition_size): lp_end = lp_param.numel() + lp_start hp_start = partition_start hp_end = partition_start + partition_size @@ -268,11 +275,6 @@ def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict fragment_numel = fragment_end - fragment_start hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel) hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel) - optim_fragment = { - key: value.narrow(0, hp_frag_address.start, hp_frag_address.numel) - for key, value in optimizer_state_dict.items() - if torch.is_tensor(value) and value.shape == flat_hp_partition.shape - } lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel) lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel) @@ -281,7 +283,6 @@ def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict lp_fragment_address=lp_frag_address, hp_fragment=hp_fragment_tensor, hp_fragment_address=hp_frag_address, - optim_fragment=optim_fragment, gradient_dict=gradient_dict, offload_gradient_dict=offload_gradient_dict, use_offload=use_offload, diff --git a/deepspeed/utils/timer.py b/deepspeed/utils/timer.py index 4aac4ded1243..11ef54fe4665 100755 --- a/deepspeed/utils/timer.py +++ b/deepspeed/utils/timer.py @@ -50,7 +50,7 @@ def __init__(self, name): self.name_ = name self.started_ = False self.event_timers = [] - self.use_host_timer = get_accelerator().is_synchronized_device() + self.use_host_timer = get_accelerator().use_host_timers() self.start_event = None self.elapsed_records = None self.start_time = 0.0 diff --git a/deepspeed/utils/zero_to_fp32.py b/deepspeed/utils/zero_to_fp32.py index 49b846633d6e..24cc342e78d1 100755 --- a/deepspeed/utils/zero_to_fp32.py +++ b/deepspeed/utils/zero_to_fp32.py @@ -191,7 +191,7 @@ def parse_optim_states(files, ds_checkpoint_dir): return zero_stage, world_size, fp32_flat_groups -def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters): """ Returns fp32 state_dict reconstructed from ds checkpoint @@ -211,9 +211,11 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') if zero_stage <= 2: - return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states) + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) elif zero_stage == 3: - return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states) + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters) def 
_zero2_merge_frozen_params(state_dict, zero_model_states): @@ -326,7 +328,8 @@ def zero2_align(x): print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") -def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states): +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): state_dict = OrderedDict() # buffers @@ -335,7 +338,8 @@ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zer if debug: print(f"added {len(buffers)} buffers") - _zero2_merge_frozen_params(state_dict, zero_model_states) + if not exclude_frozen_parameters: + _zero2_merge_frozen_params(state_dict, zero_model_states) _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) @@ -444,7 +448,8 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") -def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states): +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states, + exclude_frozen_parameters): state_dict = OrderedDict() # buffers @@ -453,7 +458,8 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer if debug: print(f"added {len(buffers)} buffers") - _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + if not exclude_frozen_parameters: + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) @@ -465,7 +471,7 @@ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zer return state_dict -def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False): """ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example @@ -474,6 +480,7 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): Args: - ``checkpoint_dir``: path to the desired checkpoint folder - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters Returns: - pytorch ``state_dict`` @@ -511,10 +518,10 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): if not os.path.isdir(ds_checkpoint_dir): raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") - return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters) -def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None, exclude_frozen_parameters=False): """ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. 
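
Illustrative note (not part of the patch): the `exclude_frozen_parameters` keyword added to these helpers can be used directly from Python, and the same switch is exposed on the CLI further down via `--exclude_frozen_parameters`. A hedged usage sketch with placeholder paths:

```python
# Sketch: consolidating a ZeRO checkpoint while skipping frozen parameters.
# Paths are placeholders; the keyword argument is the one added in this diff.
from deepspeed.utils.zero_to_fp32 import (
    convert_zero_checkpoint_to_fp32_state_dict,
    get_fp32_state_dict_from_zero_checkpoint,
)

checkpoint_dir = "path/to/checkpoints"      # placeholder
output_file = "path/to/pytorch_model.bin"   # placeholder

# In-memory fp32 state dict, trainable parameters only.
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
                                                      exclude_frozen_parameters=True)

# Or write the consolidated fp32 state dict straight to disk.
convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
                                           output_file,
                                           exclude_frozen_parameters=True)
```
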
@@ -523,9 +530,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag= - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + - ``exclude_frozen_parameters``: exclude frozen parameters """ - state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters) print(f"Saving fp32 state dict to {output_file}") torch.save(state_dict, output_file) @@ -584,9 +592,13 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): type=str, default=None, help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") + parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters") parser.add_argument("-d", "--debug", action='store_true', help="enable debug") args = parser.parse_args() debug = args.debug - convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag) + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, + args.output_file, + tag=args.tag, + exclude_frozen_parameters=args.exclude_frozen_parameters) diff --git a/tests/unit/accelerator/test_accelerator.py b/tests/unit/accelerator/test_accelerator.py new file mode 100644 index 000000000000..964cf2b24f4e --- /dev/null +++ b/tests/unit/accelerator/test_accelerator.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest + +import os +import sys +import importlib +import re + +import deepspeed + +DS_ACCEL_PATH = "deepspeed.accelerator" +IGNORE_FILES = ["abstract_accelerator.py", "real_accelerator.py"] + + +@pytest.fixture +def accel_class_name(module_name): + class_list = [] + mocked_modules = [] + + # Get the accelerator class name for a given module + while True: + try: + module = importlib.import_module(module_name) + break + except ModuleNotFoundError as e: + # If the environment is missing a module, mock it so we can still + # test importing the accelerator class + missing_module = re.search(r"\'(.*)\'", e.msg).group().strip("'") + sys.modules[missing_module] = lambda x: None + mocked_modules.append(missing_module) + for name in dir(module): + if name.endswith("_Accelerator"): + class_list.append(name) + + assert len(class_list) == 1, f"Multiple accelerator classes found in {module_name}" + + yield class_list[0] + + # Clean up mocked modules so as to not impact other tests + for module in mocked_modules: + del sys.modules[module] + + +@pytest.mark.parametrize( + "module_name", + [ + DS_ACCEL_PATH + "." 
+ f.rstrip(".py") for f in os.listdir(deepspeed.accelerator.__path__[0]) + if f.endswith("_accelerator.py") and f not in IGNORE_FILES + ], +) +def test_abstract_methods_defined(module_name, accel_class_name): + module = importlib.import_module(module_name) + accel_class = getattr(module, accel_class_name) + accel_class.__init__ = lambda self: None + _ = accel_class() diff --git a/tests/unit/checkpoint/common.py b/tests/unit/checkpoint/common.py index d6dda2f14cbe..7442e51bad5d 100644 --- a/tests/unit/checkpoint/common.py +++ b/tests/unit/checkpoint/common.py @@ -96,6 +96,19 @@ def compare_state_dicts(state0, state1, expected_mismatch_keys=[]): assert s0 == s1, f'failures with keys = {k0}, {k1}, values = {type(s0[0])} and {type(s1[0])}' +def compare_opt_state_dicts(state0, state1, expected_mismatch_keys=[]): + for param_group0, saved_param_group1 in zip(state0['param_groups'], state1['param_groups']): + compare_state_dicts(param_group0, saved_param_group1, expected_mismatch_keys) + + assert "state" in state0 + assert "state" in state1 + assert len(state0["state"].keys()) == len(state1["state"].keys()) + + for (k0, s0), (k1, s1) in zip(state0["state"].items(), state1["state"].items()): + assert k0 == k1, f'failure due to key mismatch {k0} != {k1}' + compare_state_dicts(s0, s1, expected_mismatch_keys) + + def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True): saved_optimizer = saved_model.optimizer.optimizer if fp16 else saved_model.optimizer loaded_optimizer = loaded_model.optimizer.optimizer if fp16 else loaded_model.optimizer diff --git a/tests/unit/checkpoint/test_zero_optimizer.py b/tests/unit/checkpoint/test_zero_optimizer.py index f2237341ef68..0b9efb3ec462 100644 --- a/tests/unit/checkpoint/test_zero_optimizer.py +++ b/tests/unit/checkpoint/test_zero_optimizer.py @@ -246,7 +246,8 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l model.backward(loss) model.step() if load_optim: - torch.save(model.optimizer.optimizer.state_dict(), os.path.join(tmpdir, 'opt-state-dict')) + opt_state_dict_file = f'opt-state-dict_rank{dist.get_rank()}' + torch.save(model.optimizer.optimizer.state_dict(), os.path.join(tmpdir, opt_state_dict_file)) model.save_checkpoint(tmpdir) ds_config["zero_optimization"]["elastic_checkpoint"] = elastic_load @@ -256,10 +257,9 @@ def test_elastic_checkpoint_fixed_dp(self, tmpdir, elastic_save, elastic_load, l model.load_checkpoint(tmpdir, load_optimizer_states=load_optim) if load_optim: - saved_sd = torch.load(os.path.join(tmpdir, 'opt-state-dict')) + saved_sd = torch.load(os.path.join(tmpdir, opt_state_dict_file)) curr_sd = model.optimizer.optimizer.state_dict() - for curr_param_group, saved_param_group in zip(curr_sd['param_groups'], saved_sd['param_groups']): - compare_state_dicts(curr_param_group, saved_param_group, expected_mismatch_keys) + compare_opt_state_dicts(curr_sd, saved_sd, expected_mismatch_keys) data_loader = random_dataloader(model=model, total_samples=8, hidden_dim=hidden_dim, device=model.device) for n, batch in enumerate(data_loader): diff --git a/tests/unit/common.py b/tests/unit/common.py index 420db577cf09..76bebf6b725a 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -168,7 +168,7 @@ def _launch_daemonic_procs(self, num_procs): # Shortcut to exit pytest in the case of a hanged test.
This # usually means an environment error and the rest of tests will # hang (causing super long unit test runtimes) - pytest.exit("Test hanged, exiting", returncode=0) + pytest.exit("Test hanged, exiting", returncode=1) # Tear down distributed environment and close process pools self._close_pool(pool, num_procs) @@ -204,7 +204,7 @@ def _launch_non_daemonic_procs(self, num_procs): if not any_done: for p in processes: p.terminate() - pytest.exit("Test hanged, exiting", returncode=0) + pytest.exit("Test hanged, exiting", returncode=1) # Wait for all other processes to complete for p in processes: diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 067a4969869f..f3056a225a9b 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -36,16 +36,16 @@ pytest.skip("skip inference tests on rocm for now", allow_module_level=True) _bert_models = [ - "bert-base-cased", - "bert-base-uncased", - "bert-large-cased", - "bert-large-uncased", - "bert-base-multilingual-cased", - "bert-base-multilingual-uncased", + "google-bert/bert-base-cased", + "google-bert/bert-base-uncased", + "google-bert/bert-large-cased", + "google-bert/bert-large-uncased", + "google-bert/bert-base-multilingual-cased", + "google-bert/bert-base-multilingual-uncased", "deepset/minilm-uncased-squad2", "cross-encoder/ms-marco-MiniLM-L-12-v2", "dslim/bert-base-NER", - "bert-large-uncased-whole-word-masking-finetuned-squad", + "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad", "distilbert/distilbert-base-cased-distilled-squad", ] _roberta_models = [ diff --git a/tests/unit/runtime/compile/test_compile_wrapper.py b/tests/unit/runtime/compile/test_compile_wrapper.py index fbf235fb7d62..98a7c28c6a28 100644 --- a/tests/unit/runtime/compile/test_compile_wrapper.py +++ b/tests/unit/runtime/compile/test_compile_wrapper.py @@ -8,9 +8,13 @@ import deepspeed from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.utils import required_torch_version from unit.common import DistributedTest +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1), + reason="Compile tests require PyTorch version 2.1 or above") + @pytest.fixture def base_config(): diff --git a/tests/unit/runtime/compile/test_compile_zero.py b/tests/unit/runtime/compile/test_compile_zero.py index 87e3c52b9e3c..910f32db1c96 100644 --- a/tests/unit/runtime/compile/test_compile_zero.py +++ b/tests/unit/runtime/compile/test_compile_zero.py @@ -7,11 +7,15 @@ import torch from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from deepspeed.runtime.utils import required_torch_version from unit.runtime.compile.util import compare_loss from unit.common import DistributedTest from unit.util import bf16_required_version_check +pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1), + reason="Compile tests require PyTorch version 2.1 or above") + class TestZeRO(DistributedTest): world_size = 2 diff --git a/tests/unit/runtime/compile/test_load_config.py b/tests/unit/runtime/compile/test_load_config.py index 351e91d2f69b..5f1c01b86852 100644 --- a/tests/unit/runtime/compile/test_load_config.py +++ b/tests/unit/runtime/compile/test_load_config.py @@ -9,9 +9,13 @@ from unit.simple_model import SimpleModel import deepspeed from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.utils import required_torch_version from unit.common import DistributedTest +pytestmark = pytest.mark.skipif(not
required_torch_version(min_version=2.1), + reason="Compile tests require PyTorch version 2.1 or above") + custom_backend_called = False custom_compler_fn_called = False diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index bc31e3b9a968..2594d910acff 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -1370,6 +1370,11 @@ class TestZeroAdamOptimizerStepCount(DistributedTest): world_size = 1 def test(self, zero_stage): + # We verify three conditions: + # 1. global_steps starts at 0 + # 2. All subgroups have the same step count + # 3. The global step count is the same as the step count of the first subgroup + # force all params to be partitioned by forcing threshold=0 config_dict = { "train_micro_batch_size_per_gpu": 2, @@ -1399,24 +1404,31 @@ def test(self, zero_stage): model_parameters=model.parameters()) data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device) - for i, batch in enumerate(data_loader): + assert model.global_steps == 0 + + for batch in data_loader: loss = model(batch[0], batch[1]) model.backward(loss) + + is_gradient_accumulation_boundary = model.is_gradient_accumulation_boundary() model.step() - step_counts = [] - if zero_stage == 3: - for sub_group_id, _ in enumerate(optimizer.fp16_groups): - fp32_param = optimizer.fp32_partitioned_groups_flat[sub_group_id] - state = optimizer.optimizer.state[fp32_param] - step_counts.append(state["step"]) - assert all(step == step_counts[0] for step in step_counts) - elif zero_stage == 1 or zero_stage == 2: - for param_group in optimizer.optimizer.param_groups: - for param in param_group["params"]: - state = optimizer.optimizer.state[param] + if is_gradient_accumulation_boundary: + step_counts = [] + + if zero_stage == 3: + for sub_group_id, _ in enumerate(optimizer.fp16_groups): + fp32_param = optimizer.fp32_partitioned_groups_flat[sub_group_id] + state = optimizer.optimizer.state[fp32_param] step_counts.append(state["step"]) + elif zero_stage == 1 or zero_stage == 2: + for param_group in optimizer.optimizer.param_groups: + for param in param_group["params"]: + state = optimizer.optimizer.state[param] + step_counts.append(state["step"]) + assert all(step == step_counts[0] for step in step_counts) + assert model.global_steps == step_counts[0] @pytest.mark.parametrize("zero_stage", [1, 2, 3]) diff --git a/tests/unit/runtime/zero/test_zero_tensor_fragment.py b/tests/unit/runtime/zero/test_zero_tensor_fragment.py index c223e67af697..b3adfdf96c50 100644 --- a/tests/unit/runtime/zero/test_zero_tensor_fragment.py +++ b/tests/unit/runtime/zero/test_zero_tensor_fragment.py @@ -24,35 +24,26 @@ SECOND_ORDER_KEY = 'exp_avg_sq' -def validate_full_tensors(model): +def validate_tensor(model, api_type, opt_states): + assert api_type in ["full", "local"] for _, lp in model.named_parameters(): - hp = safe_get_full_fp32_param(lp) - exp_avg = safe_get_full_optimizer_state(lp, 'exp_avg') - exp_avg_sq = safe_get_full_optimizer_state(lp, 'exp_avg_sq') - hp_grad = safe_get_full_grad(lp) - param_list = [hp, hp_grad, exp_avg, exp_avg_sq] - if lp.requires_grad: - assert all([p is not None for p in param_list]) + param_list = [] + if opt_states: + param_list.append( + safe_get_full_optimizer_state(lp, 'exp_avg') if api_type == + "full" else safe_get_local_optimizer_state(lp, 'exp_avg')) + param_list.append( + safe_get_full_optimizer_state(lp, 'exp_avg_sq') if api_type == + "full" else safe_get_local_optimizer_state(lp,
'exp_avg_sq')) else: - assert all([p is None for p in param_list]) - - -def validate_local_tensors(model): - for _, lp in model.named_parameters(): - hp = safe_get_local_fp32_param(lp) - exp_avg = safe_get_local_optimizer_state(lp, 'exp_avg') - exp_avg_sq = safe_get_local_optimizer_state(lp, 'exp_avg_sq') - hp_grad = safe_get_local_grad(lp) - param_list = [hp, hp_grad, exp_avg, exp_avg_sq] + param_list.append(safe_get_full_fp32_param(lp) if api_type == "full" else safe_get_local_fp32_param(lp)) + param_list.append(safe_get_full_grad(lp) if api_type == "full" else safe_get_local_grad(lp)) if lp.requires_grad: assert all([p is not None for p in param_list]) else: assert all([p is None for p in param_list]) -validate_funcs_mapping = {"full": validate_full_tensors, "local": validate_local_tensors} - - class MyModel(torch.nn.Module): def __init__(self, hidden_dim, frozen_weights): @@ -71,12 +62,10 @@ def forward(self, x, y): for l in self.linears: x = l(x) x = self.act(x) - loss = self.cel(x, y) - val = (x, loss) - return val + return self.cel(x, y) -def run_fragmented_model(model, config_dict, hidden_dim, dtype, validate_func): +def run_fragmented_model(model, config_dict, hidden_dim, dtype, validate_after_bwd, validate_after_step): model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) data_loader = random_dataloader(model=model, total_samples=10, @@ -86,10 +75,10 @@ def run_fragmented_model(model, config_dict, hidden_dim, dtype, validate_func): dist.barrier() for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) - loss = loss[1] model.backward(loss) - validate_func(model) + validate_after_bwd(model) model.step() + validate_after_step(model) # Needed in ZeRO 3. Not doing so can give memory leak model.destroy() @@ -147,9 +136,10 @@ def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, froz else: model = MyModel(hidden_dim, frozen_weights) - validate_func = validate_funcs_mapping[api_type] + validate_after_bwd = lambda model: validate_tensor(model, api_type, opt_states=False) + validate_after_step = lambda model: validate_tensor(model, api_type, opt_states=True) - run_fragmented_model(model, config_dict, hidden_dim, torch.float16, validate_func) + run_fragmented_model(model, config_dict, hidden_dim, torch.float16, validate_after_bwd, validate_after_step) def test_bf16_fragments(self, frozen_weights): if frozen_weights: @@ -178,7 +168,12 @@ def test_bf16_fragments(self, frozen_weights): hidden_dim = 128 model = MyModel(hidden_dim, frozen_weights) - run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16, validate_full_tensors) + + api_type = "full" + validate_after_bwd = lambda model: validate_tensor(model, api_type, opt_states=False) + validate_after_step = lambda model: validate_tensor(model, api_type, opt_states=True) + + run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16, validate_after_bwd, validate_after_step) def create_random_values(model, key_list, group, use_cuda=True): @@ -315,23 +310,21 @@ def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, dtyp if zero_stage == 3: config_dict["zero_optimization"]["param_persistence_threshold"] = hidden_dim with deepspeed.zero.Init(config_dict_or_path=config_dict): - model = SimpleModel(hidden_dim, nlayers=4) + model = SimpleModel(hidden_dim) else: - model = SimpleModel(hidden_dim, nlayers=4) + model = SimpleModel(hidden_dim) - model, _, _, _ = deepspeed.initialize(model=model, 
model_parameters=model.parameters(), config=config_dict) world = dist.get_world_size() group = dist.new_group(ranks=list(range(world))) dist.barrier() - optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY] - helper_funcs = helper_funcs_mapping[api_type] - optim_state_values = helper_funcs["create_random_values"](model, - optim_keys, - group, - use_cuda=offload_device == OffloadDeviceEnum.none) - helper_funcs["set_param_values_with_dict"](model, optim_state_values) - helper_funcs["validate_param_values_with_dict"](model, optim_state_values) - - # Needed in ZeRO 3. Not doing so can leak memory. - model.destroy() + + def validate_func(model): + optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY] + helper_funcs = helper_funcs_mapping[api_type] + optim_state_values = helper_funcs["create_random_values"]( + model, optim_keys, group, use_cuda=offload_device == OffloadDeviceEnum.none) + helper_funcs["set_param_values_with_dict"](model, optim_state_values) + helper_funcs["validate_param_values_with_dict"](model, optim_state_values) + + run_fragmented_model(model, config_dict, hidden_dim, dtype, lambda _: None, validate_func) diff --git a/tests/unit/runtime/zero/test_zeropp.py b/tests/unit/runtime/zero/test_zeropp.py index 545ed98ad2ef..7a05c2a8001b 100644 --- a/tests/unit/runtime/zero/test_zeropp.py +++ b/tests/unit/runtime/zero/test_zeropp.py @@ -202,7 +202,7 @@ def load_and_prepare_data(self, model_name): tokenizer.pad_token = tokenizer.eos_token # Load and tokenize dataset - dataset = load_dataset("wikitext", 'wikitext-103-raw-v1', split='train[:1%]') + dataset = load_dataset("wikitext", 'wikitext-103-raw-v1', split='train[:1%]').filter(lambda x: x["text"]) def tokenize_function(examples): # Tokenize and ensure 'labels' are the same as 'input_ids' diff --git a/version.txt b/version.txt index 9beb74d490bc..288adf538f0e 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.13.2 +0.13.3
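Note on the new checkpoint-consolidation option: the changes above thread an exclude_frozen_parameters switch through both the Python helpers and the command-line converter. A minimal usage sketch follows, assuming the edited script is DeepSpeed's zero_to_fp32.py (importable as deepspeed.utils.zero_to_fp32) and that the checkpoint folder, tag, and output path shown are hypothetical placeholders.

# Usage sketch only: paths, the tag value, and the positional CLI arguments are
# assumptions for illustration, not part of the diff above.
from deepspeed.utils.zero_to_fp32 import (
    convert_zero_checkpoint_to_fp32_state_dict,
    get_fp32_state_dict_from_zero_checkpoint,
)

# Consolidate a ZeRO-2/3 checkpoint into a single fp32 file, dropping frozen parameters.
convert_zero_checkpoint_to_fp32_state_dict("./checkpoints",
                                           "./pytorch_model.bin",
                                           tag="global_step14",
                                           exclude_frozen_parameters=True)

# Or build the consolidated state_dict in memory; frozen parameters are still
# merged in by default (exclude_frozen_parameters=False).
state_dict = get_fp32_state_dict_from_zero_checkpoint("./checkpoints", tag="global_step14")

# Equivalent command-line invocation, using the flag added to the argparse block
# (assuming checkpoint_dir and output_file are the script's positional arguments):
#   python zero_to_fp32.py ./checkpoints ./pytorch_model.bin --exclude_frozen_parameters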