-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[#4] Add VESSL Callback to Post Metrics to VESSL AI #6
Changes from 5 commits
29c2c36
ebf2b36
f780647
ecdff97
5b0aa90
b917faa
0b76c97
74511a4
22d5f97
f4437b0
b89a91a
16d9813
1f17adb
370b9e5
213bc6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -836,6 +836,13 @@ def get_callbacks(self) -> List[TrainerCallback]: | |
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) | ||
) | ||
|
||
if self.cfg.use_vessl: | ||
from axolotl.utils.callbacks.vessl_ import VesslLogMetricsCallback | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you tell me why you added a suffix There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They added underscore too on mlflow ( |
||
|
||
callbacks.append( | ||
VesslLogMetricsCallback(self.cfg.vessl_credential_path) | ||
) | ||
|
||
return callbacks | ||
|
||
@abstractmethod | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
"""Vessl module for trainer callbacks""" | ||
import logging | ||
from typing import Dict, List | ||
|
||
import vessl | ||
from transformers import TrainerCallback, TrainerControl, TrainerState | ||
from transformers.training_args import TrainingArguments | ||
|
||
LOG = logging.getLogger("axolotl.callbacks") | ||
|
||
class VesslLogMetricsCallback(TrainerCallback): | ||
|
||
def __init__(self, credential_path: str, metrics: List[str]) -> None: | ||
vessl.configure(credentials_file=credential_path) | ||
|
||
def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs: Dict[str, float] = None, **kwargs): | ||
if state.is_world_process_zero: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you explain where you copied this code from, please? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I take it from wandb integration: The difference between |
||
vessl.log(logs, state.global_step) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
"""Module for vessl utilities""" | ||
|
||
import os | ||
|
||
from axolotl.utils.dict import DictDefault | ||
|
||
|
||
def setup_vessl_env_vars(cfg: DictDefault): | ||
# VESSL_RUN_INITIAL_CONFIG is a variable that contain path to | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I cannot find any references explaining this variable. Can you attach a document pointing this variable? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I got it from container environment variables, will try to attach a screenshot |
||
# default credential inside a VESSL Run | ||
credential_path = os.environ.get("VESSL_RUN_INITIAL_CONFIG") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should not override There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if cfg.vessl_credential_path:
return
credential_path = os.environ.get("VESSL_RUN_INITIAL_CONFIG")
if credential_path:
cfg.vessl_credential_path = credential_path |
||
if credential_path: | ||
cfg.use_vessl = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In what case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Both are There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but you are using |
||
cfg.vessl_credential_path = credential_path |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please note that usually, it is better you check the condition from where you call the function when the function does nothing when the condition is not met.
LGTM for now to keep it consistent since line 371 does the same as yours.