From beb2a0968771ba0a07f9633a7846361cbb9746c0 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 31 Jan 2024 14:39:07 +0000 Subject: [PATCH] DeepSpeed: hardcode `torch.arange` dtype on `float` usage to avoid incorrect initialization (#28760) --- src/transformers/models/clvp/modeling_clvp.py | 2 +- .../models/codegen/modeling_codegen.py | 4 +- .../modeling_conditional_detr.py | 2 +- src/transformers/models/ctrl/modeling_ctrl.py | 4 +- .../modeling_deformable_detr.py | 6 +- .../open_llama/modeling_open_llama.py | 10 +-- .../transfo_xl/modeling_transfo_xl.py | 4 +- src/transformers/models/deta/modeling_deta.py | 6 +- src/transformers/models/detr/modeling_detr.py | 2 +- src/transformers/models/esm/modeling_esm.py | 2 +- .../models/falcon/modeling_falcon.py | 10 +-- .../modeling_fastspeech2_conformer.py | 4 +- src/transformers/models/fsmt/modeling_fsmt.py | 4 +- .../models/funnel/modeling_funnel.py | 10 +-- .../models/fuyu/image_processing_fuyu.py | 4 +- .../models/gpt_neox/modeling_gpt_neox.py | 10 +-- .../modeling_gpt_neox_japanese.py | 4 +- src/transformers/models/gptj/modeling_gptj.py | 4 +- .../models/idefics/modeling_idefics.py | 4 +- .../models/kosmos2/modeling_kosmos2.py | 4 +- .../models/llama/modeling_llama.py | 10 +-- .../models/m2m_100/modeling_m2m_100.py | 4 +- .../mask2former/modeling_mask2former.py | 4 +- .../models/maskformer/modeling_maskformer.py | 2 +- src/transformers/models/mega/modeling_mega.py | 2 +- .../models/mistral/modeling_mistral.py | 4 +- .../models/mixtral/modeling_mixtral.py | 4 +- src/transformers/models/mpt/modeling_mpt.py | 2 +- .../models/musicgen/modeling_musicgen.py | 4 +- .../models/nezha/modeling_nezha.py | 2 +- .../models/nllb_moe/modeling_nllb_moe.py | 4 +- .../models/oneformer/modeling_oneformer.py | 4 +- .../models/pegasus_x/modeling_pegasus_x.py | 2 +- .../models/persimmon/modeling_persimmon.py | 10 +-- src/transformers/models/phi/modeling_phi.py | 10 +-- .../models/qwen2/modeling_qwen2.py | 4 +- .../seamless_m4t/modeling_seamless_m4t.py | 10 +-- .../modeling_seamless_m4t_v2.py | 4 +- .../speech_to_text/modeling_speech_to_text.py | 4 +- .../modeling_speech_to_text_2.py | 4 +- .../models/speecht5/modeling_speecht5.py | 6 +- .../models/swin2sr/modeling_swin2sr.py | 4 +- .../models/swinv2/modeling_swinv2.py | 4 +- .../modeling_table_transformer.py | 2 +- .../models/trocr/modeling_trocr.py | 4 +- .../wav2vec2_bert/modeling_wav2vec2_bert.py | 6 +- .../modeling_wav2vec2_conformer.py | 6 +- src/transformers/models/xglm/modeling_xglm.py | 4 +- .../models/xlnet/modeling_xlnet.py | 8 +-- tests/deepspeed/test_deepspeed.py | 72 +++++++++++++++++++ 50 files changed, 192 insertions(+), 118 deletions(-) diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index 64c6927e4a44ba..b660f54e5d820f 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -255,7 +255,7 @@ class ClvpRotaryPositionalEmbedding(nn.Module): def __init__(self, config): super().__init__() dim = max(config.projection_dim // (config.num_attention_heads * 2), 32) - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) self.register_buffer("inv_freq", inv_freq) self.cached_sequence_length = None diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 6fc054254a4806..60496f57212226 100644 --- 
a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -53,8 +53,8 @@ # Copied from transformers.models.gptj.modeling_gptj.create_sinusoidal_positions def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor: - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) - sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim)) + sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float() return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1) diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 6c2cbb859c8e8e..2a9bbdeff6bd7c 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -443,7 +443,7 @@ def forward(self, pixel_values, pixel_mask): y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale - dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) + dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float() dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) pos_x = x_embed[:, :, :, None] / dim_t diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 3f1607eb95c4de..3814f897d545fa 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -47,8 +47,8 @@ def angle_defn(pos, i, d_model_size): def positional_encoding(position, d_model_size, dtype): # create the sinusoidal pattern for the positional encoding angle_rads = angle_defn( - torch.arange(position, dtype=dtype).unsqueeze(1), - torch.arange(d_model_size, dtype=dtype).unsqueeze(0), + torch.arange(position, dtype=torch.int64).to(dtype).unsqueeze(1), + torch.arange(d_model_size, dtype=torch.int64).to(dtype).unsqueeze(0), d_model_size, ) diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index c77eecb75cceb2..aea5b60bdee453 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -491,7 +491,7 @@ def forward(self, pixel_values, pixel_mask): y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) + dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float() dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) pos_x = x_embed[:, :, :, None] / dim_t @@ -617,7 +617,7 @@ def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int): def _reset_parameters(self): nn.init.constant_(self.sampling_offsets.weight.data, 0.0) - thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads) 
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) @@ -1557,7 +1557,7 @@ def get_proposal_pos_embed(self, proposals): temperature = 10000 scale = 2 * math.pi - dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float() dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) # batch_size, num_queries, 4 proposals = proposals.sigmoid() * scale diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 1d780e9f5eb8ad..4bf11dd1b41bc4 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -71,7 +71,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. @@ -81,7 +81,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -110,7 +110,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) @@ -135,10 +135,10 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation diff --git a/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py index 57d5e0b725054f..2fa251399b1bd4 100644 --- a/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py @@ -942,7 +942,9 @@ def forward( 
hids = [] attentions = [] if output_attentions else None if self.attn_type == 0: # default - pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype) + pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, dtype=torch.int64).type_as( + word_emb + ) if self.clamp_len > 0: pos_seq.clamp_(max=self.clamp_len) pos_emb = self.pos_emb(pos_seq) diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index 3c5687b40aedbc..eb0336f85bcf23 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -401,7 +401,7 @@ def forward(self, pixel_values, pixel_mask): y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) + dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float() dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) pos_x = x_embed[:, :, :, None] / dim_t @@ -526,7 +526,7 @@ def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int) def _reset_parameters(self): nn.init.constant_(self.sampling_offsets.weight.data, 0.0) - thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + thetas = torch.arange(self.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) @@ -1447,7 +1447,7 @@ def get_proposal_pos_embed(self, proposals): temperature = 10000 scale = 2 * math.pi - dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = torch.arange(num_pos_feats, dtype=torch.int64, device=proposals.device).float() dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) # batch_size, num_queries, 4 proposals = proposals.sigmoid() * scale diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 026100b2450698..218d63a412b170 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -435,7 +435,7 @@ def forward(self, pixel_values, pixel_mask): y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale - dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device) + dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float() dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) pos_x = x_embed[:, :, :, None] / dim_t diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index b7d0253fc4c9e3..57c436224099cc 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -94,7 +94,7 @@ class RotaryEmbedding(torch.nn.Module): def __init__(self, dim: int): super().__init__() # Generate and save the inverse frequency buffer (non trainable) - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) inv_freq = inv_freq self.register_buffer("inv_freq", inv_freq) diff --git
a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index c7e7c99bbe4c2a..8a850012a5dd36 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -138,7 +138,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. @@ -148,7 +148,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -177,7 +177,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) @@ -202,10 +202,10 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation diff --git a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py index 89148ee09d6fc5..9b8fa4ab004f29 100644 --- a/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +++ b/src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py @@ -820,9 +820,9 @@ def extend_pos_enc(self, x): # are to the left (i>j) and negative relative positions otherwise (i torch.Tensor: - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) - sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.float), inv_freq).float() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim)) + sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float() return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1) diff --git a/src/transformers/models/idefics/modeling_idefics.py 
b/src/transformers/models/idefics/modeling_idefics.py index e10c62ed8a047e..d5613a8254bcb6 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -477,7 +477,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. @@ -487,7 +487,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index e99be059f86b5b..7bbbbe8d765c23 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -774,8 +774,8 @@ def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional """ half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) - emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) if embedding_dim % 2 == 1: # zero pad diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index bd2269ce1606b6..90706cae68c0c1 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -127,7 +127,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. 
@@ -137,7 +137,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -165,7 +165,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) @@ -189,10 +189,10 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 769a6b2d21859b..1aad2bde81c8c7 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -111,8 +111,8 @@ def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional """ half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) - emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) if embedding_dim % 2 == 1: # zero pad diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index a88028a8071765..15f1759045f6a7 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -860,7 +860,7 @@ def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device) + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).type_as(x) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t @@ -2129,7 +2129,7 @@ def _init_weights(self, module: nn.Module): 
elif isinstance(module, Mask2FormerPixelDecoderEncoderMultiscaleDeformableAttention): nn.init.constant_(module.sampling_offsets.weight.data, 0.0) - thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads) + thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 026ea15d443969..dd8f7ccfdf9eb1 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -1351,7 +1351,7 @@ def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device) + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).type_as(x) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t diff --git a/src/transformers/models/mega/modeling_mega.py b/src/transformers/models/mega/modeling_mega.py index 60628fb5df81d3..dda31f5d949ea4 100644 --- a/src/transformers/models/mega/modeling_mega.py +++ b/src/transformers/models/mega/modeling_mega.py @@ -169,7 +169,7 @@ def __init__(self, config: MegaConfig): def get_sinusoid_embeddings(max_positions: int, embedding_dim: int): half_dim = embedding_dim // 2 emb = math.log(10000) / half_dim - emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) emb = torch.arange(max_positions, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) return torch.sin(emb), torch.cos(emb) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index b0550943765034..fe51d7ed2afc96 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -96,7 +96,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. 
@@ -106,7 +106,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index b8bc13fbe038bc..5c347b38bb1e86 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -189,7 +189,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. @@ -199,7 +199,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 74f214c4fcab75..fc4af29d8c696d 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -66,7 +66,7 @@ def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8, device= alibi = torch.arange(1 - sequence_length, 1, dtype=torch.int32, device=device).view(1, 1, 1, sequence_length) num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads)) - base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.float32, device=device) + base = torch.arange(1, num_heads_power_of_2 + 1, dtype=torch.int64, device=device).float() base = base * (alibi_bias_max / num_heads_power_of_2) slopes = 1.0 / torch.pow(2, base) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index a60159c7a003f2..9a6518a4d11881 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -126,8 +126,8 @@ def get_embedding(num_embeddings: int, embedding_dim: int): """ half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) - emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=1).view(num_embeddings, -1) if embedding_dim % 2 == 1: # zero pad diff --git a/src/transformers/models/nezha/modeling_nezha.py 
b/src/transformers/models/nezha/modeling_nezha.py index b6d024b9d6639a..918a10b2759a2d 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -150,7 +150,7 @@ def __init__(self, length, depth, max_relative_position=127): final_mat = distance_mat_clipped + max_relative_position embeddings_table = torch.zeros(vocab_size, depth) - position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1) + position = torch.arange(0, vocab_size, dtype=torch.int64).float().unsqueeze(1) div_term = torch.exp(torch.arange(0, depth, 2).float() * (-math.log(10000.0) / depth)) embeddings_table[:, 0::2] = torch.sin(position * div_term) embeddings_table[:, 1::2] = torch.cos(position * div_term) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index a106d5bcc410b6..e02c0b0fd77506 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -164,8 +164,8 @@ def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional """ half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) - emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) if embedding_dim % 2 == 1: # zero pad diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 894dac10f7ea4d..87014d8afbf6fa 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -2400,7 +2400,7 @@ def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device) + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=x.device).type_as(x) dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t @@ -2799,7 +2799,7 @@ def _init_weights(self, module: nn.Module): module.query_input_projection._is_hf_initialized = True elif isinstance(module, OneFormerPixelDecoderEncoderMultiscaleDeformableAttention): nn.init.constant_(module.sampling_offsets.weight.data, 0.0) - thetas = torch.arange(module.n_heads, dtype=torch.float32) * (2.0 * math.pi / module.n_heads) + thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = ( (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 98d53b5d0f015f..49539514378a08 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -109,7 +109,7 @@ def forward(self, input_embeds: torch.Tensor, past_key_values_length: int = 0) - pe = torch.zeros((seq_len, self.embed_dim), device=input_embeds.device, dtype=input_embeds.dtype) 
half_d_feature = self.embed_dim // 2 div_term = torch.exp( - torch.arange(half_d_feature, device=input_embeds.device, dtype=input_embeds.dtype) + torch.arange(half_d_feature, device=input_embeds.device, dtype=torch.int64).type_as(input_embeds) * -(np.log(float(self.max_scale)) / (half_d_feature - 1)) ) pe[:, :half_d_feature] = torch.sin(positions * div_term) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 90d89a6d9e9f5a..a936a7f89f06d0 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -48,7 +48,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. @@ -58,7 +58,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -87,7 +87,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) @@ -112,10 +112,10 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 823807a475db4f..52a7123a952399 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -86,7 +86,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) 
self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. @@ -96,7 +96,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -125,7 +125,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) @@ -150,10 +150,10 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index f8290928a5ca9e..5f7ad4bd4049d9 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -103,7 +103,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. 
@@ -113,7 +113,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 4410a18bd1bcbe..6b00754930b333 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -365,7 +365,7 @@ def __init__(self, config): dim = config.hidden_size // config.speech_encoder_attention_heads base = config.rotary_embedding_base - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) self.register_buffer("inv_freq", inv_freq) self.cached_sequence_length = None self.cached_rotary_positional_embedding = None @@ -414,9 +414,9 @@ def extend_pe(self, x): # are to the left (i>j) and negative relative positions otherwise (ij) and negative relative positions otherwise (ij) and negative relative positions otherwise (i 0: fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len) @@ -1049,7 +1049,7 @@ def relative_positional_encoding(self, qlen, klen, bsz=None): pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1) else: - fwd_pos_seq = torch.arange(beg, end, -1.0) + fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float() if self.clamp_len > 0: fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len) pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz) diff --git a/tests/deepspeed/test_deepspeed.py b/tests/deepspeed/test_deepspeed.py index 982578d455bd7b..fe623d972c86f0 100644 --- a/tests/deepspeed/test_deepspeed.py +++ b/tests/deepspeed/test_deepspeed.py @@ -25,6 +25,7 @@ from parameterized import parameterized import tests.trainer.test_trainer +import transformers from tests.trainer.test_trainer import TrainerIntegrationCommon # noqa from transformers import AutoModel, TrainingArguments, is_torch_available, logging from transformers.integrations.deepspeed import ( @@ -53,6 +54,8 @@ if is_torch_available(): + import torch + from tests.trainer.test_trainer import ( # noqa RegressionModelConfig, RegressionPreTrainedModel, @@ -70,6 +73,7 @@ T5_SMALL = "t5-small" T5_TINY = "patrickvonplaten/t5-tiny-random" GPT2_TINY = "sshleifer/tiny-gpt2" +GPTJ_TINY = "hf-internal-testing/tiny-random-gptj" def load_json(path): @@ -297,6 +301,74 @@ def _init_weights(self, module): torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)), ) + def test_arange_bf16(self): + # Tests that configuring DeepSpeed with 16 bits does not cause float `torch.arange()` tensors to be cast down. + # NOTE -- this assumes that the function calls have the following downcast-preventing pattern, i.e. + # `torch.arange(...,dtype=torch.int64)` followed by a cast like `.to(torch.float32)`. 🚨 If this pattern is + # NOT applied (e.g. `torch.arange(...,dtype=torch.float32)` is used), DeepSpeed can automatically cast it down + # at init time. See https://github.com/huggingface/transformers/issues/28685 for more info. 
+ + ds_config = { + "train_batch_size": 1, + "zero_optimization": { + "stage": 3, + }, + "bf16": {"enabled": True}, + } + + dschf = HfDeepSpeedConfig(ds_config) + + self.assertTrue(dschf.is_zero3()) + self.assertTrue(is_deepspeed_zero3_enabled()) + + with LoggingLevel(logging.INFO): + with mockenv_context(**self.dist_env_1_gpu): + logger = logging.get_logger("transformers.modeling_utils") + with CaptureLogger(logger) as cl: + model = AutoModel.from_pretrained(GPTJ_TINY) + self.assertIn("Detected DeepSpeed ZeRO-3", cl.out) + + # The model weights are in BF16 as per deepspeed config + self.assertTrue(str(model.h[0].attn.q_proj.weight.dtype) == "torch.bfloat16") + good_deepspeed_sin_cos = model.h[0].attn.embed_positions + + # Monkeypatches the function that creates RoPE embeddings using the INCORRECT torch.arange() pattern, and + # then recreates the model + def bad_deepspeed_create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor: + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim)) + # Incorrect pattern here: torch.arange receives a float dtype (inv_freq.dtype, i.e. torch.float32), so it + # will automatically be converted to BF16 by DeepSpeed + sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=inv_freq.dtype), inv_freq) + return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1) + + good_deepspeed_create_sinusoidal_positions = transformers.models.gptj.modeling_gptj.create_sinusoidal_positions + transformers.models.gptj.modeling_gptj.create_sinusoidal_positions = bad_deepspeed_create_sinusoidal_positions + + with LoggingLevel(logging.INFO): + with mockenv_context(**self.dist_env_1_gpu): + logger = logging.get_logger("transformers.modeling_utils") + with CaptureLogger(logger) as cl: + model = AutoModel.from_pretrained(GPTJ_TINY) + self.assertIn("Detected DeepSpeed ZeRO-3", cl.out) + + self.assertTrue(str(model.h[0].attn.q_proj.weight.dtype) == "torch.bfloat16") + bad_deepspeed_sin_cos = model.h[0].attn.embed_positions + + # Compares the two sets of values: they are different, and the correct one matches the plain torch + # (i.e. outside DeepSpeed) version. + good_torch_sin_cos = good_deepspeed_create_sinusoidal_positions( + model.config.max_position_embeddings, model.config.rotary_dim + ) + self.assertFalse(torch.allclose(good_deepspeed_sin_cos, bad_deepspeed_sin_cos)) + self.assertTrue(torch.allclose(good_torch_sin_cos, good_deepspeed_sin_cos.cpu())) + + # Finally, we can see that the incorrect pattern is okay on vanilla torch, demonstrating that this issue is + # exclusive to DeepSpeed + bad_torch_sin_cos = bad_deepspeed_create_sinusoidal_positions( + model.config.max_position_embeddings, model.config.rotary_dim + ) + self.assertTrue(torch.allclose(bad_torch_sin_cos, good_torch_sin_cos)) + class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus): def setUp(self):
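For readers skimming the patch, the recurring modeling change reduces to a single idiom. The snippet below is an illustrative sketch, not part of the patch (the helper names are hypothetical and it does not invoke DeepSpeed itself): it contrasts the fragile pattern, in which torch.arange yields a floating-point tensor whose dtype DeepSpeed's ZeRO-3 init can cast down at model-creation time, with the patched pattern, in which the range is materialized as torch.int64 (left untouched by DeepSpeed) and only then cast to float.

import torch

def fragile_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Pre-patch idiom: no integer dtype is pinned, so DeepSpeed's ZeRO-3 init with bf16/fp16
    # enabled may materialize or cast the arange in low precision before the division
    # (see https://github.com/huggingface/transformers/issues/28685).
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

def safe_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Patched idiom: the range is created explicitly as int64, which DeepSpeed leaves alone,
    # and is only cast to float afterwards, so the inverse frequencies stay in full precision.
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

Outside DeepSpeed the two helpers return identical values, which is the property the new test_arange_bf16 test relies on when it compares the monkeypatched GPT-J sinusoidal positions against their plain-torch counterparts.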