Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torus-acceleration for multiexponentiation on GT #485

Merged
merged 16 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions benchmarks/bench_fields_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,30 @@ import
./bench_blueprint

export notes, abstractions
proc separator*() = separator(165)
proc separator*() = separator(145)
proc smallSeparator*() = separator(8)

proc report(op, field: string, start, stop: MonoTime, startClk, stopClk: int64, iters: int) =
let ns = inNanoseconds((stop-start) div iters)
let throughput = 1e9 / float64(ns)
when SupportsGetTicks:
echo &"{op:<70} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
echo &"{op:<49} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
else:
echo &"{op:<70} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op"
echo &"{op:<49} {field:<18} {throughput:>15.3f} ops/s {ns:>9} ns/op"

macro fixFieldDisplay(T: typedesc): untyped =
# At compile-time, enums are integers and their display is buggy
# we get the Curve ID instead of the curve name.
let instantiated = T.getTypeInst()
var name = $instantiated[1][0] # 𝔽p
name.add "[" & $Algebra(instantiated[1][1].intVal) & "]"
if instantiated[1][1].kind == nnkIntLit:
name.add "[" & $Algebra(instantiated[1][1].intVal) & "]"
else:
name.add "[" & $instantiated[1][1][0] # QuadraticExt[𝔽p6[
name.add "[" & $Algebra(instantiated[1][1][1].intVal) & "]]"
result = newLit name

template bench(op: string, T: typedesc, iters: int, body: untyped): untyped =
template bench*(op: string, T: typedesc, iters: int, body: untyped): untyped =
measure(iters, startTime, stopTime, startClk, stopClk, body)
report(op, fixFieldDisplay(T), startTime, stopTime, startClk, stopClk, iters)

Expand Down Expand Up @@ -184,10 +188,10 @@ proc sqrtBench*(T: typedesc, iters: int) =
"Tonelli-Shanks"
const addchain = block:
when T.Name.hasSqrtAddchain() or T.Name.hasTonelliShanksAddchain():
"with addition chain"
"+ addchain"
else:
"without addition chain"
const desc = "Square Root (constant-time " & algoType & " " & addchain & ")"
"no addchain"
const desc = "Sqrt (constant-time " & algoType & " " & addchain & ")"
bench(desc, T, iters):
var r = x
discard r.sqrt_if_square()
Expand All @@ -211,10 +215,10 @@ proc sqrtVartimeBench*(T: typedesc, iters: int) =
"Tonelli-Shanks"
const addchain = block:
when T.Name.hasSqrtAddchain() or T.Name.hasTonelliShanksAddchain():
"with addition chain"
"+ addchain"
else:
"without addition chain"
const desc = "Square Root (vartime " & algoType & " " & addchain & ")"
"no addchain"
const desc = "Sqrt (vartime " & algoType & " " & addchain & ")"
bench(desc, T, iters):
var r = x
discard r.sqrt_if_square_vartime()
Expand Down
14 changes: 11 additions & 3 deletions benchmarks/bench_gt_multiexp_bls12_381.nim
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,25 @@ const AvailableCurves = [
BLS12_381,
]

const testNumPoints = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
# const testNumPoints = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
const testNumPoints = [128, 256]


type Fp12over4[C: static Algebra] = CubicExt[Fp4[C]]
type Fp12over6[C: static Algebra] = QuadraticExt[Fp6[C]]

proc main() =
separator()
staticFor i, 0, AvailableCurves.len:
const curve = AvailableCurves[i]
var ctx = createBenchMultiExpContext(Fp12[curve], testNumPoints)
var ctx12o4 = createBenchMultiExpContext(Fp12over4[curve], testNumPoints)
var ctx12o6 = createBenchMultiExpContext(Fp12over6[curve], testNumPoints)
separator()
for numPoints in testNumPoints:
let batchIters = max(1, Iters div numPoints)
ctx.multiExpParallelBench(numPoints, batchIters)
ctx12o4.multiExpParallelBench(numPoints, batchIters)
echo "----"
ctx12o6.multiExpParallelBench(numPoints, batchIters)
separator()
separator()

