Skip to content

Commit

Permalink
finishd sqrt and div
Browse files Browse the repository at this point in the history
  • Loading branch information
okwme committed Oct 26, 2023
1 parent 3632e07 commit d0fb991
Show file tree
Hide file tree
Showing 17 changed files with 288 additions and 174 deletions.
24 changes: 12 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,18 @@ You should see something like:

Currently the project is targeting [powersOfTau28_hez_final_20.ptau](https://github.com/iden3/snarkjs/blob/master/README.md#7-prepare-phase-2) which has a limit of 1MM constraints. Below is a table of the number of constraints used by each circuit.

| Circuit | Non-Linear Constraints |
|---------|-------------|
| absoluteValueSubtraction(252) | 259 |
| acceptableMarginOfError(60) | 128 |
| calculateForce() | 717 |
| detectCollision(3) | 510 |
| forceAccumulator(3) | 2821 |
| getDistance(20) | 142 |
| limiter(252) | 257 |
| lowerLimiter(252) | 257 |
| nft(3, 10) | 28039 |
| stepState(3, 10) | 33531 |
| Circuit | Non-Linear Constraints | seconds at 25fps under 1MM constraints |
|---------|-------------|---------------------------------------------------|
| absoluteValueSubtraction(252) | 257 | 155.64 |
| acceptableMarginOfError(60) | 125 | 320 |
| calculateForce() | 279 | 143.37 |
| detectCollision(3) | 348 | 114.94 |
| forceAccumulator(3) | 1522 | 26.28 |
| getDistance(20) | 88 | 454.55 |
| limiter(252) | 254 | 157.48|
| lowerLimiter(252) | 254 | 157.48|
| nft(3, 10) | 15184 | 26.34 |
| stepState(3, 10) | 19121 | 20.92 |

# built using circom-starter

Expand Down
168 changes: 127 additions & 41 deletions circuits/approxMath.circom
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
pragma circom 2.1.6;
include "../node_modules/circomlib/circuits/mux1.circom";
include "../node_modules/circomlib/circuits/comparators.circom";
include "../node_modules/circomlib/circuits/gates.circom";
include "helpers.circom";

function approxSqrt(n) {
if (n == 0) {
Expand Down Expand Up @@ -28,53 +30,138 @@ function approxSqrt(n) {
return [lo, mid, hi];
}


function approxDiv(dividend, divisor) {
var bitsDivident = 0;
var dividendCopy = dividend;
while(dividendCopy > 0) {
bitsDivident++;
dividendCopy = dividendCopy >> 1;
if (dividend == 0) {
return 0;
}

// Create internal signals for our binary search
var lowerBound, upperBound, midPoint, testProduct;
var lo, hi, mid, testProduct;

// Initialize our search space
lowerBound = 0;
upperBound = dividend; // Assuming worst case where divisor = 1

for (var i = 0; i < bitsDivident; i++) { // 32 iterations for 32-bit numbers as an example
midPoint = (upperBound + lowerBound) >> 1;
testProduct = midPoint * divisor;
// Adjust our bounds based on the test product
if (testProduct > dividend) {
upperBound = midPoint;
} else {
lowerBound = midPoint;
}
lo = 0;
hi = dividend; // Assuming worst case where divisor = 1

while (lo < hi) { // 32 iterations for 32-bit numbers as an example
mid = (hi + lo + 1) >> 1;
testProduct = mid * divisor;

// Adjust our bounds based on the test product
if (testProduct > dividend) {
hi = mid - 1;
} else {
lo = mid;
}
}

// Output the midpoint as our approximated quotient after iterations
return midPoint;
// Output the lo as our approximated quotient after iterations
// quotient <== lo;
return lo;
}

template Div() {
signal input dividend;
signal input divisor;
signal output quotient;

quotient <-- approxDiv(dividend, divisor); // maxBits: 64 (maxNum: 10_400_000_000_000_000_000)

// NOTE: the following constraints the approxDiv to ensure it's within the acceptable error of margin
signal approxNumerator1 <== quotient * divisor; // maxBits: 126 (maxNum: 58_831_302_400_000_000_000_000_000_000_000_000_000)

// NOTE: approxDiv always rounds down so the approximate quotient will always be less
// than the actual quotient.
signal diff <== dividend - quotient;
// log("diff ", diff);
// log("dividend", dividend);
// log("divisor", divisor);
// log("quotient", quotient);

component lessThan = LessThan(64); // forceXnum; // maxBits: 64
lessThan.in[0] <== diff;
lessThan.in[1] <== dividend;
// log("lessThan", lessThan.out, "\n");

component isZero = IsZero();
isZero.in <== dividend;

component xor = XOR();
xor.a <== isZero.out;
xor.b <== lessThan.out;
xor.out === 1;
}

// template AbsoulteValueSubtraction(maxDiffMaxBits) {
// signal input in[2];
// signal input maxDiff;
// signal output out;

// signal diffA = in[0] - in[1];
// signal diffAOffset = diffA + maxDiff;

// signal diffB = in[1] - in[0];
// signal diffBOffset <== diffB + maxDiff;

// diffAOffset - diffBOffset - in[0]

// signal isZero = IsZero();



// }
template Sqrt(unboundDistanceSquaredMax) {
signal input squaredValue;
signal output root;
signal approxSqrtResults[3];
approxSqrtResults <-- approxSqrt(squaredValue);
// approxSqrtResults[0] = lo
// approxSqrtResults[1] = mid
// approxSqrtResults[2] = hi
// log("squaredValue", squaredValue);
// log("approxSqrtResults[0]", approxSqrtResults[0]);
// log("approxSqrtResults[1]", approxSqrtResults[1]);
// log("approxSqrtResults[2]", approxSqrtResults[2]);
root <-- approxSqrtResults[1];

var distanceResults[3];
distanceResults = approxSqrt(unboundDistanceSquaredMax);
var distanceMax = distanceResults[1]; // maxNum = 1414214n
var distanceMaxBits = maxBits(distanceMax);

component isPerfect = IsZero();
isPerfect.in <== (root**2) - squaredValue;
// signal perfectSquare <== isPerfect.out;
// log("isPerfect", isPerfect.out);

// perfectSquare is true, absDiff = 0
// OR
// if lo - mid == 0, absDiff = mid**2 - actual
// if hi - mid == 0, absDiff = actual - mid**2
component isZeroDiff2 = IsZero();
isZeroDiff2.in <== approxSqrtResults[0] - approxSqrtResults[1]; // lo - mid

// need to constrain that if isZeroDiff2 is not 0 then hi - mid is 0
component isZeroDiff3 = IsZero();
isZeroDiff3.in <== approxSqrtResults[2] - approxSqrtResults[1]; // hi - mid

// firstCondition is XOR
// (isZeroDiff2 == 1 AND isZeroDiff3 == 0) OR (isZeroDiff2 == 0 AND isZeroDiff3 == 1)
// secondCondition
// OR (isPerfect = 1)

component firstCondition = XOR();
firstCondition.a <== isZeroDiff2.out;
firstCondition.b <== isZeroDiff3.out;

// one must be true;
component secondCondition = OR();
secondCondition.a <== firstCondition.out;
secondCondition.b <== isPerfect.out;
secondCondition.out === 1;

component diffMux = Mux1();
diffMux.c[0] <== (approxSqrtResults[1] ** 2) - squaredValue; // mid**2 - actual
diffMux.c[1] <== squaredValue - (approxSqrtResults[1] ** 2); // actual - mid**2
diffMux.s <== isZeroDiff2.out;
signal imperfectDiff <== diffMux.out;

// difference is 0 if perfect square is true
component diffMux2 = Mux1();
diffMux2.c[0] <== imperfectDiff;
diffMux2.c[1] <== 0;
diffMux2.s <== isPerfect.out;
signal diff <== diffMux2.out;

var distanceMaxDoubleMax = distanceMax*2; // maxNum: 2,828,428
var distanceMaxSquaredMaxBits = maxBits(distanceMaxDoubleMax); // maxBits: 22
component lessThan2 = LessEqThan(distanceMaxSquaredMaxBits);
lessThan2.in[0] <== diff;
lessThan2.in[1] <== root*2; // maxBits: 22 (maxNum: 2_828_428)
// diff must be less than root*2 as the acceptable margin of error
lessThan2.out === 1;
}


template AcceptableMarginOfError (n) {
Expand All @@ -83,7 +170,6 @@ template AcceptableMarginOfError (n) {
signal input marginOfError;
signal output out;


// The following is to ensure diff = Abs(actual - expected)
component absoluteValueSubtraction = AbsoluteValueSubtraction(n);
absoluteValueSubtraction.in[0] <== expected;
Expand All @@ -100,7 +186,7 @@ template AbsoluteValueSubtraction (n) {
signal input in[2];
signal output out;

component lessThan = LessThan(n); // TODO: test limits of squares
component lessThan = LessThan(n);
lessThan.in[0] <== in[0];
lessThan.in[1] <== in[1];
signal lessThanResult <== lessThan.out;
Expand Down
87 changes: 20 additions & 67 deletions circuits/calculateForce.circom
Original file line number Diff line number Diff line change
Expand Up @@ -134,60 +134,19 @@ template CalculateForce() {
myMux.c[0] <== unboundDistanceSquared; // maxBits: 41 (maxNum: 2_000_000_000_000)
myMux.c[1] <== minDistanceScaled; // maxBits: 36 (maxNum: 40_000_000_000)
myMux.s <== lessThan.out;
signal distanceSquared <== myMux.out; // maxBits: 41 (maxNum: 2_000_000_000_000)

signal approxSqrtResults[3];
approxSqrtResults <-- approxSqrt(distanceSquared);
// approxSqrtResults[0] = lo
// approxSqrtResults[1] = mid
// approxSqrtResults[2] = hi
signal distance <-- approxSqrtResults[1];
signal distanceSquared <== myMux.out; // maxBits: 41 (maxNum: 2_000_000_000_000)

component sqrt = Sqrt(unboundDistanceSquaredMax);
sqrt.squaredValue <== distanceSquared; // maxBits: 41 (maxNum: 2_000_000_000_000)
signal distance <== sqrt.root;


var distanceResults[3];
distanceResults = approxSqrt(unboundDistanceSquaredMax);
var distanceMax = distanceResults[1];
var distanceMax = distanceResults[1]; // maxNum = 1_414_214
var distanceMaxBits = maxBits(distanceMax);

component isPerfect = IsZero();
isPerfect.in <== (distance**2) - distanceSquared;
signal perfectSquare <== isPerfect.out;
// log("isPerfect", isPerfect.out);

// perfectSquare is true, absDiff = 0
// OR
// if lo - mid == 0, absDiff = mid**2 - actual
// if hi - mid == 0, absDiff = actual - mid**2
component isZeroDiff2 = IsZero();
isZeroDiff2.in <== approxSqrtResults[0] - approxSqrtResults[1]; // lo - mid

// need to constrain that if isZeroDiff2 is not 0 then hi - mid is 0
component isZeroDiff3 = IsZero();
isZeroDiff3.in <== approxSqrtResults[2] - approxSqrtResults[1]; // hi - mid

// one must be true;
isZeroDiff2.out + isZeroDiff3.out === 1;

component diffMux = Mux1();
diffMux.c[0] <== (approxSqrtResults[1] ** 2) - distanceSquared; // mid**2 - actual
diffMux.c[1] <== distanceSquared - (approxSqrtResults[1] ** 2); // actual - mid**2
diffMux.s <== isZeroDiff2.out;
signal imperfectDiff <== diffMux.out;

// difference is 0 if perfect square is true
component diffMux2 = Mux1();
diffMux2.c[0] <== imperfectDiff;
diffMux2.c[1] <== 0;
diffMux2.s <== perfectSquare;
signal diff <== diffMux2.out;

var distanceMaxSquaredMax = distanceMax**2;
var distanceMaxSquaredMaxBits = maxBits(distanceMaxSquaredMax);
log(distanceMaxSquaredMaxBits, " is 22 (distanceMaxSquaredMaxBits)");
component lessThan2 = LessEqThan(distanceMaxSquaredMaxBits);
lessThan2.in[0] <== diff;
lessThan2.in[1] <== distance*2; // maxBits: 22 (maxNum: 2_828_428)
// diff must be less than distance*2 as the acceptable margin of error
lessThan2.out === 1;

// NOTE: this could be tweaked as a variable for "liveliness" of bodies
signal bodies_sum_tmp <== (body1_radius + body2_radius) * 4; // maxBits: 17 (maxNum: 104_000)

Expand Down Expand Up @@ -224,15 +183,11 @@ template CalculateForce() {

signal forceXnum <== dxAbs * forceMag_numerator; // maxBits: 64 (maxNum: 10_400_000_000_000_000_000)
// log("forceXnum", forceXnum);
signal forceXunsigned <-- approxDiv(forceXnum, forceDenom); // maxBits: 64 (maxNum: 10_400_000_000_000_000_000)
// log("forceXunsigned", forceXunsigned);
// NOTE: the following constraints the approxDiv to ensure it's within the acceptable error of margin
signal approxNumerator1 <== forceXunsigned * forceDenom; // maxBits: 126 (maxNum: 58_831_302_400_000_000_000_000_000_000_000_000_000)
component acceptableMarginOfErrorDiv1 = AcceptableMarginOfError(126);
acceptableMarginOfErrorDiv1.expected <== forceXnum; // maxBits: 64
acceptableMarginOfErrorDiv1.actual <== approxNumerator1; // maxBits: 126
acceptableMarginOfErrorDiv1.marginOfError <== forceDenom; // TODO: actually could be further reduced to (realDenom / 2) + 1 but then we're using division again
acceptableMarginOfErrorDiv1.out === 1;
component div1 = Div();
div1.dividend <== forceXnum;
div1.divisor <== forceDenom;
signal forceXunsigned <== div1.quotient;


// if dxAbs + dx is 0, then forceX should be negative
component isZero3 = IsZero();
Expand All @@ -244,15 +199,13 @@ template CalculateForce() {

signal forceYnum <== dyAbs * forceMag_numerator; // maxBits:64 (maxNum: 10_400_000_000_000_000_000)
// log("forceYnum", forceYnum);
signal forceYunsigned <-- approxDiv(forceYnum, forceDenom); // maxBits: 64 (maxNum: 10_400_000_000_000_000_000)
// log("forceYunsigned", forceYunsigned);
// NOTE: the following constraints the approxDiv to ensure it's within the acceptable error of margin
signal approxNumerator2 <== forceYunsigned * forceDenom; // maxBits: 126 (maxNum: 58_831_302_400_000_000_000_000_000_000_000_000_000)
component acceptableMarginOfErrorDiv2 = AcceptableMarginOfError(126);
acceptableMarginOfErrorDiv2.expected <== forceYnum; // maxBits: 64
acceptableMarginOfErrorDiv2.actual <== approxNumerator2; // maxBits: 126
acceptableMarginOfErrorDiv2.marginOfError <== forceDenom; // TODO: actually could be further reduced to (realDenom / 2) + 1 but then we're using division again
acceptableMarginOfErrorDiv2.out === 1;

// // NOTE: the following component uses approxDiv then ensures it's within an
// acceptable margin of error
component div2 = Div();
div2.dividend <== forceYnum;
div2.divisor <== forceDenom;
signal forceYunsigned <== div2.quotient;

// if dyAbs + dy is 0, then forceY should be negative
component isZero4 = IsZero();
Expand Down
28 changes: 20 additions & 8 deletions circuits/getDistance.circom
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
pragma circom 2.1.6;

include "approxMath.circom";
include "helpers.circom";

template GetDistance(n) {
signal input x1; // maxBits: 20 (maxNum: 1_000_000) = windowWidthScaled
signal input y1; // maxBits: 20 (maxNum: 1_000_000) = windowWidthScaled
signal input x2; // maxBits: 20 (maxNum: 1_000_000) = windowWidthScaled
signal input y2; // maxBits: 20 (maxNum: 1_000_000) = windowWidthScaled
var scalingFactorFactor = 3;
var scalingFactor = 10 ** scalingFactorFactor;
var windowWidth = 1000;
var windowWidthScaled = windowWidth * scalingFactor;
var positionMax = windowWidthScaled;
signal output distance;

// signal dx <== x2 - x1;
Expand All @@ -22,15 +28,21 @@ template GetDistance(n) {
signal dyAbs <== absoluteValueSubtraction2.out; // maxBits: 20 (maxNum: 1_000_000) = windowWidthScaled

signal dxs <== dxAbs * dxAbs; // maxBits: 40 = 20 * 2 (maxNum: 1_000_000_000_000)
var dxsMax = positionMax ** 2;
signal dys <== dyAbs * dyAbs; // maxBits: 40 = 20 * 2 (maxNum: 1_000_000_000_000)
signal distanceSquared <== dxs + dys; // maxBits: 41 = 40 + 1 (maxNum: 2_000_000_000_000)
var distanceSquaredMax = dxsMax + dxsMax;

// NOTE: confirm this is correct
distance <-- approxSqrt(distanceSquared); // maxBits: 21 (maxNum: 1_414_214) ~= 41 / 2 + 2
component acceptableMarginOfError = AcceptableMarginOfError((2 * n) + 1);
acceptableMarginOfError.expected <== distance ** 2; // maxBits: 41 (maxNum: 2_000_001_237_796) ~= 21 * 2
acceptableMarginOfError.actual <== distanceSquared; // maxBits: 41
// margin of error should be midpoint between squares
acceptableMarginOfError.marginOfError <== distance * 2; // maxBits: 22 (maxNum: 2_828_428)
acceptableMarginOfError.out === 1;
component sqrt = Sqrt(distanceSquaredMax);
sqrt.squaredValue <== distanceSquared;
distance <== sqrt.root;

// // NOTE: confirm this is correct
// distance <-- approxSqrt(distanceSquared); // maxBits: 21 (maxNum: 1_414_214) ~= 41 / 2 + 2
// component acceptableMarginOfError = AcceptableMarginOfError((2 * n) + 1);
// acceptableMarginOfError.expected <== distance ** 2; // maxBits: 41 (maxNum: 2_000_001_237_796) ~= 21 * 2
// acceptableMarginOfError.actual <== distanceSquared; // maxBits: 41
// // margin of error should be midpoint between squares
// acceptableMarginOfError.marginOfError <== distance * 2; // maxBits: 22 (maxNum: 2_828_428)
// acceptableMarginOfError.out === 1;
}
Loading

0 comments on commit d0fb991

Please sign in to comment.