fix inference and phi e2e tests
winglian committed Sep 15, 2023
1 parent 8cc732c commit 2f99c6e
Showing 4 changed files with 121 additions and 23 deletions.
11 changes: 7 additions & 4 deletions src/axolotl/models/phi/modeling_mixformer_sequential.py
@@ -915,11 +915,14 @@ def forward(
                 input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
             )
         else:
-            hidden_layer = self.layers[0](
-                input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
-            )
+            hidden_layer = self.layers[0](input_ids)
             for module in self.layers[1:-1]:
-                hidden_layer = module(hidden_layer, past_cache=past_key_values)
+                hidden_layer = module(
+                    hidden_layer,
+                    past_cache=past_key_values,
+                    cu_seqlens=cu_seqlens,
+                    max_seqlen=max_seqlen,
+                )
             lm_logits = self.layers[-1](hidden_layer)

         loss = None
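Taken together, the modeling change means the embedding layer at self.layers[0] no longer receives the packing arguments, while every intermediate block now gets the KV cache plus the cu_seqlens/max_seqlen metadata. Below is a minimal runnable sketch of that forward pattern; TinyBlock and TinySequentialLM are hypothetical stand-ins for illustration, not the MixFormer implementation.

import torch
from torch import nn


class TinyBlock(nn.Module):
    """Hypothetical stand-in for a MixFormer parallel block."""

    def __init__(self, hidden: int):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, hidden_states, past_cache=None, cu_seqlens=None, max_seqlen=None):
        # A real block would use past_cache for KV caching and
        # cu_seqlens/max_seqlen for flash-attention style sample packing.
        return self.proj(hidden_states)


class TinySequentialLM(nn.Module):
    """Hypothetical stand-in mirroring the layer layout: [embedding, blocks..., lm_head]."""

    def __init__(self, vocab: int = 100, hidden: int = 16, n_blocks: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Embedding(vocab, hidden)]
            + [TinyBlock(hidden) for _ in range(n_blocks)]
            + [nn.Linear(hidden, vocab)]
        )

    def forward(self, input_ids, past_key_values=None, cu_seqlens=None, max_seqlen=None):
        # The embedding layer only needs token ids (the fix above drops the extra kwargs here).
        hidden = self.layers[0](input_ids)
        # Every intermediate block now also receives the packing metadata.
        for block in self.layers[1:-1]:
            hidden = block(
                hidden,
                past_cache=past_key_values,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            )
        # The final layer produces the logits.
        return self.layers[-1](hidden)


if __name__ == "__main__":
    model = TinySequentialLM()
    logits = model(torch.randint(0, 100, (1, 8)))
    print(logits.shape)  # torch.Size([1, 8, 100])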
1 change: 1 addition & 0 deletions tests/e2e/.gitignore
@@ -0,0 +1 @@
last_run_prepared
23 changes: 4 additions & 19 deletions tests/e2e/test_lora_llama.py
@@ -7,39 +7,23 @@
 import tempfile
 import unittest

+from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
-from axolotl.train import TrainDatasetMeta, train
+from axolotl.train import train
 from axolotl.utils.config import normalize_config
-from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_tokenizer

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"


-def load_datasets(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,  # pylint:disable=unused-argument
-) -> TrainDatasetMeta:
-    tokenizer = load_tokenizer(cfg)
-
-    train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
-
-    return TrainDatasetMeta(
-        train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
-        total_num_steps=total_num_steps,
-    )
-
-
 class TestLoraLlama(unittest.TestCase):
     """
     Test case for Llama models using LoRA
     """

     def test_lora(self):
+        # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
@@ -80,6 +64,7 @@ def test_lora(self):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

     def test_lora_packing(self):
+        # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
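With the local helper removed, both LoRA tests rely on the shared load_datasets from axolotl.cli; the deleted copy shows the expected signature (keyword-only cfg and cli_args, returning a TrainDatasetMeta). The sketch below shows the common flow each test now inlines; run_e2e is a hypothetical name added here for illustration and is not part of the commit.

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault


def run_e2e(cfg: DictDefault) -> None:
    # Shared pattern used by the e2e tests after this commit:
    # normalize the config, build CLI args, load datasets via the
    # shared axolotl.cli helper, then train.
    normalize_config(cfg)
    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
    train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)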
109 changes: 109 additions & 0 deletions tests/e2e/test_phi.py
@@ -0,0 +1,109 @@
"""
E2E tests for lora llama
"""

import logging
import os
import tempfile
import unittest

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestPhi(unittest.TestCase):
    """
    Test case for phi fine-tuning
    """

    def test_ft(self):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "microsoft/phi-1_5",
                "base_model_config": "microsoft/phi-1_5",
                "trust_remote_code": True,
                "model_type": "MixFormerSequentialForCausalLM",
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 2048,
                "sample_packing": False,
                "load_in_8bit": True,
                "adapter": None,
                "val_set_size": 0.1,
                "special_tokens": {
                    "unk_token": "<|endoftext|>",
                    "bos_token": "<|endoftext|>",
                    "eos_token": "<|endoftext|>",
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "dataset_shard_num": 10,
                "dataset_shard_idx": 0,
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "output_dir": tempfile.mkdtemp(),
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

    def test_ft_packed(self):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "microsoft/phi-1_5",
                "base_model_config": "microsoft/phi-1_5",
                "trust_remote_code": True,
                "model_type": "MixFormerSequentialForCausalLM",
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 2048,
                "sample_packing": True,
                "load_in_8bit": True,
                "adapter": None,
                "val_set_size": 0.1,
                "special_tokens": {
                    "unk_token": "<|endoftext|>",
                    "bos_token": "<|endoftext|>",
                    "eos_token": "<|endoftext|>",
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "dataset_shard_num": 10,
                "dataset_shard_idx": 0,
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "output_dir": tempfile.mkdtemp(),
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
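The two phi tests share the same config and differ only in the sample_packing flag; both load microsoft/phi-1_5 with trust_remote_code, so the model's custom MixFormerSequentialForCausalLM code is pulled from the Hub. A standalone sketch of loading that base model outside axolotl, assuming Hub access (not part of the commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumes network access to the Hugging Face Hub and that microsoft/phi-1_5
# still ships its custom (trust_remote_code) modeling file.
model_id = "microsoft/phi-1_5"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
print(model.config.architectures)  # expected to list the MixFormer class name at the time of this commit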
