Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logderivative sum: polish the implementation #523

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions prover/protocol/compiler/logderivativesum/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,39 @@ func CompileLogDerivSum(comp *wizard.CompiledIOP) {
// compilation process. We know that the query was already ignored at
// the beginning because we are iterating over the unignored keys.
comp.QueriesParams.MarkAsIgnored(qName)
// get the Numerator and Denominator from the input and prepare their compilation.
zEntries := logDeriv.Inputs
va := FinalEvaluationCheck{}

var (
zEntries = logDeriv.Inputs
va = FinalEvaluationCheck{
LogDerivSumID: qName,
}
lastRound = logDeriv.Round
)

for _, entry := range zEntries {

// get the Numerator and Denominator from the input and prepare their compilation.
zC := &lookup.ZCtx{
Round: entry.Round,
Round: lastRound,
Size: entry.Size,
SigmaNumerator: entry.Numerator,
SigmaDenominator: entry.Denominator,
}

// z-packing compile; it imposes the correct accumulation over Numerator and Denominator.
zC.Compile(comp)

// prover step; Z assignments
zAssignmentTask := lookup.ZAssignmentTask(*zC)
comp.SubProvers.AppendToInner(zC.Round, func(run *wizard.ProverRuntime) {
zAssignmentTask.Run(run)
})

// collect all the zOpening for all the z columns
va.ZOpenings = append(va.ZOpenings, zC.ZOpenings...)
}

// verifer step
va.LogDerivSumID = qName
lastRound := comp.NumRounds() - 1
comp.RegisterVerifierAction(lastRound, &va)
}

Expand Down
9 changes: 4 additions & 5 deletions prover/protocol/compiler/logderivativesum/logderivsum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ func TestLogDerivSum(t *testing.T) {
ifaces.ColumnAsVariable(q2),
}

