Add lots of return types (#595)
dakinggg authored Sep 14, 2023
1 parent 7023e76 commit 48ba632
Showing 29 changed files with 192 additions and 147 deletions.
6 changes: 3 additions & 3 deletions llmfoundry/callbacks/eval_gauntlet_callback.py
@@ -6,7 +6,7 @@
import logging
import math
from enum import Enum
from typing import Optional
from typing import Dict, Optional

from composer.core import Callback, State
from composer.loggers import Logger
@@ -95,7 +95,7 @@ def __init__(self,
assert weight is not None
benchmark['weighting'] = weight

def compute_averages(self, state: State):
def compute_averages(self, state: State) -> Dict[str, float]:
results = {}

for key in self.logger_keys:
@@ -120,7 +120,7 @@ def compute_averages(self, state: State):

return {k: sum(v) / len(v) for k, v in results.items()}

def eval_after_all(self, state: State, logger: Logger):
def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]:
new_metrics = self.compute_averages(state)
if len(new_metrics) == 0:
return {}
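
As a point of reference, here is a minimal, self-contained sketch (not code from this repo) of the typing pattern these annotations use: importing Dict from typing keeps the hints compatible with Python versions older than 3.9, where the builtin dict cannot be subscripted. The function name and example key below are hypothetical.

from typing import Dict, List

def average_by_key(results: Dict[str, List[float]]) -> Dict[str, float]:
    """Collapse lists of per-key scores into their means."""
    # Mirrors the shape of compute_averages: one float per logger key.
    return {key: sum(values) / len(values) for key, values in results.items()}

# Hypothetical usage:
# average_by_key({'icl/gauntlet': [0.5, 0.7]}) == {'icl/gauntlet': 0.6}
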
4 changes: 2 additions & 2 deletions llmfoundry/callbacks/fdiff_callback.py
@@ -26,7 +26,7 @@ def __init__(self,
self.train_prev_metric = {}
self.eval_prev_metric = {}

def batch_end(self, state: State, logger: Logger):
def batch_end(self, state: State, logger: Logger) -> None:
if self.diff_train_metrics:
if not isinstance(state.loss, torch.Tensor):
raise NotImplementedError('Multiple losses not supported yet')
@@ -46,7 +46,7 @@
value = state.train_metric_values[k]
self.train_prev_metric[k] = value

def eval_end(self, state: State, logger: Logger):
def eval_end(self, state: State, logger: Logger) -> None:
if self.diff_eval_metrics:
evaluator = state.dataloader_label
assert evaluator is not None, 'dataloader should have been set'
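
For context, a hedged sketch of the convention applied throughout these callbacks: Composer event hooks such as batch_end and eval_end only produce side effects, so they are annotated -> None. The class below is an illustrative stand-in, not the repository's actual callback.

from composer.core import Callback, State
from composer.loggers import Logger


class LossDeltaCallback(Callback):
    """Hypothetical callback tracking the change in loss between batches."""

    def __init__(self) -> None:
        self.prev_loss = float('nan')

    def batch_end(self, state: State, logger: Logger) -> None:
        del logger  # unused, but required by the event-hook signature
        loss = float(state.loss)  # assumes a single scalar loss tensor
        self.loss_delta = loss - self.prev_loss
        self.prev_loss = loss
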
4 changes: 2 additions & 2 deletions llmfoundry/callbacks/generate_callback.py
@@ -47,11 +47,11 @@ def init(self, state: State, logger: Logger):
if isinstance(destination, WandBLogger):
self.wandb_logger = destination

def batch_checkpoint(self, state: State, logger: Logger):
def batch_checkpoint(self, state: State, logger: Logger) -> None:
if (state.timestamp.batch.value % self.batch_log_interval) == 0:
self.generate(state, logger)

def generate(self, state: State, logger: Logger):
def generate(self, state: State, logger: Logger) -> None:
model = state.model
original_mode = model.training
model.eval()
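
A brief sketch of the save-and-restore pattern generate() relies on: cache the model's training flag, switch to eval mode for generation, and restore the original mode afterwards. The helper name and wrapper below are assumptions for illustration, not the callback's actual code.

from typing import Callable

import torch


def run_in_eval_mode(model: torch.nn.Module,
                     fn: Callable[[torch.nn.Module], None]) -> None:
    """Run fn(model) with the model in eval mode, then restore its prior mode."""
    original_mode = model.training  # True if the model was training
    model.eval()
    try:
        with torch.no_grad():
            fn(model)
    finally:
        # Put the model back into whichever mode it was in before.
        model.train(mode=original_mode)
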
10 changes: 6 additions & 4 deletions llmfoundry/callbacks/monolithic_ckpt_callback.py
@@ -46,22 +46,24 @@ def __init__(self,
else:
self.remote_ud = None

def init(self, state: State, logger: Logger):
def init(self, state: State, logger: Logger) -> None:
if self.upload_to_object_store and self.remote_ud is not None:
self.remote_ud.init(state, logger)
# updated_logger_destinations = [*logger.destinations, new_remote_ud]
# logger.destinations = tuple(updated_logger_destinations)
state.callbacks.append(self.remote_ud)

def batch_checkpoint(self, state: State, logger: Logger):
def batch_checkpoint(self, state: State, logger: Logger) -> None:
if state.timestamp.batch.value % self.batch_interval == 0:
self._save_checkpoint(state, logger)

def fit_end(self, state: State, logger: Logger):
def fit_end(self, state: State, logger: Logger) -> None:
if state.timestamp.batch.value % self.batch_interval != 0:
self._save_checkpoint(state, logger)

def _save_checkpoint(self, state: State, logger: Logger):
def _save_checkpoint(self, state: State, logger: Logger) -> None:
del logger # unused

filename = format_name_with_dist_and_time(self.filename_format_str,
state.run_name,
state.timestamp)
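
To make the checkpoint scheduling above concrete, a small standalone sketch (hypothetical helper names) of when this callback saves: every batch_interval-th batch during training, plus once at fit_end if the final batch did not already line up with the interval.

def should_save_at_batch(batch_index: int, batch_interval: int) -> bool:
    """Mirror of batch_checkpoint: save on every batch_interval-th batch."""
    return batch_index % batch_interval == 0


def should_save_at_fit_end(batch_index: int, batch_interval: int) -> bool:
    """Mirror of fit_end: save only if the last batch did not trigger a save."""
    return batch_index % batch_interval != 0
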
8 changes: 6 additions & 2 deletions llmfoundry/callbacks/resumption_callbacks.py
@@ -32,7 +32,9 @@ def __init__(self, lr_scale: float, wd_pct: float = 0.0):
self.lr_scale = lr_scale
self.wd_pct = wd_pct

def fit_start(self, state: State, logger: Logger):
def fit_start(self, state: State, logger: Logger) -> None:
del logger # unused

if hasattr(state, 'optimizer') and state.optimizers is None:
raise Exception('No optimizers defined')
for optimizer in state.optimizers:
@@ -65,7 +67,9 @@ class LayerFreezing(Callback):
def __init__(self, layer_names: List[str]):
self.layer_names = set(layer_names)

def fit_start(self, state: State, logger: Logger):
def fit_start(self, state: State, logger: Logger) -> None:
del logger # unused

model_layers = set(name for name, _ in state.model.named_parameters())
for layer in self.layer_names:
if layer not in model_layers:
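
As an aside, a minimal sketch of the validation idea behind LayerFreezing.fit_start: collect the model's parameter names into a set and flag any requested layer that is missing. The function name and error message are illustrative only.

from typing import Iterable

import torch


def check_layer_names(model: torch.nn.Module,
                      layer_names: Iterable[str]) -> None:
    """Raise if a requested layer name is not a parameter of the model."""
    model_layers = set(name for name, _ in model.named_parameters())
    missing = [layer for layer in layer_names if layer not in model_layers]
    if missing:
        raise ValueError(f'Layers not found in model: {missing}')
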
22 changes: 16 additions & 6 deletions llmfoundry/callbacks/scheduled_gc_callback.py
@@ -9,7 +9,7 @@


def gc_cuda():
"""Gargage collect Torch (CUDA) memory."""
"""Garbage collect Torch (CUDA) memory."""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
@@ -32,15 +32,19 @@ def __init__(
self.eval_keep_disabled = eval_keep_disabled
self.gc_init_state = None

def fit_start(self, state: State, logger: Logger):
def fit_start(self, state: State, logger: Logger) -> None:
del state, logger # unused

# cache if automatic garbage collection is enabled; reset at fit_end
self.gc_init_state = gc.isenabled()

# disable automatic garbage collection
gc.disable()
gc_cuda()

def fit_end(self, state: State, logger: Logger):
def fit_end(self, state: State, logger: Logger) -> None:
del state, logger # unused

gc_cuda()

# reset automatic garbage collection at fit_end
@@ -49,16 +53,22 @@ def fit_end(self, state: State, logger: Logger):
else:
gc.disable()

def before_dataloader(self, state: State, logger: Logger):
def before_dataloader(self, state: State, logger: Logger) -> None:
del logger # unused

if state.timestamp.batch.value % self.batch_interval == 0:
gc_cuda()

def eval_start(self, state: State, logger: Logger):
def eval_start(self, state: State, logger: Logger) -> None:
del state, logger # unused

gc_cuda()
if not self.eval_keep_disabled:
gc.enable()

def eval_end(self, state: State, logger: Logger):
def eval_end(self, state: State, logger: Logger) -> None:
del state, logger # unused

if not self.eval_keep_disabled:
gc.disable()

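
For reference, a self-contained sketch of the garbage-collection pattern these hooks implement: cache whether automatic gc was enabled, disable it for the duration of training, trigger manual collections at a fixed batch interval, and restore the original setting afterwards. The ManualGC class is a simplified stand-in, not the actual Composer callback.

import gc

import torch


def gc_cuda() -> None:
    """Run Python garbage collection and release cached CUDA memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


class ManualGC:
    """Simplified, framework-free version of scheduled garbage collection."""

    def __init__(self, batch_interval: int) -> None:
        self.batch_interval = batch_interval
        self.gc_init_state = gc.isenabled()

    def start(self) -> None:
        self.gc_init_state = gc.isenabled()  # cache so it can be restored
        gc.disable()
        gc_cuda()

    def on_batch(self, batch_index: int) -> None:
        if batch_index % self.batch_interval == 0:
            gc_cuda()

    def end(self) -> None:
        gc_cuda()
        # Restore whatever automatic-gc setting was active before start().
        if self.gc_init_state:
            gc.enable()
        else:
            gc.disable()
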
7 changes: 4 additions & 3 deletions llmfoundry/data/denoising.py
@@ -269,11 +269,11 @@ def __init__(
'`span_mean_lengths_and_ratios` and/or `sequence_mask_ratios`.')

@property
def smallest_max_raw_length(self):
def smallest_max_raw_length(self) -> int:
return int(self._smallest_max_raw_length)

@property
def largest_max_raw_length(self):
def largest_max_raw_length(self) -> int:
return int(self._largest_max_raw_length)

def __call__(self, examples: List[Dict[str,
@@ -613,7 +613,8 @@ def noise_token_sequence(

def _get_max_starting_length(max_length: int, mask_ratio: float,
mean_span_length: float, n_prefix_tokens: int,
decoder_only_format: bool, context_eos: bool):
decoder_only_format: bool,
context_eos: bool) -> int:
"""Get max num raw tokens that will fit max_length."""

def sequence_stats(length: int):
11 changes: 6 additions & 5 deletions llmfoundry/data/finetuning/dataloader.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
from typing import Union
from typing import Tuple, Union

import datasets as hf_datasets
import torch
@@ -207,7 +207,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)


def _validate_config(dataset_cfg: DictConfig):
def _validate_config(dataset_cfg: DictConfig) -> None:
"""Validates the dataset configuration.
Makes sure that the dataset is properly configured for either
@@ -352,9 +352,10 @@ def _build_hf_dataset_from_remote(
return dataset


def _build_collate_fn(dataset_cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
device_batch_size: int):
def _build_collate_fn(
dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int
) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackWrapper], int]:
collate_fn = Seq2SeqFinetuningCollator(
tokenizer=tokenizer,
max_seq_len=dataset_cfg.max_seq_len,
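
A quick sketch of the Tuple annotation pattern used for _build_collate_fn, with placeholder types: when a helper returns several values, typing.Tuple spells out the type of each element. Everything below (function name, packing_ratio behaviour) is an assumption for illustration only.

from typing import Any, Callable, List, Tuple


def build_collate_and_batch_size(
    base_batch_size: int,
    packing_ratio: float = 1.0,
) -> Tuple[Callable[[List[Any]], List[Any]], int]:
    """Return a (collate_fn, device_batch_size) pair, like _build_collate_fn."""
    # Hypothetical: packing lets several raw examples feed one packed batch.
    device_batch_size = int(base_batch_size * packing_ratio)

    def collate_fn(examples: List[Any]) -> List[Any]:
        return examples  # placeholder collation

    return collate_fn, device_batch_size
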
34 changes: 20 additions & 14 deletions llmfoundry/data/finetuning/tasks.py
@@ -35,7 +35,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import logging
import os
import warnings
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import datasets as hf_datasets
from omegaconf import DictConfig
@@ -47,8 +47,9 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
__all__ = ['dataset_constructor']


def _tokenize_formatted_example(example: Dict[str, Any],
tokenizer: PreTrainedTokenizerBase):
def _tokenize_formatted_example(
example: Dict[str, Any],
tokenizer: PreTrainedTokenizerBase) -> Dict[str, List[int]]:
if ('prompt' not in example) or ('response' not in example):
raise KeyError(
'Unable to tokenize example because it has not been properly formatted. ' +\
@@ -150,7 +151,7 @@ class DatasetConstructor:
def __init__(self):
self._task_preprocessing_registry: Dict[str, Callable] = {}

def register(self, *names: str):
def register(self, *names: str) -> Callable[[Callable], Callable]:
"""Decorator for registering preprocessing functions."""

def _register_func(name: str, func: Callable) -> None:
@@ -168,11 +169,13 @@ def wrapper(func: Callable) -> Callable:

return wrapper

def print_registered_tasks(self):
def print_registered_tasks(self) -> None:
tasks = sorted(self._task_preprocessing_registry.keys())
print('\n'.join(tasks))

def get_preprocessing_fn_from_dict(self, mapping: Union[Dict, DictConfig]):
def get_preprocessing_fn_from_dict(
self, mapping: Union[Dict, DictConfig]
) -> Callable[[Dict[str, Any]], Dict[str, str]]:
"""Get a preprocessing function from a dictionary.
The dictionary maps column names in the dataset to "prompt" and "response".
@@ -206,9 +209,11 @@ def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]:

return _preprocessor

def get_preprocessing_fn_from_str(self,
preprocessor: Optional[str],
dataset_name: Optional[str] = None):
def get_preprocessing_fn_from_str(
self,
preprocessor: Optional[str],
dataset_name: Optional[str] = None
) -> Optional[Callable[[Dict[str, Any]], Dict[str, str]]]:
"""Get a preprocessing function from a string.
String can be either a registered function or an import path.
@@ -319,15 +324,16 @@ def dataset_mapper(example: Dict):

return empty_examples_dropped_dataset

def build_from_streaming(self, *args: Any, **kwargs: Any):
def build_from_streaming(self, *args: Any,
**kwargs: Any) -> StreamingFinetuningDataset:
return StreamingFinetuningDataset(*args, **kwargs)


dataset_constructor = DatasetConstructor()


@dataset_constructor.register('tatsu-lab/alpaca')
def alpaca_preprocessing_function(inp: Dict):
def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
"""Split out prompt/response from text."""
try:
prompt, response = inp['text'].split('### Response:')
@@ -340,7 +346,7 @@ def alpaca_preprocessing_function(inp: Dict):


@dataset_constructor.register('HuggingFaceH4/databricks_dolly_15k')
def dolly_preprocessing_function(inp: Dict):
def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
"""Format the text string."""
PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n'
try:
@@ -357,7 +363,7 @@ def dolly_preprocessing_function(inp: Dict):


@dataset_constructor.register('bigscience/P3')
def p3_preprocessing_function(inp: Dict):
def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
"""Format the already-split example."""
return {
'prompt': inp['inputs'] + ':',
@@ -367,7 +373,7 @@ def p3_preprocessing_function(inp: Dict):

# Muennighoff's P3 and flan datasets share a similar convention
@dataset_constructor.register('Muennighoff/P3', 'Muennighoff/flan')
def muennighoff_tokenize_function(inp: Dict):
def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
"""Format the already-split example."""
try:
prompt: str = inp['inputs']
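
To illustrate the decorator-registry pattern that DatasetConstructor.register and the @dataset_constructor.register(...) functions above follow, here is a small self-contained sketch with the same style of Callable annotations; the module-level registry dict and dataset name are assumptions, not the repo's exact implementation.

from typing import Any, Callable, Dict

_PREPROCESSORS: Dict[str, Callable[[Dict[str, Any]], Dict[str, str]]] = {}


def register(*names: str) -> Callable[[Callable], Callable]:
    """Register a preprocessing function under one or more dataset names."""

    def wrapper(func: Callable[[Dict[str, Any]], Dict[str, str]]) -> Callable:
        for name in names:
            _PREPROCESSORS[name] = func
        return func  # hand the function back unchanged so it stays callable

    return wrapper


@register('demo/dataset')
def demo_preprocessor(example: Dict[str, Any]) -> Dict[str, str]:
    """Map raw columns onto the prompt/response format expected downstream."""
    return {'prompt': example['inputs'] + ':', 'response': example['targets']}
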
12 changes: 7 additions & 5 deletions llmfoundry/data/packing.py
@@ -48,11 +48,11 @@ def __init__(self,
self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = []

@property
def waste(self):
def waste(self) -> float:
return 1 - (self.n_packed_tokens / self.n_total_tokens)

@property
def efficiency(self):
def efficiency(self) -> float:
return self.n_packed_tokens / (self.max_seq_len *
self.n_packed_examples)

@@ -100,7 +100,8 @@ def __call__(
return batch


def extract_trim_batch_idx(batch: Dict[str, torch.Tensor], idx: int):
def extract_trim_batch_idx(batch: Dict[str, torch.Tensor],
idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
example = {k: v[idx] for k, v in batch.items()}

keep = example['attention_mask'] == 1
@@ -111,8 +112,9 @@ def extract_trim_batch_idx(batch: Dict[str, torch.Tensor], idx: int):
return size, trim_example


def combine_in_place(example: Dict[str, torch.Tensor],
add_on: Dict[str, torch.Tensor]):
def combine_in_place(
example: Dict[str, torch.Tensor],
add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
if 'labels' in add_on:
# Prevents the last token in example from being trained to
# predict the first token in add_on, which would make no sense.
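
A worked example of the two properties above, with assumed numbers: if 900 of 1,000 raw tokens end up packed into 2 packed examples with max_seq_len = 512, then waste = 1 - 900/1000 = 0.1 and efficiency = 900 / (512 * 2) ≈ 0.879. The helper below just restates those formulas; its name is hypothetical.

from typing import Dict


def packing_stats(n_packed_tokens: int, n_total_tokens: int,
                  max_seq_len: int, n_packed_examples: int) -> Dict[str, float]:
    """Compute the waste and efficiency ratios defined by the properties above."""
    waste = 1 - (n_packed_tokens / n_total_tokens)
    efficiency = n_packed_tokens / (max_seq_len * n_packed_examples)
    return {'waste': waste, 'efficiency': efficiency}


# packing_stats(900, 1000, 512, 2) -> {'waste': 0.1, 'efficiency': 0.8789...}
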
