diff --git a/llmfoundry/callbacks/eval_gauntlet_callback.py b/llmfoundry/callbacks/eval_gauntlet_callback.py index b1570e9793..78ccbb529b 100644 --- a/llmfoundry/callbacks/eval_gauntlet_callback.py +++ b/llmfoundry/callbacks/eval_gauntlet_callback.py @@ -6,7 +6,7 @@ import logging import math from enum import Enum -from typing import Optional +from typing import Dict, Optional from composer.core import Callback, State from composer.loggers import Logger @@ -95,7 +95,7 @@ def __init__(self, assert weight is not None benchmark['weighting'] = weight - def compute_averages(self, state: State): + def compute_averages(self, state: State) -> Dict[str, float]: results = {} for key in self.logger_keys: @@ -120,7 +120,7 @@ def compute_averages(self, state: State): return {k: sum(v) / len(v) for k, v in results.items()} - def eval_after_all(self, state: State, logger: Logger): + def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]: new_metrics = self.compute_averages(state) if len(new_metrics) == 0: return {} diff --git a/llmfoundry/callbacks/fdiff_callback.py b/llmfoundry/callbacks/fdiff_callback.py index 3c6064932d..1237f32e22 100644 --- a/llmfoundry/callbacks/fdiff_callback.py +++ b/llmfoundry/callbacks/fdiff_callback.py @@ -26,7 +26,7 @@ def __init__(self, self.train_prev_metric = {} self.eval_prev_metric = {} - def batch_end(self, state: State, logger: Logger): + def batch_end(self, state: State, logger: Logger) -> None: if self.diff_train_metrics: if not isinstance(state.loss, torch.Tensor): raise NotImplementedError('Multiple losses not supported yet') @@ -46,7 +46,7 @@ def batch_end(self, state: State, logger: Logger): value = state.train_metric_values[k] self.train_prev_metric[k] = value - def eval_end(self, state: State, logger: Logger): + def eval_end(self, state: State, logger: Logger) -> None: if self.diff_eval_metrics: evaluator = state.dataloader_label assert evaluator is not None, 'dataloader should have been set' diff --git a/llmfoundry/callbacks/generate_callback.py b/llmfoundry/callbacks/generate_callback.py index b6596fbc6a..bb5b557d37 100644 --- a/llmfoundry/callbacks/generate_callback.py +++ b/llmfoundry/callbacks/generate_callback.py @@ -47,11 +47,11 @@ def init(self, state: State, logger: Logger): if isinstance(destination, WandBLogger): self.wandb_logger = destination - def batch_checkpoint(self, state: State, logger: Logger): + def batch_checkpoint(self, state: State, logger: Logger) -> None: if (state.timestamp.batch.value % self.batch_log_interval) == 0: self.generate(state, logger) - def generate(self, state: State, logger: Logger): + def generate(self, state: State, logger: Logger) -> None: model = state.model original_mode = model.training model.eval() diff --git a/llmfoundry/callbacks/monolithic_ckpt_callback.py b/llmfoundry/callbacks/monolithic_ckpt_callback.py index afca099832..6d72762323 100644 --- a/llmfoundry/callbacks/monolithic_ckpt_callback.py +++ b/llmfoundry/callbacks/monolithic_ckpt_callback.py @@ -46,22 +46,24 @@ def __init__(self, else: self.remote_ud = None - def init(self, state: State, logger: Logger): + def init(self, state: State, logger: Logger) -> None: if self.upload_to_object_store and self.remote_ud is not None: self.remote_ud.init(state, logger) # updated_logger_destinations = [*logger.destinations, new_remote_ud] # logger.destinations = tuple(updated_logger_destinations) state.callbacks.append(self.remote_ud) - def batch_checkpoint(self, state: State, logger: Logger): + def batch_checkpoint(self, state: State, logger: Logger) -> None: if state.timestamp.batch.value % self.batch_interval == 0: self._save_checkpoint(state, logger) - def fit_end(self, state: State, logger: Logger): + def fit_end(self, state: State, logger: Logger) -> None: if state.timestamp.batch.value % self.batch_interval != 0: self._save_checkpoint(state, logger) - def _save_checkpoint(self, state: State, logger: Logger): + def _save_checkpoint(self, state: State, logger: Logger) -> None: + del logger # unused + filename = format_name_with_dist_and_time(self.filename_format_str, state.run_name, state.timestamp) diff --git a/llmfoundry/callbacks/resumption_callbacks.py b/llmfoundry/callbacks/resumption_callbacks.py index b5e20a7a57..751accc922 100644 --- a/llmfoundry/callbacks/resumption_callbacks.py +++ b/llmfoundry/callbacks/resumption_callbacks.py @@ -32,7 +32,9 @@ def __init__(self, lr_scale: float, wd_pct: float = 0.0): self.lr_scale = lr_scale self.wd_pct = wd_pct - def fit_start(self, state: State, logger: Logger): + def fit_start(self, state: State, logger: Logger) -> None: + del logger # unused + if hasattr(state, 'optimizer') and state.optimizers is None: raise Exception('No optimizers defined') for optimizer in state.optimizers: @@ -65,7 +67,9 @@ class LayerFreezing(Callback): def __init__(self, layer_names: List[str]): self.layer_names = set(layer_names) - def fit_start(self, state: State, logger: Logger): + def fit_start(self, state: State, logger: Logger) -> None: + del logger # unused + model_layers = set(name for name, _ in state.model.named_parameters()) for layer in self.layer_names: if layer not in model_layers: diff --git a/llmfoundry/callbacks/scheduled_gc_callback.py b/llmfoundry/callbacks/scheduled_gc_callback.py index 37c2193eda..6bd085e68f 100644 --- a/llmfoundry/callbacks/scheduled_gc_callback.py +++ b/llmfoundry/callbacks/scheduled_gc_callback.py @@ -9,7 +9,7 @@ def gc_cuda(): - """Gargage collect Torch (CUDA) memory.""" + """Garbage collect Torch (CUDA) memory.""" gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -32,7 +32,9 @@ def __init__( self.eval_keep_disabled = eval_keep_disabled self.gc_init_state = None - def fit_start(self, state: State, logger: Logger): + def fit_start(self, state: State, logger: Logger) -> None: + del state, logger # unused + # cache if automatic garbage collection is enabled; reset at fit_end self.gc_init_state = gc.isenabled() @@ -40,7 +42,9 @@ def fit_start(self, state: State, logger: Logger): gc.disable() gc_cuda() - def fit_end(self, state: State, logger: Logger): + def fit_end(self, state: State, logger: Logger) -> None: + del state, logger # unused + gc_cuda() # reset automatic garbage collection at fit_end @@ -49,16 +53,22 @@ def fit_end(self, state: State, logger: Logger): else: gc.disable() - def before_dataloader(self, state: State, logger: Logger): + def before_dataloader(self, state: State, logger: Logger) -> None: + del logger # unused + if state.timestamp.batch.value % self.batch_interval == 0: gc_cuda() - def eval_start(self, state: State, logger: Logger): + def eval_start(self, state: State, logger: Logger) -> None: + del state, logger # unused + gc_cuda() if not self.eval_keep_disabled: gc.enable() - def eval_end(self, state: State, logger: Logger): + def eval_end(self, state: State, logger: Logger) -> None: + del state, logger # unused + if not self.eval_keep_disabled: gc.disable() diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py index 302bdc4bc4..d685d0077d 100644 --- a/llmfoundry/data/denoising.py +++ b/llmfoundry/data/denoising.py @@ -269,11 +269,11 @@ def __init__( '`span_mean_lengths_and_ratios` and/or `sequence_mask_ratios`.') @property - def smallest_max_raw_length(self): + def smallest_max_raw_length(self) -> int: return int(self._smallest_max_raw_length) @property - def largest_max_raw_length(self): + def largest_max_raw_length(self) -> int: return int(self._largest_max_raw_length) def __call__(self, examples: List[Dict[str, @@ -613,7 +613,8 @@ def noise_token_sequence( def _get_max_starting_length(max_length: int, mask_ratio: float, mean_span_length: float, n_prefix_tokens: int, - decoder_only_format: bool, context_eos: bool): + decoder_only_format: bool, + context_eos: bool) -> int: """Get max num raw tokens that will fit max_length.""" def sequence_stats(length: int): diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index b0d175f2a8..a009f13660 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging import os -from typing import Union +from typing import Tuple, Union import datasets as hf_datasets import torch @@ -207,7 +207,7 @@ def build_finetuning_dataloader(cfg: DictConfig, ) -def _validate_config(dataset_cfg: DictConfig): +def _validate_config(dataset_cfg: DictConfig) -> None: """Validates the dataset configuration. Makes sure that the dataset is properly configured for either @@ -352,9 +352,10 @@ def _build_hf_dataset_from_remote( return dataset -def _build_collate_fn(dataset_cfg: DictConfig, - tokenizer: PreTrainedTokenizerBase, - device_batch_size: int): +def _build_collate_fn( + dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, + device_batch_size: int +) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackWrapper], int]: collate_fn = Seq2SeqFinetuningCollator( tokenizer=tokenizer, max_seq_len=dataset_cfg.max_seq_len, diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index f5e6ac6b27..c184dc9848 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -35,7 +35,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: import logging import os import warnings -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import datasets as hf_datasets from omegaconf import DictConfig @@ -47,8 +47,9 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: __all__ = ['dataset_constructor'] -def _tokenize_formatted_example(example: Dict[str, Any], - tokenizer: PreTrainedTokenizerBase): +def _tokenize_formatted_example( + example: Dict[str, Any], + tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]: if ('prompt' not in example) or ('response' not in example): raise KeyError( 'Unable to tokenize example because it has not been properly formatted. ' +\ @@ -150,7 +151,7 @@ class DatasetConstructor: def __init__(self): self._task_preprocessing_registry: Dict[str, Callable] = {} - def register(self, *names: str): + def register(self, *names: str) -> Callable[[Callable], Callable]: """Decorator for registering preprocessing functions.""" def _register_func(name: str, func: Callable) -> None: @@ -168,11 +169,13 @@ def wrapper(func: Callable) -> Callable: return wrapper - def print_registered_tasks(self): + def print_registered_tasks(self) -> None: tasks = sorted(self._task_preprocessing_registry.keys()) print('\n'.join(tasks)) - def get_preprocessing_fn_from_dict(self, mapping: Union[Dict, DictConfig]): + def get_preprocessing_fn_from_dict( + self, mapping: Union[Dict, DictConfig] + ) -> Callable[[Dict[str, Any]], Dict[str, str]]: """Get a preprocessing function from a dictionary. The dictionary maps column names in the dataset to "prompt" and "response". @@ -206,9 +209,11 @@ def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]: return _preprocessor - def get_preprocessing_fn_from_str(self, - preprocessor: Optional[str], - dataset_name: Optional[str] = None): + def get_preprocessing_fn_from_str( + self, + preprocessor: Optional[str], + dataset_name: Optional[str] = None + ) -> Optional[Callable[[Dict[str, Any]], Dict[str, str]]]: """Get a preprocessing function from a string. String can be either a registered function or an import path. @@ -319,7 +324,8 @@ def dataset_mapper(example: Dict): return empty_examples_dropped_dataset - def build_from_streaming(self, *args: Any, **kwargs: Any): + def build_from_streaming(self, *args: Any, + **kwargs: Any) -> StreamingFinetuningDataset: return StreamingFinetuningDataset(*args, **kwargs) @@ -327,7 +333,7 @@ def build_from_streaming(self, *args: Any, **kwargs: Any): @dataset_constructor.register('tatsu-lab/alpaca') -def alpaca_preprocessing_function(inp: Dict): +def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]: """Split out prompt/response from text.""" try: prompt, response = inp['text'].split('### Response:') @@ -340,7 +346,7 @@ def alpaca_preprocessing_function(inp: Dict): @dataset_constructor.register('HuggingFaceH4/databricks_dolly_15k') -def dolly_preprocessing_function(inp: Dict): +def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]: """Format the text string.""" PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n' try: @@ -357,7 +363,7 @@ def dolly_preprocessing_function(inp: Dict): @dataset_constructor.register('bigscience/P3') -def p3_preprocessing_function(inp: Dict): +def p3_preprocessing_function(inp: Dict) -> Dict[str, str]: """Format the already-split example.""" return { 'prompt': inp['inputs'] + ':', @@ -367,7 +373,7 @@ def p3_preprocessing_function(inp: Dict): # Muennighoff's P3 and flan datasets share a similar convention @dataset_constructor.register('Muennighoff/P3', 'Muennighoff/flan') -def muennighoff_tokenize_function(inp: Dict): +def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]: """Format the already-split example.""" try: prompt: str = inp['inputs'] diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 5f157724ce..d0a73be801 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -48,11 +48,11 @@ def __init__(self, self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = [] @property - def waste(self): + def waste(self) -> float: return 1 - (self.n_packed_tokens / self.n_total_tokens) @property - def efficiency(self): + def efficiency(self) -> float: return self.n_packed_tokens / (self.max_seq_len * self.n_packed_examples) @@ -100,7 +100,8 @@ def __call__( return batch -def extract_trim_batch_idx(batch: Dict[str, torch.Tensor], idx: int): +def extract_trim_batch_idx(batch: Dict[str, torch.Tensor], + idx: int) -> Tuple[int, Dict[str, torch.Tensor]]: example = {k: v[idx] for k, v in batch.items()} keep = example['attention_mask'] == 1 @@ -111,8 +112,9 @@ def extract_trim_batch_idx(batch: Dict[str, torch.Tensor], idx: int): return size, trim_example -def combine_in_place(example: Dict[str, torch.Tensor], - add_on: Dict[str, torch.Tensor]): +def combine_in_place( + example: Dict[str, torch.Tensor], + add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: if 'labels' in add_on: # Prevents the last token in example from being trained to # predict the first token in add_on, which would make no sense. diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 4562d3de0a..31626b237f 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -143,7 +143,7 @@ def __init__(self, self.max_seq_len = max_seq_len # How to tokenize a text sample to a token sample - def _tokenize(self, text_sample: Mapping): + def _tokenize(self, text_sample: Mapping) -> Dict[str, List[int]]: if self.tokenizer._pad_token is None: # Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs raise RuntimeError( @@ -154,13 +154,15 @@ def _tokenize(self, text_sample: Mapping): padding='max_length', max_length=self.max_seq_len) - def _read_binary_tokenized_sample(self, sample: Dict[str, Any]): + def _read_binary_tokenized_sample(self, sample: Dict[str, + Any]) -> torch.Tensor: return torch.from_numpy( np.frombuffer(sample['tokens'], dtype=np.int64)[:self.max_seq_len].copy()) # How to process a sample - def __getitem__(self, idx: int): + def __getitem__(self, + idx: int) -> Union[Dict[str, List[int]], torch.Tensor]: sample = super().__getitem__(idx) if 'text' in sample: token_sample = self._tokenize(sample) @@ -224,7 +226,7 @@ def build_text_dataloader( cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, device_batch_size: int, -): +) -> DataLoader: assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}' if cfg.dataset.get('group_method', None) is not None: raise NotImplementedError( diff --git a/llmfoundry/models/hf/hf_fsdp.py b/llmfoundry/models/hf/hf_fsdp.py index 56ba24aeff..919c33227d 100644 --- a/llmfoundry/models/hf/hf_fsdp.py +++ b/llmfoundry/models/hf/hf_fsdp.py @@ -13,7 +13,7 @@ # helper functions -def rhasattr(obj: Any, attr: str): +def rhasattr(obj: Any, attr: str) -> bool: """A chain-able attribute version of hasattr. For example, to check if @@ -31,7 +31,7 @@ def rhasattr(obj: Any, attr: str): return hasattr(_curr_obj, _nested_attrs[-1]) -def rgetattr(obj: Any, attr: str, *args: List[Any]): +def rgetattr(obj: Any, attr: str, *args: List[Any]) -> Any: """A chain-able attribute version of getattr. For example, to get the attribute `foo.bar.baz` from `obj`, you can use: @@ -45,14 +45,14 @@ def _getattr(obj: Any, attr: str): return functools.reduce(_getattr, [obj] + attr.split('.')) -def findattr(obj: Any, attrs: Iterable[str]): +def findattr(obj: Any, attrs: Iterable[str]) -> Optional[Any]: for attr in attrs: if rhasattr(obj, attr): return rgetattr(obj, attr) return None -def hf_get_causal_base_model(model: PreTrainedModel): +def hf_get_causal_base_model(model: PreTrainedModel) -> Any: """Returns the causal decoder backbone of the specified HuggingFace model. Newer HF models have a `self.get_decoder()` method. Older models do not. @@ -75,7 +75,7 @@ def hf_get_causal_base_model(model: PreTrainedModel): return causal_base_model -def hf_get_hidden_layers(model: PreTrainedModel): +def hf_get_hidden_layers(model: PreTrainedModel) -> Any: """Returns the hidden layers of the specified model. NOTE: Different model configurations have different hidden layer attribute names. @@ -102,7 +102,7 @@ def hf_get_hidden_layers(model: PreTrainedModel): return layers -def hf_get_init_device(init_device: Optional[str]): +def hf_get_init_device(init_device: Optional[str]) -> Optional[str]: """Returns the appropriate device to initialize models.""" from composer.utils import dist if init_device == 'mixed': diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 6ac496ebd8..76969b7810 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -18,7 +18,7 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, - original_is_causal: bool): + original_is_causal: bool) -> bool: # disable causal when it is not needed # necessary for flash & triton for generation with kv_cache if original_is_causal and num_query_tokens != num_key_tokens: @@ -495,7 +495,8 @@ def forward( attention_mask: Optional[torch.Tensor] = None, is_causal: bool = True, needs_weights: bool = False, - ): + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[ + torch.Tensor, torch.Tensor]]]: qkv = self.Wqkv(x) if self.clip_qkv: @@ -605,8 +606,10 @@ def __init__( device=device) -def attn_bias_shape(attn_impl: str, n_heads: int, seq_len: int, alibi: bool, - prefix_lm: bool, causal: bool, use_sequence_id: bool): +def attn_bias_shape( + attn_impl: str, n_heads: int, seq_len: int, alibi: bool, + prefix_lm: bool, causal: bool, + use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]: if attn_impl == 'flash': return None elif attn_impl in ['torch', 'triton']: @@ -629,7 +632,7 @@ def build_attn_bias( causal: bool = False, alibi: bool = False, alibi_bias_max: int = 8, -): +) -> Optional[torch.Tensor]: if attn_impl == 'flash': return None elif attn_impl in ['torch', 'triton']: @@ -652,7 +655,7 @@ def build_attn_bias( def gen_slopes(n_heads: int, alibi_bias_max: int = 8, - device: Optional[torch.device] = None): + device: Optional[torch.device] = None) -> torch.Tensor: _n_heads = 2**math.ceil(math.log2(n_heads)) m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) m = m.mul(alibi_bias_max / _n_heads) @@ -674,7 +677,7 @@ def build_alibi_bias( alibi_bias_max: int = 8, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, -): +) -> torch.Tensor: alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len) if full: diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index 0b41a753d9..af770a84f7 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -42,7 +42,7 @@ def __init__( ) self.down_proj._is_residual = True - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(self.act(self.up_proj(x))) @@ -61,7 +61,7 @@ def build_ffn( fc_type: str = 'torch', device: Optional[str] = None, **kwargs: Any, -): +) -> nn.Module: ffn_type = kwargs.pop('ffn_type') if ffn_type == 'mptmlp': if len(kwargs) > 0: diff --git a/llmfoundry/models/layers/llama_attention_monkeypatch.py b/llmfoundry/models/layers/llama_attention_monkeypatch.py index 0f75986e11..88f61e3fef 100644 --- a/llmfoundry/models/layers/llama_attention_monkeypatch.py +++ b/llmfoundry/models/layers/llama_attention_monkeypatch.py @@ -36,18 +36,20 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: head_dim) -def rotate_half(x: torch.Tensor): +def rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotates half the hidden dims of the input.""" x1 = x[..., :x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - position_ids: Optional[torch.Tensor] = None): +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] diff --git a/llmfoundry/models/layers/norm.py b/llmfoundry/models/layers/norm.py index fabe0a8ccb..2ff4eaed0c 100644 --- a/llmfoundry/models/layers/norm.py +++ b/llmfoundry/models/layers/norm.py @@ -6,7 +6,7 @@ import torch -def _cast_if_autocast_enabled(tensor: torch.Tensor): +def _cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor: if torch.is_autocast_enabled(): if tensor.device.type == 'cuda': dtype = torch.get_autocast_gpu_dtype() @@ -36,7 +36,7 @@ def __init__( dtype=dtype, ) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: module_device = x.device downcast_x = _cast_if_autocast_enabled(x) downcast_weight = _cast_if_autocast_enabled( @@ -55,7 +55,7 @@ def forward(self, x: torch.Tensor): def rms_norm(x: torch.Tensor, weight: Optional[torch.Tensor] = None, - eps: float = 1e-5): + eps: float = 1e-5) -> torch.Tensor: output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) if weight is not None: return output * weight @@ -80,7 +80,7 @@ def __init__( else: self.register_parameter('weight', None) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) @@ -102,7 +102,7 @@ def __init__( device=device, ) - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: downcast_x = _cast_if_autocast_enabled(x) downcast_weight = _cast_if_autocast_enabled( self.weight) if self.weight is not None else self.weight diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 38946b47c8..251e4f5caf 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -159,14 +159,14 @@ def __init__( self._validate_config() def _set_config_defaults(self, config: Dict[str, Any], - config_defaults: Dict[str, Any]): + config_defaults: Dict[str, Any]) -> Dict[str, Any]: # set config defaults for k, v in config_defaults.items(): if k not in config: config[k] = v return config - def _validate_config(self): + def _validate_config(self) -> None: # set config defaults self.attn_config = self._set_config_defaults( self.attn_config, diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 1b4ca764ea..3371c67a0d 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -8,7 +8,8 @@ import math import warnings -from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union +from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, + Union) import torch import torch.nn as nn @@ -152,10 +153,10 @@ def __init__(self, config: MPTConfig): log.debug(self) log.debug(f'Using {self.config.init_config["name"]} initialization.') - def get_input_embeddings(self): + def get_input_embeddings(self) -> nn.Embedding: return self.wte - def set_input_embeddings(self, value: nn.Embedding): + def set_input_embeddings(self, value: nn.Embedding) -> None: self.wte = value @torch.no_grad() @@ -166,7 +167,7 @@ def _attn_bias( attention_mask: Optional[torch.ByteTensor] = None, prefix_mask: Optional[torch.ByteTensor] = None, sequence_id: Optional[torch.LongTensor] = None, - ): + ) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]: if not self._attn_bias_initialized: if self.attn_bias_shape: self.attn_bias = torch.zeros(self.attn_bias_shape, @@ -190,7 +191,7 @@ def _attn_bias( if self.attn_bias is not None: # .to(*args, **kwargs) is a no-op if tensor is already on - # specified device or of specificed dtype + # specified device or of specified dtype self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) attn_bias = self.attn_bias @@ -231,7 +232,7 @@ def _attn_bias( return attn_bias, None def _apply_prefix_mask(self, attn_bias: torch.Tensor, - prefix_mask: torch.Tensor): + prefix_mask: torch.Tensor) -> torch.Tensor: s_k, s_q = attn_bias.shape[-2:] if (s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len): raise ValueError( @@ -262,7 +263,7 @@ def _apply_prefix_mask(self, attn_bias: torch.Tensor, return attn_bias def _apply_sequence_id(self, attn_bias: torch.Tensor, - sequence_id: torch.LongTensor): + sequence_id: torch.LongTensor) -> torch.Tensor: seq_len = sequence_id.shape[-1] if seq_len > self.config.max_seq_len: raise ValueError( @@ -296,7 +297,7 @@ def forward( output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None, inputs_embeds: Optional[torch.Tensor] = None, - ): + ) -> BaseModelOutputWithPast: return_dict = (return_dict if return_dict is not None else self.config.return_dict) use_cache = (use_cache @@ -456,7 +457,7 @@ def forward( ) # Param Initialization, needed for device='meta' fast initialization - def param_init_fn(self, module: nn.Module): + def param_init_fn(self, module: nn.Module) -> None: init_fn_name = self.config.init_config['name'] MODEL_INIT_REGISTRY[init_fn_name]( module=module, @@ -466,11 +467,11 @@ def param_init_fn(self, module: nn.Module): ) # FSDP Wrap function - def fsdp_wrap_fn(self, module: nn.Module): + def fsdp_wrap_fn(self, module: nn.Module) -> bool: return isinstance(module, MPTBlock) # Activation Checkpointing - def activation_checkpointing_fn(self, module: nn.Module): + def activation_checkpointing_fn(self, module: nn.Module) -> bool: return isinstance(module, MPTBlock) @@ -506,23 +507,24 @@ def __init__(self, config: MPTConfig): ) self.logit_scale = logit_scale - def get_input_embeddings(self): + def get_input_embeddings(self) -> nn.Embedding: return self.transformer.wte - def set_input_embeddings(self, value: Union[SharedEmbedding, nn.Embedding]): + def set_input_embeddings( + self, value: Union[SharedEmbedding, nn.Embedding]) -> None: self.transformer.wte = value - def get_output_embeddings(self): + def get_output_embeddings(self) -> nn.Embedding: return self.transformer.wte - def set_output_embeddings(self, new_embeddings: Union[SharedEmbedding, - nn.Embedding]): + def set_output_embeddings( + self, new_embeddings: Union[SharedEmbedding, nn.Embedding]) -> None: self.transformer.wte = new_embeddings - def set_decoder(self, decoder: MPTModel): + def set_decoder(self, decoder: MPTModel) -> None: self.transformer = decoder - def get_decoder(self): + def get_decoder(self) -> MPTModel: return self.transformer def forward( @@ -538,7 +540,7 @@ def forward( output_hidden_states: Optional[bool] = None, use_cache: Optional[bool] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - ): + ) -> CausalLMOutputWithPast: return_dict = (return_dict if return_dict is not None else self.config.return_dict) use_cache = (use_cache @@ -593,7 +595,7 @@ def forward( ) # Param Initialization, needed for device='meta' fast initialization - def param_init_fn(self, module: nn.Module): + def param_init_fn(self, module: nn.Module) -> None: init_fn_name = self.config.init_config['name'] MODEL_INIT_REGISTRY[init_fn_name]( module=module, @@ -603,11 +605,11 @@ def param_init_fn(self, module: nn.Module): ) # FSDP Wrap function - def fsdp_wrap_fn(self, module: nn.Module): + def fsdp_wrap_fn(self, module: nn.Module) -> bool: return isinstance(module, MPTBlock) # Activation Checkpointing - def activation_checkpointing_fn(self, module: nn.Module): + def activation_checkpointing_fn(self, module: nn.Module) -> bool: return isinstance(module, MPTBlock) def prepare_inputs_for_generation( @@ -617,7 +619,7 @@ def prepare_inputs_for_generation( torch.Tensor]]] = None, inputs_embeds: Optional[torch.Tensor] = None, **kwargs: Any, - ): + ) -> Dict[str, Any]: if inputs_embeds is not None: raise NotImplementedError( 'inputs_embeds is not implemented for MPT yet') @@ -655,8 +657,9 @@ def prepare_inputs_for_generation( } @staticmethod - def _reorder_cache(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], - beam_idx: torch.LongTensor): + def _reorder_cache( + past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], + beam_idx: torch.LongTensor) -> List[Tuple[torch.Tensor, ...]]: """Used by HuggingFace generate when using beam search with kv-caching. See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 @@ -729,12 +732,12 @@ def __init__( f'Specified loss_fn={self.loss_fn} not recognized. `loss_fn` must be one of [`fused_crossentropy`, `torch_crossentropy`].' ) - def get_targets(self, batch: Mapping): + def get_targets(self, batch: Mapping) -> torch.Tensor: targets = torch.roll(batch['labels'], shifts=-1) targets[:, -1] = -100 return targets - def forward(self, batch: MutableMapping): + def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast: if self.model.transformer.prefix_lm: add_bidirectional_mask_if_missing(batch) # Note: prefix_mask is only used if model.prefix_lm is True @@ -746,12 +749,13 @@ def forward(self, batch: MutableMapping): inputs_embeds=batch.get('inputs_embeds', None), ) - def loss(self, outputs: CausalLMOutputWithPast, batch: Mapping): + def loss(self, outputs: CausalLMOutputWithPast, + batch: Mapping) -> torch.Tensor: targets = self.get_targets(batch) return self.loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)), targets.view(-1)) - def flops_per_batch(self, batch: Mapping): + def flops_per_batch(self, batch: Mapping) -> int: # Note: this computation does not take into account padding, and assumes # that the dataset has been constructed without padding. Additionally, we # assume the backward pass is approximately 2x the forward pass diff --git a/llmfoundry/models/utils/adapt_tokenizer.py b/llmfoundry/models/utils/adapt_tokenizer.py index df98ba6895..8cb0c33697 100644 --- a/llmfoundry/models/utils/adapt_tokenizer.py +++ b/llmfoundry/models/utils/adapt_tokenizer.py @@ -10,7 +10,7 @@ NUM_SENTINEL_TOKENS: int = 100 -def adapt_tokenizer_for_denoising(tokenizer: PreTrainedTokenizerBase): +def adapt_tokenizer_for_denoising(tokenizer: PreTrainedTokenizerBase) -> None: """Adds sentinel tokens and padding token (if missing). Expands the tokenizer vocabulary to include sentinel tokens @@ -49,7 +49,8 @@ class AutoTokenizerForMOD(AutoTokenizer): """ @classmethod - def from_pretrained(cls, *args: Any, **kwargs: Any): + def from_pretrained(cls, *args: Any, + **kwargs: Any) -> PreTrainedTokenizerBase: """See `AutoTokenizer.from_pretrained` docstring.""" tokenizer = super().from_pretrained(*args, **kwargs) adapt_tokenizer_for_denoising(tokenizer) diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 2411dc8a16..2e72ccfa47 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -22,7 +22,7 @@ def torch_default_param_init_fn_( module: nn.Module, **kwargs: Any, -): +) -> None: del kwargs # unused, just to capture any extra args from the config if hasattr(module, 'reset_parameters') and isinstance( @@ -30,7 +30,7 @@ def torch_default_param_init_fn_( module.reset_parameters() -def fused_init_helper_(module: nn.Module, init_fn_: Callable): +def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None: # parameter initialization is often based on the parameters shape. # If a layer is fused, initialization should be based on the shapes # of the original tensor instead of the shape of the fused tensor. @@ -62,7 +62,7 @@ def generic_param_init_fn_( emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, **kwargs: Any, -): +) -> None: del kwargs # unused, just to capture any extra args from the config # enable user to divide _is_residual weights by @@ -198,7 +198,7 @@ def generic_param_init_fn_( ) -def _normal_init_(std: float, mean: float = 0.0): +def _normal_init_(std: float, mean: float = 0.0) -> Callable: return partial(torch.nn.init.normal_, mean=mean, std=std) @@ -211,7 +211,7 @@ def _normal_param_init_fn_( emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, **kwargs: Any, -): +) -> None: del kwargs # unused, just to capture any extra args from the config init_fn_ = _normal_init_(std=std) @@ -228,14 +228,14 @@ def _normal_param_init_fn_( def baseline_param_init_fn_( module: nn.Module, - init_std: float, + init_std: Optional[float], n_layers: int, d_model: Optional[int] = None, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, **kwargs: Any, -): +) -> None: del kwargs # unused, just to capture any extra args from the config if init_std is None: raise ValueError( @@ -260,7 +260,7 @@ def small_param_init_fn_( emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, **kwargs: Any, -): +) -> None: del kwargs # unused, just to capture any extra args from the config # very close to kaiming normal # from Transformers without Tears (2019) - Nguyen & Salazar @@ -283,7 +283,7 @@ def neox_param_init_fn_( emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, **kwargs: Any, -): +) -> None: """From section 2.3.1 of GPT-NeoX-20B: An Open-Source AutoregressiveLanguage Model — Black et. al. (2022) @@ -314,7 +314,7 @@ def kaiming_uniform_param_init_fn_( fan_mode: str = 'fan_in', init_nonlinearity: str = 'leaky_relu', **kwargs: Any, -): +) -> None: del kwargs # unused, just to capture any extra args from the config kaiming_uniform_ = partial(nn.init.kaiming_uniform_, @@ -344,7 +344,7 @@ def kaiming_normal_param_init_fn_( fan_mode: str = 'fan_in', init_nonlinearity: str = 'leaky_relu', **kwargs: Any, -): +) -> None: del kwargs # unused, just to capture any extra args from the config kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, @@ -372,7 +372,7 @@ def xavier_uniform_param_init_fn_( emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, init_gain: float = 0, **kwargs: Any, -): +) -> None: del kwargs # unused, just to capture any extra args from the config xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain) @@ -396,7 +396,7 @@ def xavier_normal_param_init_fn_( emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, init_gain: float = 0, **kwargs: Any, -): +) -> None: del kwargs # unused, just to capture any extra args from the config xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain) diff --git a/llmfoundry/optim/adaptive_lion.py b/llmfoundry/optim/adaptive_lion.py index 58c0f93ad5..06110bab23 100644 --- a/llmfoundry/optim/adaptive_lion.py +++ b/llmfoundry/optim/adaptive_lion.py @@ -101,7 +101,7 @@ def lionw(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor, @staticmethod def adjust_lr(lr: float, lr_penalty: float, num_times: int, - min_scale: float): + min_scale: float) -> float: """Adjusts LR. Multiplicatively scales down the LR by lr_penalty for each outlier diff --git a/llmfoundry/optim/outlier_detection.py b/llmfoundry/optim/outlier_detection.py index b485a17c5d..9df4381ba4 100644 --- a/llmfoundry/optim/outlier_detection.py +++ b/llmfoundry/optim/outlier_detection.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import collections +from typing import Optional class OutlierDetector: @@ -53,7 +54,7 @@ def insert_observation(self, obs: float) -> bool: delayed_mva = self.get_delayed_mva() return delayed_mva is not None and obs > self.threshold * delayed_mva - def get_delayed_mva(self): + def get_delayed_mva(self) -> Optional[float]: if len(self.delayed_moving_average) > 0: return sum(self.delayed_moving_average) / len( self.delayed_moving_average) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 937d30661e..c0eb2a59df 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -10,18 +10,20 @@ from composer.callbacks import (EarlyStopper, LRMonitor, MemoryMonitor, OptimizerMonitor, RuntimeEstimator, SpeedMonitor) -from composer.core import Evaluator +from composer.core import Algorithm, Callback, Evaluator from composer.datasets.in_context_learning_evaluation import \ get_icl_task_dataloader -from composer.loggers import (InMemoryLogger, MLFlowLogger, TensorboardLogger, - WandBLogger) +from composer.loggers import (InMemoryLogger, LoggerDestination, MLFlowLogger, + TensorboardLogger, WandBLogger) from composer.optim import DecoupledAdamW -from composer.optim.scheduler import (ConstantWithWarmupScheduler, +from composer.optim.scheduler import (ComposerScheduler, + ConstantWithWarmupScheduler, CosineAnnealingWithWarmupScheduler, LinearWithWarmupScheduler) from composer.utils import dist from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om +from torch.optim.optimizer import Optimizer from transformers import AutoTokenizer, PreTrainedTokenizerBase from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, Generate, @@ -68,7 +70,7 @@ def build_icl_data_and_gauntlet( return icl_evaluators, logger_keys, eval_gauntlet_cb -def build_callback(name: str, kwargs: Dict[str, Any]): +def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: if name == 'lr_monitor': return LRMonitor() elif name == 'memory_monitor': @@ -101,7 +103,7 @@ def build_callback(name: str, kwargs: Dict[str, Any]): raise ValueError(f'Not sure how to build callback: {name}') -def build_logger(name: str, kwargs: Dict[str, Any]): +def build_logger(name: str, kwargs: Dict[str, Any]) -> LoggerDestination: if name == 'wandb': return WandBLogger(**kwargs) elif name == 'tensorboard': @@ -114,7 +116,7 @@ def build_logger(name: str, kwargs: Dict[str, Any]): raise ValueError(f'Not sure how to build logger: {name}') -def build_algorithm(name: str, kwargs: Dict[str, Any]): +def build_algorithm(name: str, kwargs: Dict[str, Any]) -> Algorithm: if name == 'gradient_clipping': return algorithms.GradientClipping(**kwargs) elif name == 'alibi': @@ -130,7 +132,7 @@ def build_algorithm(name: str, kwargs: Dict[str, Any]): def build_optimizer(model: torch.nn.Module, name: str, - optimizer_config: Dict[str, Any]): + optimizer_config: Dict[str, Any]) -> Optimizer: if name == 'decoupled_adamw': return DecoupledAdamW(model.parameters(), **optimizer_config) elif name == 'decoupled_lionw': @@ -145,7 +147,8 @@ def build_optimizer(model: torch.nn.Module, name: str, raise ValueError(f'Not sure how to build optimizer: {name}') -def build_scheduler(name: str, scheduler_config: Dict[str, Any]): +def build_scheduler(name: str, + scheduler_config: Dict[str, Any]) -> ComposerScheduler: if name == 'constant_with_warmup': return ConstantWithWarmupScheduler(**scheduler_config) elif name == 'cosine_with_warmup': @@ -183,7 +186,7 @@ def build_icl_evaluators( default_batch_size: int, destination_dir: Optional[str] = None, icl_subset_num_batches: Optional[int] = None, -): +) -> Tuple[List[Evaluator], List[str]]: if destination_dir is None: destination_dir = os.getcwd() diff --git a/llmfoundry/utils/checkpoint_conversion_helpers.py b/llmfoundry/utils/checkpoint_conversion_helpers.py index e058706316..0627cec4cd 100644 --- a/llmfoundry/utils/checkpoint_conversion_helpers.py +++ b/llmfoundry/utils/checkpoint_conversion_helpers.py @@ -117,7 +117,7 @@ def _write_zero_bias(weight_name: str, weight_file_path: str, def _convert_weight_to_ft_each(save_dir: str, infer_gpu_num: int, tensor_name: str, config: Dict[str, Any], - data: np.ndarray): + data: np.ndarray) -> None: """Convert each MPT weight to a FasterTransformer compatible format. Args: @@ -231,7 +231,7 @@ def convert_and_save_ft_weights(named_params: dict, config: dict, infer_gpu_num: int = 1, weight_data_type: str = 'fp32', - save_dir: str = ''): + save_dir: str = '') -> None: """Convert a Composer MPT checkpoint to a FasterTransformer format. Args: diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 103f091c0a..8690271874 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -5,7 +5,7 @@ import logging import math import warnings -from typing import Any, Dict, Mapping, Optional, Union +from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union from composer.utils import dist from omegaconf import DictConfig, ListConfig @@ -46,8 +46,9 @@ def pop_config(cfg: DictConfig, return default_value -def calculate_batch_size_info(global_batch_size: int, - device_microbatch_size: Union[int, str]): +def calculate_batch_size_info( + global_batch_size: int, device_microbatch_size: Union[int, str] +) -> Tuple[int, Union[int, Literal['auto']], Union[int, Literal['auto']]]: if global_batch_size % dist.get_world_size() != 0: raise ValueError( f'Global batch size {global_batch_size} is not divisible by {dist.get_world_size()} ' @@ -73,7 +74,7 @@ def calculate_batch_size_info(global_batch_size: int, # Coming soon: this conversion math will be done inside Composer Trainer -def update_batch_size_info(cfg: DictConfig): +def update_batch_size_info(cfg: DictConfig) -> DictConfig: device_train_batch_size, device_train_microbatch_size, device_train_grad_accum = calculate_batch_size_info( cfg.global_train_batch_size, cfg.device_train_microbatch_size) cfg.n_gpus = dist.get_world_size() @@ -141,7 +142,7 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): return init_context -def log_config(cfg: DictConfig): +def log_config(cfg: DictConfig) -> None: """Logs the current config and updates the wandb and mlflow configs. This function can be called multiple times to update the wandb and MLflow diff --git a/llmfoundry/utils/huggingface_hub_utils.py b/llmfoundry/utils/huggingface_hub_utils.py index 4b837d2e67..47d7f79bff 100644 --- a/llmfoundry/utils/huggingface_hub_utils.py +++ b/llmfoundry/utils/huggingface_hub_utils.py @@ -14,7 +14,7 @@ class DeleteSpecificNodes(ast.NodeTransformer): def __init__(self, nodes_to_remove: List[ast.AST]): self.nodes_to_remove = nodes_to_remove - def visit(self, node: ast.AST): + def visit(self, node: ast.AST) -> Optional[ast.AST]: if node in self.nodes_to_remove: return None @@ -92,7 +92,7 @@ def process_file(file_path: str, folder_path: str) -> List[str]: return new_files_to_process -def edit_files_for_hf_compatibility(folder: str): +def edit_files_for_hf_compatibility(folder: str) -> None: files_to_process = [ os.path.join(folder, filename) for filename in os.listdir(folder) diff --git a/scripts/inference/convert_composer_mpt_to_ft.py b/scripts/inference/convert_composer_mpt_to_ft.py index d260c31491..79275030b3 100644 --- a/scripts/inference/convert_composer_mpt_to_ft.py +++ b/scripts/inference/convert_composer_mpt_to_ft.py @@ -8,7 +8,7 @@ import tempfile from argparse import ArgumentParser, Namespace from pathlib import Path -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import torch from composer.utils import get_file, safe_torch_load @@ -18,7 +18,7 @@ get_hf_tokenizer_from_composer_state_dict) -def save_ft_config(composer_config: dict, +def save_ft_config(composer_config: Dict[str, Any], tokenizer: PreTrainedTokenizer, save_dir: str, infer_gpu_num: int = 1, diff --git a/scripts/inference/convert_hf_to_onnx.py b/scripts/inference/convert_hf_to_onnx.py index f73836e28f..1ba1123c86 100644 --- a/scripts/inference/convert_hf_to_onnx.py +++ b/scripts/inference/convert_hf_to_onnx.py @@ -30,7 +30,7 @@ import os from argparse import ArgumentTypeError from pathlib import Path -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import torch from composer.utils import (maybe_create_object_store_from_uri, parse_uri, @@ -82,7 +82,7 @@ def export_to_onnx( export_batch_size: int, max_seq_len: Optional[int], verify_export: bool, - from_pretrained_kwargs: dict, + from_pretrained_kwargs: Dict[str, Any], ): reproducibility.seed_all(42) save_object_store = maybe_create_object_store_from_uri(output_folder) diff --git a/tests/test_model.py b/tests/test_model.py index f20381f288..501d9bf6e7 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -513,8 +513,10 @@ def test_mpt_creation(norm_type: str, no_bias: bool): assert block.norm_1.weight.shape == torch.Size([d_model]) assert block.norm_2 is not None assert block.norm_2.weight.shape == torch.Size([d_model]) + assert isinstance(block.ffn.up_proj, nn.Linear) assert block.ffn.up_proj.weight.shape == torch.Size( [hf_config.d_model * hf_config.expansion_ratio, hf_config.d_model]) + assert isinstance(block.ffn.down_proj, nn.Linear) assert block.ffn.down_proj.weight.shape == torch.Size( [hf_config.d_model, hf_config.d_model * hf_config.expansion_ratio]) assert block.resid_attn_dropout.p == 0.2