
Commit

long tensor
snarayan21 committed Jun 6, 2024
1 parent 0ef131d commit 66ea6dc
Showing 6 changed files with 20 additions and 19 deletions.
6 changes: 4 additions & 2 deletions llmfoundry/data/text_data.py
@@ -210,15 +210,17 @@ def _read_binary_tokenized_sample(
         self,
         sample: Dict[str, Any],
     ) -> torch.Tensor:
+        # Modeling code still expects int64 tensors.
         if isinstance(sample['tokens'], np.ndarray):
-            return torch.from_numpy(sample['tokens'][:self.max_seq_len].copy())
+            return torch.from_numpy(sample['tokens'][:self.max_seq_len].copy()
+                                   ).to(torch.int64)
         else:
             return torch.from_numpy(
                 np.frombuffer(
                     sample['tokens'],
                     dtype=getattr(np, self.token_encoding_type),
                 )[:self.max_seq_len].copy(),
-            )
+            ).to(torch.int64)

# How to process a sample
def __getitem__(self,
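For context, here is a minimal sketch (not part of the commit; the int32 storage dtype and token values are made up for illustration) of what the added `.to(torch.int64)` cast addresses: tokens may be stored in a narrower integer dtype to save space, `torch.from_numpy` preserves that dtype, and downstream modeling code expects int64 (LongTensor) inputs.

    import numpy as np
    import torch

    max_seq_len = 4
    raw = np.array([101, 2023, 2003, 102, 0], dtype=np.int32)  # narrow storage dtype

    without_cast = torch.from_numpy(raw[:max_seq_len].copy())
    with_cast = torch.from_numpy(raw[:max_seq_len].copy()).to(torch.int64)

    print(without_cast.dtype)  # torch.int32 -> .type() reports 'torch.IntTensor'
    print(with_cast.dtype)     # torch.int64 -> .type() reports 'torch.LongTensor'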
4 changes: 2 additions & 2 deletions llmfoundry/eval/datasets/utils.py
@@ -6,7 +6,7 @@

import logging
import random
-from typing import Any, Dict, List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set

import torch
import transformers
@@ -272,7 +272,7 @@ def __init__(

def __call__(
self,
-        input_ids: Union[torch.LongTensor, torch.IntTensor],
+        input_ids: torch.LongTensor,
scores: Optional[torch.FloatTensor] = None,
**kwargs: Dict[str, Any],
) -> bool:
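As an illustration of the narrowed annotation in use, here is a hedged sketch of a stopping criterion built on the transformers `StoppingCriteria` interface (the class name and stop ids below are hypothetical, not taken from this repository):

    from typing import Any, Dict, List, Optional

    import torch
    from transformers import StoppingCriteria


    class StopOnTokenIds(StoppingCriteria):
        """Hypothetical criterion: stop once the latest generated token is a stop id."""

        def __init__(self, stop_ids: List[int]):
            self.stop_ids = stop_ids

        def __call__(
            self,
            input_ids: torch.LongTensor,
            scores: Optional[torch.FloatTensor] = None,
            **kwargs: Dict[str, Any],
        ) -> bool:
            # transformers passes generated ids as int64, so no IntTensor case is needed.
            return int(input_ids[0, -1]) in self.stop_ids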
16 changes: 8 additions & 8 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -277,7 +277,7 @@ def gen_flash_attn_padding_info(

def apply_sequence_id(
attn_bias: torch.Tensor,
-    sequence_id: Union[torch.LongTensor, torch.IntTensor],
+    sequence_id: torch.LongTensor,
max_seq_len: int,
) -> torch.Tensor:
seq_len = sequence_id.shape[-1]
@@ -470,7 +470,7 @@ def _attn_bias(
device: torch.device,
dtype: torch.dtype,
attention_mask: Optional[torch.ByteTensor] = None,
-        sequence_id: Optional[Union[torch.LongTensor, torch.IntTensor]] = None,
+        sequence_id: Optional[torch.LongTensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]:
if not self._attn_bias_initialized:
if self.attn_bias_shape:
@@ -533,10 +533,10 @@ def _attn_bias(

def forward(
self,
-        input_ids: Optional[Union[torch.LongTensor, torch.IntTensor]] = None,
+        input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
-        sequence_id: Optional[Union[torch.LongTensor, torch.IntTensor]] = None,
+        sequence_id: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -877,11 +877,11 @@ def get_decoder(self) -> MPTModel:

def forward(
self,
-        input_ids: Optional[Union[torch.LongTensor, torch.IntTensor]] = None,
+        input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
-        sequence_id: Optional[Union[torch.LongTensor, torch.IntTensor]] = None,
-        labels: Optional[Union[torch.LongTensor, torch.IntTensor]] = None,
+        sequence_id: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -1056,7 +1056,7 @@ def prepare_inputs_for_generation(
@staticmethod
def _reorder_cache(
past_key_values: List[Tuple[torch.Tensor, torch.Tensor]],
-        beam_idx: Union[torch.LongTensor, torch.IntTensor],
+        beam_idx: torch.LongTensor,
) -> List[Tuple[torch.Tensor, ...]]:
"""Used by HuggingFace generate when using beam search with kv-caching.
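Taken together with the dataloader cast above, these narrowed annotations document that the MPT forward pass receives int64 ids. A hedged usage sketch (the values are illustrative, and `model` stands for an instantiated MPTForCausalLM, so the call is left commented out):

    import torch

    input_ids = torch.tensor([[101, 2023, 2003, 102]], dtype=torch.long)  # int64
    sequence_id = torch.zeros_like(input_ids)  # one packed sequence per row, also int64
    labels = input_ids.clone()

    assert input_ids.dtype == sequence_id.dtype == labels.dtype == torch.int64
    # outputs = model(input_ids=input_ids, sequence_id=sequence_id, labels=labels)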
2 changes: 1 addition & 1 deletion scripts/inference/hf_chat.py
@@ -87,7 +87,7 @@ class StopOnTokens(StoppingCriteria):

def __call__(
self,
-        input_ids: Union[torch.LongTensor, torch.IntTensor],
+        input_ids: torch.LongTensor,
scores: torch.FloatTensor,
**kwargs: Any,
) -> bool:
3 changes: 1 addition & 2 deletions tests/data/test_dataloader.py
@@ -270,8 +270,7 @@ def test_correct_padding(
batch = next(iter(eval_loader))

assert batch['input_ids'].shape == torch.Size([batch_size, 2048])
-    assert batch['input_ids'].type(
-    ) == 'torch.LongTensor' or batch['input_ids'].type() == 'torch.IntTensor'
+    assert batch['input_ids'].type() == 'torch.LongTensor'

# we follow the convention (from huggingface) that non-attended tokens are 0 in the attn mask and -100 in the labels
attention_mask = batch.get(
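For reference, a CPU int64 tensor is exactly what the tightened assertion accepts; a hedged, self-contained equivalent using synthetic data (not the dataloader batch):

    import torch

    input_ids = torch.randint(0, 50000, (4, 2048), dtype=torch.int64)
    assert input_ids.type() == 'torch.LongTensor'  # passes, same check as the test
    assert input_ids.dtype == torch.long           # equivalent dtype-based check

    int32_ids = input_ids.to(torch.int32)
    assert int32_ids.type() == 'torch.IntTensor'   # an int32 batch would now fail the test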
8 changes: 4 additions & 4 deletions tests/models/test_mpt_gen.py
@@ -1,7 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple
from unittest.mock import Mock, patch

import pytest
@@ -27,11 +27,11 @@ class MockMPTForCausalLM(MPTForCausalLM):

def forward(
self,
-        input_ids: Union[torch.LongTensor, torch.IntTensor],
+        input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
-        sequence_id: Optional[Union[torch.LongTensor, torch.IntTensor]] = None,
-        labels: Optional[Union[torch.LongTensor, torch.IntTensor]] = None,
+        sequence_id: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
