From 66ea6dc90ae279e038be7a67091afcfd3cd5b3d9 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Wed, 5 Jun 2024 18:28:48 -0700 Subject: [PATCH] long tensor --- llmfoundry/data/text_data.py | 6 ++++-- llmfoundry/eval/datasets/utils.py | 4 ++-- llmfoundry/models/mpt/modeling_mpt.py | 16 ++++++++-------- scripts/inference/hf_chat.py | 2 +- tests/data/test_dataloader.py | 3 +-- tests/models/test_mpt_gen.py | 8 ++++---- 6 files changed, 20 insertions(+), 19 deletions(-) diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 8c278b7e17..b6c0685960 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -210,15 +210,17 @@ def _read_binary_tokenized_sample( self, sample: Dict[str, Any], ) -> torch.Tensor: + # Modeling code still expects int64 tensors. if isinstance(sample['tokens'], np.ndarray): - return torch.from_numpy(sample['tokens'][:self.max_seq_len].copy()) + return torch.from_numpy(sample['tokens'][:self.max_seq_len].copy() + ).to(torch.int64) else: return torch.from_numpy( np.frombuffer( sample['tokens'], dtype=getattr(np, self.token_encoding_type), )[:self.max_seq_len].copy(), - ) + ).to(torch.int64) # How to process a sample def __getitem__(self, diff --git a/llmfoundry/eval/datasets/utils.py b/llmfoundry/eval/datasets/utils.py index 40881b3735..1ce249437d 100644 --- a/llmfoundry/eval/datasets/utils.py +++ b/llmfoundry/eval/datasets/utils.py @@ -6,7 +6,7 @@ import logging import random -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set import torch import transformers @@ -272,7 +272,7 @@ def __init__( def __call__( self, - input_ids: Union[torch.LongTensor, torch.IntTensor], + input_ids: torch.LongTensor, scores: Optional[torch.FloatTensor] = None, **kwargs: Dict[str, Any], ) -> bool: diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index a6a7d659ac..9d18799e93 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -277,7 +277,7 @@ def gen_flash_attn_padding_info( def apply_sequence_id( attn_bias: torch.Tensor, - sequence_id: Union[torch.LongTensor, torch.IntTensor], + sequence_id: torch.LongTensor, max_seq_len: int, ) -> torch.Tensor: seq_len = sequence_id.shape[-1] @@ -470,7 +470,7 @@ def _attn_bias( device: torch.device, dtype: torch.dtype, attention_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[Union[torch.LongTensor, torch.IntTensor]] = None, + sequence_id: Optional[torch.LongTensor] = None, ) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]: if not self._attn_bias_initialized: if self.attn_bias_shape: @@ -533,10 +533,10 @@ def _attn_bias( def forward( self, - input_ids: Optional[Union[torch.LongTensor, torch.IntTensor]] = None, + input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[Union[torch.LongTensor, torch.IntTensor]] = None, + sequence_id: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -877,11 +877,11 @@ def get_decoder(self) -> MPTModel: def forward( self, - input_ids: Optional[Union[torch.LongTensor, torch.IntTensor]] = None, + input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[Union[torch.LongTensor, torch.IntTensor]] = None, - labels: Optional[Union[torch.LongTensor, torch.IntTensor]] = None, + sequence_id: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -1056,7 +1056,7 @@ def prepare_inputs_for_generation( @staticmethod def _reorder_cache( past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], - beam_idx: Union[torch.LongTensor, torch.IntTensor], + beam_idx: torch.LongTensor, ) -> List[Tuple[torch.Tensor, ...]]: """Used by HuggingFace generate when using beam search with kv-caching. diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py index dc9776ee46..e992371c32 100644 --- a/scripts/inference/hf_chat.py +++ b/scripts/inference/hf_chat.py @@ -87,7 +87,7 @@ class StopOnTokens(StoppingCriteria): def __call__( self, - input_ids: Union[torch.LongTensor, torch.IntTensor], + input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Any, ) -> bool: diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index fee3e53c8b..ec27df8121 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -270,8 +270,7 @@ def test_correct_padding( batch = next(iter(eval_loader)) assert batch['input_ids'].shape == torch.Size([batch_size, 2048]) - assert batch['input_ids'].type( - ) == 'torch.LongTensor' or batch['input_ids'].type() == 'torch.IntTensor' + assert batch['input_ids'].type() == 'torch.LongTensor' # we follow the convention (from huggingface) that non-attended tokens are 0 in the attn mask and -100 in the labels attention_mask = batch.get( diff --git a/tests/models/test_mpt_gen.py b/tests/models/test_mpt_gen.py index 6cca704bb2..1c9b5ef9d4 100644 --- a/tests/models/test_mpt_gen.py +++ b/tests/models/test_mpt_gen.py @@ -1,7 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple from unittest.mock import Mock, patch import pytest @@ -27,11 +27,11 @@ class MockMPTForCausalLM(MPTForCausalLM): def forward( self, - input_ids: Union[torch.LongTensor, torch.IntTensor], + input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None, attention_mask: Optional[torch.ByteTensor] = None, - sequence_id: Optional[Union[torch.LongTensor, torch.IntTensor]] = None, - labels: Optional[Union[torch.LongTensor, torch.IntTensor]] = None, + sequence_id: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None,