Commit

reduce test concurrency to avoid HF rate limiting, test suite parity (#2128)

* reduce test concurrency to avoid HF rate limiting, test suite parity

* make val_set_size smaller to speed up e2e tests

* more retries for pytest fixture downloads

* val_set_size was too small

* move retry_on_request_exceptions to data utils and add retry strategy

* pre-download ultrafeedback as a test fixture

* refactor download retry into its own fn

* don't import from data utils

* use retry mechanism now for fixtures
winglian authored Dec 6, 2024
1 parent 08fa133 commit 5e9fa33
Showing 12 changed files with 126 additions and 47 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/tests-nightly.yml
@@ -23,9 +23,15 @@ jobs:
runs-on: ubuntu-latest
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
timeout-minutes: 20

steps:
@@ -55,6 +61,7 @@ jobs:
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
9 changes: 9 additions & 0 deletions .github/workflows/tests.yml
@@ -45,9 +45,15 @@ jobs:
runs-on: ubuntu-latest
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
timeout-minutes: 20

steps:
@@ -95,6 +101,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
fail-fast: false
max-parallel: 1
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
@@ -124,6 +131,8 @@ jobs:
pip3 show torch
python3 setup.py sdist
pip3 install dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Ensure axolotl CLI was installed
29 changes: 5 additions & 24 deletions src/axolotl/utils/data/sft.py
@@ -2,11 +2,9 @@

import functools
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple, Union

import requests
from datasets import (
Dataset,
DatasetDict,
@@ -44,7 +42,11 @@
UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
md5,
retry_on_request_exceptions,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
@@ -55,27 +57,6 @@
LOG = logging.getLogger("axolotl")


def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc

return wrapper

return decorator


@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
prompters = []
46 changes: 45 additions & 1 deletion src/axolotl/utils/data/utils.py
@@ -1,13 +1,57 @@
"""data handling helpers"""

import functools
import hashlib
import logging
import time
from enum import Enum

import huggingface_hub
import requests
from datasets import Dataset

LOG = logging.getLogger("axolotl")


class RetryStrategy(Enum):
"""
Enum for retry strategies.
"""

CONSTANT = 1
LINEAR = 2
EXPONENTIAL = 3


def retry_on_request_exceptions(
max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
huggingface_hub.errors.HfHubHTTPError,
) as exc:
if attempt < max_retries - 1:
if retry_strategy == RetryStrategy.EXPONENTIAL:
step_delay = delay * 2**attempt
elif retry_strategy == RetryStrategy.LINEAR:
step_delay = delay * (attempt + 1)
else:
step_delay = delay # Use constant delay.
time.sleep(step_delay)
else:
raise exc

return wrapper

return decorator


def md5(to_hash: str, encoding: str = "utf-8") -> str:
try:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
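
For context, a minimal sketch of how the relocated decorator and the new `RetryStrategy` could be applied to a Hugging Face Hub download (the helper name and retry values below are illustrative, not part of this commit):

```python
from huggingface_hub import snapshot_download

from axolotl.utils.data.utils import RetryStrategy, retry_on_request_exceptions


# Hypothetical helper: retry transient hub failures (read timeouts, connection
# errors, HfHubHTTPError) with exponential backoff of 1 s, 2 s, 4 s between
# attempts, re-raising after the fourth consecutive failure.
@retry_on_request_exceptions(
    max_retries=4, delay=1, retry_strategy=RetryStrategy.EXPONENTIAL
)
def download_dataset_with_backoff(repo_id: str) -> str:
    return snapshot_download(repo_id, repo_type="dataset")


local_path = download_dataset_with_backoff("tatsu-lab/alpaca")
```

An exponential strategy is the natural fit for rate limiting, since each retry backs off further instead of hitting the Hub again at a fixed interval.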
60 changes: 49 additions & 11 deletions tests/conftest.py
@@ -1,69 +1,107 @@
"""
shared pytest fixtures
"""
import functools
import shutil
import tempfile
import time

import pytest
import requests
from huggingface_hub import snapshot_download


def retry_on_request_exceptions(max_retries=3, delay=1):
# pylint: disable=duplicate-code
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc

return wrapper

return decorator


@retry_on_request_exceptions(max_retries=3, delay=5)
def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs)


@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
# download the model
snapshot_download("HuggingFaceTB/SmolLM2-135M")
snapshot_download_w_retry("HuggingFaceTB/SmolLM2-135M")


@pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model():
# download the model
snapshot_download("JackFram/llama-68m")
snapshot_download_w_retry("JackFram/llama-68m")


@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
# download the model
snapshot_download("Qwen/Qwen2.5-0.5B")
snapshot_download_w_retry("Qwen/Qwen2.5-0.5B")


@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
# download the dataset
snapshot_download("tatsu-lab/alpaca", repo_type="dataset")
snapshot_download_w_retry("tatsu-lab/alpaca", repo_type="dataset")


@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_dataset():
# download the dataset
snapshot_download("mhenrichsen/alpaca_2k_test", repo_type="dataset")
snapshot_download_w_retry("mhenrichsen/alpaca_2k_test", repo_type="dataset")


@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
# download the dataset
snapshot_download(
snapshot_download_w_retry(
"mhenrichsen/alpaca_2k_test", repo_type="dataset", revision="d05c1cb"
)


@pytest.fixture(scope="session", autouse=True)
def download_mlabonne_finetome_100k_dataset():
# download the dataset
snapshot_download("mlabonne/FineTome-100k", repo_type="dataset")
snapshot_download_w_retry("mlabonne/FineTome-100k", repo_type="dataset")


@pytest.fixture
@pytest.fixture(scope="session", autouse=True)
def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
# download the dataset
snapshot_download(
snapshot_download_w_retry(
"argilla/distilabel-capybara-dpo-7k-binarized", repo_type="dataset"
)


@pytest.fixture
@pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# download the dataset
snapshot_download_w_retry(
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
)


@pytest.fixture(scope="session", autouse=True)
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
# download the dataset
snapshot_download(
snapshot_download_w_retry(
"arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", repo_type="dataset"
)

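
To make the retry behavior concrete, here is a hypothetical pytest sketch (not part of this commit; the flaky callable and counts are invented) exercising the shared decorator from data utils:

```python
import requests

from axolotl.utils.data.utils import RetryStrategy, retry_on_request_exceptions


def test_retry_recovers_from_transient_connection_errors():
    calls = {"count": 0}

    # Fail twice with a retryable exception, then succeed on the third attempt;
    # delay=0 keeps the test fast while still exercising the retry loop.
    @retry_on_request_exceptions(
        max_retries=3, delay=0, retry_strategy=RetryStrategy.CONSTANT
    )
    def flaky_download():
        calls["count"] += 1
        if calls["count"] < 3:
            raise requests.exceptions.ConnectionError("transient failure")
        return "ok"

    assert flaky_download() == "ok"
    assert calls["count"] == 3
```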
4 changes: 2 additions & 2 deletions tests/e2e/patched/test_4d_multipack_llama.py
@@ -42,7 +42,7 @@ def test_sdp_lora_packing(self, temp_dir):
"lora_dropout": 0.05,
"lora_target_linear": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.02,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
@@ -86,7 +86,7 @@ def test_torch_lora_packing(self, temp_dir):
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"val_set_size": 0.02,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
4 changes: 2 additions & 2 deletions tests/e2e/patched/test_falcon_samplepack.py
@@ -40,7 +40,7 @@ def test_qlora(self, temp_dir):
"lora_dropout": 0.1,
"lora_target_linear": True,
"lora_modules_to_save": ["word_embeddings", "lm_head"],
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"bos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
@@ -80,7 +80,7 @@ def test_ft(self, temp_dir):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"bos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
2 changes: 1 addition & 1 deletion tests/e2e/patched/test_fused_llama.py
@@ -38,7 +38,7 @@ def test_fft_packing(self, temp_dir):
"flash_attn_fuse_mlp": True,
"sample_packing": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.02,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
2 changes: 1 addition & 1 deletion tests/e2e/patched/test_lora_llama_multipack.py
@@ -98,7 +98,7 @@ def test_lora_gptq_packed(self, temp_dir):
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"val_set_size": 0.02,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
4 changes: 2 additions & 2 deletions tests/e2e/patched/test_mistral_samplepack.py
@@ -39,7 +39,7 @@ def test_lora_packing(self, temp_dir):
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
@@ -80,7 +80,7 @@ def test_ft_packing(self, temp_dir):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
4 changes: 2 additions & 2 deletions tests/e2e/patched/test_mixtral_samplepack.py
@@ -40,7 +40,7 @@ def test_qlora(self, temp_dir):
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {},
"datasets": [
{
@@ -78,7 +78,7 @@ def test_ft(self, temp_dir):
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {},
"datasets": [
{
2 changes: 1 addition & 1 deletion tests/e2e/patched/test_phi_multipack.py
@@ -38,7 +38,7 @@ def test_ft_packed(self, temp_dir):
"pad_to_sequence_len": True,
"load_in_8bit": False,
"adapter": None,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
