Skip to content

Commit

Permalink
Fix by indices bug (#106)
Browse files Browse the repository at this point in the history
There was a subtle bug in `ByIndices`. The tests have also been updated to detect a wider class of bugs.
  • Loading branch information
chewxy authored Jan 14, 2021
1 parent 4ce03d1 commit d5ff158
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 107 deletions.
7 changes: 7 additions & 0 deletions api_matop.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,20 @@ func Diag(t Tensor) (retVal Tensor, err error) {
// ByIndices allows for selection of value of `a` byt the indices listed in the `indices` tensor.
// The `indices` tensor has to be a vector-like tensor of ints.
func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
if axis >= a.Shape().Dims() {
return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims())
}
if sbi, ok := a.Engine().(ByIndiceser); ok {
return sbi.SelectByIndices(a, indices, axis, opts...)
}
return nil, errors.Errorf("Unable to select by indices. Egnine %T does not support that.", a.Engine())
}

// ByIndicesB is the backpropagation of ByIndices.
func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
if axis >= a.Shape().Dims() {
return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims())
}
if sbi, ok := a.Engine().(ByIndiceser); ok {
return sbi.SelectByIndicesB(a, b, indices, axis, opts...)
}
Expand Down
18 changes: 14 additions & 4 deletions defaultengine_selbyidx.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,13 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da
dstCoord := make([]int, apRet.shape.Dims())

if isInnermost {
prevStride := apA.strides[axis-1]
retPrevStride := apRet.strides[axis-1]
prevAxis := axis - 1
if prevAxis < 0 {
// this may be the case if input is a vector
prevAxis = 0
}
prevStride := apA.strides[prevAxis]
retPrevStride := apRet.strides[prevAxis]
for i, idx := range indices {
srcCoord[axis] = idx
dstCoord[axis] = i
Expand Down Expand Up @@ -194,8 +199,13 @@ func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, data
srcCoord := make([]int, apRet.shape.Dims())

if isInnermost {
retPrevStride := apB.strides[axis-1]
prevStride := apRet.strides[axis-1]
prevAxis := axis - 1
if prevAxis < 0 {
// this may be the case if input is a vector
prevAxis = 0
}
retPrevStride := apB.strides[prevAxis]
prevStride := apRet.strides[prevAxis]
for i, idx := range indices {
dstCoord[axis] = idx
srcCoord[axis] = i
Expand Down
193 changes: 90 additions & 103 deletions dense_selbyidx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,121 +6,108 @@ import (
"github.com/stretchr/testify/assert"
)

func TestDense_SelectByIndices(t *testing.T) {
assert := assert.New(t)

a := New(WithBacking([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}), WithShape(3, 2, 4))
indices := New(WithBacking([]int{1, 1}))

e := StdEng{}

a1, err := e.SelectByIndices(a, indices, 1)
if err != nil {
t.Errorf("%v", err)
}
correct1 := []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}
assert.Equal(correct1, a1.Data())

a0, err := e.SelectByIndices(a, indices, 0)
if err != nil {
t.Errorf("%v", err)
}
correct0 := []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}
assert.Equal(correct0, a0.Data())
type selByIndicesTest struct {
Name string
Data interface{}
Shape Shape
Indices []int
Axis int
WillErr bool

Correct interface{}
CorrectShape Shape
}

a2, err := e.SelectByIndices(a, indices, 2)
if err != nil {
t.Errorf("%v", err)
}
correct2 := []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}
assert.Equal(correct2, a2.Data())
var selByIndicesTests = []selByIndicesTest{
{Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false,
Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}},

// !safe
aUnsafe := a.Clone().(*Dense)
indices = New(WithBacking([]int{1, 1, 1}))
aUnsafeSelect, err := e.SelectByIndices(aUnsafe, indices, 0, UseUnsafe())
if err != nil {
t.Errorf("%v", err)
}
correctUnsafe := []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}
assert.Equal(correctUnsafe, aUnsafeSelect.Data())
{Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false,
Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}},

// 3 indices, just to make sure the sanity of the algorithm
indices = New(WithBacking([]int{1, 1, 1}))
a1, err = e.SelectByIndices(a, indices, 1)
if err != nil {
t.Errorf("%v", err)
}
correct1 = []float64{
4, 5, 6, 7,
4, 5, 6, 7,
4, 5, 6, 7,
{Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false,
Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}},

12, 13, 14, 15,
12, 13, 14, 15,
12, 13, 14, 15,
{Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false,
Correct: []int{1, 1}, CorrectShape: Shape{2}},

20, 21, 22, 23,
20, 21, 22, 23,
20, 21, 22, 23,
}
assert.Equal(correct1, a1.Data())
{Name: "Vector, axis 1", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 1, WillErr: true,
Correct: []int{1, 1}, CorrectShape: Shape{2}},
{Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false,
Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}},
{Name: "(2,1) Matrx (colvec)m with (10) indies", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false,
Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10},
},
}

