Skip to content

Commit

Permalink
+10% using SSE in packet traversal.
Browse files Browse the repository at this point in the history
  • Loading branch information
jbikker committed Nov 6, 2024
1 parent 6891070 commit b74235d
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 17 deletions.
227 changes: 210 additions & 17 deletions tiny_bvh.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ THE SOFTWARE.
// gcc / clang
#include <cstdlib>
#include <cmath>

#include <cstring>
#define ALIGNED( x ) __attribute__( ( aligned( x ) ) )

#if defined(__x86_64__) || defined(_M_X64)
// https://stackoverflow.com/questions/32612881/why-use-mm-malloc-as-opposed-to-aligned-malloc-alligned-alloc-or-posix-mem
#include <xmmintrin.h>
Expand Down Expand Up @@ -124,8 +123,8 @@ struct ALIGNED( 16 ) bvhvec4
bvhvec4() = default;
bvhvec4( const float a, const float b, const float c, const float d ) : x( a ), y( b ), z( c ), w( d ) {}
bvhvec4( const float a ) : x( a ), y( a ), z( a ), w( a ) {}
bvhvec4( const bvhvec3& a );
bvhvec4( const bvhvec3& a, float b );
bvhvec4( const bvhvec3 & a );
bvhvec4( const bvhvec3 & a, float b );
float& operator [] ( const int i ) { return cell[i]; }
union { struct { float x, y, z, w; }; float cell[4]; };
};
Expand Down Expand Up @@ -259,11 +258,14 @@ struct Ray
Ray() = default;
Ray( bvhvec3 origin, bvhvec3 direction, float t = 1e30f )
{
memset( this, 0, sizeof( Ray ) );
O = origin, D = normalize( direction ), rD = tinybvh_safercp( D );
hit.t = t;
}
bvhvec3 O, D, rD;
Intersection hit;
ALIGNED( 16 ) bvhvec3 O; unsigned int dummy1;
ALIGNED( 16 ) bvhvec3 D; unsigned int dummy2;
ALIGNED( 16 ) bvhvec3 rD; unsigned int dummy3;
ALIGNED( 16 ) Intersection hit;
};

