diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index e7112e2e61..80050d0822 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -1,3 +1,8 @@ +use std::{ + convert::Infallible, + fmt::{Debug, Display}, +}; + use crate::{DType, DeviceLocation, Layout, MetalError, Shape}; #[derive(Debug, Clone)] @@ -194,6 +199,13 @@ pub enum Error { #[error(transparent)] Wrapped(Box), + /// Arbitrary errors wrapping with context. + #[error("{wrapped:?}\n{context:?}")] + WrappedContext { + wrapped: Box, + context: String, + }, + /// Adding path information to an error. #[error("path: {path:?} {inner}")] WithPath { @@ -215,14 +227,21 @@ pub enum Error { pub type Result = std::result::Result; impl Error { + /// Create a new error by wrapping another. pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self { Self::Wrapped(Box::new(err)).bt() } - pub fn msg(err: impl std::error::Error) -> Self { - Self::Msg(err.to_string()).bt() + /// Create a new error based on a printable error message. + /// + /// If the message implements `std::error::Error`, prefer using [`Error::wrap`] instead. + pub fn msg(msg: M) -> Self { + Self::Msg(msg.to_string()).bt() } + /// Create a new error based on a debuggable error message. + /// + /// If the message implements `std::error::Error`, prefer using [`Error::wrap`] instead. pub fn debug(err: impl std::fmt::Debug) -> Self { Self::Msg(format!("{err:?}")).bt() } @@ -267,3 +286,86 @@ pub fn zip(r1: Result, r2: Result) -> Result<(T, U)> { (_, Err(e)) => Err(e), } } + +pub(crate) mod private { + pub trait Sealed {} + + impl Sealed for std::result::Result where E: std::error::Error {} + impl Sealed for Option {} +} + +/// Attach more context to an error. +/// +/// Inspired by [`anyhow::Context`]. +pub trait Context: private::Sealed { + /// Wrap the error value with additional context. + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static; + + /// Wrap the error value with additional context that is evaluated lazily + /// only once an error does occur. + fn with_context(self, f: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C; +} + +impl Context for std::result::Result +where + E: std::error::Error + Send + Sync + 'static, +{ + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using map_err to save 2 useless frames off the captured backtrace + // in ext_context. + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context.to_string(), + }), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Ok(ok) => Ok(ok), + Err(error) => Err(Error::WrappedContext { + wrapped: Box::new(error), + context: context().to_string(), + }), + } + } +} + +impl Context for Option { + fn context(self, context: C) -> std::result::Result + where + C: Display + Send + Sync + 'static, + { + // Not using ok_or_else to save 2 useless frames off the captured + // backtrace. + match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context)), + } + } + + fn with_context(self, context: F) -> std::result::Result + where + C: Display + Send + Sync + 'static, + F: FnOnce() -> C, + { + match self { + Some(ok) => Ok(ok), + None => Err(Error::msg(context())), + } + } +} diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 2a36cebd34..a38ed4790b 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -80,7 +80,7 @@ pub use cpu_backend::{CpuStorage, CpuStorageRef}; pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3}; pub use device::{Device, DeviceLocation, NdArray}; pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; -pub use error::{Error, Result}; +pub use error::{Context, Error, Result}; pub use indexer::IndexOp; pub use layout::Layout; pub use shape::{Shape, D};