Skip to content

Commit

Permalink
Add faster BVH4 traversal.
Browse files Browse the repository at this point in the history
  • Loading branch information
jbikker committed Nov 26, 2024
1 parent ed5d196 commit fc36151
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 37 deletions.
109 changes: 79 additions & 30 deletions tiny_bvh.h
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,8 @@ void BVH::Convert( const BVHLayout from, const BVHLayout to, const bool deleteOr
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
for (int cidx = 0, i = 0; i < 4; i++) if (orig.child[i])
int cidx = 0;
for (int i = 0; i < 4; i++) if (orig.child[i])
{
const BVHNode4& child = bvh4Node[orig.child[i]];
((float*)&newNode.xmin4)[cidx] = child.aabbMin.x;
Expand All @@ -1414,6 +1415,12 @@ void BVH::Convert( const BVHLayout from, const BVHLayout to, const bool deleteOr
stack[stackPtr++] = orig.child[i];
cidx++;
}
for (; cidx < 4; cidx++)
{
((float*)&newNode.xmin4)[cidx] = 1e30f, ((float*)&newNode.xmax4)[cidx] = 1.00001e30f;
((float*)&newNode.ymin4)[cidx] = 1e30f, ((float*)&newNode.ymax4)[cidx] = 1.00001e30f;
((float*)&newNode.zmin4)[cidx] = 1e30f, ((float*)&newNode.zmax4)[cidx] = 1.00001e30f;
}
// pop next task
if (!stackPtr) break;
nodeIdx = stack[--stackPtr];
Expand Down Expand Up @@ -3176,50 +3183,92 @@ int BVH::Intersect_AltSoA( Ray& ray ) const
return steps;
}

// Traverse a 4-way BVH stored in 'Attila Áfra' layout.
// Walks bvh4Alt2 starting at node 0 using an explicit stack, updates ray.hit
// (t, u, v, prim) whenever a closer triangle is found, and returns the number
// of nodes visited (steps).
// NOTE(review): this listing appears to come from a commit diff view, so lines
// of the replaced scalar traversal (the '#if 1' variant built on IntersectAABB
// and IntersectTri) are interleaved with the newer SSE code - e.g. 'node' is
// declared twice below. Confirm against the real tiny_bvh.h before relying on it.
int BVH::Intersect_Afra( Ray& ray ) const
{
#if 1
// quick-and-dirty intersect to verify data structure
// NOTE(review): pushes onto stack[] further down are not bounds-checked.
unsigned nodeIdx = 0, stack[1024], stackPtr = 0, steps = 0;
// broadcast the ray origin and ray.rD to all four lanes
// (rD is presumably the reciprocal ray direction - TODO confirm)
const __m128 ox4 = _mm_set1_ps( ray.O.x ), rdx4 = _mm_set1_ps( ray.rD.x );
const __m128 oy4 = _mm_set1_ps( ray.O.y ), rdy4 = _mm_set1_ps( ray.rD.y );
const __m128 oz4 = _mm_set1_ps( ray.O.z ), rdz4 = _mm_set1_ps( ray.rD.z );
// t4: current closest hit distance in every lane; refreshed when a closer triangle is found
__m128 t4 = _mm_set1_ps( ray.hit.t ), zero4 = _mm_setzero_ps();
// idx4: lane numbers 0..3 kept as raw bit patterns in a float register
__m128 idx4 = _mm_castsi128_ps( _mm_setr_epi32( 0, 1, 2, 3 ) );
// idxMask: clears the two lowest bits of each lane, making room for the lane number
__m128 idxMask = _mm_castsi128_ps( _mm_set1_epi32( 0xfffffffc ) );
// inf4: distance substituted for lanes that miss
__m128 inf4 = _mm_set1_ps( 1e30f );
while (1)
{
const BVHNode4Alt2& node = bvh4Alt2[nodeIdx];
// one step per node fetched
steps++;
BVHNode4Alt2& node = bvh4Alt2[nodeIdx];
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
for (unsigned i = 0; i < 4; i++) if (node.childFirst[i] + node.triCount[i] > 0)
{
bvhvec3 bmin, bmax;
bmin.x = ((float*)&node.xmin4)[i], bmax.x = ((float*)&node.xmax4)[i];
bmin.y = ((float*)&node.ymin4)[i], bmax.y = ((float*)&node.ymax4)[i];
bmin.z = ((float*)&node.zmin4)[i], bmax.z = ((float*)&node.zmax4)[i];
float t = IntersectAABB( ray, bmin, bmax );
if (t < 1e30f)
// intersect the ray with four AABBs
// per axis: (bound - origin) * rD for the min and max bound of each of the four boxes
const __m128 x0 = _mm_sub_ps( node.xmin4, ox4 ), x1 = _mm_sub_ps( node.xmax4, ox4 );
const __m128 y0 = _mm_sub_ps( node.ymin4, oy4 ), y1 = _mm_sub_ps( node.ymax4, oy4 );
const __m128 z0 = _mm_sub_ps( node.zmin4, oz4 ), z1 = _mm_sub_ps( node.zmax4, oz4 );
const __m128 tx1 = _mm_mul_ps( x0, rdx4 ), tx2 = _mm_mul_ps( x1, rdx4 );
const __m128 ty1 = _mm_mul_ps( y0, rdy4 ), ty2 = _mm_mul_ps( y1, rdy4 );
const __m128 tz1 = _mm_mul_ps( z0, rdz4 ), tz2 = _mm_mul_ps( z1, rdz4 );
// tmin: largest of the per-axis minima; tmax: smallest of the per-axis maxima
__m128 tmin = _mm_max_ps( _mm_max_ps( _mm_min_ps( tx1, tx2 ), _mm_min_ps( ty1, ty2 ) ), _mm_min_ps( tz1, tz2 ) );
const __m128 tmax = _mm_min_ps( _mm_min_ps( _mm_max_ps( tx1, tx2 ), _mm_max_ps( ty1, ty2 ) ), _mm_max_ps( tz1, tz2 ) );
// a lane hits when tmax >= tmin, tmin < current closest t, and tmax >= 0
const __m128 hit = _mm_and_ps( _mm_and_ps( _mm_cmpge_ps( tmax, tmin ), _mm_cmplt_ps( tmin, t4 ) ), _mm_cmpge_ps( tmax, zero4 ) );
// one bit per lane
const int hits = _mm_movemask_ps( hit );
// 0 doubles as "no next node selected yet"; see the checks on nodeIdx below
nodeIdx = 0;
if (hits)
{
// blend in lane indices
// lanes that missed get 1e30; the two lowest bits of every distance are then
// overwritten with the lane number, so the lane travels with its distance through the sort
tmin = _mm_or_ps( _mm_and_ps( _mm_blendv_ps( inf4, tmin, hit ), idxMask ), idx4 );
// sort
// five compare-and-swap steps ordering d0..d3 ascending
// (LANE is defined elsewhere; presumably it reads one float lane - TODO confirm)
float tmp, d0 = LANE( tmin, 0 ), d1 = LANE( tmin, 1 ), d2 = LANE( tmin, 2 ), d3 = LANE( tmin, 3 );
if (d0 > d2) tmp = d0, d0 = d2, d2 = tmp;
if (d1 > d3) tmp = d1, d1 = d3, d3 = tmp;
if (d0 > d1) tmp = d0, d0 = d1, d1 = tmp;
if (d2 > d3) tmp = d2, d2 = d3, d3 = tmp;
if (d1 > d2) tmp = d1, d1 = d2, d2 = tmp;
// process hits
float d[4] = { d0, d1, d2, d3 };
for (int i = 0; i < 4; i++)
{
if (node.triCount[i] > 0)
// values are sorted ascending, so the first missed lane (1e30) ends the loop
if (d[i] > 1e29f) break;
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
// recover the lane number from the two lowest bits of the sorted distance
unsigned lane = *(unsigned*)&d[i] & 3;
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
// slot with neither a child nor triangles: nothing to do
if (node.triCount[lane] + node.childFirst[lane] == 0) continue; // TODO - never happens?
// no triangles: interior child. The first interior child in sorted order becomes
// the next node to visit; any later ones are pushed onto the stack.
if (node.triCount[lane] == 0)
{
// process leaf
const unsigned first = node.childFirst[i], count = node.triCount[i];
for (unsigned j = 0; j < count; j++) IntersectTri( ray, triIdx[first + j] );
const unsigned childIdx = node.childFirst[lane];
if (!nodeIdx) nodeIdx = childIdx; else stack[stackPtr++] = childIdx;
continue;
}
else
// leaf: test each referenced triangle (Moeller-Trumbore style test)
const unsigned first = node.childFirst[lane], count = node.triCount[lane];
for (unsigned j = 0; j < count; j++) // TODO: aim for 4 prims per leaf
{
// process interior node
stack[stackPtr++] = node.childFirst[i];
// a triangle occupies three consecutive entries of verts
const unsigned idx = triIdx[first + j], vertIdx = idx * 3;
const bvhvec4 v0 = verts[vertIdx];
const bvhvec3 edge1 = verts[vertIdx + 1] - v0;
const bvhvec3 edge2 = verts[vertIdx + 2] - v0;
const bvhvec3 h = cross( ray.D, edge2 );
const float a = dot( edge1, h );
if (fabs( a ) < 0.0000001f) continue; // ray parallel to triangle
const float f = 1 / a;
const bvhvec3 s = ray.O - bvhvec3( v0 );
const float u = f * dot( s, h );
// reject when u or v falls outside the triangle
if (u < 0 || u > 1) continue;
const bvhvec3 q = cross( s, edge1 );
const float v = f * dot( ray.D, q );
if (v < 0 || u + v > 1) continue;
const float t = f * dot( edge2, q );
// accept only hits in front of the origin and closer than the current best;
// t4 is refreshed so later box tests cull against the new distance
if (t > 0 && t < ray.hit.t)
ray.hit.u = u, ray.hit.v = v, ray.hit.prim = idx,
ray.hit.t = t, t4 = _mm_set1_ps( t );
}
}
}
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
// get next task
// continue with the child selected above, otherwise pop the stack; an empty stack ends traversal
if (nodeIdx) continue;
if (stackPtr == 0) break; else nodeIdx = stack[--stackPtr];
}
#else
// proper SIMD traversal
// TODO
#endif
return steps;
}

Expand Down
39 changes: 32 additions & 7 deletions tiny_bvh_speedtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
#define TRAVERSE_SOA2WAY_ST
#define TRAVERSE_2WAY_MT
#define TRAVERSE_2WAY_MT_PACKET
#define TRAVERSE_2WAY_MT_DIVERGENT
#define TRAVERSE_OPTIMIZED_ST
#define TRAVERSE_4WAY_OPTIMIZED
// #define TRAVERSE_2WAY_MT_DIVERGENT // skipping; needs improvement.
// #define EMBREE_BUILD // win64-only for now.
// #define EMBREE_TRAVERSE // win64-only for now.

Expand Down Expand Up @@ -573,6 +574,30 @@ int main()

#endif

#ifdef TRAVERSE_4WAY_OPTIMIZED

// trace all rays three times to estimate average performance
// - single core version, BVH4 in SIMD-friendly layout
#ifndef TRAVERSE_OPTIMIZED_ST
printf( "Optimizing BVH, regular... " );
bvh.Convert( BVH::WALD_32BYTE, BVH::VERBOSE );
t.reset();
bvh.Optimize( 1000000 ); // optimize the raw SBVH
bvh.Convert( BVH::VERBOSE, BVH::WALD_32BYTE );
printf( "done (%.2fs). New: %i nodes, SAH=%.2f\n", t.elapsed(), bvh.NodeCount( BVH::WALD_32BYTE ), bvh.SAHCost() );
#endif
bvh.Convert( BVH::WALD_32BYTE, BVH::BASIC_BVH4 );
bvh.Convert( BVH::BASIC_BVH4, BVH::BVH4_AFRA );
printf( "- CPU, coherent, 4-way optimized, ST: " );
t.reset();
for (int pass = 0; pass < 3; pass++)
for (int i = 0; i < Nsmall; i++) bvh.Intersect( smallBatch[i], BVH::BVH4_AFRA );
float traceTimeAfra = t.elapsed() / 3.0f;
mrays = (float)Nsmall / traceTimeAfra;
printf( "%8.1fms for %6.2fM rays => %6.2fMRay/s\n", traceTimeAfra * 1000, (float)Nsmall * 1e-6f, mrays * 1e-6f );

#endif

#if defined EMBREE_TRAVERSE && defined EMBREE_BUILD

// trace all rays three times to estimate average performance
Expand All @@ -590,18 +615,18 @@ int main()
rayhits[i].hit.instID[0] = RTC_INVALID_GEOMETRY_ID;
}
t.reset();
for (int pass = 0; pass < 3; pass++)
for (int i = 0; i < Nfull; i++) rtcIntersect1( embreeScene, rayhits + i );
float traceTimeEmbree = t.elapsed() / 3.0f;
for (int pass = 0; pass < 6; pass++)
for (int i = 0; i < Nsmall; i++) rtcIntersect1( embreeScene, rayhits + i );
float traceTimeEmbree = t.elapsed() / 6.0f;
// retrieve intersection results
for (int i = 0; i < Nfull; i++)
for (int i = 0; i < Nsmall; i++)
{
fullBatch[i].hit.t = rayhits[i].ray.tfar;
fullBatch[i].hit.u = rayhits[i].hit.u, fullBatch[i].hit.u = rayhits[i].hit.v;
fullBatch[i].hit.prim = rayhits[i].hit.primID;
}
mrays = (float)Nfull / traceTimeEmbree;
printf( "%8.1fms for %6.2fM rays => %6.2fMRay/s\n", traceTimeEmbree * 1000, (float)Nfull * 1e-6f, mrays * 1e-6f );
mrays = (float)Nsmall / traceTimeEmbree;
printf( "%8.1fms for %6.2fM rays => %6.2fMRay/s\n", traceTimeEmbree * 1000, (float)Nsmall * 1e-6f, mrays * 1e-6f );
tinybvh::free64( rayhits );

#endif
Expand Down

0 comments on commit fc36151

Please sign in to comment.