diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs index ad008177ee..fe15187b5a 100644 --- a/candle-core/examples/basics.rs +++ b/candle-core/examples/basics.rs @@ -8,11 +8,10 @@ use anyhow::Result; use candle_core::{Device, Tensor}; fn main() -> Result<()> { - let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?; - let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?; - let start = std::time::Instant::now(); - let res = inp.conv2d(&w, 0, 1, 1, 1)?; - println!("{:?}", start.elapsed()); - println!("{res:?}"); + let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?; + let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?; + let new_a = a.slice_scatter(&b, 1, 2)?; + assert_eq!(a.to_vec2::()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + assert_eq!(new_a.to_vec2::()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); Ok(()) } diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 36f5f6b17a..6c4fea91dc 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -123,12 +123,6 @@ pub trait Module { fn forward(&self, xs: &Tensor) -> Result; } -impl Module for quantized::QMatMul { - fn forward(&self, xs: &Tensor) -> Result { - self.forward(xs) - } -} - impl Result> Module for T { fn forward(&self, xs: &Tensor) -> Result { self(xs) diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 58f261b484..043733ae87 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -307,8 +307,8 @@ impl crate::CustomOp1 for QTensor { } } -impl QMatMul { - pub fn forward(&self, xs: &Tensor) -> Result { +impl crate::Module for QMatMul { + fn forward(&self, xs: &Tensor) -> Result { match self { Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()), Self::Tensor(w) => { diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index a2cecbc3b0..716cca8dee 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -1,7 +1,7 @@ use candle_core::{ quantized::{self, GgmlDType}, test_utils::to_vec2_round, - Device, Result, Tensor, + Device, Module, Result, Tensor, }; use quantized::{k_quants, GgmlType}; use rand::prelude::*; diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs index 9ded5f71b5..68d384a6df 100644 --- a/candle-nn/examples/cpu_benchmarks.rs +++ b/candle-nn/examples/cpu_benchmarks.rs @@ -6,7 +6,7 @@ extern crate intel_mkl_src; extern crate accelerate_src; use candle::quantized::GgmlType; -use candle::{CpuStorage, Device, Layout, Result, Shape, Tensor, D}; +use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D}; use clap::{Parser, Subcommand}; const CHECK_CONV2D: bool = false; diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index b0c623d3f1..ade0001291 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -17,7 +17,7 @@ extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; -use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType}; +use ::candle::{quantized::QTensor, DType, Device, Module, Tensor, WithDType}; mod utils; use utils::wrap_err; diff --git a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs index 5d53728c75..e5fa7dec23 100644 --- a/candle-wasm-tests/tests/quantized_tests.rs +++ b/candle-wasm-tests/tests/quantized_tests.rs @@ -1,7 +1,7 @@ use candle::{ quantized::{self, k_quants, GgmlDType, GgmlType}, test_utils::to_vec2_round, - Device, Result, Tensor, + Device, Module, Result, Tensor, }; use wasm_bindgen_test::*;