-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[#4] Add VESSL Callback to Post Metrics to VESSL AI #6
Changes from all commits
29c2c36
ebf2b36
f780647
ecdff97
5b0aa90
b917faa
0b76c97
74511a4
22d5f97
f4437b0
b89a91a
16d9813
1f17adb
370b9e5
213bc6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,6 +41,7 @@ | |
from axolotl.utils.models import load_tokenizer | ||
from axolotl.utils.tokenization import check_dataset_labels | ||
from axolotl.utils.trainer import prepare_optim_env | ||
from axolotl.utils.vessl_ import setup_vessl_env_vars | ||
from axolotl.utils.wandb_ import setup_wandb_env_vars | ||
|
||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | ||
|
@@ -384,6 +385,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): | |
|
||
setup_mlflow_env_vars(cfg) | ||
|
||
setup_vessl_env_vars(cfg) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please note that usually, it is better you check the condition from where you call the function when the function does nothing when the condition is not met. |
||
|
||
return cfg | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -874,6 +874,11 @@ def get_callbacks(self) -> List[TrainerCallback]: | |
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) | ||
) | ||
|
||
if self.cfg.vessl_credential_path: | ||
from axolotl.utils.callbacks.vessl_ import VesslLogMetricsCallback | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you tell me why you added a suffix There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They added underscore too on mlflow ( |
||
|
||
callbacks.append(VesslLogMetricsCallback(self.cfg.vessl_credential_path)) | ||
|
||
return callbacks | ||
|
||
@abstractmethod | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
"""Vessl module for trainer callbacks""" | ||
import logging | ||
from typing import Dict | ||
|
||
import vessl | ||
from transformers import TrainerCallback, TrainerControl, TrainerState | ||
from transformers.training_args import TrainingArguments | ||
|
||
LOG = logging.getLogger("axolotl.callbacks") | ||
|
||
|
||
class VesslLogMetricsCallback(TrainerCallback): | ||
"""Callback to send training metrics to VESSL AI""" | ||
|
||
def __init__(self, credential_path: str) -> None: | ||
vessl.configure(credentials_file=credential_path) | ||
|
||
def on_log( | ||
self, | ||
args: TrainingArguments, # pylint: disable=unused-argument | ||
state: TrainerState, | ||
control: TrainerControl, # pylint: disable=unused-argument | ||
logs: Dict[str, float], | ||
**kwargs # pylint: disable=unused-argument | ||
): | ||
# is_world_process_zero: Whether or not this process is the global main process (when training in a | ||
# distributed fashion on several machines, this is only going to be `True` for one process). | ||
if state.is_world_process_zero: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you explain where you copied this code from, please? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I take it from wandb integration: The difference between |
||
vessl.log(logs, state.global_step) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""Module for vessl utilities""" | ||
|
||
import os | ||
|
||
from axolotl.utils.dict import DictDefault | ||
|
||
|
||
def setup_vessl_env_vars(cfg: DictDefault): | ||
if cfg.vessl_credential_path: | ||
return | ||
|
||
# VESSL_RUN_INITIAL_CONFIG is a variable that contain path to default credential inside a VESSL Run. | ||
# Currently there is no docs regarding this variable, but it exists inside the container. | ||
# Ref: https://screen.yanolja.in/lrTGow4Pr8eXhAai.png | ||
credential_path = os.environ.get("VESSL_RUN_INITIAL_CONFIG") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should not override There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if cfg.vessl_credential_path:
return
credential_path = os.environ.get("VESSL_RUN_INITIAL_CONFIG")
if credential_path:
cfg.vessl_credential_path = credential_path |
||
if credential_path: | ||
cfg.vessl_credential_path = credential_path |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added
too-many-lines
becausedata.py
is already almost hitting the 1000 lines limit, and with puree dataset logic it exceeds the limit.