Skip to content

Commit

Permalink
use push constants for image stride
Browse files Browse the repository at this point in the history
  • Loading branch information
jo-m committed Dec 17, 2023
1 parent dec0cef commit f5f8eea
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 27 deletions.
16 changes: 9 additions & 7 deletions pkg/pmatch/vk.comp
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ layout(constant_id = 3) const int M = 0;
layout(constant_id = 4) const int N = 0;
layout(constant_id = 5) const int DU = 0;
layout(constant_id = 6) const int DV = 0;
layout(constant_id = 7) const int IS = 0;
layout(constant_id = 8) const int PS = 0;
layout(constant_id = 9) const int SS = 0;

layout(std430, push_constant) uniform _constants {
uint IS;
uint PS;
}
push_constants;

struct results {
uint max_uint;
Expand Down Expand Up @@ -59,12 +62,12 @@ void main() {
return;
}

uint imgPatStartIx = y * IS / 4 + x;
uint imgPatStartIx = y * push_constants.IS / 4 + x;
float dot = 0, absI2 = 0, absP2 = 0;

for (uint v = 0; v < DV; v++) {
uint pxIi = v * IS / 4;
uint pxPi = v * PS / 4;
uint pxIi = v * push_constants.IS / 4;
uint pxPi = v * push_constants.PS / 4;

for (uint u = 0; u < DU; u++) {
uvec4 pxI = rgb_le(img[imgPatStartIx + pxIi + u]);
Expand All @@ -87,7 +90,6 @@ void main() {
uint ucos = uint(cos * pow(2, 32));
atomicMax(res.max_uint, ucos);
memoryBarrierBuffer();
// TODO: compute max per local group first via shared variable
// This thread computed the max.
if (res.max_uint == ucos) {
res.max = cos;
Expand Down
48 changes: 30 additions & 18 deletions pkg/pmatch/vk.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ type results struct {
MaxY uint32
}

type pushConstants struct {
imgStride uint32
patStride uint32
}

func (r results) size() int {
buf := bytes.Buffer{}
err := binary.Write(&buf, binary.LittleEndian, results{})
Expand All @@ -43,14 +48,14 @@ func (r results) size() int {
}

type SearchVk struct {
searchRect image.Rectangle
h *vk.Handle
search *image.RGBA
resSize int
resultsBuf *vk.Buffer
imgBuf *vk.Buffer
patBuf *vk.Buffer
pipe *vk.Pipe
searchRect image.Rectangle
h *vk.Handle
resSize int
pushConstants bytes.Buffer
resultsBuf *vk.Buffer
imgBuf *vk.Buffer
patBuf *vk.Buffer
pipe *vk.Pipe
}

func (s *SearchVk) Destroy() {
Expand Down Expand Up @@ -89,7 +94,6 @@ func NewSearchVk(imgBounds, patBounds image.Rectangle, imgStride, patStride int)
Min: imgBounds.Min,
Max: imgBounds.Max.Sub(patBounds.Size()).Add(image.Pt(1, 1)),
}
s.search = image.NewRGBA(s.searchRect)

// Create instance.
var err error
Expand Down Expand Up @@ -119,6 +123,7 @@ func NewSearchVk(imgBounds, patBounds image.Rectangle, imgStride, patStride int)
}

// Create pipe.
binary.Write(&s.pushConstants, binary.LittleEndian, pushConstants{})
specInfo := []int{
// Local size.
localSizeX,
Expand All @@ -129,11 +134,8 @@ func NewSearchVk(imgBounds, patBounds image.Rectangle, imgStride, patStride int)
s.searchRect.Dy(),
patBounds.Dx(),
patBounds.Dy(),
imgStride,
patStride,
s.search.Stride,
}
s.pipe, err = s.h.NewPipe(shaderCode, []*vk.Buffer{s.resultsBuf, s.imgBuf, s.patBuf}, specInfo, 0)
s.pipe, err = s.h.NewPipe(shaderCode, []*vk.Buffer{s.resultsBuf, s.imgBuf, s.patBuf}, specInfo, s.pushConstants.Len())
if err != nil {
if err != nil {
s.Destroy()
Expand Down Expand Up @@ -165,12 +167,22 @@ func (s *SearchVk) Run(img, pat *image.RGBA) (maxX, maxY int, maxCos float64, er
return
}

// Prepare push constants.
s.pushConstants.Reset()
binary.Write(
&s.pushConstants,
binary.LittleEndian,
pushConstants{uint32(img.Stride), uint32(pat.Stride)})

// Run.
err = s.pipe.Run(s.h, [3]uint{
uint(s.searchRect.Dx()/localSizeX + 1),
uint(s.searchRect.Dy()/localSizeY + 1),
1,
}, nil)
err = s.pipe.Run(
s.h,
[3]uint{
uint(s.searchRect.Dx()/localSizeX + 1),
uint(s.searchRect.Dy()/localSizeY + 1),
1,
},
s.pushConstants.Bytes())
if err != nil {
return
}
Expand Down
28 changes: 27 additions & 1 deletion pkg/pmatch/vk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/require"
)

func Test_SearchRGBAVk(t *testing.T) {
func Test_SearchRGBAVk_Simple(t *testing.T) {
img := imutil.ToRGBA(LoadTestImg())
pat, err := imutil.Sub(img, image.Rect(x0, y0, x0+w, y0+h))
require.NoError(t, err)
Expand All @@ -29,6 +29,32 @@ func Test_SearchRGBAVk(t *testing.T) {
assert.Equal(t, y0, y)
}

func Test_SearchRGBAVk_Instance(t *testing.T) {
img := imutil.ToRGBA(LoadTestImg())
pat, err := imutil.Sub(img, image.Rect(x0, y0, x0+w, y0+h))
require.NoError(t, err)

search, err := NewSearchVk(img.Bounds(), pat.Bounds(), img.Stride, pat.(*image.RGBA).Stride)
assert.NoError(t, err)
defer search.Destroy()

search.Run(img, pat.(*image.RGBA))
x, y, score, err := search.Run(img, pat.(*image.RGBA))
assert.NoError(t, err)
assert.InDelta(t, 1., score, delta)
assert.Equal(t, x0, x)
assert.Equal(t, y0, y)

// Also resets pat bounds origin to (0,0).
patCopy := imutil.ToRGBA(pat.(*image.RGBA))

x, y, score, err = search.Run(img, patCopy)
assert.NoError(t, err)
assert.InDelta(t, 1., score, delta)
assert.Equal(t, x0, x)
assert.Equal(t, y0, y)
}

func Benchmark_SearchRGBAVk(b *testing.B) {
img := imutil.ToRGBA(LoadTestImg())
pat, err := imutil.Sub(img, image.Rect(x0, y0, x0+w, y0+h))
Expand Down
2 changes: 1 addition & 1 deletion pkg/vk/testfiles/pushconstant.comp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ layout(std430, set = 0, binding = 1) buffer readonly _buf1 { uint buf1[]; };
layout(std430, set = 0, binding = 2) buffer _buf2 { uint buf2[]; };

// Push constants.
layout(push_constant) uniform _constants { int VALUE_PUSH; }
layout(std430, push_constant) uniform _constants { int VALUE_PUSH; }
push_constants;

void main() {
Expand Down

0 comments on commit f5f8eea

Please sign in to comment.