From dda5a1e0e1e8191e570c3f5c8df77f944585f1ac Mon Sep 17 00:00:00 2001 From: JacobLinCool Date: Fri, 25 Oct 2024 18:37:57 +0000 Subject: [PATCH 1/4] feat: add text support to TensorBoardCallback --- src/transformers/integrations/integration_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index a09116552c8..be9a4aff3c7 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -697,6 +697,8 @@ def on_log(self, args, state, control, logs=None, **kwargs): for k, v in logs.items(): if isinstance(v, (int, float)): self.tb_writer.add_scalar(k, v, state.global_step) + elif isinstance(v, str): + self.tb_writer.add_text(k, v, state.global_step) else: logger.warning( "Trainer is attempting to log a value of " From ba5fd92ad02f2a4d6dbc0ffea08899fb2cd16b40 Mon Sep 17 00:00:00 2001 From: JacobLinCool Date: Fri, 25 Oct 2024 18:39:09 +0000 Subject: [PATCH 2/4] feat: ignore long strings in trainer progress --- src/transformers/trainer_callback.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 405874acf8f..03457072f95 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -594,6 +594,7 @@ class ProgressCallback(TrainerCallback): def __init__(self): self.training_bar = None self.prediction_bar = None + self.max_str_len = 100 def on_train_begin(self, args, state, control, **kwargs): if state.is_world_process_zero: @@ -631,7 +632,10 @@ def on_log(self, args, state, control, logs=None, **kwargs): # but avoid doing any value pickling. shallow_logs = {} for k, v in logs.items(): - shallow_logs[k] = v + if isinstance(v, str) and len(v) > self.max_str_len: + shallow_logs[k] = f"[String too long to display, length: {len(v)}]" + else: + shallow_logs[k] = v _ = shallow_logs.pop("total_flos", None) # round numbers so that it looks better in console if "epoch" in shallow_logs: From f55c6499f53d3c72013134b95c3a0b0d94313e7b Mon Sep 17 00:00:00 2001 From: JacobLinCool Date: Mon, 28 Oct 2024 09:38:56 +0000 Subject: [PATCH 3/4] docs: add docstring for max_str_len --- src/transformers/trainer_callback.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 03457072f95..a664b735304 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -589,12 +589,21 @@ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: Tr class ProgressCallback(TrainerCallback): """ A [`TrainerCallback`] that displays the progress of training or evaluation. + You can modify `max_str_len` to control how long strings are truncated when logging. """ - def __init__(self): + def __init__(self, max_str_len: int = 100): + """ + Initialize the callback with optional max_str_len parameter to control string truncation length. + + Args: + max_str_len (`int`): + Maximum length of strings to display in logs. + Longer strings will be truncated with a message. + """ self.training_bar = None self.prediction_bar = None - self.max_str_len = 100 + self.max_str_len = max_str_len def on_train_begin(self, args, state, control, **kwargs): if state.is_world_process_zero: @@ -633,7 +642,10 @@ def on_log(self, args, state, control, logs=None, **kwargs): shallow_logs = {} for k, v in logs.items(): if isinstance(v, str) and len(v) > self.max_str_len: - shallow_logs[k] = f"[String too long to display, length: {len(v)}]" + shallow_logs[k] = ( + f"[String too long to display, length: {len(v)} > {self.max_str_len}. " + "Consider increasing `max_str_len` if needed.]" + ) else: shallow_logs[k] = v _ = shallow_logs.pop("total_flos", None) From a05a486247079a8a8d982370f2db793bf70e70d3 Mon Sep 17 00:00:00 2001 From: JacobLinCool Date: Mon, 4 Nov 2024 16:16:37 +0000 Subject: [PATCH 4/4] style: remove trailing whitespace --- src/transformers/trainer_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 2486b43d055..cf9a83aa188 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -598,7 +598,7 @@ def __init__(self, max_str_len: int = 100): Args: max_str_len (`int`): - Maximum length of strings to display in logs. + Maximum length of strings to display in logs. Longer strings will be truncated with a message. """ self.training_bar = None