Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
KuuCi committed Jan 22, 2024
1 parent 9e31fa6 commit 34cd00a
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 9 deletions.
4 changes: 2 additions & 2 deletions composer/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
def __init__(
self,
stop_sequence: str,
tokenizer: transformers.PreTrainedTokenizer,
tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
batch_size: int,
) -> None:
self.done_tracker = [False] * batch_size
Expand Down Expand Up @@ -213,7 +213,7 @@ def __call__(self, input_ids, scores: Optional[torch.FloatTensor] = None, **kwar
return False not in self.done_tracker

def stop_sequences_criteria(
tokenizer: transformers.PreTrainedTokenizer,
tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
stop_sequences: List[str],
batch_size: int,
) -> transformers.StoppingCriteriaList:
Expand Down
5 changes: 0 additions & 5 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import os
import random
import re
#import sys
import tempfile
import textwrap
import time
Expand Down Expand Up @@ -923,10 +922,6 @@ def __init__(
# compile config for PyTorch 2.0 or higher
compile_config: Optional[Dict[str, Any]] = None,
):

# Check if the current Python version is compatible
# major, minor = sys.version_info[0], sys.version_info[1]
# assert (major == 3 and minor <= 8) or (major < 3), f"Python version {major}.{minor} is not supported. Please use Python 3.9 or higher."

self.auto_log_hparams = auto_log_hparams
self.python_log_level = python_log_level
Expand Down
4 changes: 2 additions & 2 deletions tests/datasets/test_in_context_learning_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ def test_stop_sequences_criteria(tiny_gpt2_tokenizer):
seq1 = tiny_gpt2_tokenizer('Dogs are furry')['input_ids']
seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
seq1 = [50257] * (len(seq2) - len(seq1)) + seq1
input_ids = torch.tensor([seq1, seq2])
input_ids = torch.LongTensor([seq1, seq2])
assert not eos_criteria(input_ids, None)

eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2)
seq1 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
input_ids = torch.tensor([seq1, seq2])
input_ids = torch.LongTensor([seq1, seq2])
assert eos_criteria(input_ids, None)


Expand Down

0 comments on commit 34cd00a

Please sign in to comment.