diff --git a/tuning/config/tracker_configs.py b/tuning/config/tracker_configs.py index 5a8781375..030ea9241 100644 --- a/tuning/config/tracker_configs.py +++ b/tuning/config/tracker_configs.py @@ -14,6 +14,7 @@ # Standard from dataclasses import dataclass +from typing import List, Optional @dataclass @@ -62,6 +63,11 @@ def __post_init__(self): + "/" ) +@dataclass +class WandBConfig: + project: str = 'fms-hf-tuning' # experiment / project name + entity: Optional[str] = None + @dataclass class TrackerConfigFactory: diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py index 096099306..65e272cdd 100644 --- a/tuning/trackers/tracker_factory.py +++ b/tuning/trackers/tracker_factory.py @@ -25,9 +25,10 @@ # Information about all registered trackers AIMSTACK_TRACKER = "aim" +WANDB_TRACKER = "wandb" FILE_LOGGING_TRACKER = "file_logger" -AVAILABLE_TRACKERS = [AIMSTACK_TRACKER, FILE_LOGGING_TRACKER] +AVAILABLE_TRACKERS = [AIMSTACK_TRACKER, FILE_LOGGING_TRACKER, WANDB_TRACKER] # Trackers which can be used @@ -35,7 +36,7 @@ # One time package check for list of external trackers. _is_aim_available = _is_package_available("aim") - +_is_wandb_available = _is_package_available("wandb") def _get_tracker_class(T, C): return {"tracker": T, "config": C} @@ -59,10 +60,30 @@ def _register_aim_tracker(): "\t pip install aim" ) +def _register_wandb_tracker(): + # pylint: disable=import-outside-toplevel + if _is_wandb_available: + # Local + from .wandb_tracker import WandBTracker + from tuning.config.tracker_configs import WandBConfig + + WandbTracker = _get_tracker_class(WandBTracker, WandBConfig) + + REGISTERED_TRACKERS[WANDB_TRACKER] = WandbTracker + logger.info("Registered wandb tracker") + else: + logger.info( + "Not registering WANDB due to unavailablity of package.\n" + "Please install wandb if you intend to use it.\n" + "\t pip install wandb" + ) + def _is_tracker_installed(name): if name == "aim": return _is_aim_available + if name == "wandb": + return _is_wandb_available return False @@ -79,6 +100,8 @@ def _register_trackers(): logging.info("Registering trackers") if AIMSTACK_TRACKER not in REGISTERED_TRACKERS: _register_aim_tracker() + if WANDB_TRACKER not in REGISTERED_TRACKERS: + _register_wandb_tracker() if FILE_LOGGING_TRACKER not in REGISTERED_TRACKERS: _register_file_logging_tracker() diff --git a/tuning/trackers/wandb_tracker.py b/tuning/trackers/wandb_tracker.py new file mode 100644 index 000000000..6d1ce34f7 --- /dev/null +++ b/tuning/trackers/wandb_tracker.py @@ -0,0 +1,64 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import json +import os + +# Third Party +import wandb +from transformers.integrations import WandbCallback +from transformers.utils import logging + +# Local +from .tracker import Tracker +from tuning.config.tracker_configs import WandBConfig + +class WandBTracker(Tracker): + def __init__(self, tracker_config: WandBConfig): + """Tracker which uses Wandb to collect and store metrics. + """ + super().__init__(name="aim", tracker_config=tracker_config) + self.logger = logging.get_logger("wandb_tracker") + + def get_hf_callback(self): + """Returns the WandBCallback object associated with this tracker. + """ + c = self.config + project = c.project + entity = c.entity + + run = wandb.init(project=project, entity=entity) + WandbCallback = WandbCallback() + + self.run = run + self.hf_callback = WandbCallback + return self.hf_callback + + def _wandb_log(self, data, name): + self.run.log({name: data}) + + def track(self, metric, name, stage): + """Track any additional metric with name under Aimstack tracker. + """ + if metric is None or name is None: + raise ValueError( + "wandb track function should not be called with None metric value or name" + ) + self._wandb_log(metric, name) + + def set_params(self, params, name="extra_params"): + """Attach any extra params with the run information stored in Aimstack tracker. + """ + self.run.log(params)