Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a Context trait similar to anyhow::Context. #2676

Merged
merged 2 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 65 additions & 5 deletions candle-core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@ pub struct MatMulUnexpectedStriding {
pub msg: &'static str,
}

impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self}")
}
}

/// Main library error type.
#[derive(thiserror::Error, Debug)]
#[derive(thiserror::Error)]
pub enum Error {
// === DType Errors ===
#[error("{msg}, expected: {expected:?}, got: {got:?}")]
Expand Down Expand Up @@ -199,8 +205,14 @@ pub enum Error {
UnsupportedSafeTensorDtype(safetensors::Dtype),

/// Arbitrary errors wrapping.
#[error(transparent)]
Wrapped(Box<dyn std::error::Error + Send + Sync>),
#[error("{0}")]
Wrapped(Box<dyn std::fmt::Display + Send + Sync>),

#[error("{context}\n{inner}")]
Context {
inner: Box<Self>,
context: Box<dyn std::fmt::Display + Send + Sync>,
},

/// Adding path information to an error.
#[error("path: {path:?} {inner}")]
Expand All @@ -218,16 +230,19 @@ pub enum Error {
/// User generated error message, typically created via `bail!`.
#[error("{0}")]
Msg(String),

#[error("unwrap none")]
UnwrapNone,
}

pub type Result<T> = std::result::Result<T, Error>;

impl Error {
pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
pub fn wrap(err: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Wrapped(Box::new(err)).bt()
}

pub fn msg(err: impl std::error::Error) -> Self {
pub fn msg(err: impl std::fmt::Display) -> Self {
Self::Msg(err.to_string()).bt()
}

Expand All @@ -253,6 +268,13 @@ impl Error {
path: p.as_ref().to_path_buf(),
}
}

pub fn context(self, c: impl std::fmt::Display + Send + Sync + 'static) -> Self {
Self::Context {
inner: Box::new(self),
context: Box::new(c),
}
}
}

#[macro_export]
Expand All @@ -275,3 +297,41 @@ pub fn zip<T, U>(r1: Result<T>, r2: Result<U>) -> Result<(T, U)> {
(_, Err(e)) => Err(e),
}
}

// Taken from anyhow.
pub trait Context<T> {
/// Wrap the error value with additional context.
fn context<C>(self, context: C) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static;

/// Wrap the error value with additional context that is evaluated lazily
/// only once an error does occur.
fn with_context<C, F>(self, f: F) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C;
}

impl<T> Context<T> for Option<T> {
fn context<C>(self, context: C) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
{
match self {
Some(v) => Ok(v),
None => Err(Error::UnwrapNone.context(context).bt()),
}
}

fn with_context<C, F>(self, f: F) -> Result<T>
where
C: std::fmt::Display + Send + Sync + 'static,
F: FnOnce() -> C,
{
match self {
Some(v) => Ok(v),
None => Err(Error::UnwrapNone.context(f()).bt()),
}
}
}
2 changes: 1 addition & 1 deletion candle-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
pub use error::{Context, Error, Result};
pub use indexer::{IndexOp, TensorIndexer};
pub use layout::Layout;
pub use shape::{Shape, D};
Expand Down
8 changes: 4 additions & 4 deletions candle-core/src/pickle.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result, Tensor};
use crate::{Context, DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
use std::io::BufRead;
Expand Down Expand Up @@ -537,7 +537,7 @@ impl Stack {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
let key = objs.pop().context("empty objs")?;
d.push((key, value))
}
} else {
Expand All @@ -557,7 +557,7 @@ impl Stack {
crate::bail!("setitems: not an even number of objects")
}
while let Some(value) = objs.pop() {
let key = objs.pop().unwrap();
let key = objs.pop().context("empty objs")?;
pydict.push((key, value))
}
self.push(Object::Dict(pydict))
Expand Down Expand Up @@ -661,7 +661,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
if !file_name.ends_with("data.pkl") {
continue;
}
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
let reader = zip.by_name(file_name)?;
let mut reader = std::io::BufReader::new(reader);
let mut stack = Stack::empty();
Expand Down
4 changes: 2 additions & 2 deletions candle-core/src/quantized/gguf_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//!

use super::{GgmlDType, QTensor};
use crate::{Device, Result};
use crate::{Context, Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;

Expand Down Expand Up @@ -338,7 +338,7 @@ impl Value {
if value_type.len() != 1 {
crate::bail!("multiple value-types in the same array {value_type:?}")
}
value_type.into_iter().next().unwrap()
value_type.into_iter().next().context("empty value_type")?
};
w.write_u32::<LittleEndian>(value_type.to_u32())?;
w.write_u64::<LittleEndian>(v.len() as u64)?;
Expand Down
4 changes: 2 additions & 2 deletions candle-core/src/quantized/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! Code for GGML and GGUF files
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;

Expand Down Expand Up @@ -481,7 +481,7 @@ impl crate::CustomOp1 for QTensor {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
let last_k = dst_shape.pop().context("empty dst_shape")?;
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}
Expand Down
4 changes: 2 additions & 2 deletions candle-core/src/tensor_cat.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{shape::Dim, Error, Result, Shape, Tensor};
use crate::{shape::Dim, Context, Error, Result, Shape, Tensor};

impl Tensor {
/// Concatenates two or more tensors along a particular dimension.
Expand Down Expand Up @@ -134,7 +134,7 @@ impl Tensor {
.bt())?
}
}
let next_offset = offsets.last().unwrap() + arg.elem_count();
let next_offset = offsets.last().context("empty offsets")? + arg.elem_count();
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);
Expand Down
4 changes: 2 additions & 2 deletions candle-transformers/src/generation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! Functionality for modeling sampling strategies and logits processing in text generation
//! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p),
//! and combinations thereof.
use candle::{DType, Error, Result, Tensor};
use candle::{Context, DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};