a0, err = e.SelectByIndices(a, indices, 0)
if err != nil {
t.Errorf("%v", err)
func TestDense_SelectByIndices(t *testing.T) {
assert := assert.New(t)
for i, tc := range selByIndicesTests {
T := New(WithShape(tc.Shape...), WithBacking(tc.Data))
indices := New(WithBacking(tc.Indices))
ret, err := ByIndices(T, indices, tc.Axis)
if checkErr(t, tc.WillErr, err, tc.Name, i) {
continue
}
assert.Equal(tc.Correct, ret.Data())
assert.True(tc.CorrectShape.Eq(ret.Shape()))
}
correct0 = []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}
assert.Equal(correct0, a0.Data())
}

a2, err = e.SelectByIndices(a, indices, 2)
if err != nil {
t.Errorf("%v", err)
}
correct2 = []float64{1, 1, 1, 5, 5, 5, 9, 9, 9, 13, 13, 13, 17, 17, 17, 21, 21, 21}
assert.Equal(correct2, a2.Data())
var selByIndicesBTests = []struct {
selByIndicesTest

CorrectGrad interface{}
CorrectGradShape Shape
}{
{
selByIndicesTest: selByIndicesTests[0],
CorrectGrad: []float64{0, 0, 0, 0, 0, 0, 0, 0, 16, 18, 20, 22, 24, 26, 28, 30, 0, 0, 0, 0, 0, 0, 0, 0},
CorrectGradShape: Shape{3, 2, 4},
},
{
selByIndicesTest: selByIndicesTests[1],
CorrectGrad: []float64{0, 0, 0, 0, 8, 10, 12, 14, 0, 0, 0, 0, 24, 26, 28, 30, 0, 0, 0, 0, 40, 42, 44, 46},
CorrectGradShape: Shape{3, 2, 4},
},
{
selByIndicesTest: selByIndicesTests[2],
CorrectGrad: []float64{0, 2, 0, 0, 0, 10, 0, 0, 0, 18, 0, 0, 0, 26, 0, 0, 0, 34, 0, 0, 0, 42, 0, 0},
CorrectGradShape: Shape{3, 2, 4},
},
{
selByIndicesTest: selByIndicesTests[3],
CorrectGrad: []int{0, 2, 0, 0, 0},
CorrectGradShape: Shape{5},
},
{
selByIndicesTest: selByIndicesTests[5],
CorrectGrad: []float32{4, 6, 8, 12, 8, 12, 0, 0},
CorrectGradShape: Shape{4, 2},
},
{
selByIndicesTest: selByIndicesTests[6],
CorrectGrad: []float64{0, 10},
CorrectGradShape: Shape{2, 1},
},
}

func TestDense_SelectByIndicesB(t *testing.T) {

a := New(WithBacking([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}), WithShape(3, 2, 4))
indices := New(WithBacking([]int{1, 1}))

t.Logf("a\n%v", a)

e := StdEng{}

a1, err := e.SelectByIndices(a, indices, 1)
if err != nil {
t.Errorf("%v", err)
}
t.Logf("a1\n%v", a1)

a1Grad, err := e.SelectByIndicesB(a, a1, indices, 1)
if err != nil {
t.Errorf("%v", err)
}
t.Logf("a1Grad \n%v", a1Grad)

a0, err := e.SelectByIndices(a, indices, 0)
if err != nil {
t.Errorf("%v", err)
}
t.Logf("a0\n%v", a0)
a0Grad, err := e.SelectByIndicesB(a, a0, indices, 0)
if err != nil {
t.Errorf("%v", err)
assert := assert.New(t)
for i, tc := range selByIndicesBTests {
T := New(WithShape(tc.Shape...), WithBacking(tc.Data))
indices := New(WithBacking(tc.Indices))
ret, err := ByIndices(T, indices, tc.Axis)
if checkErr(t, tc.WillErr, err, tc.Name, i) {
continue
}
grad, err := ByIndicesB(T, ret, indices, tc.Axis)
if checkErr(t, tc.WillErr, err, tc.Name, i) {
continue
}
assert.Equal(tc.CorrectGrad, grad.Data(), "%v", tc.Name)
assert.True(tc.CorrectGradShape.Eq(grad.Shape()), "%v - Grad shape should be %v. Got %v instead", tc.Name, tc.CorrectGradShape, grad.Shape())
}
t.Logf("a0Grad\n%v", a0Grad)

a2, err := e.SelectByIndices(a, indices, 2)
if err != nil {
t.Errorf("%v", err)
}
t.Logf("\n%v", a2)
a2Grad, err := e.SelectByIndicesB(a, a2, indices, 2)
if err != nil {
t.Errorf("%v", err)
}
t.Logf("a2Grad\n%v", a2Grad)
}

0 comments on commit d5ff158

Please sign in to comment.