Skip to content

Commit

Permalink
fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
rifqiyan committed Mar 29, 2024
1 parent 74511a4 commit 22d5f97
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
4 changes: 1 addition & 3 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,9 +857,7 @@ def get_callbacks(self) -> List[TrainerCallback]:
if self.cfg.vessl_credential_path:
from axolotl.utils.callbacks.vessl_ import VesslLogMetricsCallback

callbacks.append(
VesslLogMetricsCallback(self.cfg.vessl_credential_path)
)
callbacks.append(VesslLogMetricsCallback(self.cfg.vessl_credential_path))

return callbacks

Expand Down
14 changes: 11 additions & 3 deletions src/axolotl/utils/callbacks/vessl_.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
"""Vessl module for trainer callbacks"""
import logging
from typing import Dict, List
from typing import Dict

import vessl
from transformers import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments

LOG = logging.getLogger("axolotl.callbacks")


class VesslLogMetricsCallback(TrainerCallback):
"""Callback to send training metrics to VESSL AI"""

def __init__(self, credential_path: str) -> None:
vessl.configure(credentials_file=credential_path)

def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs: Dict[str, float] = None, **kwargs):
def on_log(
self,
args: TrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl, # pylint: disable=unused-argument
logs: Dict[str, float],
**kwargs # pylint: disable=unused-argument
):
if state.is_world_process_zero:
vessl.log(logs, state.global_step)

0 comments on commit 22d5f97

Please sign in to comment.