Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text support to the Trainer's TensorBoard integration #34418

Merged
merged 8 commits into the base branch
Nov 4, 2024
2 changes: 2 additions & 0 deletions src/transformers/integrations/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,8 @@ def on_log(self, args, state, control, logs=None, **kwargs):
for k, v in logs.items():
if isinstance(v, (int, float)):
self.tb_writer.add_scalar(k, v, state.global_step)
elif isinstance(v, str):
self.tb_writer.add_text(k, v, state.global_step)
else:
logger.warning(
"Trainer is attempting to log a value of "
Expand Down
6 changes: 5 additions & 1 deletion src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ class ProgressCallback(TrainerCallback):
def __init__(self):
self.training_bar = None
self.prediction_bar = None
self.max_str_len = 100

def on_train_begin(self, args, state, control, **kwargs):
if state.is_world_process_zero:
Expand Down Expand Up @@ -631,7 +632,10 @@ def on_log(self, args, state, control, logs=None, **kwargs):
# but avoid doing any value pickling.
shallow_logs = {}
for k, v in logs.items():
shallow_logs[k] = v
if isinstance(v, str) and len(v) > self.max_str_len:
shallow_logs[k] = f"[String too long to display, length: {len(v)}]"
# NOTE(review): long strings are summarized rather than printed in full (review thread resolved).
else:
shallow_logs[k] = v
_ = shallow_logs.pop("total_flos", None)
# round numbers so that it looks better in console
if "epoch" in shallow_logs:
Expand Down