Skip to content

Commit

Permalink
Merge branch 'prover/limitless-top-level' into prover/compile-grand-p…
Browse files Browse the repository at this point in the history
…roduct
  • Loading branch information
arijitdutta67 committed Jan 11, 2025
2 parents 4340a33 + 96e428f commit 66a2ab9
Show file tree
Hide file tree
Showing 16 changed files with 392 additions and 185 deletions.
20 changes: 14 additions & 6 deletions prover/protocol/compiler/logderivativesum/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,39 @@ func CompileLogDerivSum(comp *wizard.CompiledIOP) {
// compilation process. We know that the query was already ignored at
// the beginning because we are iterating over the unignored keys.
comp.QueriesParams.MarkAsIgnored(qName)
// get the Numerator and Denominator from the input and prepare their compilation.
zEntries := logDeriv.Inputs
va := FinalEvaluationCheck{}

var (
zEntries = logDeriv.Inputs
va = FinalEvaluationCheck{
LogDerivSumID: qName,
}
lastRound = logDeriv.Round
)

for _, entry := range zEntries {

// get the Numerator and Denominator from the input and prepare their compilation.
zC := &lookup.ZCtx{
Round: entry.Round,
Round: lastRound,
Size: entry.Size,
SigmaNumerator: entry.Numerator,
SigmaDenominator: entry.Denominator,
}

// z-packing compile; it imposes the correct accumulation over Numerator and Denominator.
zC.Compile(comp)

// prover step; Z assignments
zAssignmentTask := lookup.ZAssignmentTask(*zC)
comp.SubProvers.AppendToInner(zC.Round, func(run *wizard.ProverRuntime) {
zAssignmentTask.Run(run)
})

// collect all the zOpening for all the z columns
va.ZOpenings = append(va.ZOpenings, zC.ZOpenings...)
}

// verifer step
va.LogDerivSumID = qName
lastRound := comp.NumRounds() - 1
comp.RegisterVerifierAction(lastRound, &va)
}

Expand Down
9 changes: 4 additions & 5 deletions prover/protocol/compiler/logderivativesum/logderivsum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ func TestLogDerivSum(t *testing.T) {
ifaces.ColumnAsVariable(q2),
}

