Remove jnp.DeviceArray since it is deprecated. (huggingface#24875)
* Remove jnp.DeviceArray since it is deprecated.

* Replace all instances of jnp.DeviceArray with jax.Array

* Update src/transformers/models/bert/modeling_flax_bert.py

---------

Co-authored-by: Sanchit Gandhi <[email protected]>
mariecwhite and sanchit-gandhi authored Aug 4, 2023
1 parent fdd81ae commit a6e6b1c
Showing 24 changed files with 38 additions and 38 deletions.
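
For context, the change is purely a type-annotation swap: jnp.DeviceArray was a deprecated alias that newer JAX releases remove, and jax.Array is its replacement as the unified array type. Below is a minimal standalone sketch of the pattern, assuming JAX 0.4+; the function body and values are hypothetical illustrations, not the actual transformers method.

from typing import Optional

import jax
import jax.numpy as jnp


# Before: attention_mask: Optional[jnp.DeviceArray] = None  (breaks on current JAX)
# After:  attention_mask: Optional[jax.Array] = None
def prepare_inputs_for_generation(input_ids, max_length: int, attention_mask: Optional[jax.Array] = None):
    # Hypothetical body: build a default mask padded out to max_length.
    batch_size, _ = input_ids.shape
    if attention_mask is None:
        attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
    return attention_mask


mask = prepare_inputs_for_generation(jnp.zeros((2, 5), dtype="i4"), max_length=8)
print(mask.shape)  # (2, 8)

Since annotations are evaluated at definition time but never enforced at call time, the swap changes no runtime behavior, which is why every hunk in this commit is a one-line substitution.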
6 changes: 3 additions & 3 deletions src/transformers/models/bart/modeling_flax_bart.py
@@ -1467,8 +1467,8 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -1960,7 +1960,7 @@ def __call__(
 class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
     module_class = FlaxBartForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

2 changes: 1 addition & 1 deletion src/transformers/models/bert/modeling_flax_bert.py
@@ -1677,7 +1677,7 @@ def __call__(
 class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
     module_class = FlaxBertForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

2 changes: 1 addition & 1 deletion src/transformers/models/big_bird/modeling_flax_big_bird.py
@@ -2599,7 +2599,7 @@ def __call__(
 class FlaxBigBirdForCausalLM(FlaxBigBirdPreTrainedModel):
     module_class = FlaxBigBirdForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

@@ -1443,8 +1443,8 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -1441,8 +1441,8 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
2 changes: 1 addition & 1 deletion src/transformers/models/electra/modeling_flax_electra.py
@@ -1565,7 +1565,7 @@ def __call__(
 class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
     module_class = FlaxElectraForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

@@ -722,8 +722,8 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
2 changes: 1 addition & 1 deletion src/transformers/models/gpt2/modeling_flax_gpt2.py
@@ -742,7 +742,7 @@ def __call__(
 class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
     module_class = FlaxGPT2LMHeadModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

2 changes: 1 addition & 1 deletion src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py
@@ -654,7 +654,7 @@ def __call__(
 class FlaxGPTNeoForCausalLM(FlaxGPTNeoPreTrainedModel):
     module_class = FlaxGPTNeoForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

2 changes: 1 addition & 1 deletion src/transformers/models/gptj/modeling_flax_gptj.py
@@ -683,7 +683,7 @@ def __call__(
 class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel):
     module_class = FlaxGPTJForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

4 changes: 2 additions & 2 deletions src/transformers/models/longt5/modeling_flax_longt5.py
@@ -2388,8 +2388,8 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
4 changes: 2 additions & 2 deletions src/transformers/models/marian/modeling_flax_marian.py
@@ -1436,8 +1436,8 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
4 changes: 2 additions & 2 deletions src/transformers/models/mbart/modeling_flax_mbart.py
@@ -1502,8 +1502,8 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
2 changes: 1 addition & 1 deletion src/transformers/models/opt/modeling_flax_opt.py
@@ -763,7 +763,7 @@ def __call__(
 class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel):
     module_class = FlaxOPTForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

4 changes: 2 additions & 2 deletions src/transformers/models/pegasus/modeling_flax_pegasus.py
@@ -1450,8 +1450,8 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
2 changes: 1 addition & 1 deletion src/transformers/models/roberta/modeling_flax_roberta.py
@@ -1452,7 +1452,7 @@ def __call__(
 class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
     module_class = FlaxRobertaForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

2 changes: 1 addition & 1 deletion src/transformers/models/roberta_prelayernorm/modeling_flax_roberta_prelayernorm.py
@@ -1478,7 +1478,7 @@ def __call__(
 class FlaxRobertaPreLayerNormForCausalLM(FlaxRobertaPreLayerNormPreTrainedModel):
     module_class = FlaxRobertaPreLayerNormForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

@@ -745,8 +745,8 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
4 changes: 2 additions & 2 deletions src/transformers/models/t5/modeling_flax_t5.py
@@ -1740,8 +1740,8 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
@@ -688,7 +688,7 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
4 changes: 2 additions & 2 deletions src/transformers/models/whisper/modeling_flax_whisper.py
@@ -1448,8 +1448,8 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs,
     ):
2 changes: 1 addition & 1 deletion src/transformers/models/xglm/modeling_flax_xglm.py
@@ -766,7 +766,7 @@ def __call__(
 class FlaxXGLMForCausalLM(FlaxXGLMPreTrainedModel):
     module_class = FlaxXGLMForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

2 changes: 1 addition & 1 deletion src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
@@ -1469,7 +1469,7 @@ def __call__(
 class FlaxXLMRobertaForCausalLM(FlaxXLMRobertaPreTrainedModel):
     module_class = FlaxXLMRobertaForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

@@ -1469,7 +1469,7 @@ def __call__(
 class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.camelcase_modelname}}PreTrainedModel):
     module_class = Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule

-    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
+    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
         # initializing the cache
         batch_size, seq_length = input_ids.shape

@@ -2969,8 +2969,8 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         max_length,
-        attention_mask: Optional[jnp.DeviceArray] = None,
-        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        attention_mask: Optional[jax.Array] = None,
+        decoder_attention_mask: Optional[jax.Array] = None,
         encoder_outputs=None,
         **kwargs
     ):
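
A quick sanity check of the new annotation (a standalone sketch, not part of this commit): every array produced by jax.numpy is an instance of jax.Array, so the annotation matches what these methods actually receive.

import jax
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
assert isinstance(x, jax.Array)  # jnp arrays satisfy the jax.Array annotation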
