diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index a2c3f4286c..4c38a2d699 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2539,6 +2539,23 @@ impl Tensor { pub fn broadcast_pow(&self, rhs: &Tensor) -> Result { rhs.broadcast_mul(&self.log()?)?.exp() } + + pub fn is_inf(&self) -> Result { + self.broadcast_eq(&Tensor::new(f64::INFINITY, self.device())?.to_dtype(self.dtype)?) + } + + pub fn any(&self) -> Result { + let sum = self.sum_all()?; + match self.dtype { + DType::U8 => Ok(sum.to_scalar::()? == 0), + DType::U32 => Ok(sum.to_scalar::()? == 0), + DType::I64 => Ok(sum.to_scalar::()? == 0), + DType::F16 => Ok(sum.to_scalar::()? == half::f16::from_f32_const(0.)), + DType::BF16 => Ok(sum.to_scalar::()? == half::bf16::from_f32_const(0.)), + DType::F32 => Ok(sum.to_scalar::()? == 0.), + DType::F64 => Ok(sum.to_scalar::()? == 0.), + } + } } macro_rules! bin_trait { diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 6589b4b146..7e01b9a8ff 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -24,6 +24,7 @@ serde = { workspace = true } serde_json = { workspace = true } serde_plain = { workspace = true } tracing = { workspace = true } +half = { workspace = true } [features] default = [] diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 84e072a294..7e3a2e9ec3 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -577,6 +577,22 @@ impl T5LayerCrossAttention { } } +fn clamp_for_f16(xs: &Tensor) -> Result { + let mut max = match xs.dtype() { + DType::U8 => u8::MAX as f64 - 1000., + DType::U32 => u32::MAX as f64 - 1000., + DType::I64 => i64::MAX as f64 - 1000., + DType::F16 => half::f16::MAX.to_f64_const() - 1000., + DType::BF16 => half::bf16::MAX.to_f64_const() - 1000., + DType::F32 => f32::MAX as f64 - 1000., + DType::F64 => f64::MAX - 1000., + }; + if xs.is_inf()?.any()? { + max = max - 1000.; + } + xs.clamp(-max, max) +} + #[derive(Debug, Clone)] struct T5Block { self_attn: T5LayerSelfAttention, @@ -632,13 +648,22 @@ impl T5Block { false => None, }; let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias, mask.as_ref())?; - // TODO: clamp for f16? + // Clamp for f16 + if xs.dtype() == DType::F16 { + xs = clamp_for_f16(&xs)?; + } if let Some(cross_attn) = &mut self.cross_attn { (xs, _) = cross_attn.forward(&xs, None, encoder_hidden_states.unwrap())?; - // TODO: clamp for f16? + // Clamp for f16 + if xs.dtype() == DType::F16 { + xs = clamp_for_f16(&xs)?; + } + } + let mut xs = self.ff.forward(&xs)?; + // Clamp for f16 + if xs.dtype() == DType::F16 { + xs = clamp_for_f16(&xs)?; } - let xs = self.ff.forward(&xs)?; - // TODO: clamp for f16? Ok((xs, position_bias)) }