Skip to content

Commit

Permalink
Add mat! macro for a nice way to write matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
jdahlstrom committed Dec 27, 2024
1 parent ca812d2 commit 91bac87
Showing 1 changed file with 98 additions and 99 deletions.
197 changes: 98 additions & 99 deletions core/src/math/mat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use core::{
array::{self, from_fn},
fmt::{self, Debug, Formatter},
marker::PhantomData,
marker::PhantomData as Pd,
ops::Range,
};

Expand Down Expand Up @@ -43,17 +43,17 @@ pub trait Compose<Inner: LinearMap>: LinearMap<Source = Inner::Dest> {
/// A change of basis in real vector space of dimension `DIM`.
#[derive(Copy, Clone, Default, Eq, PartialEq)]
pub struct RealToReal<const DIM: usize, SrcBasis = (), DstBasis = ()>(
PhantomData<(SrcBasis, DstBasis)>,
Pd<(SrcBasis, DstBasis)>,
);

/// Mapping from real to projective space.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct RealToProj<SrcBasis>(PhantomData<SrcBasis>);
pub struct RealToProj<SrcBasis>(Pd<SrcBasis>);

/// A generic matrix type.
#[repr(transparent)]
#[derive(Copy, Eq, PartialEq)]
pub struct Matrix<Repr, Map>(pub Repr, PhantomData<Map>);
pub struct Matrix<Repr, Map>(pub Repr, Pd<Map>);

/// Type alias for a 3x3 float matrix.
pub type Mat3x3<Map = ()> = Matrix<[[f32; 3]; 3], Map>;
Expand All @@ -64,11 +64,17 @@ pub type Mat4x4<Map = ()> = Matrix<[[f32; 4]; 4], Map>;
// Inherent impls
//

macro_rules! mat {
($($i:expr),+; $($j:expr),+; $($k:expr),+; $($($l:expr),+)? $(;)?) => {
Matrix([[$($i),+], [$($j),+], [$($k),+], $([$($l),+])?], Pd)
};
}

impl<Repr, Map> Matrix<Repr, Map> {
/// Returns a matrix with the given elements.
#[inline]
pub const fn new(els: Repr) -> Self {
Self(els, PhantomData)
Self(els, Pd)
}

/// Returns a matrix equal to `self` but with mapping `M`.
Expand All @@ -80,7 +86,7 @@ impl<Repr, Map> Matrix<Repr, Map> {
where
Repr: Clone,
{
Matrix::new(self.0.clone())
Matrix(self.0.clone(), Pd)
}
}

Expand Down Expand Up @@ -149,12 +155,12 @@ impl Mat4x4 {
k: Vec3<D>,
) -> Mat4x4<RealToReal<3, S, D>> {
let (i, j, k) = (i.0, j.0, k.0);
Mat4x4::new([
[i[0], j[0], k[0], 0.0],
[i[1], j[1], k[1], 0.0],
[i[2], j[2], k[2], 0.0],
[0.0, 0.0, 0.0, 1.0],
])
mat![
i[0], j[0], k[0], 0.0;
i[1], j[1], k[1], 0.0;
i[2], j[2], k[2], 0.0;
0.0, 0.0, 0.0, 1.0
]
}
}

Expand Down Expand Up @@ -472,7 +478,7 @@ where

impl<Repr, M> From<Repr> for Matrix<Repr, M> {
fn from(repr: Repr) -> Self {
Self(repr, PhantomData)
Self(repr, Pd)
}
}

Expand All @@ -495,12 +501,12 @@ pub const fn scale(s: Vec3) -> Mat4x4<RealToReal<3>> {
}

pub const fn scale3(x: f32, y: f32, z: f32) -> Mat4x4<RealToReal<3>> {
Matrix::new([
[x, 0.0, 0.0, 0.0],
[0.0, y, 0.0, 0.0],
[0.0, 0.0, z, 0.0],
[0.0, 0.0, 0.0, 1.0],
])
mat! [
x, 0.0, 0.0, 0.0;
0.0, y, 0.0, 0.0;
0.0, 0.0, z, 0.0;
0.0, 0.0, 0.0, 1.0;
]
}

/// Returns a matrix applying a translation by `t`.
Expand All @@ -509,12 +515,12 @@ pub const fn translate(t: Vec3) -> Mat4x4<RealToReal<3>> {
}

pub const fn translate3(x: f32, y: f32, z: f32) -> Mat4x4<RealToReal<3>> {
Matrix::new([
[1.0, 0.0, 0.0, x],
[0.0, 1.0, 0.0, y],
[0.0, 0.0, 1.0, z],
[0.0, 0.0, 0.0, 1.0],
])
mat![
1.0, 0.0, 0.0, x ;
0.0, 1.0, 0.0, y ;
0.0, 0.0, 1.0, z ;
0.0, 0.0, 0.0, 1.0;
]
}

#[cfg(feature = "fp")]
Expand Down Expand Up @@ -574,54 +580,50 @@ fn orient(new_y: Vec3, new_z: Vec3) -> Mat4x4<RealToReal<3>> {
#[cfg(feature = "fp")]
pub fn rotate_x(a: Angle) -> Mat4x4<RealToReal<3>> {
let (sin, cos) = a.sin_cos();
[
[1.0, 0.0, 0.0, 0.0],
[0.0, cos, sin, 0.0],
[0.0, -sin, cos, 0.0],
[0.0, 0.0, 0.0, 1.0],
mat![
1.0, 0.0, 0.0, 0.0;
0.0, cos, sin, 0.0;
0.0, -sin, cos, 0.0;
0.0, 0.0, 0.0, 1.0;
]
.into()
}
/// Returns a matrix applying a 3D rotation about the y axis.
#[cfg(feature = "fp")]
pub fn rotate_y(a: Angle) -> Mat4x4<RealToReal<3>> {
let (sin, cos) = a.sin_cos();
[
[cos, 0.0, -sin, 0.0],
[0.0, 1.0, 0.0, 0.0],
[sin, 0.0, cos, 0.0],
[0.0, 0.0, 0.0, 1.0],
mat![
cos, 0.0, -sin, 0.0;
0.0, 1.0, 0.0, 0.0;
sin, 0.0, cos, 0.0;
0.0, 0.0, 0.0, 1.0;
]
.into()
}
/// Returns a matrix applying a 3D rotation about the z axis.
#[cfg(feature = "fp")]
pub fn rotate_z(a: Angle) -> Mat4x4<RealToReal<3>> {
let (sin, cos) = a.sin_cos();
[
[cos, sin, 0.0, 0.0],
[-sin, cos, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
mat![
cos, sin, 0.0, 0.0;
-sin, cos, 0.0, 0.0;
0.0, 0.0, 1.0, 0.0;
0.0, 0.0, 0.0, 1.0;
]
.into()
}

/// Returns a matrix applying a 2D rotation by an angle.
#[cfg(feature = "fp")]
pub fn rotate2(a: super::angle::Angle) -> Mat3x3<RealToReal<2>> {
pub fn rotate2(a: Angle) -> Mat3x3<RealToReal<2>> {
let (sin, cos) = a.sin_cos();
[
[cos, sin, 0.0], //
[-sin, cos, 0.0], //
[0.0, 0.0, 1.0],
mat![
cos, sin, 0.0;
-sin, cos, 0.0;
0.0, 0.0, 1.0;
]
.into()
}

/// Returns a matrix applying a 3D rotation about an arbitrary axis.
#[cfg(feature = "fp")]
pub fn rotate(axis: Vec3, a: super::angle::Angle) -> Mat4x4<RealToReal<3>> {
pub fn rotate(axis: Vec3, a: Angle) -> Mat4x4<RealToReal<3>> {
use crate::math::approx::ApproxEq;

// 1. Change of basis such that `axis` is mapped to the z-axis,
Expand Down Expand Up @@ -668,13 +670,12 @@ pub fn perspective(
let e11 = e00 * aspect_ratio;
let e22 = (far + near) / (far - near);
let e23 = 2.0 * far * near / (near - far);
[
[e00, 0.0, 0.0, 0.0],
[0.0, e11, 0.0, 0.0],
[0.0, 0.0, e22, e23],
[0.0, 0.0, 1.0, 0.0],
mat![
e00, 0.0, 0.0, 0.0;
0.0, e11, 0.0, 0.0;
0.0, 0.0, e22, e23;
0.0, 0.0, 1.0, 0.0;
]
.into()
}

/// Creates an orthographic projection matrix.
Expand All @@ -686,13 +687,12 @@ pub fn orthographic(lbn: Point3, rtf: Point3) -> Mat4x4<ViewToProj> {
let half_d = (rtf - lbn) / 2.0;
let [cx, cy, cz] = (lbn + half_d).0;
let [idx, idy, idz] = half_d.map(f32::recip).0;
[
[idx, 0.0, 0.0, -cx * idx],
[0.0, idy, 0.0, -cy * idy],
[0.0, 0.0, idz, -cz * idz],
[0.0, 0.0, 0.0, 1.0],
mat![
idx, 0.0, 0.0, -cx * idx;
0.0, idy, 0.0, -cy * idy;
0.0, 0.0, idz, -cz * idz;
0.0, 0.0, 0.0, 1.0;
]
.into()
}

/// Creates a viewport transform matrix with the given pixel space bounds.
Expand All @@ -706,13 +706,12 @@ pub fn viewport(bounds: Range<Point2u>) -> Mat4x4<NdcToScreen> {
let half_d = (e - s) / 2.0;
let [dx, dy] = half_d.0;
let [cx, cy] = (s + half_d).0;
[
[dx, 0.0, 0.0, cx],
[0.0, dy, 0.0, cy],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
mat![
dx, 0.0, 0.0, cx;
0.0, dy, 0.0, cy;
0.0, 0.0, 1.0, 0.0;
0.0, 0.0, 0.0, 1.0;
]
.into()
}

#[cfg(test)]
Expand Down Expand Up @@ -742,11 +741,11 @@ mod tests {
use super::*;
use crate::math::pt2;

const MAT: Mat3x3<Map> = Matrix::new([
[0.0, 1.0, 2.0], //
[10.0, 11.0, 12.0],
[20.0, 21.0, 22.0],
]);
const MAT: Mat3x3<Map> = mat![
0.0, 1.0, 2.0;
10.0, 11.0, 12.0;
20.0, 21.0, 22.0;
];

#[test]
fn row_col_vecs() {
Expand All @@ -756,16 +755,16 @@ mod tests {

#[test]
fn composition() {
let t = Mat3x3::<Map<2>>::new([
[1.0, 0.0, 2.0], //
[0.0, 1.0, -3.0],
[0.0, 0.0, 1.0],
]);
let s = Mat3x3::<InvMap<2>>::new([
[-1.0, 0.0, 0.0], //
[0.0, 2.0, 0.0],
[0.0, 0.0, 1.0],
]);
let t: Mat3x3<Map<2>> = mat![
1.0, 0.0, 2.0;
0.0, 1.0, -3.0;
0.0, 0.0, 1.0;
];
let s: Mat3x3<InvMap<2>> = mat![
-1.0, 0.0, 0.0;
0.0, 2.0, 0.0;
0.0, 0.0, 1.0;
];

let ts = t.then(&s);
let st = s.then(&t);
Expand All @@ -779,22 +778,22 @@ mod tests {

#[test]
fn scaling() {
let m = Mat3x3::<Map<2>>::new([
[2.0, 0.0, 0.0], //
[0.0, -3.0, 0.0],
[0.0, 0.0, 1.0],
]);
let m: Mat3x3<Map<2>> = mat![
2.0, 0.0, 0.0;
0.0, -3.0, 0.0;
0.0, 0.0, 1.0;
];
assert_eq!(m.apply(&vec2(1.0, 2.0)), vec2(2.0, -6.0));
assert_eq!(m.apply_pt(&pt2(2.0, -1.0)), pt2(4.0, 3.0));
}

#[test]
fn translation() {
let m = Mat3x3::<Map<2>>::new([
[1.0, 0.0, 2.0], //
[0.0, 1.0, -3.0],
[0.0, 0.0, 1.0],
]);
let m: Mat3x3<Map<2>> = mat![
1.0, 0.0, 2.0;
0.0, 1.0, -3.0;
0.0, 0.0, 1.0;
];
assert_eq!(m.apply(&vec2(1.0, 2.0)), vec2(3.0, -1.0));
assert_eq!(m.apply_pt(&pt2(2.0, -1.0)), pt2(4.0, -4.0));
}
Expand All @@ -816,12 +815,12 @@ mod tests {
use super::*;
use crate::math::pt3;

const MAT: Mat4x4<Map> = Matrix::new([
[0.0, 1.0, 2.0, 3.0],
[10.0, 11.0, 12.0, 13.0],
[20.0, 21.0, 22.0, 23.0],
[30.0, 31.0, 32.0, 33.0],
]);
const MAT: Mat4x4<Map> = mat![
0.0, 1.0, 2.0, 3.0;
10.0, 11.0, 12.0, 13.0;
20.0, 21.0, 22.0, 23.0;
30.0, 31.0, 32.0, 33.0;
];

#[test]
fn row_col_vecs() {
Expand Down

0 comments on commit 91bac87

Please sign in to comment.