key := [2]int{0, 4}
zCat1 := map[[2]int]*query.LogDerivativeSumInput{}
zCat1[key] = &query.LogDerivativeSumInput{
Round: 0,
Size: 4,
size := 4
zCat1 := map[int]*query.LogDerivativeSumInput{}
zCat1[size] = &query.LogDerivativeSumInput{
Size: size,
Numerator: numerators,
Denominator: denominators,
}
Expand Down
99 changes: 99 additions & 0 deletions prover/protocol/distributed/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package distributed

import (
"github.com/consensys/linea-monorepo/prover/protocol/coin"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/symbolic"
"github.com/consensys/linea-monorepo/prover/utils"
)

// ReplaceExternalCoins replaces the external coins with local coins, for a given expression.
// It does not check if all the columns from the expression are in the module.
// If this is required should be check before calling ReplaceExternalCoins.
// If the Coin does not exist in the initialComp it panics.
func ReplaceExternalCoins(initialComp, moduleComp *wizard.CompiledIOP, expr *symbolic.Expression) {
var (
board = expr.Board()
metadata = board.ListVariableMetadata()
)
for _, m := range metadata {
switch v := m.(type) {
case coin.Info:

if !initialComp.Coins.Exists(v.Name) {
utils.Panic("Coin %v does not exist in the InitialComp", v.Name)
}
if v.Round != 1 {
utils.Panic("Coin %v is declared in round %v != 1", v.Name, v.Round)
}
if !moduleComp.Coins.Exists(v.Name) {
moduleComp.InsertCoin(v.Round, v.Name, coin.Field)
}
}
}
}

// GetFreshModuleComp creates a [wizard.CompiledIOP] object including only the columns relevant to the module.
// It also contains the prover steps for assigning the module column
func GetFreshModuleComp(
initialComp *wizard.CompiledIOP,
disc ModuleDiscoverer,
initialProver wizard.ProverStep,
moduleName ModuleName,
) *wizard.CompiledIOP {

var (
// initialize the moduleComp
moduleComp = wizard.NewCompiledIOP()
initialRunTime = wizard.RunProver(initialComp, initialProver)
)

for round := 0; round < initialComp.NumRounds(); round++ {
var columnsInRound []ifaces.Column
// get the columns per round
for _, colName := range initialComp.Columns.AllKeysAt(round) {

col := initialComp.Columns.GetHandle(colName)
if !disc.ColumnIsInModule(col, moduleName) {
continue
}

moduleComp.InsertCommit(col.Round(), col.GetColID(), col.Size())
columnsInRound = append(columnsInRound, col)
}

// create a new moduleProver
moduleProver := moduleProver{
cols: columnsInRound,
initRun: initialRunTime,
round: round,
}

// register Prover action for the module to assign columns per round
moduleComp.RegisterProverAction(round, moduleProver)

}

return moduleComp
}

// it stores the input for the module prover
type moduleProver struct {
round int
// columns for a specific round
cols []ifaces.Column
// runtime of the initial Prover that is parent to the module.
initRun *wizard.ProverRuntime
}

// It implements [wizard.ProverAction] for the module prover.
func (p moduleProver) Run(run *wizard.ProverRuntime) {
for _, col := range p.cols {
// get the witness from the initialProver
colWitness := p.initRun.GetColumn(col.GetColID())
// assign it in the module in the round col was declared
run.AssignColumn(col.GetColID(), colWitness, col.Round())
}

}
178 changes: 115 additions & 63 deletions prover/protocol/distributed/compiler/inclusion/inclusion.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package inclusion

import (
"github.com/consensys/linea-monorepo/prover/maths/common/vector"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/protocol/column"
"github.com/consensys/linea-monorepo/prover/protocol/distributed"
"github.com/consensys/linea-monorepo/prover/protocol/ifaces"
"github.com/consensys/linea-monorepo/prover/protocol/query"
"github.com/consensys/linea-monorepo/prover/protocol/wizard"
"github.com/consensys/linea-monorepo/prover/symbolic"
"github.com/consensys/linea-monorepo/prover/utils"
)

const (
Expand All @@ -21,99 +25,147 @@ type DistributionInputs struct {
// Name of the module
ModuleName distributed.ModuleName
// query is supposed to be the global LogDerivativeSum.
QueryID ifaces.QueryID
Query query.LogDerivativeSum
}

// GetShareOfLogDerivativeSum extracts the share of the given modules from the given LogDerivativeSum query.
// It inserts a new LogDerivativeSum for the extracted share.
func GetShareOfLogDerivativeSum(in DistributionInputs) {
// DistributeLogDerivativeSum distributes a share from a global [query.LogDerivativeSum] query to the given module.
func DistributeLogDerivativeSum(
initialComp, moduleComp *wizard.CompiledIOP,
moduleName distributed.ModuleName,
disc distributed.ModuleDiscoverer,
) {

var (
initialComp = in.InitialComp
moduleComp = in.ModuleComp
numerator []*symbolic.Expression
denominator []*symbolic.Expression
zCatalog = make(map[[2]int]*query.LogDerivativeSumInput)
lastRound = in.InitialComp.NumRounds() - 1
queryID ifaces.QueryID
)
// check that the given query is a valid LogDerivateSum query in the CompiledIOP.
logDeriv, ok := initialComp.QueriesParams.Data(in.QueryID).(query.LogDerivativeSum)
if !ok {
panic("the given query is not a valid LogDerivativeSum from the compiledIOP")
}

// This ensures that the logDerivative query is not used again in the
// compilation process for the module.
/* _, ok = moduleComp.QueriesParams.Data(in.QueryID).(query.LogDerivativeSum)
if ok {
moduleComp.QueriesNoParams.MarkAsIgnored(in.QueryID)
} */

// also mark all the inclusion queries in the module as ignored
// @Azam this is because for the moment we dont know how the module-discoverer extracts moduleComp from InitialComp.
// if we are sure that inclusions are already removed from modComp, we can skip this step here.
for _, qName := range moduleComp.QueriesNoParams.AllUnignoredKeys() {
// Filter out non lookup queries
_, ok := moduleComp.QueriesNoParams.Data(qName).(query.Inclusion)
for _, qName := range initialComp.QueriesParams.AllUnignoredKeys() {

_, ok := initialComp.QueriesParams.Data(qName).(query.LogDerivativeSum)
if !ok {
continue
}
moduleComp.QueriesNoParams.MarkAsIgnored(qName)

// panic if there is more than a LogDerivativeSum query in the initialComp.
if string(queryID) != "" {
utils.Panic("found more than a LogDerivativeSum query in the initialComp")
}

queryID = qName
}

// get the share of the module from the LogDerivativeSum query
GetShareOfLogDerivativeSum(DistributionInputs{
ModuleComp: moduleComp,
InitialComp: initialComp,
Disc: disc,
ModuleName: moduleName,
Query: initialComp.QueriesParams.Data(queryID).(query.LogDerivativeSum),
})

}

// GetShareOfLogDerivativeSum extracts the share of the given modules from the given LogDerivativeSum query.
// It inserts a new LogDerivativeSum for the extracted share.
func GetShareOfLogDerivativeSum(in DistributionInputs) {

var (
initialComp = in.InitialComp
moduleComp = in.ModuleComp
numerator []*symbolic.Expression
denominator []*symbolic.Expression
keyIsInModule bool
zCatalog = make(map[int]*query.LogDerivativeSumInput)
logDeriv = in.Query
round = logDeriv.Round
)

// extract the share of the module from the global sum.
for key := range logDeriv.Inputs {
for i := range logDeriv.Inputs[key].Numerator {
if in.Disc.ExpressionIsInModule(logDeriv.Inputs[key].Numerator[i], in.ModuleName) {
if in.Disc.ExpressionIsInModule(logDeriv.Inputs[key].Denominator[i], in.ModuleName) {
denominator = append(denominator, logDeriv.Inputs[key].Denominator[i])
numerator = append(numerator, logDeriv.Inputs[key].Numerator[i])
for size := range logDeriv.Inputs {

for i := range logDeriv.Inputs[size].Numerator {

// if Denominator is in the module pass the numerator from initialComp to moduleComp
// Particularly, T might be in the module and needs to take M from initialComp.
if in.Disc.ExpressionIsInModule(logDeriv.Inputs[size].Denominator[i], in.ModuleName) {

if !in.Disc.ExpressionIsInModule(logDeriv.Inputs[size].Numerator[i], in.ModuleName) {
utils.Panic("Denominator is in the module but not Numerator")
}

denominator = append(denominator, logDeriv.Inputs[size].Denominator[i])
numerator = append(numerator, logDeriv.Inputs[size].Numerator[i])

// replaces the external coins with local coins
// note that they just appear in the denominator.
distributed.ReplaceExternalCoins(initialComp, moduleComp, logDeriv.Inputs[size].Denominator[i])
keyIsInModule = true
}
}

// if there in any numerator associated with the current key add it to the map.
if len(numerator) != 0 {
// if there in any expression relevant to the current key, add them to zCatalog
if keyIsInModule {
// zCatalog specific to the module
zCatalog[key] = &query.LogDerivativeSumInput{
Round: key[0],
Size: key[1],
zCatalog[size] = &query.LogDerivativeSumInput{
Size: size,
Numerator: numerator,
Denominator: denominator,
}
}

keyIsInModule = false
}

// insert a LogDerivativeSum specific to the module.
// insert a LogDerivativeSum specific to the module at round 1 (since initialComp has 2 rounds).
moduleComp.InsertLogDerivativeSum(
lastRound,
round,
ifaces.QueryIDf("%v_%v", LogDerivativeSum, in.ModuleName),
zCatalog,
)

// prover step to assign the parameters of LogDerivativeSum at the same round.
moduleComp.SubProvers.AppendToInner(round, func(run *wizard.ProverRuntime) {
run.AssignLogDerivSum(
ifaces.QueryIDf("%v_%v", LogDerivativeSum, in.ModuleName),
getLogDerivativeSumResult(zCatalog, run),
)
})

}

// DistributeLogDerivativeSum extract the LogDerivativeSum query that is subject to the distribution.
// It ignores the inclusion queries in the module compiledIOP and replaces them with its share of LogDerivativeSum.
func DistributeLogDerivativeSum(initialComp, moduleComp *wizard.CompiledIOP, moduleName distributed.ModuleName, disc distributed.ModuleDiscoverer) {
// getLogDerivativeSumResult is a helper allowing the prover to calculate the result of its associated LogDerivativeSum query.
func getLogDerivativeSumResult(zCatalog map[int]*query.LogDerivativeSumInput, run *wizard.ProverRuntime) field.Element {
// compute the actual sum from the Numerator and Denominator
actualSum := field.Zero()
for key := range zCatalog {
for i, num := range zCatalog[key].Numerator {

var queryID ifaces.QueryID
for _, qName := range initialComp.QueriesParams.AllUnignoredKeys() {
_, ok := initialComp.QueriesParams.Data(qName).(query.LogDerivativeSum)
if !ok {
continue
var (
numBoard = num.Board()
denBoard = zCatalog[key].Denominator[i].Board()
numeratorMetadata = numBoard.ListVariableMetadata()
denominator = column.EvalExprColumn(run, denBoard).IntoRegVecSaveAlloc()
numerator []field.Element
packedZ = field.BatchInvert(denominator)
)

if len(numeratorMetadata) == 0 {
numerator = vector.Repeat(field.One(), zCatalog[key].Size)
}

if len(numeratorMetadata) > 0 {
numerator = column.EvalExprColumn(run, numBoard).IntoRegVecSaveAlloc()
}

for k := range packedZ {
packedZ[k].Mul(&numerator[k], &packedZ[k])
if k > 0 {
packedZ[k].Add(&packedZ[k], &packedZ[k-1])
}
}

actualSum.Add(&actualSum, &packedZ[len(packedZ)-1])
}
queryID = qName
//@Azam panic if it has more than one.
// it breaks since we expect only a single query of this type.
break
}
input := DistributionInputs{
ModuleComp: moduleComp,
InitialComp: initialComp,
Disc: disc,
ModuleName: moduleName,
QueryID: queryID,
}
GetShareOfLogDerivativeSum(input)

return actualSum
}
Loading

0 comments on commit 66a2ab9

Please sign in to comment.