Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Schnorr signature #6

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
8 changes: 7 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ plugins {
id "com.jfrog.bintray" version "1.8.5"
id "maven-publish"
id 'java-library'
id 'com.diffplug.gradle.spotless' version '3.28.1'
}

group = "com.ing.dlt"
version '1.0.8'

repositories {
mavenCentral()
maven { url "https://dl.bintray.com/ethereum/maven/" }
maven { url "https://dl.bintray.com/ethereum/maven/"}
}

dependencies {
Expand All @@ -35,6 +36,11 @@ java {
withJavadocJar()
}

spotless {
kotlin {
ktlint("0.37.1")
}
}

// Bintray publishing
bintray {
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/com/ing/dlt/zkkrypto/ecc/EllipticCurve.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ interface EllipticCurve : Arithmetic {

// identity element
val zero: EllipticCurvePoint
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ data class EllipticCurvePoint(val x: BigInteger, val y: BigInteger, val curve: E
fun scalarMult(scalar: BigInteger): EllipticCurvePoint = curve.scalarMult(this, scalar)
fun double(): EllipticCurvePoint = curve.double(this)
fun isOnCurve(): Boolean = curve.isOnCurve(this)
}
}
9 changes: 9 additions & 0 deletions src/main/kotlin/com/ing/dlt/zkkrypto/ecc/HashEnum.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.ing.dlt.zkkrypto.ecc

enum class HashEnum {
BLAKE2S,
BLAKE2B,
PEDERSEN,
MIMC,
POSEIDON
}
2 changes: 1 addition & 1 deletion src/main/kotlin/com/ing/dlt/zkkrypto/ecc/ZKHash.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ interface ZKHash {
val hashLength: Int

fun hash(msg: ByteArray): ByteArray
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ interface Arithmetic {
fun scalarMult(p: EllipticCurvePoint, scalar: BigInteger): EllipticCurvePoint
fun double(p: EllipticCurvePoint): EllipticCurvePoint
fun isOnCurve(p: EllipticCurvePoint): Boolean
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ object EdwardsForm : Arithmetic {
// x = (pointA.x * pointB.y + pointA.y * pointB.x) / (1 + d * pointA.x * pointB.x * pointA.y * pointB.y)
// y = (pointA.y * pointB.y + pointA.x * pointB.x) / (1 - d * pointA.x * pointB.x * pointA.y * pointB.y)

if(a.curve != b.curve) throw IllegalArgumentException("Points should be on the same curve, A's curve is ${a.curve}, B's curve is ${b.curve}")
if (a.curve != b.curve) throw IllegalArgumentException("Points should be on the same curve, A's curve is ${a.curve}, B's curve is ${b.curve}")

val curve = a.curve as EdwardsCurve

Expand Down Expand Up @@ -43,7 +43,7 @@ object EdwardsForm : Arithmetic {
var doubling: EllipticCurvePoint = p.copy()
var result: EllipticCurvePoint = p.curve.zero

for( i in 0 until s.bitLength() ) {
for (i in 0 until s.bitLength()) {
if (s.testBit(i)) {
result = result.add(doubling)
}
Expand All @@ -69,5 +69,4 @@ object EdwardsForm : Arithmetic {

return y2.subtract(x2).mod(p.curve.R) == BigInteger.ONE.add(dTimesX2Y2).mod(p.curve.R)
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ object AltBabyJubjub : EdwardsCurve {
override val cofactor: BigInteger = BigInteger.valueOf(8)

override val zero: EllipticCurvePoint = EllipticCurvePoint(BigInteger.ZERO, BigInteger.ONE, this)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,4 @@ object BabyJubjub : EdwardsCurve {
override val cofactor: BigInteger = BigInteger.valueOf(8)

override val zero: EllipticCurvePoint = EllipticCurvePoint(BigInteger.ZERO, BigInteger.ONE, this)

}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package com.ing.dlt.zkkrypto.ecc.curves

import com.ing.dlt.zkkrypto.ecc.arithmetic.EdwardsForm
import com.ing.dlt.zkkrypto.ecc.EllipticCurve
import com.ing.dlt.zkkrypto.ecc.EllipticCurvePoint
import com.ing.dlt.zkkrypto.ecc.arithmetic.EdwardsForm
import java.math.BigInteger

/**
Expand All @@ -19,5 +19,5 @@ interface EdwardsCurve : EllipticCurve {

override fun double(p: EllipticCurvePoint): EllipticCurvePoint = EdwardsForm.double(p)

override fun isOnCurve(p: EllipticCurvePoint): Boolean = EdwardsForm.isOnCurve(p)
}
override fun isOnCurve(p: EllipticCurvePoint): Boolean = EdwardsForm.isOnCurve(p)
}
2 changes: 1 addition & 1 deletion src/main/kotlin/com/ing/dlt/zkkrypto/ecc/curves/Jubjub.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ object Jubjub : EdwardsCurve {
override val cofactor: BigInteger = BigInteger.valueOf(8)

override val zero: EllipticCurvePoint = EllipticCurvePoint(BigInteger.ZERO, BigInteger.ONE, this)
}
}
13 changes: 6 additions & 7 deletions src/main/kotlin/com/ing/dlt/zkkrypto/ecc/mimc/Mimc7Hash.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,24 @@ import org.bouncycastle.jcajce.provider.digest.Keccak
import java.math.BigInteger
import kotlin.math.min


data class Mimc7Hash(
val r: BigInteger = BabyJubjub.R,
val numRounds: Int = defaultNumRounds,
val roundConstants: List<BigInteger> = generateRoundConstants(r = r, numRounds = numRounds)
): ZKHash {
) : ZKHash {

/**
* Hash size in bytes
*/
override val hashLength = r.bitLength() / 8 + if(r.bitLength() / 8 != 0) 1 else 0
override val hashLength = r.bitLength() / 8 + if (r.bitLength() / 8 != 0) 1 else 0

override fun hash(msg: ByteArray): ByteArray = hash(bytesToField(msg))

fun hash(msg: List<BigInteger>): ByteArray {

// check if all elements are in field R
msg.forEach {
if( it > r) throw IllegalArgumentException("Element $it is not in field $r")
if (it > r) throw IllegalArgumentException("Element $it is not in field $r")
}

var result = BigInteger.ZERO
Expand All @@ -33,7 +32,7 @@ data class Mimc7Hash(
result = (result + it + hashElement(it, result)) % r
}
val bytes = result.toByteArray()
return if(bytes.size == hashLength)
return if (bytes.size == hashLength)
bytes
else
ByteArray(hashLength - bytes.size).plus(bytes)
Expand Down Expand Up @@ -63,7 +62,7 @@ data class Mimc7Hash(

for (i in msg.indices step n) {
// We revert array here because bytes are supposed to be little-endian unlike BigInteger
val int = BigInteger(1, msg.sliceArray(i until min(i+n, msg.size)).reversedArray())
val int = BigInteger(1, msg.sliceArray(i until min(i + n, msg.size)).reversedArray())
ints.add(int)
}
return ints
Expand Down Expand Up @@ -99,4 +98,4 @@ data class Mimc7Hash(
} else bytes
}
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.ing.dlt.zkkrypto.ecc.pedersenhash

import com.ing.dlt.zkkrypto.ecc.EllipticCurvePoint
import com.ing.dlt.zkkrypto.ecc.EllipticCurve
import com.ing.dlt.zkkrypto.ecc.EllipticCurvePoint
import com.ing.dlt.zkkrypto.ecc.curves.AltBabyJubjub
import com.ing.dlt.zkkrypto.ecc.curves.BabyJubjub
import com.ing.dlt.zkkrypto.ecc.curves.Jubjub
Expand Down Expand Up @@ -162,4 +162,4 @@ object GeneratorsGenerator {
else -> throw IllegalArgumentException("Unknown curve: $curve")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ data class PedersenHash(
val curve: EllipticCurve,
val generators: List<EllipticCurvePoint> = GeneratorsGenerator.defaultForCurve(curve),
val defaultSalt: BitArray? = null
): ZKHash {
) : ZKHash {

init {
generators.forEach {
if (!it.isOnCurve()) throw IllegalStateException("Point is not on the curve")
if(it.curve != generators.first().curve) throw IllegalStateException("Generators should belong to same curve")
if (it.curve != generators.first().curve) throw IllegalStateException("Generators should belong to same curve")
}
if(window != 3) throw IllegalStateException("Only supporting window size 3 at the moment")
if (window != 3) throw IllegalStateException("Only supporting window size 3 at the moment")
}

private val chunkShift = window + 1
Expand All @@ -46,18 +46,18 @@ data class PedersenHash(
var hashPoint = curve.zero
val salted = salted(msg, salt)

if(salted.size > maxBitLength()) throw IllegalArgumentException("Message is too long, length = ${salted.size}, limit = ${maxBitLength()}")
if (salted.size > maxBitLength()) throw IllegalArgumentException("Message is too long, length = ${salted.size}, limit = ${maxBitLength()}")

val m = padded(salted)

val numProducts = numProducts(m)
for (i in 0 until numProducts) {
hashPoint = hashPoint.add(generators[i].scalarMult(product(m, i)))
hashPoint = hashPoint.add(generators[i].scalarMult(product(m, i)))
}

// we want constant size hashes so we add trailing zero bytes to the beginning
val bytes = hashPoint.x.toByteArray()
return if(bytes.size == hashLength)
return if (bytes.size == hashLength)
bytes
else
ByteArray(hashLength - bytes.size).plus(bytes)
Expand All @@ -68,11 +68,11 @@ data class PedersenHash(
val lowestBitIndex = (productIndex * chunksPerGenerator + chunkIndex) * window

var chunk = BigInteger.ONE
for(i in 0 until window-1) {
for (i in 0 until window - 1) {
chunk += m.get(lowestBitIndex + i).shiftLeft(i)
}

if(m.testBit(lowestBitIndex + (window-1))) {
if (m.testBit(lowestBitIndex + (window - 1))) {
// only works for ZCash 3-bits window now because iden3 4-bit algorithm uses different sign switch :shrug:
chunk = chunk.negate()
}
Expand All @@ -81,7 +81,7 @@ data class PedersenHash(
}

private fun fieldNegate(element: BigInteger): BigInteger {
return if(element >= BigInteger.ZERO) {
return if (element >= BigInteger.ZERO) {
element
} else {
curve.S + element
Expand All @@ -93,7 +93,7 @@ data class PedersenHash(

var product = BigInteger.ZERO

for(j in 0 until numChunksInProduct(msg, i)) {
for (j in 0 until numChunksInProduct(msg, i)) {
val chunk = fieldNegate(chunk(msg, i, j).shiftLeft(chunkShift * j))
product += chunk
}
Expand All @@ -104,13 +104,13 @@ data class PedersenHash(
// chunksPerGenerator in most cases but variable for last product (it can be shorter)
private fun numChunksInProduct(msg: BitArray, productIndex: Int): Int {

val lastProduct = if(msg.size % productBitSize() == 0) msg.size / productBitSize() - 1 else msg.size / productBitSize()
val lastProduct = if (msg.size % productBitSize() == 0) msg.size / productBitSize() - 1 else msg.size / productBitSize()

return if(productIndex == lastProduct) {
if(msg.size % productBitSize() == 0)
return if (productIndex == lastProduct) {
if (msg.size % productBitSize() == 0)
chunksPerGenerator
else {
msg.size % productBitSize() / window + if(msg.size % window == 0) 0 else 1
msg.size % productBitSize() / window + if (msg.size % window == 0) 0 else 1
}
} else chunksPerGenerator
}
Expand All @@ -123,7 +123,7 @@ data class PedersenHash(
return generators.size * productBitSize()
}

private fun numProducts(m: BitArray) = min(generators.size, m.size / productBitSize() + if(m.size % productBitSize() == 0) 0 else 1)
private fun numProducts(m: BitArray) = min(generators.size, m.size / productBitSize() + if (m.size % productBitSize() == 0) 0 else 1)

private fun salted(msg: BitArray, salt: BitArray?): BitArray {
return salt?.plus(msg) ?: msg
Expand All @@ -143,4 +143,3 @@ data class PedersenHash(
fun zcash() = PedersenHash(curve = Jubjub, chunksPerGenerator = 63)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ data class Constants(

fun defaultRoundConstants(): Constants {
return Constants(
defaultCStrings.map { it.map { BigInteger(it, 10) }},
defaultMStrings.map { it.map { it.map { BigInteger(it, 10) }}} ,
defaultCStrings.map { it.map { BigInteger(it, 10) } },
defaultMStrings.map { it.map { it.map { BigInteger(it, 10) } } },
defaultNumRoundsF,
defaultNumRoundsP
)
Expand Down Expand Up @@ -3472,4 +3472,4 @@ data class Constants(
)
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,35 +9,35 @@ import kotlin.math.min
data class PoseidonHash(
val r: BigInteger = BabyJubjub.R,
val constants: Constants = defaultRoundConstants()
): ZKHash {
) : ZKHash {

/**
* Hash size in bytes
*/
override val hashLength = r.bitLength() / 8 + if(r.bitLength() / 8 != 0) 1 else 0
override val hashLength = r.bitLength() / 8 + if (r.bitLength() / 8 != 0) 1 else 0

override fun hash(msg: ByteArray): ByteArray = hash(bytesToField(msg)).toByteArray()

fun hash(msg: List<BigInteger>): BigInteger {
fun hash(msg: List<BigInteger>): BigInteger {

if (msg.isEmpty() || msg.size >= constants.numRoundsP.size - 1)
throw Exception("Invalid inputs length: ${msg.size}, maximum allowed: ${constants.numRoundsP.size - 2}")

msg.forEach { if(it >= r) throw Exception("Element {$it} is out of the field {$r}") }
msg.forEach { if (it >= r) throw Exception("Element {$it} is out of the field {$r}") }

val t = msg.size + 1

var state = msg.plus(BigInteger.ZERO).toMutableList()

val nRoundsP = constants.numRoundsP[t-2]
val nRoundsP = constants.numRoundsP[t - 2]

val lastRound = constants.numRoundsF + nRoundsP - 1

for (round in 0 until constants.numRoundsF + nRoundsP) {

// Add Round Key
for (i in state.indices) {
state[i] = state[i] + constants.c[t-2][round * t + i]
state[i] = state[i] + constants.c[t - 2][round * t + i]
}

// S-Box
Expand All @@ -53,7 +53,7 @@ data class PoseidonHash(

// If not last round: mix (via matrix multiplication)
if (round != lastRound) {
state = mix(state, constants.m[t-2]).toMutableList()
state = mix(state, constants.m[t - 2]).toMutableList()
}
}

Expand Down Expand Up @@ -87,9 +87,9 @@ data class PoseidonHash(
val ints = mutableListOf<BigInteger>()

for (i in msg.indices step n) {
val int = BigInteger(1, msg.sliceArray(i until min(i+n, msg.size)))
val int = BigInteger(1, msg.sliceArray(i until min(i + n, msg.size)))
ints.add(int)
}
return ints
}
}
}
Loading