diff --git a/packages/squiggle-lang/src/custom-types/index.d.ts b/packages/squiggle-lang/src/custom-types/index.d.ts index 297bfe52de..d924e4bcd0 100644 --- a/packages/squiggle-lang/src/custom-types/index.d.ts +++ b/packages/squiggle-lang/src/custom-types/index.d.ts @@ -81,4 +81,15 @@ declare module "jstat" { } export function factorial(x: number): number; + + export namespace binomial { + export function pdf(k: number, n: number, p: number): number; + export function cdf(k: number, n: number, p: number): number; + } + + export namespace poisson { + export function pdf(k: number, l: number): number; + export function cdf(x: number, l: number): number; + export function sample(l: number): number; + } } diff --git a/packages/squiggle-lang/src/dist/SampleSetDist/kde.ts b/packages/squiggle-lang/src/dist/SampleSetDist/kde.ts index 3186295a01..59955501d5 100644 --- a/packages/squiggle-lang/src/dist/SampleSetDist/kde.ts +++ b/packages/squiggle-lang/src/dist/SampleSetDist/kde.ts @@ -23,7 +23,9 @@ export const kde = ({ let xWidth = kernelWidth ?? nrd0(samples); samples = samples.filter((v) => Number.isFinite(v)); // Not sure if this is needed? const len = samples.length; - if (len === 0) return { usedWidth: xWidth, xs: [], ys: [] }; + + // It's not clear what to do when xWidth is zero. We might want to throw an error or otherwise instead. This was an issue for discrete distributions, like binomial. + if (len === 0 || xWidth === 0) return { usedWidth: xWidth, xs: [], ys: [] }; // Sample min and range const smin = samples[0]; diff --git a/packages/squiggle-lang/src/dist/SymbolicDist.ts b/packages/squiggle-lang/src/dist/SymbolicDist.ts index 365c39cb6b..edf13cee63 100644 --- a/packages/squiggle-lang/src/dist/SymbolicDist.ts +++ b/packages/squiggle-lang/src/dist/SymbolicDist.ts @@ -1,5 +1,6 @@ import { BaseDist } from "./BaseDist.js"; import * as Result from "../utility/result.js"; +import sum from "lodash/sum.js"; import jstat from "jstat"; import * as E_A_Floats from "../utility/E_A_Floats.js"; import * as XYShape from "../XYShape.js"; @@ -8,7 +9,7 @@ import * as Operation from "../operation.js"; import { PointSetDist } from "./PointSetDist.js"; import { Ok, result } from "../utility/result.js"; import { ContinuousShape } from "../PointSet/Continuous.js"; -import { DistError, xyShapeDistError } from "./DistError.js"; +import { DistError, notYetImplemented, xyShapeDistError } from "./DistError.js"; import { OperationError } from "../operationError.js"; import { DiscreteShape } from "../PointSet/Discrete.js"; import { Env } from "./env.js"; @@ -1070,6 +1071,123 @@ export class PointMass extends SymbolicDist { } } +export class Binomial extends SymbolicDist { + constructor( + public n: number, + public p: number + ) { + super(); + } + toString() { + return `Binomial(${this.n},{${this.p}})`; + } + static make(n: number, p: number): result { + if (!Number.isInteger(n) || n < 0) { + return Result.Err( + "The number of trials (n) must be a non-negative integer." + ); + } + if (p < 0 || p > 1) { + return Result.Err("Binomial p must be between 0 and 1"); + } + return Ok(new Binomial(n, p)); + } + + simplePdf(x: number) { + return jstat.binomial.pdf(x, this.n, this.p); + } + + cdf(k: number) { + return jstat.binomial.cdf(k, this.n, this.p); + } + + // Not needed, until we support Sym.Binomial + inv(p: number): number { + throw notYetImplemented(); + } + + mean() { + return this.n * this.p; + } + + variance(): result { + return Ok(this.n * this.p * (1 - this.p)); + } + + sample() { + const bernoulli = Bernoulli.make(this.p); + if (bernoulli.ok) { + // less space efficient than adding a bunch of draws, but cleaner. Taken from Guesstimate. + return sum( + Array.from({ length: this.n }, () => bernoulli.value.sample()) + ); + } else { + throw new Error("Binomial sampling failed"); + } + } + + _isEqual(other: Binomial) { + return this.n === other.n && this.p === other.p; + } + + // Not needed, until we support Sym.Binomial + override toPointSetDist(): result { + return Result.Err(notYetImplemented()); + } +} + +export class Poisson extends SymbolicDist { + constructor(public lambda: number) { + super(); + } + toString() { + return `Poisson(${this.lambda}})`; + } + static make(lambda: number): result { + if (lambda <= 0) { + throw new Error( + "Lambda must be a positive number for a Poisson distribution." + ); + } + + return Ok(new Poisson(lambda)); + } + + simplePdf(x: number) { + return jstat.poisson.pdf(x, this.lambda); + } + + cdf(k: number) { + return jstat.poisson.cdf(k, this.lambda); + } + + // Not needed, until we support Sym.Poisson + inv(p: number): number { + throw new Error("Poisson inv not implemented"); + } + + mean() { + return this.lambda; + } + + variance(): result { + return Ok(this.lambda); + } + + sample() { + return jstat.poisson.sample(this.lambda); + } + + _isEqual(other: Poisson) { + return this.lambda === other.lambda; + } + + // Not needed, until we support Sym.Poisson + override toPointSetDist(): result { + return Result.Err(notYetImplemented()); + } +} + /* Calling e.g. "Normal.operate" returns an optional Result. If the result is undefined, there is no valid analytic solution. If it's a Result object, it can still return an error if there is a serious problem, diff --git a/packages/squiggle-lang/src/fr/danger.ts b/packages/squiggle-lang/src/fr/danger.ts index fae5a97d8e..056f0d4e33 100644 --- a/packages/squiggle-lang/src/fr/danger.ts +++ b/packages/squiggle-lang/src/fr/danger.ts @@ -20,12 +20,18 @@ import { frLambda, frNumber, } from "../library/registry/frTypes.js"; -import { FnFactory, unpackDistResult } from "../library/registry/helpers.js"; +import { + FnFactory, + unpackDistResult, + distResultToValue, + makeTwoArgsDist, + makeOneArgDist, +} from "../library/registry/helpers.js"; import { ReducerContext } from "../reducer/context.js"; import { Lambda } from "../reducer/lambda.js"; import * as E_A from "../utility/E_A.js"; import { Value, vArray, vNumber } from "../value/index.js"; -import { distResultToValue } from "./genericDist.js"; +import * as SymbolicDist from "../dist/SymbolicDist.js"; const { factorial } = jstat; @@ -410,6 +416,18 @@ const mapYLibrary: FRFunction[] = [ // TODO - shouldn't it be other way around, e^value? fn: (dist, env) => unpackDistResult(scalePower(dist, Math.E, { env })), }), + maker.make({ + name: "binomialDist", + examples: ["Danger.binomialDist(8, 0.5)"], + definitions: [makeTwoArgsDist((n, p) => SymbolicDist.Binomial.make(n, p))], + }), + maker.make({ + name: "poissonDist", + examples: ["Danger.poissonDist(10)"], + definitions: [ + makeOneArgDist((lambda) => SymbolicDist.Poisson.make(lambda)), + ], + }), ]; export const library = [ diff --git a/packages/squiggle-lang/src/fr/dist.ts b/packages/squiggle-lang/src/fr/dist.ts index 5d08801cf3..eceaffceed 100644 --- a/packages/squiggle-lang/src/fr/dist.ts +++ b/packages/squiggle-lang/src/fr/dist.ts @@ -1,22 +1,18 @@ -import { BaseDist } from "../dist/BaseDist.js"; import { argumentError, otherError } from "../dist/DistError.js"; -import * as SampleSetDist from "../dist/SampleSetDist/index.js"; import * as SymbolicDist from "../dist/SymbolicDist.js"; -import { REDistributionError, REOther } from "../errors/messages.js"; -import { Env } from "../index.js"; +import { REDistributionError } from "../errors/messages.js"; import { FRFunction } from "../library/registry/core.js"; import { makeDefinition } from "../library/registry/fnDefinition.js"; +import { frDist, frNumber, frDict } from "../library/registry/frTypes.js"; import { - frDistOrNumber, - frDist, - frNumber, - frDict, -} from "../library/registry/frTypes.js"; -import { FnFactory } from "../library/registry/helpers.js"; -import { OtherOperationError } from "../operationError.js"; + FnFactory, + makeOneArgDist, + makeSampleSet, + makeTwoArgsDist, + twoVarSample, +} from "../library/registry/helpers.js"; import * as Result from "../utility/result.js"; -import { Value, vDist } from "../value/index.js"; -import { distResultToValue } from "./genericDist.js"; +import { vDist } from "../value/index.js"; import { CI_CONFIG, symDistResultToValue } from "./distUtil.js"; import { mixtureDefinitions } from "./mixture.js"; @@ -25,68 +21,6 @@ const maker = new FnFactory({ requiresNamespace: false, }); -function makeSampleSet(d: BaseDist, env: Env) { - const result = SampleSetDist.SampleSetDist.fromDist(d, env); - if (!result.ok) { - throw new REDistributionError(result.value); - } - return result.value; -} - -function twoVarSample( - v1: BaseDist | number, - v2: BaseDist | number, - env: Env, - fn: ( - v1: number, - v2: number - ) => Result.result -): Value { - const sampleFn = (a: number, b: number) => - Result.fmap2( - fn(a, b), - (d) => d.sample(), - (e) => new OtherOperationError(e) - ); - - if (v1 instanceof BaseDist && v2 instanceof BaseDist) { - const s1 = makeSampleSet(v1, env); - const s2 = makeSampleSet(v2, env); - return distResultToValue( - SampleSetDist.map2({ - fn: sampleFn, - t1: s1, - t2: s2, - }) - ); - } else if (v1 instanceof BaseDist && typeof v2 === "number") { - const s1 = makeSampleSet(v1, env); - return distResultToValue(s1.samplesMap((a) => sampleFn(a, v2))); - } else if (typeof v1 === "number" && v2 instanceof BaseDist) { - const s2 = makeSampleSet(v2, env); - return distResultToValue(s2.samplesMap((a) => sampleFn(v1, a))); - } else if (typeof v1 === "number" && typeof v2 === "number") { - const result = fn(v1, v2); - if (!result.ok) { - throw new REOther(result.value); - } - return vDist(makeSampleSet(result.value, env)); - } - throw new REOther("Impossible branch"); -} - -function makeTwoArgsDist( - fn: ( - v1: number, - v2: number - ) => Result.result -) { - return makeDefinition( - [frDistOrNumber, frDistOrNumber], - ([v1, v2], { environment }) => twoVarSample(v1, v2, environment, fn) - ); -} - function makeCIDist( lowKey: K1, highKey: K2, @@ -115,31 +49,6 @@ function makeMeanStdevDist( ); } -function makeOneArgDist( - fn: (v: number) => Result.result -) { - return makeDefinition([frDistOrNumber], ([v], { environment }) => { - const sampleFn = (a: number) => - Result.fmap2( - fn(a), - (d) => d.sample(), - (e) => new OtherOperationError(e) - ); - - if (v instanceof BaseDist) { - const s = makeSampleSet(v, environment); - return distResultToValue(s.samplesMap(sampleFn)); - } else if (typeof v === "number") { - const result = fn(v); - if (!result.ok) { - throw new REOther(result.value); - } - return vDist(makeSampleSet(result.value, environment)); - } - throw new REOther("Impossible branch"); - }); -} - export const library: FRFunction[] = [ //We might want to later add all of the options to make() tht SampleSet has. For example, function() and list(). maker.make({ diff --git a/packages/squiggle-lang/src/fr/genericDist.ts b/packages/squiggle-lang/src/fr/genericDist.ts index 08c4bce8bf..a3c449b195 100644 --- a/packages/squiggle-lang/src/fr/genericDist.ts +++ b/packages/squiggle-lang/src/fr/genericDist.ts @@ -1,33 +1,24 @@ -import { BaseDist } from "../dist/BaseDist.js"; -import { DistError } from "../dist/DistError.js"; import * as SymbolicDist from "../dist/SymbolicDist.js"; import { BinaryOperation, binaryOperations, } from "../dist/distOperations/index.js"; -import { REDistributionError } from "../errors/messages.js"; import { FRFunction } from "../library/registry/core.js"; import { makeDefinition } from "../library/registry/fnDefinition.js"; import { frDist, frNumber } from "../library/registry/frTypes.js"; -import { FnFactory, unpackDistResult } from "../library/registry/helpers.js"; +import { + FnFactory, + distResultToValue, + unpackDistResult, +} from "../library/registry/helpers.js"; import * as magicNumbers from "../magicNumbers.js"; -import * as Result from "../utility/result.js"; -import { Value, vArray, vDist, vNumber } from "../value/index.js"; +import { vArray, vNumber } from "../value/index.js"; const maker = new FnFactory({ nameSpace: "", requiresNamespace: false, }); -export function distResultToValue( - result: Result.result -): Value { - if (!result.ok) { - throw new REDistributionError(result.value); - } - return vDist(result.value); -} - type OpPair = [string, BinaryOperation]; const algebraicOps: OpPair[] = [ ["add", binaryOperations.algebraicAdd], diff --git a/packages/squiggle-lang/src/library/registry/helpers.ts b/packages/squiggle-lang/src/library/registry/helpers.ts index 9fd8ac7aac..1aa2e1ad41 100644 --- a/packages/squiggle-lang/src/library/registry/helpers.ts +++ b/packages/squiggle-lang/src/library/registry/helpers.ts @@ -13,7 +13,16 @@ import * as Result from "../../utility/result.js"; import { Value, vBool, vDist, vNumber, vString } from "../../value/index.js"; import { FRFunction } from "./core.js"; import { FnDefinition, makeDefinition } from "./fnDefinition.js"; -import { frBool, frDist, frNumber, frString } from "./frTypes.js"; +import { + frBool, + frDist, + frDistOrNumber, + frNumber, + frString, +} from "./frTypes.js"; +import * as SampleSetDist from "../../dist/SampleSetDist/index.js"; +import * as SymbolicDist from "../../dist/SymbolicDist.js"; +import { OtherOperationError } from "../../operationError.js"; type SimplifiedArgs = Omit & Partial>; @@ -287,3 +296,99 @@ export function doBinaryLambdaCall( } throw new REOther("Expected function to return a boolean value"); } + +export function distResultToValue( + result: Result.result +): Value { + if (!result.ok) { + throw new REDistributionError(result.value); + } + return vDist(result.value); +} + +export function makeSampleSet(d: BaseDist, env: Env) { + const result = SampleSetDist.SampleSetDist.fromDist(d, env); + if (!result.ok) { + throw new REDistributionError(result.value); + } + return result.value; +} + +export function twoVarSample( + v1: BaseDist | number, + v2: BaseDist | number, + env: Env, + fn: ( + v1: number, + v2: number + ) => Result.result +): Value { + const sampleFn = (a: number, b: number) => + Result.fmap2( + fn(a, b), + (d) => d.sample(), + (e) => new OtherOperationError(e) + ); + + if (v1 instanceof BaseDist && v2 instanceof BaseDist) { + const s1 = makeSampleSet(v1, env); + const s2 = makeSampleSet(v2, env); + return distResultToValue( + SampleSetDist.map2({ + fn: sampleFn, + t1: s1, + t2: s2, + }) + ); + } else if (v1 instanceof BaseDist && typeof v2 === "number") { + const s1 = makeSampleSet(v1, env); + return distResultToValue(s1.samplesMap((a) => sampleFn(a, v2))); + } else if (typeof v1 === "number" && v2 instanceof BaseDist) { + const s2 = makeSampleSet(v2, env); + return distResultToValue(s2.samplesMap((a) => sampleFn(v1, a))); + } else if (typeof v1 === "number" && typeof v2 === "number") { + const result = fn(v1, v2); + if (!result.ok) { + throw new REOther(result.value); + } + return vDist(makeSampleSet(result.value, env)); + } + throw new REOther("Impossible branch"); +} + +export function makeTwoArgsDist( + fn: ( + v1: number, + v2: number + ) => Result.result +) { + return makeDefinition( + [frDistOrNumber, frDistOrNumber], + ([v1, v2], { environment }) => twoVarSample(v1, v2, environment, fn) + ); +} + +export function makeOneArgDist( + fn: (v: number) => Result.result +) { + return makeDefinition([frDistOrNumber], ([v], { environment }) => { + const sampleFn = (a: number) => + Result.fmap2( + fn(a), + (d) => d.sample(), + (e) => new OtherOperationError(e) + ); + + if (v instanceof BaseDist) { + const s = makeSampleSet(v, environment); + return distResultToValue(s.samplesMap(sampleFn)); + } else if (typeof v === "number") { + const result = fn(v); + if (!result.ok) { + throw new REOther(result.value); + } + return vDist(makeSampleSet(result.value, environment)); + } + throw new REOther("Impossible branch"); + }); +} diff --git a/packages/website/src/pages/docs/Api/Danger.md b/packages/website/src/pages/docs/Api/Danger.md index 8dc47ba6ed..4f6fdcf620 100644 --- a/packages/website/src/pages/docs/Api/Danger.md +++ b/packages/website/src/pages/docs/Api/Danger.md @@ -44,6 +44,39 @@ Danger.binomial: (number, number, number) => number `Danger.binomial(n, k, p)` returns `choose((n, k)) * pow(p, k) * pow(1 - p, n - k)`, i.e., the probability that an event of probability p will happen exactly k times in n draws. + +### binomialDist + +``` +Danger.binomialDist: (n: distribution|number,p: distribution|number) => distribution +``` +A binomial distribution. + +``n`` must be above 0, and ``p`` must be between 0 and 1. + +Note: The binomial distribution is a discrete distribution. When representing this, the Squiggle distribution component might show it as partially or fully continuous. This is a visual mistake; if you inspect the underlying data, it should be discrete. + +**Examples** + +```squiggle +binomialDist(5, 0.5) +binomialDist(10, 0.3) +``` + +### poissonDist + +``` +Danger.poissonDist: (distribution|number) => distribution +``` + +Note: The Poisson distribution is a discrete distribution. When representing this, the Squiggle distribution component might show it as partially or fully continuous. This is a visual mistake; if you inspect the underlying data, it should be discrete. + +**Examples** + +```squiggle +poissonDist(20) +``` + ### integrateFunctionBetweenWithNumIntegrationPoints ```