Fix code quality
irenedea committed Oct 24, 2023
1 parent 8848934 commit e85fc95
Showing 9 changed files with 8 additions and 11 deletions.
2 changes: 2 additions & 0 deletions tests/fixtures/autouse.py
@@ -27,10 +27,12 @@ def clear_cuda_cache(request: pytest.FixtureRequest):
         torch.cuda.empty_cache()
         gc.collect()  # Only gc on GPU tests as it 2x slows down CPU tests
 
+
 @pytest.fixture
 def random_seed() -> int:
     return 17
 
+
 @pytest.fixture(autouse=True)
 def seed_all(random_seed: int):
     """Sets the seed for reproducibility."""
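The body of seed_all is collapsed in this view. For reference, a minimal sketch of what such an autouse seeding fixture typically looks like (an assumed shape, not the exact llm-foundry implementation):

# A minimal sketch of an autouse seeding fixture; the real body is
# collapsed in the diff above, so this shape is an assumption.
import random

import numpy as np
import pytest
import torch


@pytest.fixture(autouse=True)
def seed_all(random_seed: int):
    """Sets the seed for reproducibility."""
    random.seed(random_seed)        # Python's built-in RNG
    np.random.seed(random_seed)     # NumPy's global RNG
    torch.manual_seed(random_seed)  # torch CPU and CUDA RNGs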
7 changes: 5 additions & 2 deletions tests/fixtures/data.py
@@ -47,8 +47,11 @@ def tiny_ft_dataloader(tiny_ft_dataset_path: str,
         'timeout': 0
     })
 
-    return build_finetuning_dataloader(
+    dataloader = build_finetuning_dataloader(
         dataloader_cfg,
         mpt_tokenizer,
         device_batch_size,
-    )
+    ).dataloader
+
+    assert isinstance(dataloader, DataLoader)
+    return dataloader
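The change above unwraps the .dataloader attribute of the object returned by build_finetuning_dataloader (a Composer DataSpec-style wrapper, presumably typed more broadly than DataLoader) and asserts its concrete type, which both fails fast on a bad wrapper and narrows the type for static checkers. A self-contained sketch of the same pattern, with FakeSpec and build_loader as illustrative stand-ins rather than real llm-foundry APIs:

# The unwrap-and-narrow pattern from the diff above, in isolation.
from dataclasses import dataclass
from typing import Iterable, Union

import torch
from torch.utils.data import DataLoader, TensorDataset


@dataclass
class FakeSpec:
    # Typed broadly, as a wrapper might be; callers must narrow it.
    dataloader: Union[Iterable, DataLoader]


def build_loader() -> FakeSpec:
    dataset = TensorDataset(torch.arange(8))
    return FakeSpec(dataloader=DataLoader(dataset, batch_size=2))


def get_dataloader() -> DataLoader:
    dataloader = build_loader().dataloader
    # Fails fast on a bad wrapper and narrows the static type
    # from Union[Iterable, DataLoader] down to DataLoader.
    assert isinstance(dataloader, DataLoader)
    return dataloader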
1 change: 0 additions & 1 deletion tests/test_flash_triton_torch.py
@@ -3,7 +3,6 @@
 
 import pytest
 import torch
-from composer.utils import reproducibility
 from omegaconf import OmegaConf as om
 
 
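This file and five more below delete the same unused from composer.utils import reproducibility line, and tests/test_model.py trims it from a combined import. Unused imports like these are exactly what pyflakes-based linters flag; a small sketch using pyflakes' public API (an assumed illustration, not necessarily the tooling this repo uses):

# Flagging an unused import with pyflakes' public API.
import sys

from pyflakes.api import check
from pyflakes.reporter import Reporter

source = 'from composer.utils import reproducibility\n'
# check() returns the number of warnings and writes them to the reporter,
# printing something like:
#   example.py:1 'composer.utils.reproducibility' imported but unused
num_warnings = check(source, 'example.py', Reporter(sys.stdout, sys.stderr))
print(f'{num_warnings} warning(s) found')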
1 change: 0 additions & 1 deletion tests/test_hf_config.py
@@ -9,7 +9,6 @@
 
 import pytest
 import torch
-from composer.utils import reproducibility
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
 from transformers import AutoModelForCausalLM
3 changes: 0 additions & 3 deletions tests/test_hf_v_mpt.py
@@ -5,7 +5,6 @@
 
 import pytest
 import torch
-from composer.utils import reproducibility
 from omegaconf import OmegaConf as om
 
 from llmfoundry import COMPOSER_MODEL_REGISTRY
@@ -52,8 +51,6 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool,
     batch_size = 2  # set batch size
     device = 'cuda'  # set device
 
-
-
     # get hf gpt2 cfg
     hf_cfg = om.create({
         'model': {
1 change: 0 additions & 1 deletion tests/test_init_fn.py
@@ -8,7 +8,6 @@
 
 import pytest
 import torch
-from composer.utils import reproducibility
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 from torch import nn
1 change: 0 additions & 1 deletion tests/test_llama_patch.py
@@ -6,7 +6,6 @@
 import pytest
 import torch
 import transformers
-from composer.utils import reproducibility
 from transformers.models.llama.modeling_llama import LlamaAttention
 
 from llmfoundry.models.layers.llama_attention_monkeypatch import (
2 changes: 1 addition & 1 deletion tests/test_model.py
@@ -16,7 +16,7 @@
 from composer.core.precision import Precision, get_precision_context
 from composer.optim import DecoupledAdamW
 from composer.trainer.dist_strategy import prepare_fsdp_module
-from composer.utils import dist, get_device, reproducibility
+from composer.utils import dist, get_device
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedModel,
1 change: 0 additions & 1 deletion tests/test_onnx.py
@@ -4,7 +4,6 @@
 import pathlib
 
 import torch
-from composer.utils import reproducibility
 from transformers import AutoModelForCausalLM
 
 from llmfoundry import MPTConfig, MPTForCausalLM
