Skip to content

Commit

Permalink
chore: lint
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Jan 1, 2024
1 parent 47f80f4 commit 5c2129a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
16 changes: 8 additions & 8 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def load_datasets(
def load_rl_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
) -> TrainDatasetMeta:
train_datasets: List[Any] = []
for i, ds_cfg in enumerate(cfg.datasets):
Expand All @@ -340,7 +340,7 @@ def load_rl_datasets(
# )
eval_dataset = None

def argilla_apply_chatml(sample):
def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
Expand All @@ -354,7 +354,7 @@ def argilla_apply_chatml(sample):
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
return sample

def intel_apply_chatml(sample):
def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
Expand All @@ -368,7 +368,7 @@ def intel_apply_chatml(sample):
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample

def apply_chatml(sample):
def apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
Expand All @@ -382,7 +382,7 @@ def apply_chatml(sample):
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample

def ultra_apply_chatml(sample):
def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
Expand All @@ -396,10 +396,10 @@ def ultra_apply_chatml(sample):
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample

for i, ds in enumerate(train_datasets):
for i, data_set in enumerate(train_datasets):
_type = cfg.datasets[i]["type"]
fn = locals()[_type]
train_datasets[i] = ds.map(fn)
ds_type_fn = locals()[_type]
train_datasets[i] = data_set.map(ds_type_fn)
train_dataset = concatenate_datasets(train_datasets)

# eval_dataset = eval_dataset.map(intel_apply_chatml)
Expand Down
17 changes: 15 additions & 2 deletions src/axolotl/core/trainers/trl.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
"""
module for TRL PPO training
"""
import torch
from tqdm import tqdm
from trl import PPOTrainer


class TRLPPOTrainer(PPOTrainer):
def train(self, reward_pipe, resume_from_checkpoint=None):
"""
wrapper for ppo trainer to handle customizations
"""

def train(
self,
reward_pipe,
resume_from_checkpoint=None, # pylint: disable=unused-argument
):
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
Expand All @@ -19,7 +30,9 @@ def train(self, reward_pipe, resume_from_checkpoint=None):
"batch_size": 16,
}

for epoch, batch in tqdm(enumerate(self.dataloader)):
for epoch, batch in tqdm( # pylint: disable=unused-variable
enumerate(self.dataloader)
):
query_tensors = batch["input_ids"]

# generate model response
Expand Down
4 changes: 1 addition & 3 deletions tests/core/test_trainer_builder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""
unit tests for axolotl.core.trainer_builder
"""
import unittest

import pytest

from axolotl.core.trainer_builder import HFDPOTrainerBuilder
Expand Down Expand Up @@ -45,7 +43,7 @@ def fixture_model(cfg, tokenizer):
return load_model(cfg, tokenizer)


class TestHFDPOTrainerBuilder(unittest.TestCase):
class TestHFDPOTrainerBuilder:
"""
TestCase class for DPO trainer builder
"""
Expand Down

0 comments on commit 5c2129a

Please sign in to comment.