chore: autocast #246

Draft · wants to merge 1 commit into base: master

10 changes: 4 additions & 6 deletions crates/ratchet-core/src/dtype/mod.rs
@@ -95,13 +95,11 @@ impl DType {
matches!(self, DType::F16 | DType::BF16 | DType::F32)
}

- /// Returns the activation dtype for the given quantized dtype.
- pub fn activation_dt(&self) -> DType {
+ /// Returns the compute dtype for the given quantized dtype.
+ pub fn compute_dt(&self) -> DType {
match self {
- DType::Q8_0H(_) => DType::F16,
- DType::Q8_0F(_) => DType::F32,
- DType::Q4_KH(_) => DType::F16,
- DType::Q4_KF(_) => DType::F32,
+ DType::Q8_0H(_) | DType::Q4_KH(_) => DType::F16,
+ DType::Q8_0F(_) | DType::Q4_KF(_) => DType::F32,
_ => *self,
}
}
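A quick sketch of how the renamed helper is intended to be used when picking a cast destination for quantized data. This is an illustration only, assuming it sits inside ratchet-core so crate::DType is in scope; Tensor::cast further down in this diff applies the same redirection.

    use crate::DType;

    /// Minimal sketch: quantized dtypes are not valid cast targets, so redirect
    /// them to their associated compute precision.
    fn cast_target(requested: DType) -> DType {
        if requested.is_quantized() {
            requested.compute_dt() // Q8_0H / Q4_KH -> F16, Q8_0F / Q4_KF -> F32
        } else {
            requested
        }
    }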
65 changes: 63 additions & 2 deletions crates/ratchet-core/src/op.rs
@@ -3,11 +3,12 @@ use crate::gpu::{
PoolError, WgpuDevice,
};
use crate::{
- ops::*, rvec, CompiledOp, InvariantError, Kernel, KernelBuildError, KernelMetadata,
- KernelModuleDesc, RVec, StorageView, Tensor, WgslFragment, WorkgroupSize,
+ ops::*, rvec, CompiledOp, DType, InvariantError, Kernel, KernelBuildError, KernelMetadata,
+ KernelModuleDesc, RVec, Shape, StorageView, Tensor, WgslFragment, WorkgroupSize,
};
use std::borrow::Cow;
use std::fmt::Debug;
use std::panic::Location;

#[derive(Clone, Debug)]
#[non_exhaustive]
@@ -139,6 +140,8 @@ pub enum OperationError {
InplaceError(String),
#[error(transparent)]
DeviceError(#[from] crate::DeviceError),
#[error("Unsupported cast: {0} -> {1}")]
UnsupportedCast(DType, DType),
}

/// Unique string representing a kernel.
@@ -204,6 +207,64 @@ impl std::fmt::Display for KernelSource {
}
}

#[derive(thiserror::Error, Debug)]
pub enum GuardError {
#[error("Shape mismatch in {op_name} operation\n{shapes:?}\n{expected:?}\nHint: Ensure that the shapes are compatible for this operation.")]
ShapeMismatch {
op_name: &'static str,
shapes: Vec<Shape>,
expected: Option<Shape>,
},
#[error("DType mismatch in {op_name} operation\n{dtypes:?}\nHint: Ensure all dtypes are compatible for this operation.")]
DTypeMismatch {
op_name: &'static str,
dtypes: Vec<DType>,
},
#[error("Error in {op_name} operation: {message}")]
CustomError {
op_name: &'static str,
message: Cow<'static, str>,
},
}

impl GuardError {
pub fn custom<O: Operation>(op: &O, message: impl Into<Cow<'static, str>>) -> Self {
Self::CustomError {
op_name: op.name(),
message: message.into(),
}
}

pub fn shape_mismatch<O: Operation>(
op: &O,
shapes: impl Into<Vec<Shape>>,
expected: Option<Shape>,
) -> Self {
Self::ShapeMismatch {
op_name: op.name(),
shapes: shapes.into(),
expected,
}
}

pub fn dtype_mismatch<O: Operation>(op: &O, dtypes: Vec<DType>) -> Self {
Self::DTypeMismatch {
op_name: op.name(),
dtypes,
}
}

pub fn panic(self, location: &Location) -> ! {
panic!(
"{}\n\nError occurred at {}:{}:{}",
self,
location.file(),
location.line(),
location.column()
)
}
}

/// # Operation Guards - Runtime guards for operation correctness.
///
/// Guards should be implemented for all types that will be a node on the high-level CFG.
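To see what these guards report at runtime, here is a hedged sketch of a unit test (not part of the diff) exercising the Display impl declared above. It assumes the test lives inside ratchet-core and that the shape! macro used by the cast.rs test below is available at the crate root.

    #[test]
    fn shape_mismatch_message() {
        use crate::{shape, GuardError};

        // Build the error the same way Binary::check_shapes does when broadcasting fails.
        let err = GuardError::ShapeMismatch {
            op_name: "binary",
            shapes: vec![shape![1, 3, 4], shape![2, 5]],
            expected: None,
        };

        // The thiserror attribute renders the op name, the offending shapes, and a hint.
        assert!(err.to_string().contains("Shape mismatch in binary operation"));
    }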
22 changes: 16 additions & 6 deletions crates/ratchet-core/src/ops/binary.rs
@@ -1,3 +1,5 @@
use std::panic::Location;

use derive_new::new;
use encase::ShaderType;
use half::f16;
@@ -6,10 +8,10 @@ use ratchet_macros::WgslMetadata;

use crate::{
gpu::{dtype::WgslDType, BindGroupLayoutDescriptor},
- rvec, Array, BindingMode, BuiltIn, DType, GPUOperation, InvariantError, Kernel, KernelElement,
- KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec, Scalar, Shape,
- StorageView, Strides, Tensor, Vec2, Vec4, WgslKernelBuilder, WgslPrimitive, WorkgroupSize,
- Workload,
+ rvec, Array, BindingMode, BuiltIn, DType, GPUOperation, GuardError, InvariantError, Kernel,
+ KernelElement, KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec,
+ Scalar, Shape, StorageView, Strides, Tensor, Vec2, Vec4, WgslKernelBuilder, WgslPrimitive,
+ WorkgroupSize, Workload,
};
#[cfg(test)]
use test_strategy::Arbitrary;
@@ -125,14 +127,22 @@ pub struct BinaryMeta {
}

impl OpGuards for Binary {
#[track_caller]
fn check_shapes(&self) {
let shapes = [self.lhs.shape(), self.rhs.shape()];
let broadcasted = Shape::multi_broadcast(&shapes);
- assert!(broadcasted.is_some());
+ if broadcasted.is_none() {
+ let shapes = shapes.iter().map(|s| (*s).clone()).collect::<Vec<_>>();
+ GuardError::shape_mismatch(self, shapes, None).panic(Location::caller());
+ }
}

#[track_caller]
fn check_dtypes(&self) {
- assert_eq!(self.lhs.dt(), self.rhs.dt());
+ if self.lhs.dt() != self.rhs.dt() {
+ GuardError::dtype_mismatch(self, vec![self.lhs.dt(), self.rhs.dt()])
+ .panic(Location::caller());
+ }
}
}

25 changes: 21 additions & 4 deletions crates/ratchet-core/src/ops/cache.rs
@@ -8,7 +8,7 @@ use wgpu::BindGroupLayoutEntry;

use crate::{
gpu::{BindGroupLayoutDescriptor, BindGroupLayoutEntryExt},
- rvec, Array, BindingMode, BuiltIn, DType, GPUOperation, Kernel, KernelElement,
+ rvec, Array, BindingMode, BuiltIn, DType, GPUOperation, GuardError, Kernel, KernelElement,
KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec, Scalar, Shape,
StorageView, Strides, Tensor, Vec2, Vec4, WgslKernelBuilder, WgslPrimitive, WorkgroupSize,
Workload,
@@ -118,12 +118,29 @@ pub struct CacheMeta {

impl OpGuards for Cache {
fn check_shapes(&self) {
- assert!(self.cache.rank() >= 3);
- assert!(self.offset <= self.cache.shape()[self.dim]);
if self.cache.rank() < 3 {
let msg = format!(
"Cache tensor must have rank >= 3, got {}",
self.cache.rank()
);
GuardError::custom(self, msg).panic(std::panic::Location::caller());
}

if self.offset > self.cache.shape()[self.dim] {
let msg = format!(
"Cache capacity exceeded, attempted to write at offset {} in dim {}, but cache size is {}",
self.offset,
self.dim,
self.cache.shape()[self.dim]
);
GuardError::custom(self, msg).panic(std::panic::Location::caller());
}
}

fn check_dtypes(&self) {
- assert_eq!(self.cache.dt(), self.source.dt());
if self.cache.dt() != self.source.dt() {
GuardError::dtype_mismatch(self, vec![self.cache.dt(), self.source.dt()])
.panic(std::panic::Location::caller());
}
}
}

2 changes: 1 addition & 1 deletion crates/ratchet-core/src/ops/cast.rs
@@ -239,7 +239,7 @@ def cast(a):
let device = Device::request_device(DeviceRequest::GPU).unwrap();
let device_precision = device.compute_precision();
if matches!(device_precision, DType::F32) {
- return Ok(())
+ return Ok(());
}
let CastProblem { dst_dt, B, M, N } = prob;
let input = Tensor::randn::<f32>(shape![B, M, N], Device::CPU);
18 changes: 5 additions & 13 deletions crates/ratchet-core/src/ops/matmul/mod.rs
@@ -12,9 +12,9 @@ use std::cmp::Ordering;

use crate::{
gpu::{BindGroupLayoutDescriptor, CpuUniform},
- rvec, DType, GPUOperation, Kernel, KernelElement, KernelKey, KernelMetadata, KernelRenderable,
- KernelSource, OpGuards, Operation, OperationError, RVec, Shape, StorageView, Strides, Tensor,
- WorkgroupSize, Workload, Q4_KF, Q4_KH, Q8_0F, Q8_0H,
+ rvec, DType, GPUOperation, GuardError, Kernel, KernelElement, KernelKey, KernelMetadata,
+ KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec, Shape, StorageView,
+ Strides, Tensor, WorkgroupSize, Workload, Q4_KF, Q4_KH, Q8_0F, Q8_0H,
};

//https://link.springer.com/chapter/10.1007/978-3-642-29737-3_42
@@ -404,20 +404,12 @@ impl OpGuards for Matmul {
];

if !allowed_pairs.contains(&(self.lhs.dt(), self.rhs.dt())) {
- panic!(
- "DType mismatch: lhs: {:?}, rhs: {:?}",
- self.lhs.dt(),
- self.rhs.dt()
- );
+ GuardError::dtype_mismatch(self, vec![self.lhs.dt(), self.rhs.dt()])
+ .panic(std::panic::Location::caller());
}

if let Some(bias) = &self.bias {
if bias.dt() != self.rhs.dt() {
- panic!(
- "DType mismatch: bias: {:?}, rhs: {:?}",
- bias.dt(),
- self.rhs.dt()
- );
+ GuardError::dtype_mismatch(self, vec![bias.dt(), self.rhs.dt()])
+ .panic(std::panic::Location::caller());
}
}
}
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/ops/select.rs
@@ -42,7 +42,7 @@ impl Operation for IndexSelect {
let strides = Strides::from(&output_shape);
Ok(StorageView::new(
output_shape,
- self.src.dt().activation_dt(),
+ self.src.dt().compute_dt(),
strides,
))
}
29 changes: 24 additions & 5 deletions crates/ratchet-core/src/ops/softmax.rs
@@ -1,3 +1,5 @@
use std::panic::Location;

use derive_new::new;
use encase::ShaderType;
use half::f16;
@@ -6,9 +8,10 @@ use ratchet_macros::WgslMetadata;

use crate::{
gpu::{dtype::WgslDType, BindGroupLayoutDescriptor},
- rvec, wgc, wgs, Array, BindingMode, BuiltIn, DType, GPUOperation, Kernel, KernelElement,
- KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec, Scalar, StorageView,
- Tensor, Vec2, Vec4, WgslKernelBuilder, WgslPrimitive, WorkgroupSize, Workload,
+ rvec, wgc, wgs, Array, BindingMode, BuiltIn, DType, GPUOperation, GuardError, Kernel,
+ KernelElement, KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec,
+ Scalar, StorageView, Tensor, Vec2, Vec4, WgslKernelBuilder, WgslPrimitive, WorkgroupSize,
+ Workload,
};

#[derive(new, Debug, Clone)]
@@ -26,10 +29,26 @@ pub struct SoftmaxMeta {
}

impl OpGuards for Softmax {
#[track_caller]
fn check_shapes(&self) {
let input = &self.input;
- assert!(input.rank() >= 2);
- assert!(self.dim < input.rank());

if input.rank() < 2 {
GuardError::custom(
self,
format!("Input rank must be at least 2, got: {}", input.rank()),
)
.panic(Location::caller());
}

if self.dim >= input.rank() {
let msg = format!(
"Dim {} is out of bounds for input with rank {}",
self.dim,
input.rank(),
);
GuardError::custom(self, msg).panic(Location::caller());
}
}

fn check_dtypes(&self) {
1 change: 0 additions & 1 deletion crates/ratchet-core/src/ops/unary.rs
@@ -7,7 +7,6 @@ use half::f16;
use inline_wgsl::wgsl;
use ratchet_macros::WgslMetadata;

- use strum::IntoEnumIterator;
use strum_macros::EnumIter;

use crate::{
42 changes: 40 additions & 2 deletions crates/ratchet-core/src/tensor.rs
@@ -302,6 +302,17 @@ impl Tensor {
return Ok(self);
}

let dst_dt = if dst_dt.is_quantized() {
log::warn!(
"Cannot cast to quantized type: {:?}, casting to associated compute precision: {:?}",
dst_dt,
dst_dt.compute_dt()
);
dst_dt.compute_dt()
} else {
dst_dt
};

let device = self.device.clone();
let cast = Cast::new(self, dst_dt);
let new_view = cast.compute_view()?;
@@ -384,7 +395,15 @@ impl Tensor {
//TODO: horrific interface
pub fn matmul(self, rhs: Tensor, trans_lhs: bool, trans_rhs: bool) -> anyhow::Result<Tensor> {
let device = self.device.clone();
- let matmul = Matmul::new(self, rhs, None, trans_lhs, trans_rhs, false);

let (lhs, rhs) = if self.dt() != rhs.dt() {
let unified_dt = self.dt();
(self, rhs.cast(unified_dt)?)
} else {
(self, rhs)
};

let matmul = Matmul::new(lhs, rhs, None, trans_lhs, trans_rhs, false);
let new_view = matmul.compute_view()?;
Ok(Tensor::lazy(LazyOp::Matmul(matmul), new_view, device))
}
@@ -398,7 +417,26 @@ impl Tensor {
trans_out: bool,
) -> anyhow::Result<Tensor> {
let device = self.device.clone();
- let gemm = Matmul::new(self, rhs, bias, trans_lhs, trans_rhs, trans_out);

let (lhs, rhs) = if self.dt() != rhs.dt() {
let unified_dt = self.dt();
(self, rhs.cast(unified_dt)?)
} else {
(self, rhs)
};

// Cast bias if required
let bias = if let Some(b) = bias {
if b.dt() != rhs.dt() {
Some(b.cast(rhs.dt())?)
} else {
Some(b)
}
} else {
None
};

let gemm = Matmul::new(lhs, rhs, bias, trans_lhs, trans_rhs, trans_out);
let new_view = gemm.compute_view()?;
Ok(Tensor::lazy(LazyOp::Matmul(gemm), new_view, device))
}
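Taken together, the tensor.rs changes mean mixed-precision call sites no longer need a manual cast. A minimal sketch of the intended behavior, assuming the crate-root items used by the cast.rs test above (Tensor, Device, DType, shape!) and that randn and cast compose lazily as in that test:

    use crate::{shape, DType, Device, Tensor};

    fn mixed_precision_matmul() -> anyhow::Result<()> {
        let lhs = Tensor::randn::<f32>(shape![1, 64, 64], Device::CPU);
        // Deliberately put the rhs in half precision.
        let rhs = Tensor::randn::<f32>(shape![1, 64, 64], Device::CPU).cast(DType::F16)?;
        // With this change, matmul casts the rhs back to the lhs dtype (F32) before
        // building the Matmul node, instead of hitting the dtype guard.
        let _c = lhs.matmul(rhs, false, false)?;
        Ok(())
    }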
2 changes: 1 addition & 1 deletion crates/ratchet-models/src/moondream/model.rs
@@ -190,7 +190,7 @@ impl Moondream {
LayerNorm::new(lt("vision_encoder.encoder.model.visual.norm.weight"), Some(lt("vision_encoder.encoder.model.visual.norm.bias")), ln_eps),
);

- let vision_encoder = VisionEncoder::new(projection, transformer);
+ let vision_encoder = VisionEncoder::new(transformer, projection);
Ok(Self {
vision_encoder,
text_model,
3 changes: 2 additions & 1 deletion crates/ratchet-models/src/moondream/vision_encoder.rs
@@ -156,15 +156,16 @@ impl Module for VisionProjection {

#[derive(Debug, derive_new::new)]
pub struct VisionEncoder {
- projection: VisionProjection,
transformer: VisionTransformer,
projection: VisionProjection,
}

impl Module for VisionEncoder {
type Input = Tensor;

fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> {
let transformed = self.transformer.schedule(input)?;
log::warn!("SUCCESSFULLY TRANSFORMED");
self.projection.schedule(Tensor::cat(
rvec![transformed.clone(), transformed.clone()],
2,
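The call-site swap in model.rs above is forced by this field reorder: derive_new::new generates constructor arguments in field-declaration order. A small self-contained sketch of that behavior, using stand-in field types rather than the real VisionTransformer and VisionProjection:

    use derive_new::new;

    #[derive(Debug, new)]
    struct Encoder {
        transformer: u32,
        projection: u32,
    }

    fn demo() {
        // Arguments follow field order: transformer first, then projection.
        let enc = Encoder::new(1, 2);
        println!("{enc:?}");
    }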