
Add pytest testing
rockerBOO committed Nov 29, 2024
1 parent 2a61fc0 commit c7cadbc
Showing 4 changed files with 216 additions and 2 deletions.
54 changes: 54 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,54 @@

name: Python package

on: [push]

jobs:
  build:

    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.10", "3.11"]

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip setuptools wheel
          pip install -r requirements.txt
      - name: Test with pytest
        run: |
          pip install pytest pytest-cov
          pytest --junitxml=junit/test-results-${{ matrix.python-version }}.xml --cov=library --cov-report=xml --cov-report=html
      - name: Upload pytest test results
        uses: actions/upload-artifact@v4
        with:
          name: pytest-results-${{ matrix.python-version }}
          path: junit/test-results-${{ matrix.python-version }}.xml
        # Use always() to always run this step to publish test results when there are test failures
        if: ${{ always() }}
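To reproduce the test step locally, install pytest and pytest-cov and run the same pytest invocation from the repository root; the JUnit report is written under junit/ and the HTML coverage report to htmlcov/, pytest-cov's default output directory.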
4 changes: 2 additions & 2 deletions library/train_util.py
@@ -21,7 +21,7 @@
     Optional,
     Sequence,
     Tuple,
-    Union,
+    Union
 )
 from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
 import glob
@@ -4598,7 +4598,7 @@ def task():
         accelerator.load_state(dirname)


-def get_optimizer(args, trainable_params):
+def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
     # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"

     optimizer_type = args.optimizer_type
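The added annotation documents the three-part return value: the resolved optimizer name, the optimizer-args string, and the optimizer instance. A minimal sketch of the call site, mirroring the tests added below (setup_parser and get_optimizer are the repository's own helpers):

    import torch
    from torch.nn import Parameter
    from library.train_util import get_optimizer
    from train_network import setup_parser

    # Default arguments; get_optimizer resolves the class from args.optimizer_type.
    args = setup_parser().parse_args([])
    param = Parameter(torch.tensor([1.5, 1.5]))
    optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])
    # Per the tests below, the defaults yield "torch.optim.adamw.AdamW" and "".

Note that the builtin-generic annotation tuple[str, str, object] is evaluated at function definition time and so requires Python 3.9 or newer, which the 3.10/3.11 CI matrix satisfies.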
7 changes: 7 additions & 0 deletions pytest.ini
@@ -0,0 +1,7 @@
[pytest]
minversion = 6.0
testpaths =
    tests
filterwarnings =
    ignore::DeprecationWarning
    ignore::UserWarning
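With this configuration, running pytest from the repository root requires pytest 6.0 or newer, collects only the tests/ directory, and silences DeprecationWarning and UserWarning output during the run.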
153 changes: 153 additions & 0 deletions tests/test_optimizer.py
@@ -0,0 +1,153 @@
from unittest.mock import patch
from library.train_util import get_optimizer
from train_network import setup_parser
import torch
from torch.nn import Parameter

# Optimizer libraries
import bitsandbytes as bnb
from lion_pytorch import lion_pytorch
import schedulefree

import dadaptation
import dadaptation.experimental as dadapt_experimental

import prodigyopt
import schedulefree as sf
import transformers


def test_default_get_optimizer():
    with patch("sys.argv", [""]):
        parser = setup_parser()
        args = parser.parse_args()
        params_t = torch.tensor([1.5, 1.5])

        param = Parameter(params_t)
        optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])
        assert optimizer_name == "torch.optim.adamw.AdamW"
        assert optimizer_args == ""
        assert isinstance(optimizer, torch.optim.AdamW)


def test_get_schedulefree_optimizer():
    with patch("sys.argv", ["", "--optimizer_type", "AdamWScheduleFree"]):
        parser = setup_parser()
        args = parser.parse_args()
        params_t = torch.tensor([1.5, 1.5])

        param = Parameter(params_t)
        optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])
        assert optimizer_name == "schedulefree.adamw_schedulefree.AdamWScheduleFree"
        assert optimizer_args == ""
        assert isinstance(optimizer, schedulefree.adamw_schedulefree.AdamWScheduleFree)


def test_all_supported_optimizers():
    optimizers = [
        {
            "name": "bitsandbytes.optim.adamw.AdamW8bit",
            "alias": "AdamW8bit",
            "instance": bnb.optim.AdamW8bit,
        },
        {
            "name": "lion_pytorch.lion_pytorch.Lion",
            "alias": "Lion",
            "instance": lion_pytorch.Lion,
        },
        {
            "name": "torch.optim.adamw.AdamW",
            "alias": "AdamW",
            "instance": torch.optim.AdamW,
        },
        {
            "name": "bitsandbytes.optim.lion.Lion8bit",
            "alias": "Lion8bit",
            "instance": bnb.optim.Lion8bit,
        },
        {
            "name": "bitsandbytes.optim.adamw.PagedAdamW8bit",
            "alias": "PagedAdamW8bit",
            "instance": bnb.optim.PagedAdamW8bit,
        },
        {
            "name": "bitsandbytes.optim.lion.PagedLion8bit",
            "alias": "PagedLion8bit",
            "instance": bnb.optim.PagedLion8bit,
        },
        {
            "name": "bitsandbytes.optim.adamw.PagedAdamW",
            "alias": "PagedAdamW",
            "instance": bnb.optim.PagedAdamW,
        },
        {
            "name": "bitsandbytes.optim.adamw.PagedAdamW32bit",
            "alias": "PagedAdamW32bit",
            "instance": bnb.optim.PagedAdamW32bit,
        },
        {"name": "torch.optim.sgd.SGD", "alias": "SGD", "instance": torch.optim.SGD},
        {
            "name": "dadaptation.experimental.dadapt_adam_preprint.DAdaptAdamPreprint",
            "alias": "DAdaptAdamPreprint",
            "instance": dadapt_experimental.DAdaptAdamPreprint,
        },
        {
            "name": "dadaptation.dadapt_adagrad.DAdaptAdaGrad",
            "alias": "DAdaptAdaGrad",
            "instance": dadaptation.DAdaptAdaGrad,
        },
        {
            "name": "dadaptation.dadapt_adan.DAdaptAdan",
            "alias": "DAdaptAdan",
            "instance": dadaptation.DAdaptAdan,
        },
        {
            "name": "dadaptation.experimental.dadapt_adan_ip.DAdaptAdanIP",
            "alias": "DAdaptAdanIP",
            "instance": dadapt_experimental.DAdaptAdanIP,
        },
        {
            "name": "dadaptation.dadapt_lion.DAdaptLion",
            "alias": "DAdaptLion",
            "instance": dadaptation.DAdaptLion,
        },
        {
            "name": "dadaptation.dadapt_sgd.DAdaptSGD",
            "alias": "DAdaptSGD",
            "instance": dadaptation.DAdaptSGD,
        },
        {
            "name": "prodigyopt.prodigy.Prodigy",
            "alias": "Prodigy",
            "instance": prodigyopt.Prodigy,
        },
        {
            "name": "transformers.optimization.Adafactor",
            "alias": "Adafactor",
            "instance": transformers.optimization.Adafactor,
        },
        {
            "name": "schedulefree.adamw_schedulefree.AdamWScheduleFree",
            "alias": "AdamWScheduleFree",
            "instance": sf.AdamWScheduleFree,
        },
        {
            "name": "schedulefree.sgd_schedulefree.SGDScheduleFree",
            "alias": "SGDScheduleFree",
            "instance": sf.SGDScheduleFree,
        },
    ]

    for opt in optimizers:
        with patch("sys.argv", ["", "--optimizer_type", opt.get("alias")]):
            parser = setup_parser()
            args = parser.parse_args()
            params_t = torch.tensor([1.5, 1.5])

            param = Parameter(params_t)
            optimizer_name, _, optimizer = get_optimizer(args, [param])
            assert optimizer_name == opt.get("name")

            instance = opt.get("instance")
            assert instance is not None
            assert isinstance(optimizer, instance)
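A note on the pattern: test_all_supported_optimizers drives every optimizer through one test body, so the first failing alias halts the remaining checks. A sketch of an alternative using pytest's parametrize marker, which reports each optimizer as its own test case (only two rows shown; the full table above would slot in the same way, and the module imports above are reused):

    import pytest

    @pytest.mark.parametrize(
        "alias,name",
        [
            ("AdamW", "torch.optim.adamw.AdamW"),
            ("SGD", "torch.optim.sgd.SGD"),
        ],
    )
    def test_optimizer_by_alias(alias, name):
        # Same setup as the loop body above, but one test per (alias, name) pair.
        with patch("sys.argv", ["", "--optimizer_type", alias]):
            args = setup_parser().parse_args()
            param = Parameter(torch.tensor([1.5, 1.5]))
            optimizer_name, _, optimizer = get_optimizer(args, [param])
            assert optimizer_name == name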
