Skip to content

Commit

Permalink
add finetome dataset to fixtures, check eval_loss in test (#2106) [sk…
Browse files Browse the repository at this point in the history
…ip ci]

* add finetome dataset to fixtures, check eval_loss in test

* add qwen 0.5b to pytest session fixture
  • Loading branch information
winglian authored Nov 30, 2024
1 parent 724b660 commit 6e0fb4a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ def download_smollm2_135m_model():
snapshot_download("HuggingFaceTB/SmolLM2-135M")


@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
# download the model
snapshot_download("Qwen/Qwen2.5-0.5B")


@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
# download the model
Expand All @@ -26,6 +32,11 @@ def download_mhenrichsen_alpaca_2k_dataset():
snapshot_download("mhenrichsen/alpaca_2k_test", repo_type="dataset")


def download_mlabonne_finetome_100k_dataset():
# download the model
snapshot_download("mlabonne/FineTome-100k", repo_type="dataset")


@pytest.fixture
def temp_dir():
# Create a temporary directory
Expand Down
29 changes: 23 additions & 6 deletions tests/e2e/multigpu/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@

import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault

from ..utils import most_recent_subdir

LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"

Expand All @@ -26,7 +29,7 @@ def test_eval_sample_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
Expand All @@ -40,8 +43,8 @@ def test_eval_sample_packing(self, temp_dir):
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"val_set_size": 0.004,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
Expand All @@ -66,6 +69,7 @@ def test_eval_sample_packing(self, temp_dir):
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
}
)

Expand All @@ -87,12 +91,18 @@ def test_eval_sample_packing(self, temp_dir):
str(Path(temp_dir) / "config.yaml"),
]
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "eval/loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.5, "Loss is too high"

def test_eval(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_8bit": False,
"load_in_4bit": True,
"strict": False,
Expand All @@ -106,8 +116,8 @@ def test_eval(self, temp_dir):
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.1,
"special_tokens": {"pad_token": "<|end_of_text|>"},
"val_set_size": 0.0004,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
Expand All @@ -132,6 +142,7 @@ def test_eval(self, temp_dir):
"saves_per_epoch": 1,
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
}
)

Expand All @@ -153,3 +164,9 @@ def test_eval(self, temp_dir):
str(Path(temp_dir) / "config.yaml"),
]
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == "eval/loss")] # pylint: disable=invalid-name
assert df.value.values[-1] < 2.9, "Loss is too high"

0 comments on commit 6e0fb4a

Please sign in to comment.