Expand Down
80 changes: 64 additions & 16 deletions benchmarks/bench_gt_parallel_template.nim
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ proc report(op, domain: string, start, stop: MonoTime, startClk, stopClk: int64,
let ns = inNanoseconds((stop-start) div iters)
let throughput = 1e9 / float64(ns)
when SupportsGetTicks:
echo &"{op:<68} {domain:<20} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
echo &"{op:<65} {domain:<20} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)"
else:
echo &"{op:<68} {domain:<20} {throughput:>15.3f} ops/s {ns:>9} ns/op"
echo &"{op:<65} {domain:<20} {throughput:>15.3f} ops/s {ns:>9} ns/op"

macro fixFieldDisplay(T: typedesc): untyped =
# At compile-time, enums are integers and their display is buggy
Expand All @@ -52,7 +52,7 @@ macro fixFieldDisplay(T: typedesc): untyped =
result = newLit name

func fixDisplay(T: typedesc): string =
when T is (Fp or Fp2 or Fp4 or Fp6 or Fp12):
when T is (Fp or ExtensionField):
fixFieldDisplay(T)
else:
$T
Expand All @@ -68,7 +68,7 @@ func random_gt*(rng: var RngState, F: typedesc): F {.inline, noInit.} =
result = rng.random_unsafe(F)
result.finalExp()

# Multi-exponentiations
# multi-exp
# ---------------------------------------------------------------------------

type BenchMultiexpContext*[GT] = object
Expand Down Expand Up @@ -126,11 +126,19 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in

var r{.noInit.}: GT
var startNaive, stopNaive, startMultiExpBaseline, stopMultiExpBaseline: MonoTime
var startMultiExpOpt, stopMultiExpOpt, startMultiExpPara, stopMultiExpPara: MonoTime
var startMultiExpOpt, stopMultiExpOpt: MonoTime
var startMultiExpPara, stopMultiExpPara: MonoTime
var startMultiExpParaTorus, stopMultiExpParaTorus: MonoTime

when GT is QuadraticExt:
var startMultiExpBaselineTorus: MonoTime
var stopMultiExpBaselineTorus: MonoTime
var startMultiExpOptTorus: MonoTime
var stopMultiExpOptTorus: Monotime

if numInputs <= 100000:
# startNaive = getMonotime()
bench("𝔾ₜ exponentiations " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
bench("𝔾ₜ exponentiations " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
var tmp: GT
r.setOne()
for i in 0 ..< elems.len:
Expand All @@ -140,7 +148,7 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in

if numInputs <= 100000:
startNaive = getMonotime()
bench("𝔾ₜ exponentiations vartime " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
bench("𝔾ₜ exponentiations vartime " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
var tmp: GT
r.setOne()
for i in 0 ..< elems.len:
Expand All @@ -150,30 +158,59 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in

if numInputs <= 100000:
startMultiExpBaseline = getMonotime()
bench("𝔾ₜ multi-exponentiations baseline " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_reference_vartime(elems, exponents)
bench("𝔾ₜ multi-exp baseline " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_reference_vartime(elems, exponents, useTorus = false)
stopMultiExpBaseline = getMonotime()

if numInputs <= 100000:
when GT is QuadraticExt:
startMultiExpBaselineTorus = getMonotime()
bench("𝔾ₜ multi-exp baseline + torus" & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_reference_vartime(elems, exponents, useTorus = true)
stopMultiExpBaselineTorus = getMonotime()

block:
startMultiExpOpt = getMonotime()
bench("𝔾ₜ multi-exponentiations optimized " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_vartime(elems, exponents)
bench("𝔾ₜ multi-exp opt " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_vartime(elems, exponents, useTorus = false)
stopMultiExpOpt = getMonotime()

when GT is QuadraticExt:
block:
startMultiExpOptTorus = getMonotime()
bench("𝔾ₜ multi-exp opt + torus " & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
r.multiExp_vartime(elems, exponents, useTorus = true)
stopMultiExpOptTorus = getMonotime()

block:
ctx.tp = Threadpool.new()

startMultiExpPara = getMonotime()
bench("𝔾ₜ multi-exponentiations" & align($ctx.tp.numThreads & " threads", 11) & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
ctx.tp.multiExp_vartime_parallel(r, elems, exponents)
bench("𝔾ₜ multi-exp " & align($ctx.tp.numThreads & " threads", 11) & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
ctx.tp.multiExp_vartime_parallel(r, elems, exponents, useTorus = false)
stopMultiExpPara = getMonotime()

ctx.tp.shutdown()

when GT is QuadraticExt:
block:
ctx.tp = Threadpool.new()

startMultiExpParaTorus = getMonotime()
bench("𝔾ₜ multi-exp torus" & align($ctx.tp.numThreads & " threads", 11) & align($numInputs, 10) & " (" & $bits & "-bit exponents)", GT, iters):
ctx.tp.multiExp_vartime_parallel(r, elems, exponents, useTorus = true)
stopMultiExpParaTorus = getMonotime()

ctx.tp.shutdown()

let perfNaive = inNanoseconds((stopNaive-startNaive) div iters)
let perfMultiExpBaseline = inNanoseconds((stopMultiExpBaseline-startMultiExpBaseline) div iters)
let perfMultiExpOpt = inNanoseconds((stopMultiExpOpt-startMultiExpOpt) div iters)
let perfMultiExpPara = inNanoseconds((stopMultiExpPara-startMultiExpPara) div iters)
when GT is QuadraticExt:
let perfMultiExpBaselineTorus = inNanoseconds((stopMultiExpBaselineTorus-startMultiExpBaselineTorus) div iters)
let perfMultiExpOptTorus = inNanoseconds((stopMultiExpOptTorus-startMultiExpOptTorus) div iters)
let perfMultiExpParaTorus = inNanoSeconds((stopMultiExpParaTorus-startMultiExpParaTorus) div iters)

if numInputs <= 100000:
let speedupBaseline = float(perfNaive) / float(perfMultiExpBaseline)
Expand All @@ -182,8 +219,19 @@ proc multiExpParallelBench*[GT](ctx: var BenchMultiExpContext[GT], numInputs: in
let speedupOpt = float(perfNaive) / float(perfMultiExpOpt)
echo &"Speedup ratio optimized over naive linear combination: {speedupOpt:>6.3f}x"

let speedupOptBaseline = float(perfMultiExpBaseline) / float(perfMultiExpOpt)
echo &"Speedup ratio optimized over baseline linear combination: {speedupOptBaseline:>6.3f}x"
when GT is QuadraticExt:
let speedupTorusOverBaseline = float(perfMultiExpBaseline) / float(perfMultiExpBaselineTorus)
echo &"Speedup ratio baseline + Torus over baseline linear combination: {speedupTorusOverBaseline:>6.3f}x"

let speedupTorusOverOpt = float(perfMultiExpOpt) / float(perfMultiExpOptTorus)
echo &"Speedup ratio optimized + Torus over optimized: {speedupTorusOverOpt:>6.3f}x"

let speedupParaOpt = float(perfMultiExpOpt) / float(perfMultiExpPara)
echo &"Speedup ratio parallel over optimized linear combination: {speedupParaOpt:>6.3f}x"
echo &"Speedup ratio parallel over serial optimized linear combination: {speedupParaOpt:>6.3f}x"

when GT is QuadraticExt:
let speedupParaTorus = float(perfMultiExpOptTorus) / float(perfMultiExpParaTorus)
echo &"Speedup ratio parallel over serial for Torus-based multiexp: {speedupParaTorus:>6.3f}x"

let speedupParaTorusOpt = float(perfMultiExpPara) / float(perfMultiExpParaTorus)
echo &"Speedup ratio parallel over parallel Torus-based multiexp: {speedupParaTorusOpt:>6.3f}x"
Loading
Loading