Skip to content

Commit

Permalink
polish the implementation (#523)
Browse files Browse the repository at this point in the history
* removing the round from zcatalog
  • Loading branch information
AlexandreBelling authored Jan 10, 2025
1 parent 2c0e1b3 commit 5f2718c
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 80 deletions.
20 changes: 14 additions & 6 deletions prover/protocol/compiler/logderivativesum/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,39 @@ func CompileLogDerivSum(comp *wizard.CompiledIOP) {
// compilation process. We know that the query was already ignored at
// the beginning because we are iterating over the unignored keys.
comp.QueriesParams.MarkAsIgnored(qName)
// get the Numerator and Denominator from the input and prepare their compilation.
zEntries := logDeriv.Inputs
va := FinalEvaluationCheck{}

var (
zEntries = logDeriv.Inputs
va = FinalEvaluationCheck{
LogDerivSumID: qName,
}
lastRound = logDeriv.Round
)

for _, entry := range zEntries {

// get the Numerator and Denominator from the input and prepare their compilation.
zC := &lookup.ZCtx{
Round: entry.Round,
Round: lastRound,
Size: entry.Size,
SigmaNumerator: entry.Numerator,
SigmaDenominator: entry.Denominator,
}

// z-packing compile; it imposes the correct accumulation over Numerator and Denominator.
zC.Compile(comp)

// prover step; Z assignments
zAssignmentTask := lookup.ZAssignmentTask(*zC)
comp.SubProvers.AppendToInner(zC.Round, func(run *wizard.ProverRuntime) {
zAssignmentTask.Run(run)
})

// collect all the zOpening for all the z columns
va.ZOpenings = append(va.ZOpenings, zC.ZOpenings...)
}

// verifer step
va.LogDerivSumID = qName
lastRound := comp.NumRounds() - 1
comp.RegisterVerifierAction(lastRound, &va)
}

Expand Down
9 changes: 4 additions & 5 deletions prover/protocol/compiler/logderivativesum/logderivsum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ func TestLogDerivSum(t *testing.T) {
ifaces.ColumnAsVariable(q2),
}

key := [2]int{0, 4}
zCat1 := map[[2]int]*query.LogDerivativeSumInput{}
zCat1[key] = &query.LogDerivativeSumInput{
Round: 0,
Size: 4,
size := 4
zCat1 := map[int]*query.LogDerivativeSumInput{}
zCat1[size] = &query.LogDerivativeSumInput{
Size: size,
Numerator: numerators,
Denominator: denominators,
}
Expand Down
81 changes: 36 additions & 45 deletions prover/protocol/distributed/compiler/inclusion/inclusion.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type DistributionInputs struct {
// Name of the module
ModuleName distributed.ModuleName
// query is supposed to be the global LogDerivativeSum.
QueryID ifaces.QueryID
Query query.LogDerivativeSum
// it contains the whole witness,
// and also the witness for the auxiliary columns such as multiplicity column for Inclusion.
InitialProver *wizard.ProverRuntime
Expand All @@ -41,115 +41,105 @@ func DistributeLogDerivativeSum(
var (
queryID ifaces.QueryID
)

for _, qName := range initialComp.QueriesParams.AllUnignoredKeys() {

_, ok := initialComp.QueriesParams.Data(qName).(query.LogDerivativeSum)
if !ok {
continue
}

// panic if there is more than a LogDerivativeSum query in the initialComp.
if string(queryID) != "" {
utils.Panic("found more than a LogDerivativeSum query in the initialComp")
}

queryID = qName
}
input := DistributionInputs{

// get the share of the module from the LogDerivativeSum query
GetShareOfLogDerivativeSum(DistributionInputs{
ModuleComp: moduleComp,
InitialComp: initialComp,
Disc: disc,
ModuleName: moduleName,
QueryID: queryID,
Query: initialComp.QueriesParams.Data(queryID).(query.LogDerivativeSum),
InitialProver: initialProver,
}
// get the share of the module from the LogDerivativeSum query
GetShareOfLogDerivativeSum(input)
})

}

// GetShareOfLogDerivativeSum extracts the share of the given modules from the given LogDerivativeSum query.
// It inserts a new LogDerivativeSum for the extracted share.
func GetShareOfLogDerivativeSum(in DistributionInputs) {

var (
initialComp = in.InitialComp
moduleComp = in.ModuleComp
initialProver = in.InitialProver
numerator []*symbolic.Expression
denominator []*symbolic.Expression
keyIsInModule bool
zCatalog = make(map[[2]int]*query.LogDerivativeSumInput)
zCatalog = make(map[int]*query.LogDerivativeSumInput)
logDeriv = in.Query
round = logDeriv.Round
)
// check that the given query is a valid LogDerivateSum query in the CompiledIOP.
logDeriv, ok := initialComp.QueriesParams.Data(in.QueryID).(query.LogDerivativeSum)
if !ok {
panic("the given query is not a valid LogDerivativeSum from the compiledIOP")
}

// @Azam this is because for the moment we dont know how the module-discoverer extracts moduleComp from InitialComp.
// if we are sure that inclusions are already removed from modComp, we can skip this step here.
for _, qName := range moduleComp.QueriesNoParams.AllUnignoredKeys() {
// Filter out non lookup queries
_, ok := moduleComp.QueriesNoParams.Data(qName).(query.Inclusion)
if !ok {
continue
}
// ignore the query as it is about to be compiled and replaces with low level queries.
moduleComp.QueriesNoParams.MarkAsIgnored(qName)
}

// extract the share of the module from the global sum.
for key := range logDeriv.Inputs {
for i := range logDeriv.Inputs[key].Numerator {
for size := range logDeriv.Inputs {

for i := range logDeriv.Inputs[size].Numerator {

// if Denominator is in the module pass the numerator from initialComp to moduleComp
// Particularly, T might be in the module and needs to take M from initialComp.
if in.Disc.ExpressionIsInModule(logDeriv.Inputs[key].Denominator[i], in.ModuleName) {
if !in.Disc.ExpressionIsInModule(logDeriv.Inputs[key].Numerator[i], in.ModuleName) {
distributed.PassColumnToModule(initialComp, moduleComp, initialProver, logDeriv.Inputs[key].Numerator[i])
if in.Disc.ExpressionIsInModule(logDeriv.Inputs[size].Denominator[i], in.ModuleName) {

if !in.Disc.ExpressionIsInModule(logDeriv.Inputs[size].Numerator[i], in.ModuleName) {
distributed.PassColumnToModule(initialComp, moduleComp, initialProver, logDeriv.Inputs[size].Numerator[i])
}
denominator = append(denominator, logDeriv.Inputs[key].Denominator[i])
numerator = append(numerator, logDeriv.Inputs[key].Numerator[i])

denominator = append(denominator, logDeriv.Inputs[size].Denominator[i])
numerator = append(numerator, logDeriv.Inputs[size].Numerator[i])

// replaces the external coins with local coins
// note that they just appear in the denominator.
distributed.ReplaceExternalCoins(initialComp, moduleComp, logDeriv.Inputs[key].Denominator[i])
distributed.ReplaceExternalCoins(initialComp, moduleComp, logDeriv.Inputs[size].Denominator[i])
keyIsInModule = true
}
}

// if there in any expression relevant to the current key, add them to zCatalog
if keyIsInModule {
// zCatalog specific to the module
zCatalog[key] = &query.LogDerivativeSumInput{
Round: key[0],
Size: key[1],
zCatalog[size] = &query.LogDerivativeSumInput{
Size: size,
Numerator: numerator,
Denominator: denominator,
}
}
keyIsInModule = false

}
// sanity check; the initialComp has only two rounds
if initialComp.NumRounds() != 2 {
utils.Panic("expected initialComp to have 2 rounds but it has %v rounds", initialComp.NumRounds())
keyIsInModule = false
}

// insert a LogDerivativeSum specific to the module at round 1 (since initialComp has 2 rounds).
moduleComp.InsertLogDerivativeSum(
1,
round,
ifaces.QueryIDf("%v_%v", LogDerivativeSum, in.ModuleName),
zCatalog,
)

// prover step to assign the parameters of LogDerivativeSum at the same round.
moduleComp.SubProvers.AppendToInner(1, func(run *wizard.ProverRuntime) {
moduleComp.SubProvers.AppendToInner(round, func(run *wizard.ProverRuntime) {
run.AssignLogDerivSum(
ifaces.QueryIDf("%v_%v", LogDerivativeSum, in.ModuleName),
GetLogDerivativeSumResult(zCatalog, run),
getLogDerivativeSumResult(zCatalog, run),
)
})

}

// GetLogDerivativeSumResult is a helper allowing the prover to calculate the result of its associated LogDerivativeSum query.
func GetLogDerivativeSumResult(zCatalog map[[2]int]*query.LogDerivativeSumInput, run *wizard.ProverRuntime) field.Element {
// getLogDerivativeSumResult is a helper allowing the prover to calculate the result of its associated LogDerivativeSum query.
func getLogDerivativeSumResult(zCatalog map[int]*query.LogDerivativeSumInput, run *wizard.ProverRuntime) field.Element {
// compute the actual sum from the Numerator and Denominator
actualSum := field.Zero()
for key := range zCatalog {
Expand Down Expand Up @@ -178,6 +168,7 @@ func GetLogDerivativeSumResult(zCatalog map[[2]int]*query.LogDerivativeSumInput,
packedZ[k].Add(&packedZ[k], &packedZ[k-1])
}
}

actualSum.Add(&actualSum, &packedZ[len(packedZ)-1])
}
}
Expand Down
29 changes: 12 additions & 17 deletions prover/protocol/distributed/preparation.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func IntoLogDerivativeSum(comp *wizard.CompiledIOP) {
// zCatalog stores a mapping (round, size) into query.LogDerivativeSumInput and helps finding
// which Z context should be used to handle a part of a given inclusion
// query.
zCatalog = map[[2]int]*query.LogDerivativeSumInput{}
zCatalog = map[int]*query.LogDerivativeSumInput{}
)

// Skip the compilation phase if no lookup constraint is being used. Otherwise
Expand All @@ -60,45 +60,41 @@ func IntoLogDerivativeSum(comp *wizard.CompiledIOP) {
)

// push single-columns into zCatalog
PushToZCatalog(tableCtx, zCatalog)
pushToZCatalog(tableCtx, zCatalog)

a := lookup.MAssignmentTask{
M: tableCtx.M,
S: checkTable,
T: lookupTable,
SFilter: includedFilters,
}

// assign the multiplicity column
comp.SubProvers.AppendToInner(round, a.Run)

}

// insert a single LogDerivativeSum query for the global zCatalog.
comp.InsertLogDerivativeSum(lastRound, "GlobalLogDerivativeSum", zCatalog)
comp.InsertLogDerivativeSum(lastRound+1, "GlobalLogDerivativeSum", zCatalog)

// assign parameters of LogDerivativeSum, it is just to prevent the panic attack in the prover
comp.SubProvers.AppendToInner(lastRound, func(run *wizard.ProverRuntime) {
comp.SubProvers.AppendToInner(lastRound+1, func(run *wizard.ProverRuntime) {
run.AssignLogDerivSum("GlobalLogDerivativeSum", field.Zero())
})
}

// PushToZCatalog constructs the numerators and denominators for the collapsed S and T
// pushToZCatalog constructs the numerators and denominators for the collapsed S and T
// into zCatalog, for their corresponding rounds and size.
func PushToZCatalog(stc lookup.SingleTableCtx, zCatalog map[[2]int]*query.LogDerivativeSumInput) {

var (
round = stc.Gamma.Round
)
func pushToZCatalog(stc lookup.SingleTableCtx, zCatalog map[int]*query.LogDerivativeSumInput) {

// tableCtx push to -> zCtx
// Process the T columns
for frag := range stc.T {
size := stc.M[frag].Size()

key := [2]int{round, size}
key := size
if zCatalog[key] == nil {
zCatalog[key] = &query.LogDerivativeSumInput{
Size: size,
Round: round,
Size: size,
}
}

Expand All @@ -118,11 +114,10 @@ func PushToZCatalog(stc lookup.SingleTableCtx, zCatalog map[[2]int]*query.LogDer
sFilter = symbolic.NewVariable(stc.SFilters[table])
}

key := [2]int{round, size}
key := size
if zCatalog[key] == nil {
zCatalog[key] = &query.LogDerivativeSumInput{
Size: size,
Round: round,
Size: size,
}
}

Expand Down
10 changes: 7 additions & 3 deletions prover/protocol/query/logderiv_sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

// LogDerivativeSumInput stores the input to the query
type LogDerivativeSumInput struct {
Round, Size int
Size int
Numerator []*sym.Expression // T -> -M, S -> +Filter
Denominator []*sym.Expression // S or T -> ({S,T} + X)
}
Expand All @@ -27,7 +27,8 @@ type LogDerivativeSumInput struct {
// N_{i,j} is the i-th element of the underlying column of j-th Numerator
// D_{i,j} is the i-th element of the underlying column of j-th Denominator
type LogDerivativeSum struct {
Inputs map[[2]int]*LogDerivativeSumInput
Round int
Inputs map[int]*LogDerivativeSumInput
ID ifaces.QueryID
}

Expand All @@ -42,24 +43,27 @@ func (l LogDerivSumParams) UpdateFS(fs *fiatshamir.State) {
}

// NewLogDerivativeSum creates the new context LogDerivativeSum.
func NewLogDerivativeSum(inp map[[2]int]*LogDerivativeSumInput, id ifaces.QueryID) LogDerivativeSum {
func NewLogDerivativeSum(round int, inp map[int]*LogDerivativeSumInput, id ifaces.QueryID) LogDerivativeSum {

// check the length consistency
for key := range inp {
if len(inp[key].Numerator) != len(inp[key].Denominator) || len(inp[key].Numerator) == 0 {
utils.Panic("Numerator and Denominator should have the same (no-zero) length, %v , %v", len(inp[key].Numerator), len(inp[key].Denominator))
}
for i := range inp[key].Numerator {

if err := inp[key].Numerator[i].Validate(); err != nil {
utils.Panic(" Numerator[%v] is not a valid expression", i)
}

if err := inp[key].Denominator[i].Validate(); err != nil {
utils.Panic(" Denominator[%v] is not a valid expression", i)
}
}
}

return LogDerivativeSum{
Round: round,
Inputs: inp,
ID: id,
}
Expand Down
4 changes: 2 additions & 2 deletions prover/protocol/query/logderiv_sum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ func TestLogDerivSum(t *testing.T) {
ifaces.ColumnAsVariable(q2),
}

key := [2]int{0, 0}
zCat1 := map[[2]int]*query.LogDerivativeSumInput{}
key := 0
zCat1 := map[int]*query.LogDerivativeSumInput{}
zCat1[key] = &query.LogDerivativeSumInput{
Numerator: numerators,
Denominator: denominators,
Expand Down
4 changes: 2 additions & 2 deletions prover/protocol/wizard/compiled.go
Original file line number Diff line number Diff line change
Expand Up @@ -582,9 +582,9 @@ func (c *CompiledIOP) InsertLocalOpening(round int, name ifaces.QueryID, pol ifa
// InsertLogDerivativeSum registers a new LogDerivativeSum query [query.LogDerivativeSum].
// It generates a single global summation for many Sigma Columns from Lookup compilation.
// The sigma columns are categorized by [round,size].
func (c *CompiledIOP) InsertLogDerivativeSum(lastRound int, id ifaces.QueryID, in map[[2]int]*query.LogDerivativeSumInput) query.LogDerivativeSum {
func (c *CompiledIOP) InsertLogDerivativeSum(lastRound int, id ifaces.QueryID, in map[int]*query.LogDerivativeSumInput) query.LogDerivativeSum {
c.assertConsistentRound(lastRound)
q := query.NewLogDerivativeSum(in, id)
q := query.NewLogDerivativeSum(lastRound, in, id)
// Finally registers the query
c.QueriesParams.AddToRound(lastRound, id, q)
return q
Expand Down

0 comments on commit 5f2718c

Please sign in to comment.