Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mat! macro for a nice way to write matrices #219

Merged
merged 1 commit into from
Dec 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 98 additions & 99 deletions core/src/math/mat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use core::{
array::{self, from_fn},
fmt::{self, Debug, Formatter},
marker::PhantomData,
marker::PhantomData as Pd,
ops::Range,
};

Expand Down Expand Up @@ -43,17 +43,17 @@ pub trait Compose<Inner: LinearMap>: LinearMap<Source = Inner::Dest> {
/// A change of basis in real vector space of dimension `DIM`.
#[derive(Copy, Clone, Default, Eq, PartialEq)]
pub struct RealToReal<const DIM: usize, SrcBasis = (), DstBasis = ()>(
PhantomData<(SrcBasis, DstBasis)>,
Pd<(SrcBasis, DstBasis)>,
);

/// Mapping from real to projective space.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct RealToProj<SrcBasis>(PhantomData<SrcBasis>);
pub struct RealToProj<SrcBasis>(Pd<SrcBasis>);

/// A generic matrix type.
#[repr(transparent)]
#[derive(Copy, Eq, PartialEq)]
pub struct Matrix<Repr, Map>(pub Repr, PhantomData<Map>);
pub struct Matrix<Repr, Map>(pub Repr, Pd<Map>);

/// Type alias for a 3x3 float matrix.
pub type Mat3x3<Map = ()> = Matrix<[[f32; 3]; 3], Map>;
Expand All @@ -64,11 +64,17 @@ pub type Mat4x4<Map = ()> = Matrix<[[f32; 4]; 4], Map>;
// Inherent impls
//

macro_rules! mat {
($($i:expr),+; $($j:expr),+; $($k:expr),+; $($($l:expr),+)? $(;)?) => {
Matrix([[$($i),+], [$($j),+], [$($k),+], $([$($l),+])?], Pd)
};
}

impl<Repr, Map> Matrix<Repr, Map> {
/// Returns a matrix with the given elements.
#[inline]
pub const fn new(els: Repr) -> Self {
Self(els, PhantomData)
Self(els, Pd)
}

/// Returns a matrix equal to `self` but with mapping `M`.
Expand All @@ -80,7 +86,7 @@ impl<Repr, Map> Matrix<Repr, Map> {
where
Repr: Clone,
{
Matrix::new(self.0.clone())
Matrix(self.0.clone(), Pd)
}
}

Expand Down Expand Up @@ -149,12 +155,12 @@ impl Mat4x4 {
k: Vec3<D>,
) -> Mat4x4<RealToReal<3, S, D>> {
let (i, j, k) = (i.0, j.0, k.0);
Mat4x4::new([
[i[0], j[0], k[0], 0.0],
[i[1], j[1], k[1], 0.0],
[i[2], j[2], k[2], 0.0],
[0.0, 0.0, 0.0, 1.0],
])
mat![
i[0], j[0], k[0], 0.0;
i[1], j[1], k[1], 0.0;
i[2], j[2], k[2], 0.0;
0.0, 0.0, 0.0, 1.0
]
}
}

Expand Down Expand Up @@ -472,7 +478,7 @@ where

impl<Repr, M> From<Repr> for Matrix<Repr, M> {
fn from(repr: Repr) -> Self {
Self(repr, PhantomData)
Self(repr, Pd)
}
}

Expand All @@ -495,12 +501,12 @@ pub const fn scale(s: Vec3) -> Mat4x4<RealToReal<3>> {
}

pub const fn scale3(x: f32, y: f32, z: f32) -> Mat4x4<RealToReal<3>> {
Matrix::new([
[x, 0.0, 0.0, 0.0],
[0.0, y, 0.0, 0.0],
[0.0, 0.0, z, 0.0],
[0.0, 0.0, 0.0, 1.0],
])
mat! [
x, 0.0, 0.0, 0.0;
0.0, y, 0.0, 0.0;
0.0, 0.0, z, 0.0;
0.0, 0.0, 0.0, 1.0;
]
}

/// Returns a matrix applying a translation by `t`.
Expand All @@ -509,12 +515,12 @@ pub const fn translate(t: Vec3) -> Mat4x4<RealToReal<3>> {
}

pub const fn translate3(x: f32, y: f32, z: f32) -> Mat4x4<RealToReal<3>> {
Matrix::new([
[1.0, 0.0, 0.0, x],
[0.0, 1.0, 0.0, y],
[0.0, 0.0, 1.0, z],
[0.0, 0.0, 0.0, 1.0],
])
mat![
1.0, 0.0, 0.0, x ;
0.0, 1.0, 0.0, y ;
0.0, 0.0, 1.0, z ;
0.0, 0.0, 0.0, 1.0;
]
}

#[cfg(feature = "fp")]
Expand Down Expand Up @@ -574,54 +580,50 @@ fn orient(new_y: Vec3, new_z: Vec3) -> Mat4x4<RealToReal<3>> {
#[cfg(feature = "fp")]
pub fn rotate_x(a: Angle) -> Mat4x4<RealToReal<3>> {
let (sin, cos) = a.sin_cos();
[
[1.0, 0.0, 0.0, 0.0],
[0.0, cos, sin, 0.0],
[0.0, -sin, cos, 0.0],
[0.0, 0.0, 0.0, 1.0],
mat![
1.0, 0.0, 0.0, 0.0;
0.0, cos, sin, 0.0;
0.0, -sin, cos, 0.0;
0.0, 0.0, 0.0, 1.0;
]
.into()
}
/// Returns a matrix applying a 3D rotation about the y axis.
#[cfg(feature = "fp")]
pub fn rotate_y(a: Angle) -> Mat4x4<RealToReal<3>> {
let (sin, cos) = a.sin_cos();
[
[cos, 0.0, -sin, 0.0],
[0.0, 1.0, 0.0, 0.0],
[sin, 0.0, cos, 0.0],
[0.0, 0.0, 0.0, 1.0],
mat![
cos, 0.0, -sin, 0.0;
0.0, 1.0, 0.0, 0.0;
sin, 0.0, cos, 0.0;
0.0, 0.0, 0.0, 1.0;
]
.into()
}
/// Returns a matrix applying a 3D rotation about the z axis.
#[cfg(feature = "fp")]
pub fn rotate_z(a: Angle) -> Mat4x4<RealToReal<3>> {
let (sin, cos) = a.sin_cos();
[
[cos, sin, 0.0, 0.0],
[-sin, cos, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
mat![
cos, sin, 0.0, 0.0;
-sin, cos, 0.0, 0.0;
0.0, 0.0, 1.0, 0.0;
0.0, 0.0, 0.0, 1.0;
]
.into()
}

/// Returns a matrix applying a 2D rotation by an angle.
#[cfg(feature = "fp")]
pub fn rotate2(a: super::angle::Angle) -> Mat3x3<RealToReal<2>> {
pub fn rotate2(a: Angle) -> Mat3x3<RealToReal<2>> {
let (sin, cos) = a.sin_cos();
[
[cos, sin, 0.0], //
[-sin, cos, 0.0], //
[0.0, 0.0, 1.0],
mat![
cos, sin, 0.0;
-sin, cos, 0.0;
0.0, 0.0, 1.0;
]
.into()
}

/// Returns a matrix applying a 3D rotation about an arbitrary axis.
#[cfg(feature = "fp")]
pub fn rotate(axis: Vec3, a: super::angle::Angle) -> Mat4x4<RealToReal<3>> {
pub fn rotate(axis: Vec3, a: Angle) -> Mat4x4<RealToReal<3>> {
use crate::math::approx::ApproxEq;

// 1. Change of basis such that `axis` is mapped to the z-axis,
Expand Down Expand Up @@ -668,13 +670,12 @@ pub fn perspective(
let e11 = e00 * aspect_ratio;
let e22 = (far + near) / (far - near);
let e23 = 2.0 * far * near / (near - far);
[
[e00, 0.0, 0.0, 0.0],
[0.0, e11, 0.0, 0.0],
[0.0, 0.0, e22, e23],
[0.0, 0.0, 1.0, 0.0],
mat![
e00, 0.0, 0.0, 0.0;
0.0, e11, 0.0, 0.0;
0.0, 0.0, e22, e23;
0.0, 0.0, 1.0, 0.0;
]
.into()
}

/// Creates an orthographic projection matrix.
Expand All @@ -686,13 +687,12 @@ pub fn orthographic(lbn: Point3, rtf: Point3) -> Mat4x4<ViewToProj> {
let half_d = (rtf - lbn) / 2.0;
let [cx, cy, cz] = (lbn + half_d).0;
let [idx, idy, idz] = half_d.map(f32::recip).0;
[
[idx, 0.0, 0.0, -cx * idx],
[0.0, idy, 0.0, -cy * idy],
[0.0, 0.0, idz, -cz * idz],
[0.0, 0.0, 0.0, 1.0],
mat![
idx, 0.0, 0.0, -cx * idx;
0.0, idy, 0.0, -cy * idy;
0.0, 0.0, idz, -cz * idz;
0.0, 0.0, 0.0, 1.0;
]
.into()
}

/// Creates a viewport transform matrix with the given pixel space bounds.
Expand All @@ -706,13 +706,12 @@ pub fn viewport(bounds: Range<Point2u>) -> Mat4x4<NdcToScreen> {
let half_d = (e - s) / 2.0;
let [dx, dy] = half_d.0;
let [cx, cy] = (s + half_d).0;
[
[dx, 0.0, 0.0, cx],
[0.0, dy, 0.0, cy],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
mat![
dx, 0.0, 0.0, cx;
0.0, dy, 0.0, cy;
0.0, 0.0, 1.0, 0.0;
0.0, 0.0, 0.0, 1.0;
]
.into()
}

#[cfg(test)]
Expand Down Expand Up @@ -742,11 +741,11 @@ mod tests {
use super::*;
use crate::math::pt2;

const MAT: Mat3x3<Map> = Matrix::new([
[0.0, 1.0, 2.0], //
[10.0, 11.0, 12.0],
[20.0, 21.0, 22.0],
]);
const MAT: Mat3x3<Map> = mat![
0.0, 1.0, 2.0;
10.0, 11.0, 12.0;
20.0, 21.0, 22.0;
];

#[test]
fn row_col_vecs() {
Expand All @@ -756,16 +755,16 @@ mod tests {

#[test]
fn composition() {
let t = Mat3x3::<Map<2>>::new([
[1.0, 0.0, 2.0], //
[0.0, 1.0, -3.0],
[0.0, 0.0, 1.0],
]);
let s = Mat3x3::<InvMap<2>>::new([
[-1.0, 0.0, 0.0], //
[0.0, 2.0, 0.0],
[0.0, 0.0, 1.0],
]);
let t: Mat3x3<Map<2>> = mat![
1.0, 0.0, 2.0;
0.0, 1.0, -3.0;
0.0, 0.0, 1.0;
];
let s: Mat3x3<InvMap<2>> = mat![
-1.0, 0.0, 0.0;
0.0, 2.0, 0.0;
0.0, 0.0, 1.0;
];

let ts = t.then(&s);
let st = s.then(&t);
Expand All @@ -779,22 +778,22 @@ mod tests {

#[test]
fn scaling() {
let m = Mat3x3::<Map<2>>::new([
[2.0, 0.0, 0.0], //
[0.0, -3.0, 0.0],
[0.0, 0.0, 1.0],
]);
let m: Mat3x3<Map<2>> = mat![
2.0, 0.0, 0.0;
0.0, -3.0, 0.0;
0.0, 0.0, 1.0;
];
assert_eq!(m.apply(&vec2(1.0, 2.0)), vec2(2.0, -6.0));
assert_eq!(m.apply_pt(&pt2(2.0, -1.0)), pt2(4.0, 3.0));
}

#[test]
fn translation() {
let m = Mat3x3::<Map<2>>::new([
[1.0, 0.0, 2.0], //
[0.0, 1.0, -3.0],
[0.0, 0.0, 1.0],
]);
let m: Mat3x3<Map<2>> = mat![
1.0, 0.0, 2.0;
0.0, 1.0, -3.0;
0.0, 0.0, 1.0;
];
assert_eq!(m.apply(&vec2(1.0, 2.0)), vec2(3.0, -1.0));
assert_eq!(m.apply_pt(&pt2(2.0, -1.0)), pt2(4.0, -4.0));
}
Expand All @@ -816,12 +815,12 @@ mod tests {
use super::*;
use crate::math::pt3;

const MAT: Mat4x4<Map> = Matrix::new([
[0.0, 1.0, 2.0, 3.0],
[10.0, 11.0, 12.0, 13.0],
[20.0, 21.0, 22.0, 23.0],
[30.0, 31.0, 32.0, 33.0],
]);
const MAT: Mat4x4<Map> = mat![
0.0, 1.0, 2.0, 3.0;
10.0, 11.0, 12.0, 13.0;
20.0, 21.0, 22.0, 23.0;
30.0, 31.0, 32.0, 33.0;
];

#[test]
fn row_col_vecs() {
Expand Down
Loading