diff --git a/src/Bit255.test.ts b/src/Bit255.test.ts index d04373b..b4ee438 100644 --- a/src/Bit255.test.ts +++ b/src/Bit255.test.ts @@ -68,4 +68,10 @@ describe('Bit255', () => { let converted = Bit255.fromFields(b.toFields()); expect(b.toBigInt()).toEqual(converted.toBigInt()); }); + + it('Should convert between bits correctly', async () => { + let b = Bit255.fromScalar(Scalar.random()); + let converted = Bit255.fromBits(b.toBits()); + expect(b.toBigInt()).toEqual(converted.toBigInt()); + }); }); diff --git a/src/Bit255.ts b/src/Bit255.ts index 6503cb5..8b4cbdc 100644 --- a/src/Bit255.ts +++ b/src/Bit255.ts @@ -1,19 +1,31 @@ import { Bool, Field, Gadgets, Poseidon, Scalar, Struct } from 'o1js'; +const MASK_HEAD = (1n << 128n) - 1n; +const MASK_TAIL = (1n << 127n) - 1n; + // WARNING - Convert between Scalar and Bit255 does not preserve bigint value export class Bit255 extends Struct({ head: Field, tail: Field, }) { static fromScalar(scalar: Scalar): Bit255 { + let bits = [scalar.lowBit].concat(scalar.high254.toBits()); return new Bit255({ - head: scalar.toFields()[0], - tail: scalar.toFields()[1], + head: Field.fromBits(bits.slice(0, 128)), + tail: Field.fromBits(bits.slice(128)), }); } static toScalar(b: Bit255): Scalar { - return Scalar.fromFields([b.head, b.tail]); + return Scalar.fromFields([ + b.head.toBits()[0].toField(), + Field.fromBits( + b.head + .toBits() + .slice(1, 128) + .concat(b.tail.toBits().slice(0, 127)) + ), + ]); } static fromFields(fields: Field[]): Bit255 { @@ -24,7 +36,7 @@ export class Bit255 extends Struct({ } static toFields(value: { head: Field; tail: Field }): Field[] { - return [value.head].concat([value.tail]); + return [value.head, value.tail]; } static sizeInFields(): number { @@ -33,40 +45,35 @@ export class Bit255 extends Struct({ static xor(a: Bit255, b: Bit255): Bit255 { return new Bit255({ - head: Gadgets.xor(a.head, b.head, 1), - tail: Gadgets.xor(a.tail, b.tail, 254), + head: Gadgets.xor(a.head, b.head, 128), + tail: Gadgets.xor(a.tail, b.tail, 127), }); } static fromBits(bits: Bool[]): Bit255 { if (bits.length !== 255) throw new Error('Invalid input length'); return new Bit255({ - head: Field.fromBits(bits.slice(0, 1)), - tail: Field.fromBits(bits.slice(1)), + head: Field.fromBits(bits.slice(0, 128)), + tail: Field.fromBits(bits.slice(128)), }); } + static toBits(b: Bit255): Bool[] { + return b.head + .toBits() + .slice(0, 128) + .concat(b.tail.toBits().slice(0, 127)); + } + static fromBigInt(i: bigint): Bit255 { - let bits = new Array(255).fill(Bool(false)); - let index = 0; - while (i > 0) { - bits[index] = Bool(i % 2n == 1n); - i = i / 2n; - index += 1; - } - return Bit255.fromBits(bits); + return new Bit255({ + head: Field((i >> 127n) & MASK_HEAD), + tail: Field(i & MASK_TAIL), + }); } static toBigInt(b: Bit255): bigint { - let bits = b.head - .toBits() - .slice(0, 1) - .concat(b.tail.toBits().slice(0, 254)); - let res = 0n; - for (let i = 0; i < 255; i++) { - if (bits[i].toBoolean()) res += BigInt(Math.pow(2, i)); - } - return res; + return (b.head.toBigInt() << 127n) + b.tail.toBigInt(); } equals(b: Bit255): Bool { @@ -81,10 +88,6 @@ export class Bit255 extends Struct({ return Poseidon.hash(this.toFields()); } - fromFields(fields: Field[]): Bit255 { - return Bit255.fromFields(fields); - } - toFields(): Field[] { return Bit255.toFields(this); } @@ -100,4 +103,8 @@ export class Bit255 extends Struct({ toBigInt(): bigint { return Bit255.toBigInt(this); } + + toBits(): Bool[] { + return Bit255.toBits(this); + } }