From 7193ae950859cd637c8246ee52407b2ed012289c Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 22 Oct 2023 23:35:35 +0900 Subject: [PATCH] feat: test env set for wandb and recommendation --- tests/test_validation.py | 72 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/tests/test_validation.py b/tests/test_validation.py index 35d90a2cb4..f845e2f39e 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,6 +1,7 @@ """Module for testing the validation module""" import logging +import os import unittest from typing import Optional @@ -8,6 +9,7 @@ from axolotl.utils.config import validate_config from axolotl.utils.dict import DictDefault +from axolotl.utils.wandb_ import setup_wandb_env_vars class ValidationTest(unittest.TestCase): @@ -565,3 +567,73 @@ def test_no_conflict_eval_strategy(self): ) validate_config(cfg) + + def test_wandb_rename_run_id_to_name(self): + cfg = DictDefault( + { + "wandb_run_id": "foo", + } + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert any( + "wandb_run_id is not recommended anymore. Please use wandb_name instead." + in record.message + for record in self._caplog.records + ) + + assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None + + cfg = DictDefault( + { + "wandb_name": "foo", + } + ) + + validate_config(cfg) + + def test_wandb_sets_env(self): + cfg = DictDefault( + { + "wandb_project": "foo", + "wandb_name": "bar", + "wandb_entity": "baz", + "wandb_mode": "online", + "wandb_watch": "false", + "wandb_log_model": "checkpoint", + } + ) + + validate_config(cfg) + + setup_wandb_env_vars(cfg) + + assert os.environ.get("WANDB_PROJECT", "") == "foo" + assert os.environ.get("WANDB_NAME", "") == "bar" + assert os.environ.get("WANDB_ENTITY", "") == "baz" + assert os.environ.get("WANDB_MODE", "") == "online" + assert os.environ.get("WANDB_WATCH", "") == "false" + assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint" + assert os.environ.get("WANDB_DISABLED", "") != "true" + + def test_wandb_set_disabled(self): + cfg = DictDefault({}) + + validate_config(cfg) + + setup_wandb_env_vars(cfg) + + assert os.environ.get("WANDB_DISABLED", "") == "true" + + cfg = DictDefault( + { + "wandb_project": "foo", + } + ) + + validate_config(cfg) + + setup_wandb_env_vars(cfg) + + assert os.environ.get("WANDB_DISABLED", "") != "true"