From 7cba1e49fb995c968a5eb0d54616697a2d8d5bae Mon Sep 17 00:00:00 2001 From: Vincent Chen Date: Tue, 30 Jan 2024 15:03:54 -0800 Subject: [PATCH 1/4] throw error when no EOS --- llmfoundry/utils/builders.py | 3 +++ tests/utils/test_builders.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 42f817b386..47023bad92 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -407,6 +407,9 @@ def build_tokenizer( int(1e30), ) + if not hasattr(tokenizer, 'eos_token_id') or tokenizer.eos_token_id is None: + raise ValueError(f"The tokenizer '{tokenizer_name}' must have an 'eos_token_id'.") + if dist.is_available() and dist.is_initialized( ) and dist.get_world_size() > 1: if dist.get_local_rank() == 0: diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 303afc9b7d..07d3fe1258 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -47,6 +47,24 @@ def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict): 'model_max_length'] assert isinstance(tokenizer, PreTrainedTokenizerBase) +@pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [ + ('bert-base-uncased', { + 'model_max_length': 10, + 'eos_token_id': None, + }), + ('bert-base-uncased', { + 'model_max_length': 10 + }), +]) +def test_tokenizer_no_EOS(tokenizer_name: str, tokenizer_kwargs: dict): + tokenizer_kwargs = { + 'model_max_length': 10, + } + tokenizer_name = 'bert-base-uncased' + + with pytest.raises(ValueError) as exc_info: + build_tokenizer(tokenizer_name, tokenizer_kwargs) + assert "must have an 'eos_token_id'" in str(exc_info.value), "Error message for missing eos_token_id is not correct" def test_build_callback_fails(): with pytest.raises(ValueError): From e9e6db17718c4df47ba9282c75cc0466b0797fb0 Mon Sep 17 00:00:00 2001 From: Vincent Chen Date: Tue, 30 Jan 2024 15:26:37 -0800 Subject: [PATCH 2/4] lint --- llmfoundry/utils/builders.py | 3 ++- tests/utils/test_builders.py | 7 +++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 47023bad92..c9ae6211aa 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -408,7 +408,8 @@ def build_tokenizer( ) if not hasattr(tokenizer, 'eos_token_id') or tokenizer.eos_token_id is None: - raise ValueError(f"The tokenizer '{tokenizer_name}' must have an 'eos_token_id'.") + raise ValueError( + f'The tokenizer {tokenizer_name} must have an eos_token_id.') if dist.is_available() and dist.is_initialized( ) and dist.get_world_size() > 1: diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 07d3fe1258..7e7e482378 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -47,6 +47,7 @@ def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict): 'model_max_length'] assert isinstance(tokenizer, PreTrainedTokenizerBase) + @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [ ('bert-base-uncased', { 'model_max_length': 10, @@ -61,10 +62,12 @@ def test_tokenizer_no_EOS(tokenizer_name: str, tokenizer_kwargs: dict): 'model_max_length': 10, } tokenizer_name = 'bert-base-uncased' - + with pytest.raises(ValueError) as exc_info: build_tokenizer(tokenizer_name, tokenizer_kwargs) - assert "must have an 'eos_token_id'" in str(exc_info.value), "Error message for missing eos_token_id is not correct" + assert 'must have an eos_token_id.' in str( + exc_info.value), 'Error message for missing eos_token_id is not correct' + def test_build_callback_fails(): with pytest.raises(ValueError): From 4e03b2ab6ca0e884a1b762da9c7cf88b1e275ecc Mon Sep 17 00:00:00 2001 From: Vincent Chen Date: Tue, 30 Jan 2024 16:47:00 -0800 Subject: [PATCH 3/4] clean up test case --- tests/utils/test_builders.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 7e7e482378..294ea67aa8 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -48,25 +48,12 @@ def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict): assert isinstance(tokenizer, PreTrainedTokenizerBase) -@pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [ - ('bert-base-uncased', { - 'model_max_length': 10, - 'eos_token_id': None, - }), - ('bert-base-uncased', { - 'model_max_length': 10 - }), -]) def test_tokenizer_no_EOS(tokenizer_name: str, tokenizer_kwargs: dict): - tokenizer_kwargs = { - 'model_max_length': 10, - } - tokenizer_name = 'bert-base-uncased' - - with pytest.raises(ValueError) as exc_info: - build_tokenizer(tokenizer_name, tokenizer_kwargs) - assert 'must have an eos_token_id.' in str( - exc_info.value), 'Error message for missing eos_token_id is not correct' + with pytest.raises(ValueError, match=r".*must have an eos_token_id.*"): + build_tokenizer('bert-base-uncased', {}) + + with pytest.raises(ValueError, match=r".*must have an eos_token_id.*"): + build_tokenizer('bert-base-uncased', {'eos_token_id': None}) def test_build_callback_fails(): From 785f906f7ac5d7e5ce8f6bfe3895b814ba269b6d Mon Sep 17 00:00:00 2001 From: Vincent Chen Date: Wed, 31 Jan 2024 11:21:19 -0800 Subject: [PATCH 4/4] lint --- llmfoundry/utils/builders.py | 4 ++-- tests/utils/test_builders.py | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index c9ae6211aa..457f146986 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -407,9 +407,9 @@ def build_tokenizer( int(1e30), ) - if not hasattr(tokenizer, 'eos_token_id') or tokenizer.eos_token_id is None: + if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None: raise ValueError( - f'The tokenizer {tokenizer_name} must have an eos_token_id.') + f'The tokenizer {tokenizer_name} must have an eos_token.') if dist.is_available() and dist.is_initialized( ) and dist.get_world_size() > 1: diff --git a/tests/utils/test_builders.py b/tests/utils/test_builders.py index 294ea67aa8..b35e053c5d 100644 --- a/tests/utils/test_builders.py +++ b/tests/utils/test_builders.py @@ -48,12 +48,11 @@ def test_tokenizer_builder(tokenizer_name: str, tokenizer_kwargs: dict): assert isinstance(tokenizer, PreTrainedTokenizerBase) -def test_tokenizer_no_EOS(tokenizer_name: str, tokenizer_kwargs: dict): - with pytest.raises(ValueError, match=r".*must have an eos_token_id.*"): +def test_tokenizer_no_EOS(): + with pytest.raises( + ValueError, + match='The tokenizer bert-base-uncased must have an eos_token.'): build_tokenizer('bert-base-uncased', {}) - - with pytest.raises(ValueError, match=r".*must have an eos_token_id.*"): - build_tokenizer('bert-base-uncased', {'eos_token_id': None}) def test_build_callback_fails():