Commit

fix typing and formatting
aspfohl committed Nov 27, 2023
1 parent 13cbc13 commit 204d2f7
Showing 3 changed files with 19 additions and 19 deletions.
9 changes: 3 additions & 6 deletions llmfoundry/data/dataloader.py
@@ -18,11 +18,8 @@
 }
 
 
-def build_dataloader(
-    cfg: DictConfig,
-    tokenizer: PreTrainedTokenizerBase,
-    device_batch_size: int,
-) -> DataSpec:
+def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
+                     device_batch_size: int) -> DataSpec:
     """Builds a dataloader from a config.
 
     Args:
@@ -33,7 +30,7 @@ def build_dataloader(
     """
 
     if cfg.name not in LOADER_NAME_TO_FUNCTION:
-        allowed = ", ".join(LOADER_NAME_TO_FUNCTION.keys())
+        allowed = ', '.join(LOADER_NAME_TO_FUNCTION.keys())
        raise ValueError(f'Expected dataloader name to be one of {allowed}' +
                         f' but found name "{cfg.name}" in config: {cfg}')

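For context, build_dataloader dispatches on cfg.name to one of the registered builders in LOADER_NAME_TO_FUNCTION. A minimal usage sketch under assumptions: the 'text' loader name and the dataset fields are illustrative and may not match the repo version, and the call site is hypothetical, not taken from this commit.

# Hypothetical call site; loader name and dataset fields are assumptions.
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from llmfoundry.data.dataloader import build_dataloader

cfg = OmegaConf.create({
    'name': 'text',  # must be a key in LOADER_NAME_TO_FUNCTION
    'dataset': {'local': '/tmp/my-copy-c4', 'split': 'train'},  # illustrative
})
tokenizer = AutoTokenizer.from_pretrained('gpt2')
data_spec = build_dataloader(cfg, tokenizer, device_batch_size=8)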
24 changes: 14 additions & 10 deletions llmfoundry/utils/builders.py
@@ -42,22 +42,26 @@
 
 
 def build_eval_loader(
-    eval_loader_config: DictConfig,
+    eval_loader_config: Union[DictConfig, ListConfig],
     model: Union[Any, ComposerHFCausalLM],
     tokenizer: PreTrainedTokenizerBase,
     device_eval_batch_size: int,
-) -> Evaluator:
+) -> List[Evaluator]:
     assert model.train_metrics is not None
     eval_metric_names = list(model.train_metrics.keys())
 
-    evaluators = []
+    evaluators: List[Evaluator] = []
+    if isinstance(eval_loader_config, ListConfig):
+        eval_configs: ListConfig = eval_loader_config
+        is_multi_eval = True
+    else:
+        eval_configs = ListConfig([eval_loader_config])
+        is_multi_eval = False
 
-    is_multi_eval = isinstance(eval_loader_config, ListConfig)
-    eval_configs = eval_loader_config if is_multi_eval else [eval_loader_config]
     for eval_config in eval_configs:
         eval_dataloader = build_dataloader(eval_config, tokenizer,
                                            device_eval_batch_size)
-        eval_loader = Evaluator(
+        eval_loader: Evaluator = Evaluator(
             label=f'eval/{eval_config.label}' if is_multi_eval else 'eval',
             dataloader=eval_dataloader,
             metric_names=eval_metric_names,
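The explicit isinstance branch (replacing the old one-line conditional) is what lets a static checker narrow the Union: inside the if arm, eval_loader_config is known to be a ListConfig, and the else arm wraps the single DictConfig so eval_configs always has one consistent type. A minimal sketch of the same pattern outside this codebase (normalize_configs is a hypothetical stand-in, not a function from the repo):

from typing import Union

from omegaconf import DictConfig, ListConfig

def normalize_configs(config: Union[DictConfig, ListConfig]) -> ListConfig:
    if isinstance(config, ListConfig):
        # mypy narrows config from the Union to ListConfig in this branch
        configs: ListConfig = config
    else:
        # wrap the single config so callers always iterate a ListConfig
        configs = ListConfig([config])
    return configs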
@@ -220,8 +224,8 @@ def build_tokenizer(

     signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'
 
-    if dist.is_available() and dist.is_initialized() and dist.get_world_size(
-    ) > 1:
+    if dist.is_available() and dist.is_initialized(
+    ) and dist.get_world_size() > 1:
         # Make sure the tokenizer files are downloaded and cached first by local rank 0
         with dist.local_rank_zero_download_and_wait(signal_file_path):
             pass
@@ -240,8 +244,8 @@
             int(1e30),
         )
 
-    if dist.is_available() and dist.is_initialized() and dist.get_world_size(
-    ) > 1:
+    if dist.is_available() and dist.is_initialized(
+    ) and dist.get_world_size() > 1:
         if dist.get_local_rank() == 0:
             with open(signal_file_path, 'wb') as f:
                 f.write(b'local_rank0_completed_tokenizer_setup')
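Both guards implement a rank-zero-first pattern: local rank 0 downloads and caches the tokenizer files, writes a signal file, and the other local ranks block until it appears, so they never race on the shared cache. A generic sketch of the same idea using torch.distributed directly (download_fn and the helper name are hypothetical stand-ins for the composer utilities used above):

import os

import torch.distributed as dist

def download_on_local_rank_zero(download_fn, signal_file: str) -> None:
    multi = (dist.is_available() and dist.is_initialized() and
             dist.get_world_size() > 1)
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    if not multi or local_rank == 0:
        download_fn()  # populate the shared cache exactly once per node
        if multi:
            with open(signal_file, 'wb') as f:
                f.write(b'done')  # signal local peers that the cache is ready
    if multi:
        dist.barrier()  # all ranks sync before touching the cache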
5 changes: 2 additions & 3 deletions scripts/eval/eval.py
@@ -334,9 +334,8 @@ def main(cfg: DictConfig):
         if eval_gauntlet_df is not None and eval_gauntlet_callback is not None:
             assert composite_scores is not None
             row = {'model_name': model_cfg['model_name']}
-            row.update({
-                k.split('/')[-1]: v for k, v in composite_scores.items()
-            })
+            row.update(
+                {k.split('/')[-1]: v for k, v in composite_scores.items()})
             eval_gauntlet_df = pd.concat(
                 [eval_gauntlet_df, pd.DataFrame([row])], ignore_index=True)

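The reformatted update keeps the same behavior: each composite score key is truncated to its last path segment before becoming a DataFrame column. A tiny illustration with made-up keys and values (the score name and model names are invented for the example):

import pandas as pd

composite_scores = {'icl/gauntlet/core_average': 0.42}  # made-up key and value
row = {'model_name': 'my-model'}  # hypothetical model name
row.update({k.split('/')[-1]: v for k, v in composite_scores.items()})
eval_gauntlet_df = pd.DataFrame([{'model_name': 'baseline',
                                  'core_average': 0.40}])
eval_gauntlet_df = pd.concat([eval_gauntlet_df, pd.DataFrame([row])],
                             ignore_index=True)
# eval_gauntlet_df now has two rows, columns ['model_name', 'core_average']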
