Skip to content

Commit

Permalink
WIP: pointer fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
markkurossi committed Sep 12, 2024
1 parent 8c139dc commit 196a58e
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 14 deletions.
34 changes: 29 additions & 5 deletions compiler/ast/lrvalue.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ import (
// - structField is nil
// - valueType is the type of Name
// - value is the value of Name
//
// 4. Array element:
// - baseInfo points to the containing variable
// - baseValue is the value of the containing variable
// - valueType is the type of the array element
// - value is the value of the array element
type LRValue struct {
ctx *Codegen
ast AST
Expand Down Expand Up @@ -176,10 +182,15 @@ func (lrv *LRValue) RValue() ssa.Value {
fieldType := lrv.valueType
fieldType.Offset = 0

fmt.Printf("*** RValue: fieldType=%v\n", fieldType)
fmt.Printf(" - valueType=%v: [%d-%d[\n", lrv.valueType,
lrv.valueType.Offset, lrv.valueType.Offset+lrv.valueType.Bits)
fmt.Printf(" - baseInfo.Offset=%v\n", lrv.baseInfo.Offset)

lrv.value = lrv.gen.AnonVal(fieldType)

from := int64(lrv.valueType.Offset)
to := int64(lrv.valueType.Offset + lrv.valueType.Bits)
from := int64(lrv.baseInfo.Offset + lrv.valueType.Offset)
to := from + int64(lrv.valueType.Bits)

if to > from {
fromConst := lrv.gen.Constant(from, types.Undefined)
Expand Down Expand Up @@ -262,13 +273,18 @@ func (ctx *Codegen) LookupVar(block *ssa.Block, gen *ssa.Generator,
lrv.baseValue = *v
}

// XXX must keep the base value intact
var structType types.Info

if lrv.baseValue.Type.Type == types.TPtr {
structType = *lrv.baseValue.Type.ElementType
lrv.baseInfo = lrv.baseValue.PtrInfo
lrv.baseValue, err = lrv.ptrBaseValue()
if err != nil {
return nil, false, false, err
return nil, false, false, nil
}
} else {
structType = lrv.baseValue.Type
lrv.baseInfo = &ssa.PtrInfo{
Name: ref.Name.Package,
Bindings: env,
Expand All @@ -277,11 +293,13 @@ func (ctx *Codegen) LookupVar(block *ssa.Block, gen *ssa.Generator,
}
}

if lrv.baseValue.Type.Type != types.TStruct {
// fmt.Printf(" => structType=%v\n", structType)

if structType.Type != types.TStruct {
return nil, false, false, fmt.Errorf("%s undefined", ref.Name)
}

for _, f := range lrv.baseValue.Type.Struct {
for _, f := range structType.Struct {
if f.Name == ref.Name.Name {
lrv.structField = &f
break
Expand All @@ -293,6 +311,12 @@ func (ctx *Codegen) LookupVar(block *ssa.Block, gen *ssa.Generator,
ref.Name, lrv.baseValue.Type, ref.Name.Name)
}
lrv.valueType = lrv.structField.Type
lrv.value.Type = types.Undefined
lrv.value.PtrInfo = nil

fmt.Printf(" => valueType=%v@%d, baseInfo.Offset=%d\n",
lrv.valueType,
lrv.valueType.Offset, lrv.baseInfo.Offset)

return lrv, true, false, nil
}
Expand Down
82 changes: 82 additions & 0 deletions compiler/ast/ptr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//
// Copyright (c) 2024 Markku Rossi
//
// All rights reserved.
//

package ast

import (
"slices"

"github.com/markkurossi/mpc/compiler/ssa"
"github.com/markkurossi/mpc/types"
)

func indexOfs(idx *Index, block *ssa.Block, ctx *Codegen, gen *ssa.Generator) (
*ssa.Block, *LRValue, types.Info, types.Info, types.Size, error) {
undef := types.Undefined

lv := idx

var err error
var values []ssa.Value
var indices []arrayIndex
var lrv *LRValue

for lrv == nil {
block, values, err = idx.Index.SSA(block, ctx, gen)
if err != nil {
return nil, nil, undef, undef, 0, err
}
if len(values) != 1 {
return nil, nil, undef, undef, 0,
ctx.Errorf(idx.Index, "invalid index")
}
index, err := values[0].ConstInt()
if err != nil {
return nil, nil, undef, undef, 0, ctx.Error(idx.Index, err.Error())
}
indices = append(indices, arrayIndex{
i: index,
ast: idx.Index,
})
switch i := idx.Expr.(type) {
case *Index:
idx = i

case *VariableRef:
lrv, _, _, err = ctx.LookupVar(block, gen, block.Bindings, i)
if err != nil {
return nil, nil, undef, undef, 0, err
}

default:
return nil, nil, undef, undef, 0, ctx.Errorf(idx.Expr,
"invalid operation: cannot index %v (%T)",
idx.Expr, idx.Expr)
}
}
slices.Reverse(indices)

lrv = lrv.Indirect()
baseType := lrv.ValueType()
elType := baseType
var offset types.Size

for _, index := range indices {
if !elType.Type.Array() {
return nil, nil, undef, undef, 0, ctx.Errorf(index.ast,
"indexing non-array %s (%s)", lv.Expr, elType)
}
if index.i >= elType.ArraySize {
return nil, nil, undef, undef, 0, ctx.Errorf(index.ast,
"invalid array index %d (out of bounds for %d-element array)",
index.i, elType.ArraySize)
}
offset += index.i * elType.ElementType.Bits
elType = *elType.ElementType
}

return block, lrv, baseType, elType, offset, nil
}
32 changes: 24 additions & 8 deletions compiler/ast/ssagen.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"slices"

"github.com/markkurossi/mpc/compiler/ssa"
"github.com/markkurossi/mpc/compiler/utils"
"github.com/markkurossi/mpc/types"
"github.com/markkurossi/tabulate"
)
Expand Down Expand Up @@ -318,20 +319,20 @@ func (ast *Assign) SSA(block *ssa.Block, ctx *Codegen,
"a non-name %s on left side of :=", lv)
}
var err error
var v []ssa.Value
var values []ssa.Value
var indices []arrayIndex
var lrv *LRValue
idx := lv

for lrv == nil {
block, v, err = idx.Index.SSA(block, ctx, gen)
block, values, err = idx.Index.SSA(block, ctx, gen)
if err != nil {
return nil, nil, err
}
if len(v) != 1 {
if len(values) != 1 {
return nil, nil, ctx.Errorf(idx.Index, "invalid index")
}
index, err := v[0].ConstInt()
index, err := values[0].ConstInt()
if err != nil {
return nil, nil, ctx.Error(idx.Index, err.Error())
}
Expand Down Expand Up @@ -1681,6 +1682,8 @@ func (ast *Unary) SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator) (
return block, []ssa.Value{constVal}, nil
}

utils.Tracef("Unary.SSA\n")

switch ast.Type {
case UnaryMinus:
block, exprs, err := ast.Expr.SSA(block, ctx, gen)
Expand Down Expand Up @@ -1733,6 +1736,7 @@ func (ast *Unary) SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator) (
return block, []ssa.Value{t}, nil

case UnaryAddr:
utils.Tracef("UnaryAddr: %T", ast.Expr)
switch v := ast.Expr.(type) {
case *VariableRef:
lrv, _, _, err := ctx.LookupVar(block, gen, block.Bindings, v)
Expand All @@ -1759,19 +1763,30 @@ func (ast *Unary) SSA(block *ssa.Block, ctx *Codegen, gen *ssa.Generator) (
return block, []ssa.Value{t}, nil

case *Index:
lrv, ptrType, offset, err := ast.addrIndex(block, ctx, gen, v)
var lrv *LRValue
var baseType, elType types.Info
var offset types.Size

block, lrv, baseType, elType, offset, err = indexOfs(v, block, ctx,
gen)
if err != nil {
return nil, nil, err
}
utils.Tracef(" - Index: offset=%v, base=%v, el=%v\n",
offset, baseType, elType)

t := gen.AnonVal(types.Info{
Type: types.TPtr,
IsConcrete: true,
Bits: ptrType.Bits,
MinBits: ptrType.Bits,
ElementType: ptrType,
Bits: elType.Bits,
MinBits: elType.Bits,
ElementType: &elType,
})
t.PtrInfo = lrv.BasePtrInfo()
t.PtrInfo.Offset += offset

utils.Tracef(" => %v\n", t)

return block, []ssa.Value{t}, nil

default:
Expand Down Expand Up @@ -2095,6 +2110,7 @@ func (ast *VariableRef) SSA(block *ssa.Block, ctx *Codegen,
}

value := lrv.RValue()
fmt.Printf(" => value=%v\n", value)
if value.Const {
gen.AddConstant(value)
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/ssa/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ type PtrInfo struct {
}

func (ptr PtrInfo) String() string {
return fmt.Sprintf("*%s@%d", ptr.Name, ptr.Scope)
return fmt.Sprintf("*%s{%d}%s", ptr.Name, ptr.Scope, ptr.ContainerType)
}

// Equal tests if this PtrInfo is equal to the argument PtrInfo.
Expand Down
43 changes: 43 additions & 0 deletions compiler/utils/trace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//
// Copyright (c) 2024 Markku Rossi
//
// All rights reserved.
//

package utils

import (
"fmt"
"path/filepath"
"runtime"
"strings"
)

func Tracef(format string, a ...interface{}) {
var filename string
var linenum int

for skip := 1; ; skip++ {
pc, file, line, ok := runtime.Caller(skip)
if !ok {
break
}
f := runtime.FuncForPC(pc)
if f != nil && strings.HasSuffix(f.Name(), ".errf") {
continue
}

filename = filepath.Base(file)
linenum = line
break
}

if len(filename) > 0 {
fmt.Printf("%s:%d: ", filename, linenum)
}
msg := fmt.Sprintf(format, a...)
if msg[len(msg)-1] != '\n' {
msg += "\n"
}
fmt.Print(msg)
}
1 change: 1 addition & 0 deletions pkg/crypto/ed25519/internal/edwards25519/ed25519.mpcl
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ func selectPoint(t *PreComputedGroupElement, pos int32, b int32) {
t.Zero()
var x PreComputedGroupElement
for i := int32(0); i < 8; i++ {
// PreComputedGroupElementCMove(t, &base[pos][i], equal(bAbs, i+1))
x = base[pos][i]
PreComputedGroupElementCMove(t, &x, equal(bAbs, i+1))
}
Expand Down

0 comments on commit 196a58e

Please sign in to comment.