diff --git a/src/experiment_pipeline.py b/src/experiment_pipeline.py index 8174b21..7123e8b 100644 --- a/src/experiment_pipeline.py +++ b/src/experiment_pipeline.py @@ -76,8 +76,11 @@ def _get_random_nums_experiment_name(self): args = self.args random_num_exp_args = args.random_nums_experiment_arguments model_name = args.model_arguments.model_name_or_path if args.model_arguments.model_name_or_path else args.model_arguments.config_name - return (f'randomNums_nVars{random_num_exp_args.n_vars}_seqLen{random_num_exp_args.seq_len}_varLen{random_num_exp_args.var_len}' + experiment_name = (f'randomNums_nVars{random_num_exp_args.n_vars}_seqLen{random_num_exp_args.seq_len}_varLen{random_num_exp_args.var_len}' f'_bs{self.batch_size_string}_eps{self.epochs_string}_{model_name.split("/")[-1].replace("-","_")}') + if args.experiment_arguments.name_prefix: + experiment_name = f'{args.experiment_arguments.name_prefix}_{experiment_name}' + return experiment_name @property