Skip to content

Commit

Permalink
fix dictconfig stuff again
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress committed Apr 10, 2024
1 parent 8f1177b commit 4965cd2
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def validate_config(cfg: TrainConfig):
`label` attribute.')
loaders.append(loader)
else:
loaders.append(eval_loader)
if eval_loader is not None:
loaders.append(eval_loader)
for loader in loaders:
if loader['name'] == 'text':
if cfg.model['name'] == 'hf_t5':
Expand Down Expand Up @@ -257,32 +258,32 @@ def main(cfg: DictConfig) -> Trainer:
dist.initialize_dist(get_device(None), timeout=dist_timeout)

# Mandatory model training configs
model_config: Dict[str, Any] = scfg.model
model_config: DictConfig = DictConfig(scfg.model)
tokenizer_config: Dict[str, Any] = scfg.tokenizer
optimizer_config: Dict[str, Any] = scfg.optimizer
scheduler_config: Dict[str, Any] = scfg.scheduler
train_loader_config: Dict[str, Any] = scfg.train_loader
train_loader_config: DictConfig = DictConfig(scfg.train_loader)

# Optional fsdp data, fine-tuning, and eval configs
fsdp_config: Optional[Dict[str, Any]] = scfg.fsdp_config

if scfg.eval_loader is not None and scfg.eval_loaders is not None:
raise ValueError(
'Only one of `eval_loader` or `eval_loaders` should be provided.')
eval_loader_config: Optional[Union[Dict[str, Any], List[Dict[
str,
Any]]]] = scfg.eval_loader if scfg.eval_loader is not None else scfg.eval_loaders
icl_tasks_config: Optional[Union[
List[Dict[str, Any]],
str]] = scfg.icl_tasks if scfg.icl_tasks is not None else scfg.icl_tasks_str
eval_gauntlet_config: Optional[Union[
DictConfig,
str]] = scfg.eval_gauntlet if scfg.eval_gauntlet is not None else scfg.eval_gauntlet_str
eval_loader_config: Optional[Union[DictConfig, ListConfig]] = DictConfig(
scfg.eval_loader) if scfg.eval_loader is not None else ListConfig(
scfg.eval_loaders) if scfg.eval_loaders is not None else None
icl_tasks_config: Optional[Union[ListConfig, str]] = ListConfig(
scfg.icl_tasks) if scfg.icl_tasks is not None else scfg.icl_tasks_str
eval_gauntlet_config: Optional[Union[DictConfig, str]] = DictConfig(
scfg.eval_gauntlet
) if scfg.eval_gauntlet is not None else scfg.eval_gauntlet_str
icl_subset_num_batches: Optional[int] = scfg.icl_subset_num_batches
icl_seq_len: Optional[int] = scfg.icl_seq_len
# Optional logging, evaluation and callback configs
logger_configs: Optional[Dict[str, Any]] = scfg.loggers
callback_configs: Optional[Dict[str, Any]] = scfg.callbacks
callback_configs: Optional[DictConfig] = DictConfig(
scfg.callbacks) if scfg.callbacks else None
algorithm_configs: Optional[Dict[str, Any]] = scfg.algorithms

# Mandatory hyperparameters for training
Expand Down Expand Up @@ -392,7 +393,8 @@ def main(cfg: DictConfig) -> Trainer:

# Profiling
profiler: Optional[Profiler] = None
profiler_cfg: Optional[DictConfig] = scfg.profiler
profiler_cfg: Optional[DictConfig] = DictConfig(
scfg.profiler) if scfg.profiler is not None else None
if profiler_cfg:
profiler_schedule_cfg: Dict = pop_config(profiler_cfg,
'schedule',
Expand Down

0 comments on commit 4965cd2

Please sign in to comment.