Skip to content

Commit

Permalink
Element-wise sum and product for vectors. (#462)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ababwa committed Jan 25, 2024
1 parent 76f6a19 commit 441c8c4
Show file tree
Hide file tree
Showing 35 changed files with 702 additions and 0 deletions.
94 changes: 94 additions & 0 deletions codegen/templates/vec.rs.tera
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,100 @@ impl {{ self_t }} {
{% endif %}
}

/// Returns the sum of all elements of `self`.
///
/// In other words, this computes `self.x + self.y + ..`.
#[inline]
#[must_use]
pub fn element_sum(self) -> {{ scalar_t }} {
{% if is_scalar %}
{% for c in components %}
self.{{ c }} {% if not loop.last %} + {% endif %}
{%- endfor %}
{% elif is_sse2 %}
{% if dim == 3 %}
// SAFETY: only register-to-register intrinsics are used, no pointers are dereferenced.
// NOTE(review): assumes the sse2 backend is only compiled when SSE2 is available - confirm.
unsafe {
let v = self.0;
// Shuffle yields [y, x, 0, 0]: lane 0 becomes x + y, lane 2 becomes z + 0.
// The zero vector keeps the unused fourth lane of `self` out of the result.
let v = _mm_add_ps(v, _mm_shuffle_ps(v, Self::ZERO.0, 0b00_11_00_01));
// Bring lane 2 down to lane 0, so lane 0 holds (x + y) + (z + 0).
let v = _mm_add_ps(v, _mm_shuffle_ps(v, v, 0b00_00_00_10));
_mm_cvtss_f32(v)
}
{% elif dim == 4 %}
// SAFETY: only register-to-register intrinsics are used, no pointers are dereferenced.
// NOTE(review): assumes the sse2 backend is only compiled when SSE2 is available - confirm.
unsafe {
let v = self.0;
// Shuffle yields [y, x, w, x]: lane 0 becomes x + y, lane 2 becomes z + w.
let v = _mm_add_ps(v, _mm_shuffle_ps(v, v, 0b00_11_00_01));
// Bring lane 2 down to lane 0, so lane 0 holds (x + y) + (z + w).
let v = _mm_add_ps(v, _mm_shuffle_ps(v, v, 0b00_00_00_10));
_mm_cvtss_f32(v)
}
{% endif %}
{% elif is_wasm32 %}
{% if dim == 3 %}
let v = self.0;
// Shuffle yields [y, x, 0, x] (index 4 is lane 0 of the zero vector):
// lane 0 becomes x + y, lane 2 becomes z + 0.
let v = f32x4_add(v, i32x4_shuffle::<1, 0, 4, 0>(v, Self::ZERO.0));
// Bring lane 2 down to lane 0, so lane 0 holds (x + y) + (z + 0).
let v = f32x4_add(v, i32x4_shuffle::<2, 0, 0, 0>(v, v));
f32x4_extract_lane::<0>(v)
{% elif dim == 4 %}
let v = self.0;
// Shuffle yields [y, x, w, x]: lane 0 becomes x + y, lane 2 becomes z + w.
let v = f32x4_add(v, i32x4_shuffle::<1, 0, 3, 0>(v, v));
// Bring lane 2 down to lane 0, so lane 0 holds (x + y) + (z + w).
let v = f32x4_add(v, i32x4_shuffle::<2, 0, 0, 0>(v, v));
f32x4_extract_lane::<0>(v)
{% endif %}
{% elif is_coresimd %}
{% if dim == 3 %}
// Replace the unused fourth lane with 0 (index 4 is lane 0 of the zero vector) before
// reducing across all four lanes.
simd_swizzle!(self.0, Self::ZERO.0, [0, 1, 2, 4]).reduce_sum()
{% elif dim == 4 %}
self.0.reduce_sum()
{% endif %}
{% endif %}
}

/// Returns the product of all elements of `self`.
///
/// In other words, this computes `self.x * self.y * ..`.
#[inline]
#[must_use]
pub fn element_product(self) -> {{ scalar_t }} {
{% if is_scalar %}
{% for c in components %}
self.{{ c }} {% if not loop.last %} * {% endif %}
{%- endfor %}
{% elif is_sse2 %}
{% if dim == 3 %}
// SAFETY: only register-to-register intrinsics are used, no pointers are dereferenced.
// NOTE(review): assumes the sse2 backend is only compiled when SSE2 is available - confirm.
unsafe {
let v = self.0;
// Shuffle yields [y, x, 1, 1]: lane 0 becomes x * y, lane 2 becomes z * 1.
// The ones vector keeps the unused fourth lane of `self` out of the result.
let v = _mm_mul_ps(v, _mm_shuffle_ps(v, Self::ONE.0, 0b00_11_00_01));
// Bring lane 2 down to lane 0, so lane 0 holds (x * y) * (z * 1).
let v = _mm_mul_ps(v, _mm_shuffle_ps(v, v, 0b00_00_00_10));
_mm_cvtss_f32(v)
}
{% elif dim == 4 %}
// SAFETY: only register-to-register intrinsics are used, no pointers are dereferenced.
// NOTE(review): assumes the sse2 backend is only compiled when SSE2 is available - confirm.
unsafe {
let v = self.0;
// Shuffle yields [y, x, w, x]: lane 0 becomes x * y, lane 2 becomes z * w.
let v = _mm_mul_ps(v, _mm_shuffle_ps(v, v, 0b00_11_00_01));
// Bring lane 2 down to lane 0, so lane 0 holds (x * y) * (z * w).
let v = _mm_mul_ps(v, _mm_shuffle_ps(v, v, 0b00_00_00_10));
_mm_cvtss_f32(v)
}
{% endif %}
{% elif is_wasm32 %}
{% if dim == 3 %}
let v = self.0;
// Shuffle yields [y, x, 1, x] (index 4 is lane 0 of the ones vector):
// lane 0 becomes x * y, lane 2 becomes z * 1.
let v = f32x4_mul(v, i32x4_shuffle::<1, 0, 4, 0>(v, Self::ONE.0));
// Bring lane 2 down to lane 0, so lane 0 holds (x * y) * (z * 1).
let v = f32x4_mul(v, i32x4_shuffle::<2, 0, 0, 0>(v, v));
f32x4_extract_lane::<0>(v)
{% elif dim == 4 %}
let v = self.0;
// Shuffle yields [y, x, w, x]: lane 0 becomes x * y, lane 2 becomes z * w.
let v = f32x4_mul(v, i32x4_shuffle::<1, 0, 3, 0>(v, v));
// Bring lane 2 down to lane 0, so lane 0 holds (x * y) * (z * w).
let v = f32x4_mul(v, i32x4_shuffle::<2, 0, 0, 0>(v, v));
f32x4_extract_lane::<0>(v)
{% endif %}
{% elif is_coresimd %}
{% if dim == 3 %}
// Replace the unused fourth lane with 1 (index 4 is lane 0 of the ones vector) before
// reducing across all four lanes.
simd_swizzle!(self.0, Self::ONE.0, [0, 1, 2, 4]).reduce_product()
{% elif dim == 4 %}
self.0.reduce_product()
{% endif %}
{% endif %}
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
///
Expand Down
18 changes: 18 additions & 0 deletions src/f32/coresimd/vec3a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,24 @@ impl Vec3A {
v[0]
}

/// Adds together every element of `self` and returns the total.
///
/// This is the horizontal reduction `self.x + self.y + ..`; the spare fourth SIMD lane is
/// substituted with a lane taken from `Self::ZERO`, so it does not contribute.
#[inline]
#[must_use]
pub fn element_sum(self) -> f32 {
    // Index 4 selects lane 0 of the second operand, giving [x, y, z, ZERO lane].
    let padded = simd_swizzle!(self.0, Self::ZERO.0, [0, 1, 2, 4]);
    padded.reduce_sum()
}

/// Multiplies together every element of `self` and returns the result.
///
/// This is the horizontal reduction `self.x * self.y * ..`; the spare fourth SIMD lane is
/// substituted with a lane taken from `Self::ONE`, so it does not contribute.
#[inline]
#[must_use]
pub fn element_product(self) -> f32 {
    // Index 4 selects lane 0 of the second operand, giving [x, y, z, ONE lane].
    let padded = simd_swizzle!(self.0, Self::ONE.0, [0, 1, 2, 4]);
    padded.reduce_product()
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
///
Expand Down
18 changes: 18 additions & 0 deletions src/f32/coresimd/vec4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,24 @@ impl Vec4 {
self.0.reduce_max()
}

/// Adds together every element of `self` and returns the total.
///
/// This is the horizontal reduction `self.x + self.y + ..` over the underlying SIMD vector.
#[inline]
#[must_use]
pub fn element_sum(self) -> f32 {
    let lanes = self.0;
    lanes.reduce_sum()
}

/// Multiplies together every element of `self` and returns the result.
///
/// This is the horizontal reduction `self.x * self.y * ..` over the underlying SIMD vector.
#[inline]
#[must_use]
pub fn element_product(self) -> f32 {
    let lanes = self.0;
    lanes.reduce_product()
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
///
Expand Down
18 changes: 18 additions & 0 deletions src/f32/scalar/vec3a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,24 @@ impl Vec3A {
self.x.max(self.y.max(self.z))
}

/// Adds together the `x`, `y` and `z` elements of `self`.
///
/// The elements are accumulated left to right: `(self.x + self.y) + self.z`.
#[inline]
#[must_use]
pub fn element_sum(self) -> f32 {
    let xy = self.x + self.y;
    xy + self.z
}

/// Multiplies together the `x`, `y` and `z` elements of `self`.
///
/// The elements are accumulated left to right: `(self.x * self.y) * self.z`.
#[inline]
#[must_use]
pub fn element_product(self) -> f32 {
    let xy = self.x * self.y;
    xy * self.z
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
///
Expand Down
18 changes: 18 additions & 0 deletions src/f32/scalar/vec4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,24 @@ impl Vec4 {
self.x.max(self.y.max(self.z.max(self.w)))
}

/// Adds together the `x`, `y`, `z` and `w` elements of `self`.
///
/// The elements are accumulated left to right: `((self.x + self.y) + self.z) + self.w`.
#[inline]
#[must_use]
pub fn element_sum(self) -> f32 {
    let xy = self.x + self.y;
    let xyz = xy + self.z;
    xyz + self.w
}

/// Multiplies together the `x`, `y`, `z` and `w` elements of `self`.
///
/// The elements are accumulated left to right: `((self.x * self.y) * self.z) * self.w`.
#[inline]
#[must_use]
pub fn element_product(self) -> f32 {
    let xy = self.x * self.y;
    let xyz = xy * self.z;
    xyz * self.w
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
///
Expand Down
28 changes: 28 additions & 0 deletions src/f32/sse2/vec3a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,34 @@ impl Vec3A {
}
}

/// Adds together every element of `self` and returns the total.
///
/// This is the horizontal reduction `self.x + self.y + ..`, evaluated pairwise on the SIMD
/// register as `(x + y) + (z + 0.0)`.
#[inline]
#[must_use]
pub fn element_sum(self) -> f32 {
    // SAFETY: only register-to-register intrinsics are used, no pointers are dereferenced.
    // NOTE(review): assumes this module is only compiled when SSE2 is available - confirm.
    unsafe {
        let xyz = self.0;
        // [y, x, 0, 0]: the zero vector keeps the unused fourth lane out of the result.
        let swapped = _mm_shuffle_ps(xyz, _mm_setzero_ps(), 0b00_11_00_01);
        // Lane 0 is now x + y and lane 2 is z + 0.
        let partial = _mm_add_ps(xyz, swapped);
        // Move lane 2 into lane 0 so the two partial sums can be combined.
        let upper = _mm_shuffle_ps(partial, partial, 0b00_00_00_10);
        _mm_cvtss_f32(_mm_add_ps(partial, upper))
    }
}

/// Multiplies together every element of `self` and returns the result.
///
/// This is the horizontal reduction `self.x * self.y * ..`, evaluated pairwise on the SIMD
/// register as `(x * y) * (z * 1.0)`.
#[inline]
#[must_use]
pub fn element_product(self) -> f32 {
    // SAFETY: only register-to-register intrinsics are used, no pointers are dereferenced.
    // NOTE(review): assumes this module is only compiled when SSE2 is available - confirm.
    unsafe {
        let xyz = self.0;
        // [y, x, 1, 1]: the ones vector keeps the unused fourth lane out of the result.
        let swapped = _mm_shuffle_ps(xyz, _mm_set1_ps(1.0), 0b00_11_00_01);
        // Lane 0 is now x * y and lane 2 is z * 1.
        let partial = _mm_mul_ps(xyz, swapped);
        // Move lane 2 into lane 0 so the two partial products can be combined.
        let upper = _mm_shuffle_ps(partial, partial, 0b00_00_00_10);
        _mm_cvtss_f32(_mm_mul_ps(partial, upper))
    }
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
///
Expand Down
28 changes: 28 additions & 0 deletions src/f32/sse2/vec4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,34 @@ impl Vec4 {
}
}

/// Adds together every element of `self` and returns the total.
///
/// This is the horizontal reduction `self.x + self.y + ..`, evaluated pairwise on the SIMD
/// register as `(x + y) + (z + w)`.
#[inline]
#[must_use]
pub fn element_sum(self) -> f32 {
    // SAFETY: only register-to-register intrinsics are used, no pointers are dereferenced.
    // NOTE(review): assumes this module is only compiled when SSE2 is available - confirm.
    unsafe {
        let xyzw = self.0;
        // [y, x, w, x]
        let swapped = _mm_shuffle_ps(xyzw, xyzw, 0b00_11_00_01);
        // Lane 0 is now x + y and lane 2 is z + w.
        let partial = _mm_add_ps(xyzw, swapped);
        // Move lane 2 into lane 0 so the two partial sums can be combined.
        let upper = _mm_shuffle_ps(partial, partial, 0b00_00_00_10);
        _mm_cvtss_f32(_mm_add_ps(partial, upper))
    }
}

/// Multiplies together every element of `self` and returns the result.
///
/// This is the horizontal reduction `self.x * self.y * ..`, evaluated pairwise on the SIMD
/// register as `(x * y) * (z * w)`.
#[inline]
#[must_use]
pub fn element_product(self) -> f32 {
    // SAFETY: only register-to-register intrinsics are used, no pointers are dereferenced.
    // NOTE(review): assumes this module is only compiled when SSE2 is available - confirm.
    unsafe {
        let xyzw = self.0;
        // [y, x, w, x]
        let swapped = _mm_shuffle_ps(xyzw, xyzw, 0b00_11_00_01);
        // Lane 0 is now x * y and lane 2 is z * w.
        let partial = _mm_mul_ps(xyzw, swapped);
        // Move lane 2 into lane 0 so the two partial products can be combined.
        let upper = _mm_shuffle_ps(partial, partial, 0b00_00_00_10);
        _mm_cvtss_f32(_mm_mul_ps(partial, upper))
    }
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
///
Expand Down
18 changes: 18 additions & 0 deletions src/f32/vec2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,24 @@ impl Vec2 {
self.x.max(self.y)
}

/// Adds together the `x` and `y` elements of `self`.
///
/// Equivalent to `self.x + self.y`.
#[inline]
#[must_use]
pub fn element_sum(self) -> f32 {
    let (x, y) = (self.x, self.y);
    x + y
}

/// Multiplies together the `x` and `y` elements of `self`.
///
/// Equivalent to `self.x * self.y`.
#[inline]
#[must_use]
pub fn element_product(self) -> f32 {
    let (x, y) = (self.x, self.y);
    x * y
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
///
Expand Down
18 changes: 18 additions & 0 deletions src/f32/vec3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,24 @@ impl Vec3 {
self.x.max(self.y.max(self.z))
}

/// Adds together the `x`, `y` and `z` elements of `self`.
///
/// The elements are accumulated left to right: `(self.x + self.y) + self.z`.
#[inline]
#[must_use]
pub fn element_sum(self) -> f32 {
    let (x, y, z) = (self.x, self.y, self.z);
    x + y + z
}

/// Multiplies together the `x`, `y` and `z` elements of `self`.
///
/// The elements are accumulated left to right: `(self.x * self.y) * self.z`.
#[inline]
#[must_use]
pub fn element_product(self) -> f32 {
    let (x, y, z) = (self.x, self.y, self.z);
    x * y * z
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
///
Expand Down
24 changes: 24 additions & 0 deletions src/f32/wasm32/vec3a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,30 @@ impl Vec3A {
f32x4_extract_lane::<0>(v)
}

/// Adds together every element of `self` and returns the total.
///
/// This is the horizontal reduction `self.x + self.y + ..`, evaluated pairwise on the SIMD
/// vector as `(x + y) + (z + 0.0)`.
#[inline]
#[must_use]
pub fn element_sum(self) -> f32 {
    let xyz = self.0;
    // [y, x, 0, x]: index 4 is lane 0 of the zero splat, which pairs with z in lane 2.
    let swapped = i32x4_shuffle::<1, 0, 4, 0>(xyz, f32x4_splat(0.0));
    // Lane 0 is now x + y and lane 2 is z + 0.
    let partial = f32x4_add(xyz, swapped);
    // Move lane 2 into lane 0 so the two partial sums can be combined.
    let upper = i32x4_shuffle::<2, 0, 0, 0>(partial, partial);
    f32x4_extract_lane::<0>(f32x4_add(partial, upper))
}

/// Multiplies together every element of `self` and returns the result.
///
/// This is the horizontal reduction `self.x * self.y * ..`, evaluated pairwise on the SIMD
/// vector as `(x * y) * (z * 1.0)`.
#[inline]
#[must_use]
pub fn element_product(self) -> f32 {
    let xyz = self.0;
    // [y, x, 1, x]: index 4 is lane 0 of the ones splat, which pairs with z in lane 2.
    let swapped = i32x4_shuffle::<1, 0, 4, 0>(xyz, f32x4_splat(1.0));
    // Lane 0 is now x * y and lane 2 is z * 1.
    let partial = f32x4_mul(xyz, swapped);
    // Move lane 2 into lane 0 so the two partial products can be combined.
    let upper = i32x4_shuffle::<2, 0, 0, 0>(partial, partial);
    f32x4_extract_lane::<0>(f32x4_mul(partial, upper))
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
///
Expand Down
24 changes: 24 additions & 0 deletions src/f32/wasm32/vec4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,30 @@ impl Vec4 {
f32x4_extract_lane::<0>(v)
}

/// Adds together every element of `self` and returns the total.
///
/// This is the horizontal reduction `self.x + self.y + ..`, evaluated pairwise on the SIMD
/// vector as `(x + y) + (z + w)`.
#[inline]
#[must_use]
pub fn element_sum(self) -> f32 {
    let xyzw = self.0;
    // [y, x, w, x]
    let swapped = i32x4_shuffle::<1, 0, 3, 0>(xyzw, xyzw);
    // Lane 0 is now x + y and lane 2 is z + w.
    let partial = f32x4_add(xyzw, swapped);
    // Move lane 2 into lane 0 so the two partial sums can be combined.
    let upper = i32x4_shuffle::<2, 0, 0, 0>(partial, partial);
    f32x4_extract_lane::<0>(f32x4_add(partial, upper))
}

/// Multiplies together every element of `self` and returns the result.
///
/// This is the horizontal reduction `self.x * self.y * ..`, evaluated pairwise on the SIMD
/// vector as `(x * y) * (z * w)`.
#[inline]
#[must_use]
pub fn element_product(self) -> f32 {
    let xyzw = self.0;
    // [y, x, w, x]
    let swapped = i32x4_shuffle::<1, 0, 3, 0>(xyzw, xyzw);
    // Lane 0 is now x * y and lane 2 is z * w.
    let partial = f32x4_mul(xyzw, swapped);
    // Move lane 2 into lane 0 so the two partial products can be combined.
    let upper = i32x4_shuffle::<2, 0, 0, 0>(partial, partial);
    f32x4_extract_lane::<0>(f32x4_mul(partial, upper))
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
///
Expand Down
Loading

0 comments on commit 441c8c4

Please sign in to comment.