diff --git a/fn/rational.fn b/fn/rational.fn index fbb09ea..3c1511f 100644 --- a/fn/rational.fn +++ b/fn/rational.fn @@ -5,10 +5,12 @@ print(#( "4 + 2 / 3 + 4", 4 + 2 / 3 + 4 )); print(#( "4 + 2 / -3 + 4", 4 + 2 / -3 + 4 )); print(#( "2 / 3 + 4 / 5", 2 / 3 + 4 / 5 )); print(#( "(2 / 3) * (4 / 5)", (2 / 3) * (4 / 5) )); -print(#( "(2 / 3) / (4 / 5)", (2 / 3) / (4 / 5) )); +print(#( "(6 / 5)", (6 / 5) )); +print(#( "(2 / 3) / (6 / 5)", (2 / 3) / (6 / 5) )); print(#( "(2 / 3) % (4 / 5)", (2 / 3) % (4 / 5) )); print(#( "1914882942 ** 10", 1914882942 ** 10 )); print(#( "1/3 % 8", 1/3 % 8 )); print(#( "-9 % 8", -9 % 8 )); +print(#( "(1/2) ** 2", (1/2) ** 2 )); print(#( "(1914882942 ** 5 / 5) % (2 / 3)", (1914882942 ** 5 / 5) % (2 / 3) )) diff --git a/src/arithmetic.c b/src/arithmetic.c index c862a05..063b561 100644 --- a/src/arithmetic.c +++ b/src/arithmetic.c @@ -59,6 +59,11 @@ static Value One = { .val = VALUE_VAL_STDINT(1) }; +static Value Zero = { + .type = VALUE_TYPE_STDINT, + .val = VALUE_VAL_STDINT(0) +}; + #ifdef DEBUG_ARITHMETIC static void ppNumber(Value number) { switch (number.type) { @@ -98,6 +103,12 @@ static Value intValue(int i) { # define ASSERT_STDINT(x) #endif +static int littleCmp(Value left, Value right) { + ASSERT_STDINT(left); + ASSERT_STDINT(right); + return left.val.stdint < right.val.stdint ? -1 : left.val.stdint == right.val.stdint ? 0 : 1; +} + static Value littleAdd(Value left, Value right) { ASSERT_STDINT(left); ASSERT_STDINT(right); @@ -119,6 +130,9 @@ static Value littleSub(Value left, Value right) { static Value littleDivide(Value left, Value right) { ASSERT_STDINT(left); ASSERT_STDINT(right); + if (littleCmp(right, Zero) == 0) { + cant_happen("attempted div zero"); + } return intValue(left.val.stdint / right.val.stdint); } @@ -131,6 +145,9 @@ static Value littlePower(Value left, Value right) { static Value littleModulo(Value left, Value right) { ASSERT_STDINT(left); ASSERT_STDINT(right); + if (littleCmp(right, Zero) == 0) { + cant_happen("attempted mod zero"); + } return intValue(left.val.stdint % right.val.stdint); } @@ -153,12 +170,6 @@ static Value littleGcd(Value left, Value right) { return intValue(gcd(left.val.stdint, right.val.stdint)); } -static int littleCmp(Value left, Value right) { - ASSERT_STDINT(left); - ASSERT_STDINT(right); - return left.val.stdint < right.val.stdint ? -1 : left.val.stdint == right.val.stdint ? 0 : 1; -} - static void littleNeg(Value v) { ASSERT_STDINT(v); v.val.stdint = -v.val.stdint; @@ -179,6 +190,14 @@ static bool littleIsNeg(Value v) { # define ASSERT_BIGINT(x) #endif +static int bigCmp(Value left, Value right) { + ENTER(bigCmp); + ASSERT_BIGINT(left); + ASSERT_BIGINT(right); + LEAVE(bigCmp); + return cmpBigInt(left.val.bigint, right.val.bigint); +} + static Value bigIntValue(BigInt *i) { Value val; val.type = VALUE_TYPE_BIGINT; @@ -232,6 +251,9 @@ static Value bigDivide(Value left, Value right) { IFDEBUG(ppNumber(right)); ASSERT_BIGINT(left); ASSERT_BIGINT(right); + if (bigCmp(right, Zero) == 0) { + cant_happen("attempted div zero"); + } BigInt *result = divBigInt(left.val.bigint, right.val.bigint); int save = PROTECT(result); Value res = bigIntValue(result); @@ -258,6 +280,9 @@ static Value bigModulo(Value left, Value right) { ENTER(bigModulo); ASSERT_BIGINT(left); ASSERT_BIGINT(right); + if (bigCmp(right, Zero) == 0) { + cant_happen("attempted mod zero"); + } BigInt *result = modBigInt(left.val.bigint, right.val.bigint); int save = PROTECT(result); Value res = bigIntValue(result); @@ -278,14 +303,6 @@ static Value bigGcd(Value left, Value right) { return res; } -static int bigCmp(Value left, Value right) { - ENTER(bigCmp); - ASSERT_BIGINT(left); - ASSERT_BIGINT(right); - LEAVE(bigCmp); - return cmpBigInt(left.val.bigint, right.val.bigint); -} - static void bigNeg(Value v) { ASSERT_BIGINT(v); negateBigInt(v.val.bigint); @@ -499,6 +516,52 @@ static Value ratModulo(Value left, Value right) { return res; } +static Value _ratPower(Value left, Value right) { + ENTER(_ratPower); + ASSERT_RATIONAL(left); + Value numerator = left.val.vec->values[NUMERATOR]; + Value denominator = left.val.vec->values[DENOMINATOR]; + numerator = int_power(numerator, right); + int save = protectValue(numerator); + denominator = int_power(denominator, right); + protectValue(denominator); + Value res = ratSimplify(numerator, denominator); + protectValue(res); + LEAVE(_ratPower); + IFDEBUG(ppNumber(res)); + UNPROTECT(save); + return res; +} + +static Value ratPower(Value left, Value right) { + ENTER(ratPower); + LEAVE(ratPower); + IFDEBUG(ppNumber(left)); + IFDEBUG(ppNumber(right)); + Value res; + int save = protectValue(left); + protectValue(right); + if (left.type == VALUE_TYPE_RATIONAL) { + if (right.type == VALUE_TYPE_RATIONAL) { + cant_happen("raising numbers to a rational power not supported yet"); + } else { + // only left rational + res = _ratPower(left, right); + protectValue(res); + } + } else if (right.type == VALUE_TYPE_RATIONAL) { + cant_happen("raising numbers to a rational power not supported yet"); + } else { + // neither rational + res = int_power(left, right); + protectValue(res); + } + LEAVE(ratPower); + IFDEBUG(ppNumber(res)); + UNPROTECT(save); + return res; +} + void init_arithmetic() { if (bigint_flag) { int_add = bigAdd; @@ -511,6 +574,9 @@ void init_arithmetic() { int_cmp = bigCmp; int_neg = bigNeg; int_isneg = bigIsNeg; + BigInt *zero = bigIntFromInt(0); + Zero.type = VALUE_TYPE_BIGINT; + Zero.val = VALUE_VAL_BIGINT(zero); BigInt *one = bigIntFromInt(1); One.type = VALUE_TYPE_BIGINT; One.val = VALUE_VAL_BIGINT(one); @@ -532,7 +598,7 @@ void init_arithmetic() { sub = ratSub; mul = ratMul; divide = ratDivide; - power = int_power; + power = ratPower; modulo = ratModulo; } else { add = int_add; @@ -545,5 +611,6 @@ void init_arithmetic() { } void markArithmetic() { + markValue(Zero); markValue(One); }