class BVH
Expand Down Expand Up @@ -293,8 +295,8 @@ class BVH
// Destructor: releases the node array through ALIGNED_FREE and the triIdx /
// fragment arrays through delete[], then clears the three pointers.
// Each array is deleted exactly once; deleting the same pointer twice is
// undefined behavior.
~BVH()
{
	ALIGNED_FREE( bvhNode );
	delete[] triIdx;
	delete[] fragment;
	bvhNode = 0, triIdx = 0, fragment = 0;
}
float SAHCost( const unsigned int nodeIdx = 0 ) const
Expand All @@ -321,6 +323,7 @@ class BVH
void Refit();
int Intersect( Ray& ray ) const;
void Intersect256Rays( Ray* first ) const;
void Intersect256RaysSSE( Ray* packet ) const;
private:
void IntersectTri( Ray& ray, const unsigned int triIdx ) const;
static float IntersectAABB( const Ray& ray, const bvhvec3& aabbMin, const bvhvec3& aabbMax );
Expand Down Expand Up @@ -747,18 +750,18 @@ void BVH::Intersect256Rays( Ray* packet ) const
// Traverse the tree with the packet
int first = 0, last = 255; // first and last active ray in the packet
BVHNode* node = &bvhNode[0];
ALIGNED(64) unsigned int stack[64], stackPtr = 0;
ALIGNED( 64 ) unsigned int stack[64], stackPtr = 0;
while (1)
{
if (node->isLeaf())
{
// handle leaf node
for (unsigned int j = 0; j < node->triCount; j++)
for (unsigned int j = 0; j < node->triCount; j++)
{
const unsigned int idx = triIdx[node->leftFirst + j], vid = idx * 3;
const bvhvec3 edge1 = verts[vid + 1] - verts[vid], edge2 = verts[vid + 2] - verts[vid];
const bvhvec3 s = O - bvhvec3( verts[vid] );
for( int i = first; i <= last; i++ )
for (int i = first; i <= last; i++)
{
Ray& ray = packet[i];
const bvhvec3 h = cross( ray.D, edge2 );
Expand All @@ -775,7 +778,7 @@ void BVH::Intersect256Rays( Ray* packet ) const
}
}
if (stackPtr == 0) break; else // pop
last = stack[--stackPtr], node = bvhNode + stack[--stackPtr],
last = stack[--stackPtr], node = bvhNode + stack[--stackPtr],
first = last >> 8, last &= 255;
}
else
Expand Down Expand Up @@ -811,15 +814,15 @@ void BVH::Intersect256Rays( Ray* packet ) const
else
{
// 3. Last resort: update first and last, stay in node if first > last
for( ; leftFirst <= leftLast; leftFirst++ )
for (; leftFirst <= leftLast; leftFirst++)
{
const bvhvec3 rD = packet[leftFirst].rD;
const float tx1 = ox1 * rD.x, tx2 = ox2 * rD.x, ty1 = oy1 * rD.y, ty2 = oy2 * rD.y, tz1 = oz1 * rD.z, tz2 = oz2 * rD.z;
const float tmin = tinybvh_max( tinybvh_max( tinybvh_min( tx1, tx2 ), tinybvh_min( ty1, ty2 ) ), tinybvh_min( tz1, tz2 ) );
const float tmax = tinybvh_min( tinybvh_min( tinybvh_max( tx1, tx2 ), tinybvh_max( ty1, ty2 ) ), tinybvh_max( tz1, tz2 ) );
if (tmax >= tmin && tmin < packet[leftFirst].hit.t && tmax >= 0) { distLeft = tmin; break; }
}
for( ; leftLast >= leftFirst; leftLast-- )
for (; leftLast >= leftFirst; leftLast--)
{
const bvhvec3 rD = packet[leftLast].rD;
const float tx1 = ox1 * rD.x, tx2 = ox2 * rD.x, ty1 = oy1 * rD.y, ty2 = oy2 * rD.y, tz1 = oz1 * rD.z, tz2 = oz2 * rD.z;
Expand Down Expand Up @@ -856,15 +859,15 @@ void BVH::Intersect256Rays( Ray* packet ) const
else
{
// 3. Last resort: update first and last, stay in node if first > last
for( ; rightFirst <= rightLast; rightFirst++ )
for (; rightFirst <= rightLast; rightFirst++)
{
const bvhvec3 rD = packet[rightFirst].rD;
const float tx1 = ox1 * rD.x, tx2 = ox2 * rD.x, ty1 = oy1 * rD.y, ty2 = oy2 * rD.y, tz1 = oz1 * rD.z, tz2 = oz2 * rD.z;
const float tmin = tinybvh_max( tinybvh_max( tinybvh_min( tx1, tx2 ), tinybvh_min( ty1, ty2 ) ), tinybvh_min( tz1, tz2 ) );
const float tmax = tinybvh_min( tinybvh_min( tinybvh_max( tx1, tx2 ), tinybvh_max( ty1, ty2 ) ), tinybvh_max( tz1, tz2 ) );
if (tmax >= tmin && tmin < packet[rightFirst].hit.t && tmax >= 0) { distRight = tmin; break; }
}
for( ; rightLast >= first; rightLast-- )
for (; rightLast >= first; rightLast--)
{
const bvhvec3 rD = packet[rightLast].rD;
const float tx1 = ox1 * rD.x, tx2 = ox2 * rD.x, ty1 = oy1 * rD.y, ty2 = oy2 * rD.y, tz1 = oz1 * rD.z, tz2 = oz2 * rD.z;
Expand Down Expand Up @@ -899,12 +902,202 @@ void BVH::Intersect256Rays( Ray* packet ) const
else if (visitRight) // continue with right
node = right, first = rightFirst, last = rightLast;
else if (stackPtr == 0) break; else // pop
last = stack[--stackPtr], node = bvhNode + stack[--stackPtr],
last = stack[--stackPtr], node = bvhNode + stack[--stackPtr],
first = last >> 8, last &= 255;
}
}
}

#ifdef BVH_USEAVX

// Intersect a BVH with a packet of 256 rays, basic SSE-optimized version.
//
// packet: pointer to 256 consecutive rays; closest hits are written to
//         packet[i].hit (t, u, v, prim).
//
// What the code below relies on:
// - The origin of packet[0] is used for every ray in the packet, so all rays
//   are expected to share one origin.
// - Rays 0, 51, 204 and 255 are treated as the corner rays of the packet and
//   define the four bounding planes used for the early-out test.
// - Ray::O, Ray::rD, BVHNode::aabbMin and BVHNode::aabbMax are loaded through
//   __m128 pointers: each must be 16-byte aligned and followed by a fourth
//   readable float. Only lanes 0..2 are used.
// - A BVHNode is also read as a float array, with the box minimum at indices
//   0..2 and the box maximum at indices 4..6.
// - The traversal stack holds two entries per deferred node: the node index,
//   followed by the active ray range packed as (first << 8) + last.
void BVH::Intersect256RaysSSE( Ray* packet ) const
{
	// Corner rays are: 0, 51, 204 and 255
	// Construct the bounding planes, with normals pointing outwards
	bvhvec3 O = packet[0].O; // same for all rays in this case
	__m128 O4 = *(__m128*)&packet[0].O;
	bvhvec3 p0 = packet[0].O + packet[0].D; // top-left
	bvhvec3 p1 = packet[51].O + packet[51].D; // top-right
	bvhvec3 p2 = packet[204].O + packet[204].D; // bottom-left
	bvhvec3 p3 = packet[255].O + packet[255].D; // bottom-right
	bvhvec3 plane0 = normalize( cross( p0 - O, p0 - p2 ) ); // left plane
	bvhvec3 plane1 = normalize( cross( p3 - O, p3 - p1 ) ); // right plane
	bvhvec3 plane2 = normalize( cross( p1 - O, p1 - p0 ) ); // top plane
	bvhvec3 plane3 = normalize( cross( p2 - O, p2 - p3 ) ); // bottom plane
	// Per plane and per axis: index into the node's float view of the box
	// coordinate that minimizes dot( corner, plane ). A negative normal
	// component selects the maximum (4..6), otherwise the minimum (0..2).
	int sign0x = plane0.x < 0 ? 4 : 0, sign0y = plane0.y < 0 ? 5 : 1, sign0z = plane0.z < 0 ? 6 : 2;
	int sign1x = plane1.x < 0 ? 4 : 0, sign1y = plane1.y < 0 ? 5 : 1, sign1z = plane1.z < 0 ? 6 : 2;
	int sign2x = plane2.x < 0 ? 4 : 0, sign2y = plane2.y < 0 ? 5 : 1, sign2z = plane2.z < 0 ? 6 : 2;
	int sign3x = plane3.x < 0 ? 4 : 0, sign3y = plane3.y < 0 ? 5 : 1, sign3z = plane3.z < 0 ? 6 : 2;
	// Plane offsets: all four planes pass through the shared origin.
	float t0 = dot( O, plane0 ), t1 = dot( O, plane1 );
	float t2 = dot( O, plane2 ), t3 = dot( O, plane3 );
	// Traverse the tree with the packet
	int first = 0, last = 255; // first and last active ray in the packet
	BVHNode* node = &bvhNode[0];
	ALIGNED( 64 ) unsigned int stack[64], stackPtr = 0;
	while (1)
	{
		if (node->isLeaf())
		{
			// handle leaf node: test every triangle against every active ray
			for (unsigned int j = 0; j < node->triCount; j++)
			{
				const unsigned int idx = triIdx[node->leftFirst + j], vid = idx * 3;
				const bvhvec3 edge1 = verts[vid + 1] - verts[vid], edge2 = verts[vid + 2] - verts[vid];
				// s depends on the shared origin only, so it is computed once per triangle
				const bvhvec3 s = O - bvhvec3( verts[vid] );
				for (int i = first; i <= last; i++)
				{
					Ray& ray = packet[i];
					const bvhvec3 h = cross( ray.D, edge2 );
					const float a = dot( edge1, h );
					if (fabs( a ) < 0.0000001f) continue; // ray parallel to triangle
					const float f = 1 / a, u = f * dot( s, h );
					if (u < 0 || u > 1) continue;
					const bvhvec3 q = cross( s, edge1 );
					const float v = f * dot( ray.D, q );
					if (v < 0 || u + v > 1) continue;
					const float t = f * dot( edge2, q );
					if (t <= 0 || t >= ray.hit.t) continue; // behind the origin or not closer
					ray.hit.t = t, ray.hit.u = u, ray.hit.v = v, ray.hit.prim = idx;
				}
			}
			if (stackPtr == 0) break; else // pop: packed range first, then node index
				last = stack[--stackPtr], node = bvhNode + stack[--stackPtr],
				first = last >> 8, last &= 255;
		}
		else
		{
			// fetch pointers to child nodes
			BVHNode* left = bvhNode + node->leftFirst;
			BVHNode* right = bvhNode + node->leftFirst + 1;
			bool visitLeft = true, visitRight = true;
			int leftFirst = first, leftLast = last, rightFirst = first, rightLast = last;
			float distLeft, distRight;
			{
				// see if we want to intersect the left child
				const __m128 minO4 = _mm_sub_ps( *(__m128*)&left->aabbMin, O4 );
				const __m128 maxO4 = _mm_sub_ps( *(__m128*)&left->aabbMax, O4 );
				// 1. Early-in test: if first ray hits the node, the packet visits the node
				const __m128 rD4 = *(__m128*)&packet[first].rD;
				const __m128 st1 = _mm_mul_ps( minO4, rD4 ), st2 = _mm_mul_ps( maxO4, rD4 );
				const __m128 vmax4 = _mm_max_ps( st1, st2 ), vmin4 = _mm_min_ps( st1, st2 );
				const float tmax = tinybvh_min( LANE( vmax4, 0 ), tinybvh_min( LANE( vmax4, 1 ), LANE( vmax4, 2 ) ) );
				const float tmin = tinybvh_max( LANE( vmin4, 0 ), tinybvh_max( LANE( vmin4, 1 ), LANE( vmin4, 2 ) ) );
				const bool earlyHit = (tmax >= tmin && tmin < packet[first].hit.t && tmax >= 0);
				distLeft = tmin;
				// 2. Early-out test: if the node aabb is outside the four planes, we skip the node
				if (!earlyHit)
				{
					float* minmax = (float*)left;
					bvhvec3 p0( minmax[sign0x], minmax[sign0y], minmax[sign0z] );
					bvhvec3 p1( minmax[sign1x], minmax[sign1y], minmax[sign1z] );
					bvhvec3 p2( minmax[sign2x], minmax[sign2y], minmax[sign2z] );
					bvhvec3 p3( minmax[sign3x], minmax[sign3y], minmax[sign3z] );
					if (dot( p0, plane0 ) > t0 || dot( p1, plane1 ) > t1 || dot( p2, plane2 ) > t2 || dot( p3, plane3 ) > t3)
						visitLeft = false;
					else
					{
						// 3. Last resort: shrink [leftFirst, leftLast] to the first and last
						// ray that hit the child; the child is skipped if no ray hits it.
						for (; leftFirst <= leftLast; leftFirst++)
						{
							const __m128 rD4 = *(__m128*)&packet[leftFirst].rD;
							const __m128 st1 = _mm_mul_ps( minO4, rD4 ), st2 = _mm_mul_ps( maxO4, rD4 );
							const __m128 vmax4 = _mm_max_ps( st1, st2 ), vmin4 = _mm_min_ps( st1, st2 );
							const float tmax = tinybvh_min( LANE( vmax4, 0 ), tinybvh_min( LANE( vmax4, 1 ), LANE( vmax4, 2 ) ) );
							const float tmin = tinybvh_max( LANE( vmin4, 0 ), tinybvh_max( LANE( vmin4, 1 ), LANE( vmin4, 2 ) ) );
							if (tmax >= tmin && tmin < packet[leftFirst].hit.t && tmax >= 0) { distLeft = tmin; break; }
						}
						for (; leftLast >= leftFirst; leftLast--)
						{
							const __m128 rD4 = *(__m128*)&packet[leftLast].rD;
							const __m128 st1 = _mm_mul_ps( minO4, rD4 ), st2 = _mm_mul_ps( maxO4, rD4 );
							const __m128 vmax4 = _mm_max_ps( st1, st2 ), vmin4 = _mm_min_ps( st1, st2 );
							const float tmax = tinybvh_min( LANE( vmax4, 0 ), tinybvh_min( LANE( vmax4, 1 ), LANE( vmax4, 2 ) ) );
							const float tmin = tinybvh_max( LANE( vmin4, 0 ), tinybvh_max( LANE( vmin4, 1 ), LANE( vmin4, 2 ) ) );
							if (tmax >= tmin && tmin < packet[leftLast].hit.t && tmax >= 0) break;
						}
						visitLeft = leftLast >= leftFirst;
					}
				}
			}
			{
				// see if we want to intersect the right child
				const __m128 minO4 = _mm_sub_ps( *(__m128*)&right->aabbMin, O4 );
				const __m128 maxO4 = _mm_sub_ps( *(__m128*)&right->aabbMax, O4 );
				// 1. Early-in test: if first ray hits the node, the packet visits the node
				const __m128 rD4 = *(__m128*)&packet[first].rD;
				const __m128 st1 = _mm_mul_ps( minO4, rD4 ), st2 = _mm_mul_ps( maxO4, rD4 );
				const __m128 vmax4 = _mm_max_ps( st1, st2 ), vmin4 = _mm_min_ps( st1, st2 );
				const float tmax = tinybvh_min( LANE( vmax4, 0 ), tinybvh_min( LANE( vmax4, 1 ), LANE( vmax4, 2 ) ) );
				const float tmin = tinybvh_max( LANE( vmin4, 0 ), tinybvh_max( LANE( vmin4, 1 ), LANE( vmin4, 2 ) ) );
				const bool earlyHit = (tmax >= tmin && tmin < packet[first].hit.t && tmax >= 0);
				distRight = tmin;
				// 2. Early-out test: if the node aabb is outside the four planes, we skip the node
				if (!earlyHit)
				{
					float* minmax = (float*)right;
					bvhvec3 p0( minmax[sign0x], minmax[sign0y], minmax[sign0z] );
					bvhvec3 p1( minmax[sign1x], minmax[sign1y], minmax[sign1z] );
					bvhvec3 p2( minmax[sign2x], minmax[sign2y], minmax[sign2z] );
					bvhvec3 p3( minmax[sign3x], minmax[sign3y], minmax[sign3z] );
					if (dot( p0, plane0 ) > t0 || dot( p1, plane1 ) > t1 || dot( p2, plane2 ) > t2 || dot( p3, plane3 ) > t3)
						visitRight = false;
					else
					{
						// 3. Last resort: shrink [rightFirst, rightLast] to the first and last
						// ray that hit the child; the child is skipped if no ray hits it.
						for (; rightFirst <= rightLast; rightFirst++)
						{
							const __m128 rD4 = *(__m128*)&packet[rightFirst].rD;
							const __m128 st1 = _mm_mul_ps( minO4, rD4 ), st2 = _mm_mul_ps( maxO4, rD4 );
							const __m128 vmax4 = _mm_max_ps( st1, st2 ), vmin4 = _mm_min_ps( st1, st2 );
							const float tmax = tinybvh_min( LANE( vmax4, 0 ), tinybvh_min( LANE( vmax4, 1 ), LANE( vmax4, 2 ) ) );
							const float tmin = tinybvh_max( LANE( vmin4, 0 ), tinybvh_max( LANE( vmin4, 1 ), LANE( vmin4, 2 ) ) );
							if (tmax >= tmin && tmin < packet[rightFirst].hit.t && tmax >= 0) { distRight = tmin; break; }
						}
						// Stop at rightFirst: the rays below it were already rejected by the
						// forward scan, so testing them again cannot change the outcome.
						for (; rightLast >= rightFirst; rightLast--)
						{
							const __m128 rD4 = *(__m128*)&packet[rightLast].rD;
							const __m128 st1 = _mm_mul_ps( minO4, rD4 ), st2 = _mm_mul_ps( maxO4, rD4 );
							const __m128 vmax4 = _mm_max_ps( st1, st2 ), vmin4 = _mm_min_ps( st1, st2 );
							const float tmax = tinybvh_min( LANE( vmax4, 0 ), tinybvh_min( LANE( vmax4, 1 ), LANE( vmax4, 2 ) ) );
							const float tmin = tinybvh_max( LANE( vmin4, 0 ), tinybvh_max( LANE( vmin4, 1 ), LANE( vmin4, 2 ) ) );
							if (tmax >= tmin && tmin < packet[rightLast].hit.t && tmax >= 0) break;
						}
						visitRight = rightLast >= rightFirst;
					}
				}
			}
			// process intersection result
			if (visitLeft && visitRight)
			{
				if (distLeft < distRight)
				{
					// push right, continue with left
					stack[stackPtr++] = node->leftFirst + 1;
					stack[stackPtr++] = (rightFirst << 8) + rightLast;
					node = left, first = leftFirst, last = leftLast;
				}
				else
				{
					// push left, continue with right
					stack[stackPtr++] = node->leftFirst;
					stack[stackPtr++] = (leftFirst << 8) + leftLast;
					node = right, first = rightFirst, last = rightLast;
				}
			}
			else if (visitLeft) // continue with left
				node = left, first = leftFirst, last = leftLast;
			else if (visitRight) // continue with right
				node = right, first = rightFirst, last = rightLast;
			else if (stackPtr == 0) break; else // pop: packed range first, then node index
				last = stack[--stackPtr], node = bvhNode + stack[--stackPtr],
				first = last >> 8, last &= 255;
		}
	}
}

#endif

// IntersectTri
void BVH::IntersectTri( Ray& ray, const unsigned int idx ) const
{
Expand Down
22 changes: 22 additions & 0 deletions tiny_bvh_speedtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,28 @@ int main()
mrays = (float)N / traceTimeMTP;
printf( "%.2fms for %.2fM rays (%.2fMRays/s)\n", traceTimeMTP * 1000, (float)N * 1e-6f, mrays * 1e-6f );

#ifdef BVH_USEAVX

// trace all rays three times to estimate average performance
// - coherent distribution, multi-core, packet traversal, SSE version
t.reset();
printf( "- CPU, coherent, basic 2-way layout, MT, packets (SSE): " );
for (int j = 0; j < 3; j++)
{
const int batchCount = N / (30 * 256); // batches of 30 packets of 256 rays
#pragma omp parallel for schedule(dynamic)
for (int batch = 0; batch < batchCount; batch++)
{
const int batchStart = batch * 30 * 256;
for (int i = 0; i < 30; i++) bvh.Intersect256RaysSSE( rays + batchStart + i * 256 );
}
}
float traceTimeMTPS = t.elapsed() / 3.0f;
mrays = (float)N / traceTimeMTPS;
printf( "%.2fms for %.2fM rays (%.2fMRays/s)\n", traceTimeMTPS * 1000, (float)N * 1e-6f, mrays * 1e-6f );

#endif

#endif

#ifdef TRAVERSE_2WAY_MT_DIVERGENT
Expand Down

0 comments on commit b74235d

Please sign in to comment.