diff --git a/composer/datasets/utils.py b/composer/datasets/utils.py
index 431a860900..44186ac58e 100644
--- a/composer/datasets/utils.py
+++ b/composer/datasets/utils.py
@@ -179,7 +179,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
     def __init__(
         self,
         stop_sequence: str,
-        tokenizer: transformers.PreTrainedTokenizer,
+        tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
         batch_size: int,
     ) -> None:
         self.done_tracker = [False] * batch_size
@@ -213,7 +213,7 @@ def __call__(self, input_ids, scores: Optional[torch.FloatTensor] = None, **kwargs) -> bool:
         return False not in self.done_tracker
 
 def stop_sequences_criteria(
-    tokenizer: transformers.PreTrainedTokenizer,
+    tokenizer: Union[transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast],
     stop_sequences: List[str],
     batch_size: int,
 ) -> transformers.StoppingCriteriaList:
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 9ebec6c373..2b9c9731a5 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -13,7 +13,6 @@
 import os
 import random
 import re
-#import sys
 import tempfile
 import textwrap
 import time
@@ -923,10 +922,6 @@ def __init__(
         # compile config for PyTorch 2.0 or higher
         compile_config: Optional[Dict[str, Any]] = None,
     ):
-
-        # Check if the current Python version is compatible
-        # major, minor = sys.version_info[0], sys.version_info[1]
-        # assert (major == 3 and minor <= 8) or (major < 3), f"Python version {major}.{minor} is not supported. Please use Python 3.9 or higher."
 
         self.auto_log_hparams = auto_log_hparams
         self.python_log_level = python_log_level
diff --git a/tests/datasets/test_in_context_learning_datasets.py b/tests/datasets/test_in_context_learning_datasets.py
index ec7df306d6..2a3ff87884 100644
--- a/tests/datasets/test_in_context_learning_datasets.py
+++ b/tests/datasets/test_in_context_learning_datasets.py
@@ -73,13 +73,13 @@ def test_stop_sequences_criteria(tiny_gpt2_tokenizer):
     seq1 = tiny_gpt2_tokenizer('Dogs are furry')['input_ids']
     seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
     seq1 = [50257] * (len(seq2) - len(seq1)) + seq1
-    input_ids = torch.tensor([seq1, seq2])
+    input_ids = torch.LongTensor([seq1, seq2])
     assert not eos_criteria(input_ids, None)
 
     eos_criteria = MultiTokenEOSCriteria('\n\n', tiny_gpt2_tokenizer, 2)
     seq1 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
     seq2 = tiny_gpt2_tokenizer('Dogs are furry\n\n')['input_ids']
-    input_ids = torch.tensor([seq1, seq2])
+    input_ids = torch.LongTensor([seq1, seq2])
     assert eos_criteria(input_ids, None)
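
Not part of the patch above: a minimal sketch of how the widened tokenizer annotation is exercised, assuming the Hugging Face GPT-2 tokenizer can be loaded (AutoTokenizer returns a PreTrainedTokenizerFast for it) and that composer is importable. It mirrors the updated test, where every sequence in the batch already ends with the stop string.

import torch
import transformers

from composer.datasets.utils import stop_sequences_criteria

# AutoTokenizer loads a PreTrainedTokenizerFast for GPT-2, which the Union
# annotation now accepts alongside the slow PreTrainedTokenizer.
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')

# Build a StoppingCriteriaList for the '\n\n' stop sequence over a batch of 2.
criteria = stop_sequences_criteria(tokenizer, stop_sequences=['\n\n'], batch_size=2)

# Pass token ids as a LongTensor, matching the updated test.
input_ids = torch.LongTensor(tokenizer(['Dogs are furry\n\n', 'Dogs are furry\n\n'])['input_ids'])

# Both rows end with the stop sequence, so this is expected to print True.
print(criteria[0](input_ids, None))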