Skip to content

Commit

Permalink
Added dist sum, product, cumsum, cumprod, diff, to Number.
Browse files Browse the repository at this point in the history
  • Loading branch information
OAGr committed Oct 26, 2023
1 parent bda1431 commit 6043164
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 36 deletions.
18 changes: 18 additions & 0 deletions packages/squiggle-lang/__tests__/library/number_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ describe("Numbers", () => {
testEvalToBe("Number.log2(10)", "3.321928094887362");
testEvalToBe("Number.sum([2,5,3])", "10");
testEvalToBe("sum([2,5,3])", "10");
testEvalToBe("sum([2,Dist.make(2),Dist.make(4)])", "PointMass(8)");
testEvalToBe("sum([2,Dist.make(2),2 to 10])", "Sample Set Distribution");
testEvalToBe("Number.product([2,5,3])", "30");
testEvalToBe(
"Number.product([Dist.make(2),Dist.make(5),Dist.make(3)])",
"PointMass(30)"
);
testEvalToBe("Number.min([2,5,3])", "2");
testEvalToBe("Number.max([2,5,3])", "5");
testEvalToBe("Number.mean([0,5,10])", "5");
Expand All @@ -21,6 +27,18 @@ describe("Numbers", () => {
testEvalToBe("Number.variance([0,5,10,15])", "31.25");
testEvalToBe("Number.sort([10,0,15,5])", "[0,5,10,15]");
testEvalToBe("Number.cumsum([1,5,3])", "[1,6,9]");
testEvalToBe(
"Number.cumsum([Dist.make(1), Dist.make(5), Dist.make(3)])",
"[PointMass(1),PointMass(6),PointMass(9)]"
);
testEvalToBe("Number.cumprod([1,5,3])", "[1,5,15]");
testEvalToBe(
"Number.cumprod([Dist.make(1),Dist.make(5),Dist.make(3)])",
"[PointMass(1),PointMass(5),PointMass(15)]"
);
testEvalToBe("Number.diff([1,5,3])", "[4,-2]");
testEvalToBe(
"Number.diff([Dist.make(1),Dist.make(5),Dist.make(3)])",
"[PointMass(4),PointMass(-2)]"
);
});
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ export const binaryOperations = {
pointwisePower: (t1, t2, { env }) => pointwise(t1, t2, env, "Power"),
} satisfies { [k: string]: BinaryOperation };

export const algebraicAddMany = (dists: BaseDist[], env: Env): DistResult =>
export const algebraicSum = (dists: BaseDist[], env: Env): DistResult =>
dists.reduce<DistResult>(
(accumulatedDist, currentDist) =>
bind(accumulatedDist, (aVal) =>
Expand All @@ -66,16 +66,13 @@ export const algebraicAddMany = (dists: BaseDist[], env: Env): DistResult =>
Ok(new PointMass(0))
);

export const algebraicMultiplyMany = (
dists: BaseDist[],
env: Env
): DistResult =>
export const algebraicProduct = (dists: BaseDist[], env: Env): DistResult =>
dists.reduce<DistResult>(
(accumulatedDist, currentDist) =>
bind(accumulatedDist, (aVal) =>
binaryOperations.algebraicMultiply(aVal, currentDist, { env })
),
Ok(new PointMass(2))
Ok(new PointMass(1))
);

export const algebraicCumSum = (dists: BaseDist[], env: Env): BaseDist[] =>
Expand Down
93 changes: 64 additions & 29 deletions packages/squiggle-lang/src/fr/number.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import { BaseDist } from "../dist/BaseDist.js";
import * as SymbolicDist from "../dist/SymbolicDist.js";
import { binaryOperations } from "../dist/distOperations/binaryOperations.js";
import {
algebraicSum,
algebraicCumDiff,
algebraicCumProd,
algebraicCumSum,
algebraicProduct,
} from "../dist/distOperations/binaryOperations.js";
import { REArgumentError } from "../errors/messages.js";
import { makeDefinition } from "../library/registry/fnDefinition.js";
import {
Expand All @@ -10,6 +17,7 @@ import {
} from "../library/registry/frTypes.js";
import { FnFactory } from "../library/registry/helpers.js";
import * as E_A_Floats from "../utility/E_A_Floats.js";
import { getExt } from "../utility/result.js";
import { NumericRangeDomain } from "../value/domain.js";
import { vArray, vDomain, vNumber, vDist } from "../value/index.js";

Expand Down Expand Up @@ -42,6 +50,14 @@ function makeNumberArrayToNumberArrayDefinition(
});
}

function distOrNumberToDistOrNumber(d: BaseDist | number): BaseDist {
if (typeof d == "number") {
return getExt(SymbolicDist.PointMass.make(d));
} else {
return d;
}
}

export const library = [
maker.n2n({
name: "floor",
Expand Down Expand Up @@ -91,31 +107,13 @@ export const library = [
examples: [`sum([3,5,2])`],
definitions: [
makeNumberArrayToNumberDefinition((arr) => E_A_Floats.sum(arr)),
makeDefinition([frArray(frDistOrNumber)], ([dists], { environment }) => {
const d = dists.map((dist) => {
if (typeof dist == "number") {
const result = SymbolicDist.PointMass.make(dist);
if (result.ok) {
return result.value;
} else {
throw new Error(result.value);
}
} else {
return dist;
}
});
const ee = d.reduce((a, b) => {
const result = binaryOperations.algebraicAdd(a, b, {
env: environment,
});
if (result.ok) {
return result.value;
} else {
throw new Error("FAIL");
}
}, new SymbolicDist.PointMass(0));
return vDist(ee);
}),
makeDefinition([frArray(frDistOrNumber)], ([dists], { environment }) =>
vDist(
getExt(
algebraicSum(dists.map(distOrNumberToDistOrNumber), environment)
)
)
),
],
}),
maker.make({
Expand All @@ -124,6 +122,13 @@ export const library = [
examples: [`product([3,5,2])`],
definitions: [
makeNumberArrayToNumberDefinition((arr) => E_A_Floats.product(arr)),
makeDefinition([frArray(frDistOrNumber)], ([dists], { environment }) =>
vDist(
getExt(
algebraicProduct(dists.map(distOrNumberToDistOrNumber), environment)
)
)
),
],
}),
maker.make({
Expand Down Expand Up @@ -185,20 +190,50 @@ export const library = [
output: "Array",
description: "cumulative sum",
examples: [`cumsum([3,5,2,3,5])`],
definitions: [makeNumberArrayToNumberArrayDefinition(E_A_Floats.cumSum)],
definitions: [
makeNumberArrayToNumberArrayDefinition(E_A_Floats.cumSum),
makeDefinition([frArray(frDistOrNumber)], ([dists], { environment }) =>
vArray(
algebraicCumSum(
dists.map(distOrNumberToDistOrNumber),
environment
).map((r) => vDist(r))
)
),
],
}),
maker.make({
name: "cumprod",
description: "cumulative product",
output: "Array",
examples: [`cumprod([3,5,2,3,5])`],
definitions: [makeNumberArrayToNumberArrayDefinition(E_A_Floats.cumProd)],
definitions: [
makeNumberArrayToNumberArrayDefinition(E_A_Floats.cumProd),
makeDefinition([frArray(frDistOrNumber)], ([dists], { environment }) =>
vArray(
algebraicCumProd(
dists.map(distOrNumberToDistOrNumber),
environment
).map((r) => vDist(r))
)
),
],
}),
maker.make({
name: "diff",
output: "Array",
examples: [`diff([3,5,2,3,5])`],
definitions: [makeNumberArrayToNumberArrayDefinition(E_A_Floats.diff)],
definitions: [
makeNumberArrayToNumberArrayDefinition(E_A_Floats.diff),
makeDefinition([frArray(frDistOrNumber)], ([dists], { environment }) =>
vArray(
algebraicCumDiff(
dists.map(distOrNumberToDistOrNumber),
environment
).map((r) => vDist(r))
)
),
],
}),
maker.make({
name: "rangeDomain",
Expand Down
2 changes: 1 addition & 1 deletion packages/squiggle-lang/src/utility/E_A.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export const accumulate = <A>(
fn: (x: A, y: A) => A
): A[] => {
const len = items.length;
const result = new Array(length);
const result = new Array(len);
for (let i = 0; i < len; i++) {
const element = items[i];
if (i === 0) {
Expand Down

0 comments on commit 6043164

Please sign in to comment.