#[derive(Clone, PartialEq, Debug)]
Expand Down Expand Up @@ -45,7 +45,7 @@ impl LogitsProcessor {
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
.unwrap();
.context("empty logits")?;
Ok(next_token)
}

Expand Down
4 changes: 2 additions & 2 deletions candle-transformers/src/models/chinese_clip/vision_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! - 💻 [Chinese-CLIP](https://github.com/OFA-Sys/Chinese-CLIP)
//! - 💻 [GH](https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py_

use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D};
use candle::{Context, DType, IndexOp, Module, Result, Shape, Tensor, D};
use candle_nn as nn;

use super::{Activation, EncoderConfig};
Expand Down Expand Up @@ -363,7 +363,7 @@ impl ChineseClipVisionTransformer {
.apply(&self.pre_layer_norm)?;

let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
let encoder_outputs = result.last().unwrap();
let encoder_outputs = result.last().context("no last")?;
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)
Expand Down
4 changes: 2 additions & 2 deletions candle-transformers/src/models/clip/vision_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! https://github.com/openai/CLIP
//! https://github.com/huggingface/transformers/tree/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip

use candle::{IndexOp, Result, Shape, Tensor, D};
use candle::{Context, IndexOp, Result, Shape, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
use nn::Conv2dConfig;
Expand Down Expand Up @@ -149,7 +149,7 @@ impl ClipVisionTransformer {
.apply(&self.embeddings)?
.apply(&self.pre_layer_norm)?;
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
let encoder_outputs = result.last().unwrap();
let encoder_outputs = result.last().context("no last")?;
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)
Expand Down
4 changes: 2 additions & 2 deletions candle-transformers/src/models/efficientnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! See:
//! - ["EfficientBERT: Progressively Searching Multilayer Perceptron Architectures for BERT"](https://arxiv.org/abs/2201.00462)
//!
use candle::{Result, Tensor, D};
use candle::{Context, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};

Expand Down Expand Up @@ -289,7 +289,7 @@ impl EfficientNet {
pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
let f_p = p.pp("features");
let first_in_c = configs[0].input_channels;
let last_out_c = configs.last().unwrap().out_channels;
let last_out_c = configs.last().context("no last")?.out_channels;
let final_out_c = 4 * last_out_c;
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
let nconfigs = configs.len();
Expand Down
4 changes: 2 additions & 2 deletions candle-transformers/src/models/fastvit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//!
//! Implementation based on [timm model](https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/fastvit.py)

use candle::{DType, Result, Tensor, D};
use candle::{Context, DType, Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, linear, linear_no_bias, ops::sigmoid, ops::softmax,
BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
Expand Down Expand Up @@ -178,7 +178,7 @@ fn squeeze_and_excitation(
// based on the _fuse_bn_tensor method in timm
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
let (gamma, beta) = bn.weight_and_bias().unwrap();
let (gamma, beta) = bn.weight_and_bias().context("no weight-bias")?;
let mu = bn.running_mean();
let sigma = (bn.running_var() + bn.eps())?.sqrt();
let gps = (gamma / sigma)?;
Expand Down
22 changes: 9 additions & 13 deletions candle-transformers/src/models/llava/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer}
use crate::models::llama::{Cache, Llama};
use crate::models::with_tracing::linear;

use candle::{bail, Device, IndexOp, Result, Tensor};
use candle::{bail, Context, Device, IndexOp, Result, Tensor};
use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
use fancy_regex::Regex;
use utils::get_anyres_image_grid_shape;
Expand Down Expand Up @@ -145,7 +145,7 @@ impl ClipVisionTower {
let config = if config.is_none() {
ClipVisionConfig::clip_vit_large_patch14_336()
} else {
config.clone().unwrap()
config.clone().context("no config")?
};
let select_layer = match select_layer {
-1 | -2 => select_layer,
Expand Down Expand Up @@ -262,14 +262,14 @@ impl LLaVA {
let image_features = if mm_patch_merge_type == "flat" {
image_features
.iter()
.map(|x| x.flatten(0, 1).unwrap())
.collect::<Vec<Tensor>>()
.map(|x| x.flatten(0, 1))
.collect::<Result<Vec<Tensor>>>()?
} else if mm_patch_merge_type.starts_with("spatial") {
let mut new_image_features = Vec::new();
for (image_idx, image_feature) in image_features.iter().enumerate() {
let new_image_feature = if image_feature.dims()[0] > 1 {
let base_image_feature = image_feature.get(0).unwrap();
let patch_image_feature = image_feature.i(1..).unwrap();
let base_image_feature = image_feature.get(0)?;
let patch_image_feature = image_feature.i(1..)?;
let height = self.clip_vision_tower.num_patches_per_side();
let width = height;
assert_eq!(height * width, base_image_feature.dims()[0]);
Expand Down Expand Up @@ -313,16 +313,12 @@ impl LLaVA {
};
Tensor::cat(&[base_image_feature, new_image_feature], 0)?
} else {
let new_image_feature = image_feature.get(0).unwrap();
let new_image_feature = image_feature.get(0)?;
if mm_patch_merge_type.contains("unpad") {
Tensor::cat(
&[
new_image_feature,
self.image_newline.clone().unsqueeze(0).unwrap(),
],
&[new_image_feature, self.image_newline.clone().unsqueeze(0)?],
0,
)
.unwrap()
)?
} else {
new_image_feature
}
Expand Down
4 changes: 2 additions & 2 deletions candle-transformers/src/models/segformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
//!

use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
use candle::{Module, ModuleT, Result, Tensor, D};
use candle::{Context, Module, ModuleT, Result, Tensor, D};
use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
Expand Down Expand Up @@ -633,7 +633,7 @@ impl ImageClassificationModel {
impl Module for ImageClassificationModel {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let all_hidden_states = self.segformer.forward(x)?;
let hidden_states = all_hidden_states.last().unwrap();
let hidden_states = all_hidden_states.last().context("no last")?;
let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
let mean = hidden_states.mean(1)?;
self.classifier.forward(&mean)
Expand Down
Loading