Skip to content

Commit

Permalink
Merge pull request #7 from okwme/convertSqrt
Browse files Browse the repository at this point in the history
  • Loading branch information
okwme authored Oct 26, 2023
2 parents 84ea71a + d0fb991 commit 962f375
Show file tree
Hide file tree
Showing 18 changed files with 346 additions and 134 deletions.
24 changes: 12 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,18 @@ You should see something like:

Currently the project is targeting [powersOfTau28_hez_final_20.ptau](https://github.com/iden3/snarkjs/blob/master/README.md#7-prepare-phase-2) which has a limit of 1MM constraints. Below is a table of the number of constraints used by each circuit.

| Circuit | Non-Linear Constraints |
|---------|-------------|
| absoluteValueSubtraction(252) | 259 |
| acceptableMarginOfError(60) | 128 |
| calculateForce() | 717 |
| detectCollision(3) | 510 |
| forceAccumulator(3) | 2821 |
| getDistance(20) | 142 |
| limiter(252) | 257 |
| lowerLimiter(252) | 257 |
| nft(3, 10) | 28039 |
| stepState(3, 10) | 33531 |
| Circuit | Non-Linear Constraints | seconds at 25fps under 1MM constraints |
|---------|-------------|---------------------------------------------------|
| absoluteValueSubtraction(252) | 257 | 155.64 |
| acceptableMarginOfError(60) | 125 | 320 |
| calculateForce() | 279 | 143.37 |
| detectCollision(3) | 348 | 114.94 |
| forceAccumulator(3) | 1522 | 26.28 |
| getDistance(20) | 88 | 454.55 |
| limiter(252) | 254 | 157.48|
| lowerLimiter(252) | 254 | 157.48|
| nft(3, 10) | 15184 | 26.34 |
| stepState(3, 10) | 19121 | 20.92 |

# built using circom-starter

Expand Down
156 changes: 131 additions & 25 deletions circuits/approxMath.circom
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
pragma circom 2.1.6;
include "../node_modules/circomlib/circuits/mux1.circom";
include "../node_modules/circomlib/circuits/comparators.circom";
include "../node_modules/circomlib/circuits/gates.circom";
include "helpers.circom";

function approxSqrt(n) {
if (n == 0) {
return 0;
return [0,0,0];
}

var lo = 0;
Expand All @@ -16,7 +18,7 @@ function approxSqrt(n) {
// TODO: Make more accurate by checking if lo + hi is odd or even before bit shifting
midSquared = (mid * mid);
if (midSquared == n) {
return mid; // Exact square root found
return [lo,mid,hi]; // Exact square root found
} else if (midSquared < n) {
lo = mid + 1; // Adjust lower bound
} else {
Expand All @@ -25,35 +27,140 @@ function approxSqrt(n) {
}
// If we reach here, no exact square root was found.
// return the closest approximation
return mid;
return [lo, mid, hi];
}


function approxDiv(dividend, divisor) {
var bitsDivident = 0;
var dividendCopy = dividend;
while(dividendCopy > 0) {
bitsDivident++;
dividendCopy = dividendCopy >> 1;
if (dividend == 0) {
return 0;
}

// Create internal signals for our binary search
var lowerBound, upperBound, midPoint, testProduct;
var lo, hi, mid, testProduct;

// Initialize our search space
lowerBound = 0;
upperBound = dividend; // Assuming worst case where divisor = 1

for (var i = 0; i < bitsDivident; i++) { // 32 iterations for 32-bit numbers as an example
midPoint = (upperBound + lowerBound) >> 1;
testProduct = midPoint * divisor;
// Adjust our bounds based on the test product
if (testProduct > dividend) {
upperBound = midPoint;
} else {
lowerBound = midPoint;
}
lo = 0;
hi = dividend; // Assuming worst case where divisor = 1

while (lo < hi) { // 32 iterations for 32-bit numbers as an example
mid = (hi + lo + 1) >> 1;
testProduct = mid * divisor;

// Adjust our bounds based on the test product
if (testProduct > dividend) {
hi = mid - 1;
} else {
lo = mid;
}
}
// Output the lo as our approximated quotient after iterations
// quotient <== lo;
return lo;
}

template Div() {
signal input dividend;
signal input divisor;
signal output quotient;

quotient <-- approxDiv(dividend, divisor); // maxBits: 64 (maxNum: 10_400_000_000_000_000_000)

// NOTE: the following constraints the approxDiv to ensure it's within the acceptable error of margin
signal approxNumerator1 <== quotient * divisor; // maxBits: 126 (maxNum: 58_831_302_400_000_000_000_000_000_000_000_000_000)

// NOTE: approxDiv always rounds down so the approximate quotient will always be less
// than the actual quotient.
signal diff <== dividend - quotient;
// log("diff ", diff);
// log("dividend", dividend);
// log("divisor", divisor);
// log("quotient", quotient);

component lessThan = LessThan(64); // forceXnum; // maxBits: 64
lessThan.in[0] <== diff;
lessThan.in[1] <== dividend;
// log("lessThan", lessThan.out, "\n");

component isZero = IsZero();
isZero.in <== dividend;

component xor = XOR();
xor.a <== isZero.out;
xor.b <== lessThan.out;
xor.out === 1;
}

// Output the midpoint as our approximated quotient after iterations
return midPoint;
template Sqrt(unboundDistanceSquaredMax) {
signal input squaredValue;
signal output root;
signal approxSqrtResults[3];
approxSqrtResults <-- approxSqrt(squaredValue);
// approxSqrtResults[0] = lo
// approxSqrtResults[1] = mid
// approxSqrtResults[2] = hi
// log("squaredValue", squaredValue);
// log("approxSqrtResults[0]", approxSqrtResults[0]);
// log("approxSqrtResults[1]", approxSqrtResults[1]);
// log("approxSqrtResults[2]", approxSqrtResults[2]);
root <-- approxSqrtResults[1];

var distanceResults[3];
distanceResults = approxSqrt(unboundDistanceSquaredMax);
var distanceMax = distanceResults[1]; // maxNum = 1414214n
var distanceMaxBits = maxBits(distanceMax);

component isPerfect = IsZero();
isPerfect.in <== (root**2) - squaredValue;
// signal perfectSquare <== isPerfect.out;
// log("isPerfect", isPerfect.out);

// perfectSquare is true, absDiff = 0
// OR
// if lo - mid == 0, absDiff = mid**2 - actual
// if hi - mid == 0, absDiff = actual - mid**2
component isZeroDiff2 = IsZero();
isZeroDiff2.in <== approxSqrtResults[0] - approxSqrtResults[1]; // lo - mid

// need to constrain that if isZeroDiff2 is not 0 then hi - mid is 0
component isZeroDiff3 = IsZero();
isZeroDiff3.in <== approxSqrtResults[2] - approxSqrtResults[1]; // hi - mid

// firstCondition is XOR
// (isZeroDiff2 == 1 AND isZeroDiff3 == 0) OR (isZeroDiff2 == 0 AND isZeroDiff3 == 1)
// secondCondition
// OR (isPerfect = 1)

component firstCondition = XOR();
firstCondition.a <== isZeroDiff2.out;
firstCondition.b <== isZeroDiff3.out;

// one must be true;
component secondCondition = OR();
secondCondition.a <== firstCondition.out;
secondCondition.b <== isPerfect.out;
secondCondition.out === 1;

component diffMux = Mux1();
diffMux.c[0] <== (approxSqrtResults[1] ** 2) - squaredValue; // mid**2 - actual
diffMux.c[1] <== squaredValue - (approxSqrtResults[1] ** 2); // actual - mid**2
diffMux.s <== isZeroDiff2.out;
signal imperfectDiff <== diffMux.out;

// difference is 0 if perfect square is true
component diffMux2 = Mux1();
diffMux2.c[0] <== imperfectDiff;
diffMux2.c[1] <== 0;
diffMux2.s <== isPerfect.out;
signal diff <== diffMux2.out;

var distanceMaxDoubleMax = distanceMax*2; // maxNum: 2,828,428
var distanceMaxSquaredMaxBits = maxBits(distanceMaxDoubleMax); // maxBits: 22
component lessThan2 = LessEqThan(distanceMaxSquaredMaxBits);
lessThan2.in[0] <== diff;
lessThan2.in[1] <== root*2; // maxBits: 22 (maxNum: 2_828_428)
// diff must be less than root*2 as the acceptable margin of error
lessThan2.out === 1;
}


Expand All @@ -63,7 +170,6 @@ template AcceptableMarginOfError (n) {
signal input marginOfError;
signal output out;


// The following is to ensure diff = Abs(actual - expected)
component absoluteValueSubtraction = AbsoluteValueSubtraction(n);
absoluteValueSubtraction.in[0] <== expected;
Expand All @@ -80,7 +186,7 @@ template AbsoluteValueSubtraction (n) {
signal input in[2];
signal output out;

component lessThan = LessThan(n); // TODO: test limits of squares
component lessThan = LessThan(n);
lessThan.in[0] <== in[0];
lessThan.in[1] <== in[1];
signal lessThanResult <== lessThan.out;
Expand Down
Loading

0 comments on commit 962f375

Please sign in to comment.