key := [2]int{0, 4}
zCat1 := map[[2]int]*query.LogDerivativeSumInput{}
zCat1[key] = &query.LogDerivativeSumInput{
Round: 0,
Size: 4,
size := 4
zCat1 := map[int]*query.LogDerivativeSumInput{}
zCat1[size] = &query.LogDerivativeSumInput{
Size: size,
Numerator: numerators,
Denominator: denominators,
}
Expand Down
81 changes: 36 additions & 45 deletions prover/protocol/distributed/compiler/inclusion/inclusion.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type DistributionInputs struct {
// Name of the module
ModuleName distributed.ModuleName
// query is supposed to be the global LogDerivativeSum.
QueryID ifaces.QueryID
Query query.LogDerivativeSum
// it contains the whole witness,
// and also the witness for the auxiliary columns such as multiplicity column for Inclusion.
InitialProver *wizard.ProverRuntime
Expand All @@ -41,115 +41,105 @@ func DistributeLogDerivativeSum(
var (
queryID ifaces.QueryID
)

for _, qName := range initialComp.QueriesParams.AllUnignoredKeys() {

_, ok := initialComp.QueriesParams.Data(qName).(query.LogDerivativeSum)
if !ok {
continue
}

// panic if there is more than a LogDerivativeSum query in the initialComp.
if string(queryID) != "" {
utils.Panic("found more than a LogDerivativeSum query in the initialComp")
}

queryID = qName
}
input := DistributionInputs{

// get the share of the module from the LogDerivativeSum query
GetShareOfLogDerivativeSum(DistributionInputs{
ModuleComp: moduleComp,
InitialComp: initialComp,
Disc: disc,
ModuleName: moduleName,
QueryID: queryID,
Query: initialComp.QueriesParams.Data(queryID).(query.LogDerivativeSum),
InitialProver: initialProver,
}
// get the share of the module from the LogDerivativeSum query
GetShareOfLogDerivativeSum(input)
})

}

// GetShareOfLogDerivativeSum extracts the share of the given modules from the given LogDerivativeSum query.
// It inserts a new LogDerivativeSum for the extracted share.
func GetShareOfLogDerivativeSum(in DistributionInputs) {

var (
initialComp = in.InitialComp
moduleComp = in.ModuleComp
initialProver = in.InitialProver
numerator []*symbolic.Expression
denominator []*symbolic.Expression
keyIsInModule bool
zCatalog = make(map[[2]int]*query.LogDerivativeSumInput)
zCatalog = make(map[int]*query.LogDerivativeSumInput)
logDeriv = in.Query
round = logDeriv.Round
)
// check that the given query is a valid LogDerivateSum query in the CompiledIOP.
logDeriv, ok := initialComp.QueriesParams.Data(in.QueryID).(query.LogDerivativeSum)
if !ok {
panic("the given query is not a valid LogDerivativeSum from the compiledIOP")
}

// @Azam this is because for the moment we dont know how the module-discoverer extracts moduleComp from InitialComp.
// if we are sure that inclusions are already removed from modComp, we can skip this step here.
for _, qName := range moduleComp.QueriesNoParams.AllUnignoredKeys() {
// Filter out non lookup queries
_, ok := moduleComp.QueriesNoParams.Data(qName).(query.Inclusion)
if !ok {
continue
}
// ignore the query as it is about to be compiled and replaces with low level queries.
moduleComp.QueriesNoParams.MarkAsIgnored(qName)
}

// extract the share of the module from the global sum.
for key := range logDeriv.Inputs {
for i := range logDeriv.Inputs[key].Numerator {
for size := range logDeriv.Inputs {

for i := range logDeriv.Inputs[size].Numerator {

// if Denominator is in the module pass the numerator from initialComp to moduleComp
// Particularly, T might be in the module and needs to take M from initialComp.
if in.Disc.ExpressionIsInModule(logDeriv.Inputs[key].Denominator[i], in.ModuleName) {
if !in.Disc.ExpressionIsInModule(logDeriv.Inputs[key].Numerator[i], in.ModuleName) {
distributed.PassColumnToModule(initialComp, moduleComp, initialProver, logDeriv.Inputs[key].Numerator[i])
if in.Disc.ExpressionIsInModule(logDeriv.Inputs[size].Denominator[i], in.ModuleName) {

if !in.Disc.ExpressionIsInModule(logDeriv.Inputs[size].Numerator[i], in.ModuleName) {
distributed.PassColumnToModule(initialComp, moduleComp, initialProver, logDeriv.Inputs[size].Numerator[i])
}
denominator = append(denominator, logDeriv.Inputs[key].Denominator[i])
numerator = append(numerator, logDeriv.Inputs[key].Numerator[i])

denominator = append(denominator, logDeriv.Inputs[size].Denominator[i])
numerator = append(numerator, logDeriv.Inputs[size].Numerator[i])

// replaces the external coins with local coins
// note that they just appear in the denominator.
distributed.ReplaceExternalCoins(initialComp, moduleComp, logDeriv.Inputs[key].Denominator[i])
distributed.ReplaceExternalCoins(initialComp, moduleComp, logDeriv.Inputs[size].Denominator[i])
keyIsInModule = true
}
}

// if there in any expression relevant to the current key, add them to zCatalog
if keyIsInModule {
// zCatalog specific to the module
zCatalog[key] = &query.LogDerivativeSumInput{
Round: key[0],
Size: key[1],
zCatalog[size] = &query.LogDerivativeSumInput{
Size: size,
Numerator: numerator,
Denominator: denominator,
}
}
keyIsInModule = false

}
// sanity check; the initialComp has only two rounds
if initialComp.NumRounds() != 2 {
utils.Panic("expected initialComp to have 2 rounds but it has %v rounds", initialComp.NumRounds())
keyIsInModule = false
}

// insert a LogDerivativeSum specific to the module at round 1 (since initialComp has 2 rounds).
moduleComp.InsertLogDerivativeSum(
1,
round,
ifaces.QueryIDf("%v_%v", LogDerivativeSum, in.ModuleName),
zCatalog,
)

// prover step to assign the parameters of LogDerivativeSum at the same round.
moduleComp.SubProvers.AppendToInner(1, func(run *wizard.ProverRuntime) {
moduleComp.SubProvers.AppendToInner(round, func(run *wizard.ProverRuntime) {
run.AssignLogDerivSum(
ifaces.QueryIDf("%v_%v", LogDerivativeSum, in.ModuleName),
GetLogDerivativeSumResult(zCatalog, run),
getLogDerivativeSumResult(zCatalog, run),
)
})

}

// GetLogDerivativeSumResult is a helper allowing the prover to calculate the result of its associated LogDerivativeSum query.
func GetLogDerivativeSumResult(zCatalog map[[2]int]*query.LogDerivativeSumInput, run *wizard.ProverRuntime) field.Element {
// getLogDerivativeSumResult is a helper allowing the prover to calculate the result of its associated LogDerivativeSum query.
func getLogDerivativeSumResult(zCatalog map[int]*query.LogDerivativeSumInput, run *wizard.ProverRuntime) field.Element {
// compute the actual sum from the Numerator and Denominator
actualSum := field.Zero()
for key := range zCatalog {
Expand Down Expand Up @@ -178,6 +168,7 @@ func GetLogDerivativeSumResult(zCatalog map[[2]int]*query.LogDerivativeSumInput,
packedZ[k].Add(&packedZ[k], &packedZ[k-1])
}
}

actualSum.Add(&actualSum, &packedZ[len(packedZ)-1])
}
}
Expand Down
29 changes: 12 additions & 17 deletions prover/protocol/distributed/preparation.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func IntoLogDerivativeSum(comp *wizard.CompiledIOP) {
// zCatalog stores a mapping (round, size) into query.LogDerivativeSumInput and helps finding
// which Z context should be used to handle a part of a given inclusion
// query.
zCatalog = map[[2]int]*query.LogDerivativeSumInput{}
zCatalog = map[int]*query.LogDerivativeSumInput{}
)

// Skip the compilation phase if no lookup constraint is being used. Otherwise
Expand All @@ -60,45 +60,41 @@ func IntoLogDerivativeSum(comp *wizard.CompiledIOP) {
)

// push single-columns into zCatalog
PushToZCatalog(tableCtx, zCatalog)
pushToZCatalog(tableCtx, zCatalog)

a := lookup.MAssignmentTask{
M: tableCtx.M,
S: checkTable,
T: lookupTable,
SFilter: includedFilters,
}

// assign the multiplicity column
comp.SubProvers.AppendToInner(round, a.Run)

}

// insert a single LogDerivativeSum query for the global zCatalog.
comp.InsertLogDerivativeSum(lastRound, "GlobalLogDerivativeSum", zCatalog)
comp.InsertLogDerivativeSum(lastRound+1, "GlobalLogDerivativeSum", zCatalog)

// assign parameters of LogDerivativeSum, it is just to prevent the panic attack in the prover
comp.SubProvers.AppendToInner(lastRound, func(run *wizard.ProverRuntime) {
comp.SubProvers.AppendToInner(lastRound+1, func(run *wizard.ProverRuntime) {
run.AssignLogDerivSum("GlobalLogDerivativeSum", field.Zero())
})
}

// PushToZCatalog constructs the numerators and denominators for the collapsed S and T
// pushToZCatalog constructs the numerators and denominators for the collapsed S and T
// into zCatalog, for their corresponding rounds and size.
func PushToZCatalog(stc lookup.SingleTableCtx, zCatalog map[[2]int]*query.LogDerivativeSumInput) {

var (
round = stc.Gamma.Round
)
func pushToZCatalog(stc lookup.SingleTableCtx, zCatalog map[int]*query.LogDerivativeSumInput) {

// tableCtx push to -> zCtx
// Process the T columns
for frag := range stc.T {
size := stc.M[frag].Size()

key := [2]int{round, size}
key := size
if zCatalog[key] == nil {
zCatalog[key] = &query.LogDerivativeSumInput{
Size: size,
Round: round,
Size: size,
}
}

Expand All @@ -118,11 +114,10 @@ func PushToZCatalog(stc lookup.SingleTableCtx, zCatalog map[[2]int]*query.LogDer
sFilter = symbolic.NewVariable(stc.SFilters[table])
}

key := [2]int{round, size}
key := size
if zCatalog[key] == nil {
zCatalog[key] = &query.LogDerivativeSumInput{
Size: size,
Round: round,
Size: size,
}
}

Expand Down
10 changes: 7 additions & 3 deletions prover/protocol/query/logderiv_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

// LogDerivativeSumInput stores the input to the query
type LogDerivativeSumInput struct {
Round, Size int
Size int
Numerator []*sym.Expression // T -> -M, S -> +Filter
Denominator []*sym.Expression // S or T -> ({S,T} + X)
}
Expand All @@ -27,7 +27,8 @@ type LogDerivativeSumInput struct {
// N_{i,j} is the i-th element of the underlying column of j-th Numerator
// D_{i,j} is the i-th element of the underlying column of j-th Denominator
type LogDerivativeSum struct {
Inputs map[[2]int]*LogDerivativeSumInput
Round int
Inputs map[int]*LogDerivativeSumInput
ID ifaces.QueryID
}

Expand All @@ -42,24 +43,27 @@ func (l LogDerivSumParams) UpdateFS(fs *fiatshamir.State) {
}

// NewLogDerivativeSum creates the new context LogDerivativeSum.
func NewLogDerivativeSum(inp map[[2]int]*LogDerivativeSumInput, id ifaces.QueryID) LogDerivativeSum {
func NewLogDerivativeSum(round int, inp map[int]*LogDerivativeSumInput, id ifaces.QueryID) LogDerivativeSum {

// check the length consistency
for key := range inp {
if len(inp[key].Numerator) != len(inp[key].Denominator) || len(inp[key].Numerator) == 0 {
utils.Panic("Numerator and Denominator should have the same (no-zero) length, %v , %v", len(inp[key].Numerator), len(inp[key].Denominator))
}
for i := range inp[key].Numerator {

if err := inp[key].Numerator[i].Validate(); err != nil {
utils.Panic(" Numerator[%v] is not a valid expression", i)
}

if err := inp[key].Denominator[i].Validate(); err != nil {
utils.Panic(" Denominator[%v] is not a valid expression", i)
}
}
}

return LogDerivativeSum{
Round: round,
Inputs: inp,
ID: id,
}
Expand Down
4 changes: 2 additions & 2 deletions prover/protocol/query/logderiv_sum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ func TestLogDerivSum(t *testing.T) {
ifaces.ColumnAsVariable(q2),
}

key := [2]int{0, 0}
zCat1 := map[[2]int]*query.LogDerivativeSumInput{}
key := 0
zCat1 := map[int]*query.LogDerivativeSumInput{}
zCat1[key] = &query.LogDerivativeSumInput{
Numerator: numerators,
Denominator: denominators,
Expand Down
4 changes: 2 additions & 2 deletions prover/protocol/wizard/compiled.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,9 +582,9 @@ func (c *CompiledIOP) InsertLocalOpening(round int, name ifaces.QueryID, pol ifa
// InsertLogDerivativeSum registers a new LogDerivativeSum query [query.LogDerivativeSum].
// It generates a single global summation for many Sigma Columns from Lookup compilation.
// The sigma columns are categorized by [round,size].
func (c *CompiledIOP) InsertLogDerivativeSum(lastRound int, id ifaces.QueryID, in map[[2]int]*query.LogDerivativeSumInput) query.LogDerivativeSum {
func (c *CompiledIOP) InsertLogDerivativeSum(lastRound int, id ifaces.QueryID, in map[int]*query.LogDerivativeSumInput) query.LogDerivativeSum {
c.assertConsistentRound(lastRound)
q := query.NewLogDerivativeSum(in, id)
q := query.NewLogDerivativeSum(lastRound, in, id)
// Finally registers the query
c.QueriesParams.AddToRound(lastRound, id, q)
return q
Expand Down
Loading