Skip to content

Commit

Permalink
Add support to UL2 model family (#1300)
Browse files Browse the repository at this point in the history
* Add support to UL2 model family

* Update docs with UL2

* Create ActivationWithOptionalGating to avoid polluting activations

* Also refactor quantized t5

* Remove useless conversion

* Revert Activation::NewGelu name change

* Remove useless return

* Apply rustfmt and clippy recommendations

* Reuse t5::ActivationWithOptionalGating in quantized version

* (cosmetic change) use a match rather than ifs + avoid early returns.

---------

Co-authored-by: Laurent <[email protected]>
  • Loading branch information
Juarez Bochi and LaurentMazare authored Nov 9, 2023
1 parent 6958384 commit 18d3000
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 15 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false }
safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2"
serde_json = "1.0.99"
thiserror = "1"
tokenizers = { version = "0.13.4", default-features = false }
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ If you have an addition to this list, please submit a pull request.
- Replit-code-v1.5-3B.
- Bert.
- Text to text.
- T5 and its variants: FlanT5, MADLAD400 (translation), CoEdit (Grammar correction).
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
- Whisper (multi-lingual support).
- Text to image.
Expand Down
3 changes: 2 additions & 1 deletion candle-core/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,8 @@ unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);

/// `gelu` operation
/// Tanh based approximation of the `gelu` operation
/// GeluErf is the more precise one.
/// <https://en.wikipedia.org/wiki/Activation_function#Comparison_of_activation_functions>
impl UnaryOpT for Gelu {
const NAME: &'static str = "gelu";
Expand Down
4 changes: 3 additions & 1 deletion candle-examples/examples/t5/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ $ cargo run --example t5 --release -- --model-id "t5-small" --prompt "translate
9 tokens generated (2.42 token/s)
```

Variants such as [flan-t5](https://huggingface.co/google/flan-t5-small), [flan-ul2](https://huggingface.co/google/flan-ul2) (with `--revision "refs/pr/25"`), and [Co-EdIT](https://huggingface.co/grammarly/coedit-large) are also supported.

## Translation with [MADLAD-400](https://arxiv.org/abs/2309.04662)

MADLAD-400 is a series of multilingual machine translation T5 models trained on 250 billion tokens covering over 450 languages using publicly available data. These models are competitive with significantly larger models.
Expand All @@ -22,7 +24,7 @@ cargo run --example t5 --release -- \
Wie geht es dir, mein Freund?
```

## Sentence embedding example:
## Sentence embedding example

```bash
$ cargo run --example t5 --release -- --model-id "t5-small" --prompt "A beautiful candle."
Expand Down
11 changes: 11 additions & 0 deletions candle-examples/examples/t5/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,17 @@ impl T5ModelBuilder {
api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?,
]
} else if model_id == "google/flan-ul2" {
vec![
api.get("model-00001-of-00008.safetensors")?,
api.get("model-00002-of-00008.safetensors")?,
api.get("model-00003-of-00008.safetensors")?,
api.get("model-00004-of-00008.safetensors")?,
api.get("model-00005-of-00008.safetensors")?,
api.get("model-00006-of-00008.safetensors")?,
api.get("model-00007-of-00008.safetensors")?,
api.get("model-00008-of-00008.safetensors")?,
]
} else {
vec![api.get("model.safetensors")?]
};
Expand Down
1 change: 0 additions & 1 deletion candle-nn/src/activation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use serde::Deserialize;
pub enum Activation {
#[default]
Gelu,
#[serde(rename = "gated-gelu")]
NewGelu,
Relu,
Relu2,
Expand Down
1 change: 1 addition & 0 deletions candle-transformers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ rand = { workspace = true }
rayon = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
serde_plain = { workspace = true }
tracing = { workspace = true }
wav = { workspace = true }

Expand Down
14 changes: 9 additions & 5 deletions candle-transformers/src/models/quantized_t5.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// T5 Text Model, quantized version
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

use crate::models::t5::{deserialize_feed_forward_proj_activation, ActivationWithOptionalGating};
use crate::models::with_tracing::QMatMul;
use crate::quantized_nn::Embedding;
pub use crate::quantized_var_builder::VarBuilder;
Expand Down Expand Up @@ -54,8 +55,8 @@ pub struct Config {
dropout_rate: f64,
layer_norm_epsilon: f64,
initializer_factor: f64,
#[serde(default)]
feed_forward_proj: Activation,
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
pub feed_forward_proj: ActivationWithOptionalGating,
#[serde(default = "default_tie_word_embeddings")]
tie_word_embeddings: bool,
#[serde(default = "default_is_decoder")]
Expand Down Expand Up @@ -83,7 +84,10 @@ impl Default for Config {
dropout_rate: 0.1,
layer_norm_epsilon: 1e-6,
initializer_factor: 1.0,
feed_forward_proj: Activation::Relu,
feed_forward_proj: ActivationWithOptionalGating {
gated: false,
activation: Activation::Relu,
},
tie_word_embeddings: true,
is_decoder: false,
is_encoder_decoder: true,
Expand Down Expand Up @@ -176,7 +180,7 @@ impl T5DenseGatedActDense {
wi_0,
wi_1,
wo,
act: Activation::NewGelu,
act: cfg.feed_forward_proj.activation,
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
})
}
Expand Down Expand Up @@ -205,7 +209,7 @@ impl T5LayerFF {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
(
None,
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
Expand Down
49 changes: 43 additions & 6 deletions candle-transformers/src/models/t5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,37 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m)
}

/// An activation function paired with a flag saying whether the T5
/// feed-forward block is gated (as in the `"gated-gelu"` / `"gated-silu"`
/// values of the `feed_forward_proj` config field).
#[derive(Debug, Deserialize, Default, Clone, PartialEq)]
pub struct ActivationWithOptionalGating {
    // True when the feed-forward projection uses the gated (wi_0/wi_1) variant.
    pub gated: bool,
    // The underlying activation applied inside the feed-forward block.
    pub activation: candle_nn::Activation,
}

pub fn deserialize_feed_forward_proj_activation<'de, D>(
deserializer: D,
) -> std::result::Result<ActivationWithOptionalGating, D::Error>
where
D: serde::de::Deserializer<'de>,
{
match String::deserialize(deserializer)?.as_str() {
"gated-gelu" => Ok(ActivationWithOptionalGating {
gated: true,
activation: candle_nn::Activation::NewGelu,
}),
"gated-silu" => Ok(ActivationWithOptionalGating {
gated: true,
activation: candle_nn::Activation::Silu,
}),
buf => {
let activation = serde_plain::from_str(buf).map_err(serde::de::Error::custom)?;
Ok(ActivationWithOptionalGating {
gated: false,
activation,
})
}
}
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
vocab_size: usize,
Expand All @@ -52,8 +83,8 @@ pub struct Config {
dropout_rate: f64,
layer_norm_epsilon: f64,
initializer_factor: f64,
#[serde(default)]
feed_forward_proj: Activation,
#[serde(default, deserialize_with = "deserialize_feed_forward_proj_activation")]
feed_forward_proj: ActivationWithOptionalGating,
#[serde(default = "default_tie_word_embeddings")]
tie_word_embeddings: bool,
#[serde(default = "default_is_decoder")]
Expand Down Expand Up @@ -81,7 +112,10 @@ impl Default for Config {
dropout_rate: 0.1,
layer_norm_epsilon: 1e-6,
initializer_factor: 1.0,
feed_forward_proj: Activation::Relu,
feed_forward_proj: ActivationWithOptionalGating {
gated: false,
activation: Activation::Relu,
},
tie_word_embeddings: true,
is_decoder: false,
is_encoder_decoder: true,
Expand All @@ -102,7 +136,10 @@ impl Config {
d_model: 768,
dropout_rate: 0.1,
eos_token_id: 1,
feed_forward_proj: Activation::Relu,
feed_forward_proj: ActivationWithOptionalGating {
gated: false,
activation: Activation::Relu,
},
tie_word_embeddings: true,
initializer_factor: 1.0,
is_decoder: false,
Expand Down Expand Up @@ -202,7 +239,7 @@ impl T5DenseGatedActDense {
wi_0,
wi_1,
wo,
act: Activation::NewGelu,
act: cfg.feed_forward_proj.activation,
span: tracing::span!(tracing::Level::TRACE, "dense-gated-act-dense"),
})
}
Expand Down Expand Up @@ -231,7 +268,7 @@ impl T5LayerFF {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let layer_norm =
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj == Activation::NewGelu {
let (dense_act, gated_dense_act) = if cfg.feed_forward_proj.gated {
(
None,
Some(T5DenseGatedActDense::load(vb.pp("DenseReluDense"), cfg)?),
Expand Down

0 comments on commit 18d3000

Please sign in to comment.