Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Danger - Binomial and Poisson distributions #2385

Merged
merged 8 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions packages/squiggle-lang/src/custom-types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,15 @@ declare module "jstat" {
}

export function factorial(x: number): number;

// Typings for the jstat binomial helpers used by this package.
// In both functions `k` is the point being evaluated, while `n` and `p`
// are the distribution parameters (callers pass the trial count and the
// success probability, in that order).
export namespace binomial {
// Probability mass at `k`.
export function pdf(k: number, n: number, p: number): number;
// Cumulative probability at `k`.
export function cdf(k: number, n: number, p: number): number;
}

// Typings for the jstat Poisson helpers used by this package.
// `l` is the distribution's rate parameter (callers pass `lambda`).
export namespace poisson {
// Probability mass at `k`.
export function pdf(k: number, l: number): number;
// Cumulative probability at `x`.
export function cdf(x: number, l: number): number;
// Draws one random value from the distribution.
export function sample(l: number): number;
}
}
4 changes: 3 additions & 1 deletion packages/squiggle-lang/src/dist/SampleSetDist/kde.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ export const kde = ({
let xWidth = kernelWidth ?? nrd0(samples);
samples = samples.filter((v) => Number.isFinite(v)); // Not sure if this is needed?
const len = samples.length;
if (len === 0) return { usedWidth: xWidth, xs: [], ys: [] };

// It's not clear what to do when xWidth is zero; we might want to throw an error or handle it some other way instead. This was an issue for discrete distributions, like binomial.
if (len === 0 || xWidth === 0) return { usedWidth: xWidth, xs: [], ys: [] };
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My best guess about this:

  • xWidth is zero if all samples are identical, but somehow we decided to run kde on them, maybe because there are fewer than 5 of them (nrd0 will return 0 if variance is 0)
  • so you want to discard them because otherwise kde did something weird or failed

If so, I'm not sure if it's the best way to fix the problem (discarding samples seems bad, and this is too entangled with samplesToPointSetDist needs and other quirks). But I also don't have any better suggestions. Maybe this deserves a comment?


// Sample min and range
const smin = samples[0];
Expand Down
120 changes: 119 additions & 1 deletion packages/squiggle-lang/src/dist/SymbolicDist.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { BaseDist } from "./BaseDist.js";
import * as Result from "../utility/result.js";
import sum from "lodash/sum.js";
import jstat from "jstat";
import * as E_A_Floats from "../utility/E_A_Floats.js";
import * as XYShape from "../XYShape.js";
Expand All @@ -8,7 +9,7 @@ import * as Operation from "../operation.js";
import { PointSetDist } from "./PointSetDist.js";
import { Ok, result } from "../utility/result.js";
import { ContinuousShape } from "../PointSet/Continuous.js";
import { DistError, xyShapeDistError } from "./DistError.js";
import { DistError, notYetImplemented, xyShapeDistError } from "./DistError.js";
import { OperationError } from "../operationError.js";
import { DiscreteShape } from "../PointSet/Discrete.js";
import { Env } from "./env.js";
Expand Down Expand Up @@ -1070,6 +1071,123 @@ export class PointMass extends SymbolicDist {
}
}

/**
 * Binomial distribution: the number of successes in `n` trials, each of
 * which succeeds with probability `p`.
 *
 * Only sampling, pdf/cdf and the closed-form moments are available for now;
 * `inv` and `toPointSetDist` both report "not yet implemented".
 */
export class Binomial extends SymbolicDist {
  constructor(
    public n: number,
    public p: number
  ) {
    super();
  }

  /** Human-readable form, e.g. `Binomial(8,0.5)`. */
  toString() {
    return `Binomial(${this.n},${this.p})`;
  }

  /**
   * Validating factory.
   *
   * @param n number of trials; must be a non-negative integer
   * @param p success probability; must lie in [0, 1]
   * @returns `Ok(Binomial)` on success, otherwise `Err` with a message
   */
  static make(n: number, p: number): result<Binomial, string> {
    if (!Number.isInteger(n) || n < 0) {
      return Result.Err(
        "The number of trials (n) must be a non-negative integer."
      );
    }
    // Written as a negated range test so that NaN is rejected as well:
    // every comparison against NaN is false, so `p < 0 || p > 1` lets it through.
    if (!(p >= 0 && p <= 1)) {
      return Result.Err("Binomial p must be between 0 and 1");
    }
    return Ok(new Binomial(n, p));
  }

  /** Probability mass at `x`, delegated to jstat. */
  simplePdf(x: number) {
    return jstat.binomial.pdf(x, this.n, this.p);
  }

  /** Cumulative probability at `k`, delegated to jstat. */
  cdf(k: number) {
    return jstat.binomial.cdf(k, this.n, this.p);
  }

  // Not needed, until we support Sym.Binomial
  inv(p: number): number {
    throw notYetImplemented();
  }

  mean() {
    return this.n * this.p;
  }

  variance(): result<number, DistError> {
    return Ok(this.n * this.p * (1 - this.p));
  }

  /** Draws one value by summing `n` Bernoulli(p) draws. */
  sample() {
    const bernoulli = Bernoulli.make(this.p);
    if (bernoulli.ok) {
      // less space efficient than adding a bunch of draws, but cleaner. Taken from Guesstimate.
      return sum(
        Array.from({ length: this.n }, () => bernoulli.value.sample())
      );
    } else {
      throw new Error("Binomial sampling failed");
    }
  }

  _isEqual(other: Binomial) {
    return this.n === other.n && this.p === other.p;
  }

  // Not needed, until we support Sym.Binomial
  override toPointSetDist(): result<PointSetDist, DistError> {
    return Result.Err(notYetImplemented());
  }
}

/**
 * Poisson distribution with rate `lambda`; its mean and variance are both
 * `lambda`.
 *
 * Only sampling, pdf/cdf and the closed-form moments are available for now;
 * `inv` throws and `toPointSetDist` reports "not yet implemented".
 */
export class Poisson extends SymbolicDist {
  constructor(public lambda: number) {
    super();
  }

  /** Human-readable form, e.g. `Poisson(10)`. */
  toString() {
    return `Poisson(${this.lambda})`;
  }

  /**
   * Validating factory.
   *
   * @param lambda rate parameter; must be strictly positive
   * @returns `Ok(Poisson)` on success, otherwise `Err` with a message.
   *   Invalid input is reported through the result (not thrown) so that
   *   callers checking `result.ok` can handle it, matching `Binomial.make`.
   */
  static make(lambda: number): result<Poisson, string> {
    // `!(lambda > 0)` rather than `lambda <= 0` so that NaN is rejected too.
    if (!(lambda > 0)) {
      return Result.Err(
        "Lambda must be a positive number for a Poisson distribution."
      );
    }

    return Ok(new Poisson(lambda));
  }

  /** Probability mass at `x`, delegated to jstat. */
  simplePdf(x: number) {
    return jstat.poisson.pdf(x, this.lambda);
  }

  /** Cumulative probability at `k`, delegated to jstat. */
  cdf(k: number) {
    return jstat.poisson.cdf(k, this.lambda);
  }

  // Not needed, until we support Sym.Poisson
  inv(p: number): number {
    throw new Error("Poisson inv not implemented");
  }

  mean() {
    return this.lambda;
  }

  variance(): result<number, DistError> {
    return Ok(this.lambda);
  }

  /** Draws one value using jstat's Poisson sampler. */
  sample() {
    return jstat.poisson.sample(this.lambda);
  }

  _isEqual(other: Poisson) {
    return this.lambda === other.lambda;
  }

  // Not needed, until we support Sym.Poisson
  override toPointSetDist(): result<PointSetDist, DistError> {
    return Result.Err(notYetImplemented());
  }
}

/* Calling e.g. "Normal.operate" returns an optional Result.
If the result is undefined, there is no valid analytic solution.
If it's a Result object, it can still return an error if there is a serious problem,
Expand Down
22 changes: 20 additions & 2 deletions packages/squiggle-lang/src/fr/danger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@ import {
frLambda,
frNumber,
} from "../library/registry/frTypes.js";
import { FnFactory, unpackDistResult } from "../library/registry/helpers.js";
import {
FnFactory,
unpackDistResult,
distResultToValue,
makeTwoArgsDist,
makeOneArgDist,
} from "../library/registry/helpers.js";
import { ReducerContext } from "../reducer/context.js";
import { Lambda } from "../reducer/lambda.js";
import * as E_A from "../utility/E_A.js";
import { Value, vArray, vNumber } from "../value/index.js";
import { distResultToValue } from "./genericDist.js";
import * as SymbolicDist from "../dist/SymbolicDist.js";

const { factorial } = jstat;

Expand Down Expand Up @@ -410,6 +416,18 @@ const mapYLibrary: FRFunction[] = [
// TODO - shouldn't it be other way around, e^value?
fn: (dist, env) => unpackDistResult(scalePower(dist, Math.E, { env })),
}),
maker.make({
name: "binomialDist",
examples: ["Danger.binomialDist(8, 0.5)"],
definitions: [makeTwoArgsDist((n, p) => SymbolicDist.Binomial.make(n, p))],
}),
maker.make({
name: "poissonDist",
examples: ["Danger.poissonDist(10)"],
definitions: [
makeOneArgDist((lambda) => SymbolicDist.Poisson.make(lambda)),
],
}),
];

export const library = [
Expand Down
109 changes: 9 additions & 100 deletions packages/squiggle-lang/src/fr/dist.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
import { BaseDist } from "../dist/BaseDist.js";
import { argumentError, otherError } from "../dist/DistError.js";
import * as SampleSetDist from "../dist/SampleSetDist/index.js";
import * as SymbolicDist from "../dist/SymbolicDist.js";
import { REDistributionError, REOther } from "../errors/messages.js";
import { Env } from "../index.js";
import { REDistributionError } from "../errors/messages.js";
import { FRFunction } from "../library/registry/core.js";
import { makeDefinition } from "../library/registry/fnDefinition.js";
import { frDist, frNumber, frDict } from "../library/registry/frTypes.js";
import {
frDistOrNumber,
frDist,
frNumber,
frDict,
} from "../library/registry/frTypes.js";
import { FnFactory } from "../library/registry/helpers.js";
import { OtherOperationError } from "../operationError.js";
FnFactory,
makeOneArgDist,
makeSampleSet,
makeTwoArgsDist,
twoVarSample,
} from "../library/registry/helpers.js";
import * as Result from "../utility/result.js";
import { Value, vDist } from "../value/index.js";
import { distResultToValue } from "./genericDist.js";
import { vDist } from "../value/index.js";
import { CI_CONFIG, symDistResultToValue } from "./distUtil.js";
import { mixtureDefinitions } from "./mixture.js";

Expand All @@ -25,68 +21,6 @@ const maker = new FnFactory({
requiresNamespace: false,
});

/**
 * Converts any distribution into a sample set under the given environment.
 *
 * @throws REDistributionError when the conversion fails
 */
function makeSampleSet(d: BaseDist, env: Env) {
  const converted = SampleSetDist.SampleSetDist.fromDist(d, env);
  if (converted.ok) {
    return converted.value;
  }
  throw new REDistributionError(converted.value);
}

/**
 * Builds a distribution value from a two-parameter symbolic constructor whose
 * arguments may each be either a plain number or a distribution.
 *
 * - Two numbers: construct the symbolic distribution once and convert it to a
 *   sample set.
 * - Otherwise: convert the distribution argument(s) to sample sets and, for
 *   every sample (or pair of samples), construct the symbolic distribution
 *   and draw a single value from it.
 *
 * @throws REOther when both arguments are numbers and `fn` rejects them
 */
function twoVarSample(
  v1: BaseDist | number,
  v2: BaseDist | number,
  env: Env,
  fn: (
    v1: number,
    v2: number
  ) => Result.result<SymbolicDist.SymbolicDist, string>
): Value {
  // Construct the distribution for one parameter pair and draw one value,
  // turning a construction failure into an operation error.
  const drawOne = (left: number, right: number) =>
    Result.fmap2(
      fn(left, right),
      (dist) => dist.sample(),
      (message) => new OtherOperationError(message)
    );

  if (v1 instanceof BaseDist && v2 instanceof BaseDist) {
    const leftSamples = makeSampleSet(v1, env);
    const rightSamples = makeSampleSet(v2, env);
    const combined = SampleSetDist.map2({
      fn: drawOne,
      t1: leftSamples,
      t2: rightSamples,
    });
    return distResultToValue(combined);
  }

  if (v1 instanceof BaseDist && typeof v2 === "number") {
    const leftSamples = makeSampleSet(v1, env);
    return distResultToValue(
      leftSamples.samplesMap((sampled) => drawOne(sampled, v2))
    );
  }

  if (typeof v1 === "number" && v2 instanceof BaseDist) {
    const rightSamples = makeSampleSet(v2, env);
    return distResultToValue(
      rightSamples.samplesMap((sampled) => drawOne(v1, sampled))
    );
  }

  if (typeof v1 === "number" && typeof v2 === "number") {
    const built = fn(v1, v2);
    if (built.ok) {
      return vDist(makeSampleSet(built.value, env));
    }
    throw new REOther(built.value);
  }

  throw new REOther("Impossible branch");
}

/**
 * Wraps a two-parameter symbolic constructor as a function definition whose
 * arguments may each be a number or a distribution.
 */
function makeTwoArgsDist(
  fn: (
    v1: number,
    v2: number
  ) => Result.result<SymbolicDist.SymbolicDist, string>
) {
  return makeDefinition(
    [frDistOrNumber, frDistOrNumber],
    ([first, second], context) => {
      return twoVarSample(first, second, context.environment, fn);
    }
  );
}

function makeCIDist<K1 extends string, K2 extends string>(
lowKey: K1,
highKey: K2,
Expand Down Expand Up @@ -115,31 +49,6 @@ function makeMeanStdevDist(
);
}

/**
 * Wraps a one-parameter symbolic constructor as a function definition whose
 * argument may be a number or a distribution.
 *
 * - Distribution argument: convert it to a sample set and, for each sample,
 *   construct the symbolic distribution and draw a single value from it.
 * - Number argument: construct the symbolic distribution once and convert it
 *   to a sample set.
 *
 * The produced definition throws REOther when a numeric argument is rejected
 * by `fn`.
 */
function makeOneArgDist(
  fn: (v: number) => Result.result<SymbolicDist.SymbolicDist, string>
) {
  return makeDefinition([frDistOrNumber], ([arg], context) => {
    const env = context.environment;

    // Construct the distribution for one parameter value and draw one value,
    // turning a construction failure into an operation error.
    const drawOne = (param: number) =>
      Result.fmap2(
        fn(param),
        (dist) => dist.sample(),
        (message) => new OtherOperationError(message)
      );

    if (arg instanceof BaseDist) {
      const samples = makeSampleSet(arg, env);
      return distResultToValue(samples.samplesMap(drawOne));
    }

    if (typeof arg === "number") {
      const built = fn(arg);
      if (built.ok) {
        return vDist(makeSampleSet(built.value, env));
      }
      throw new REOther(built.value);
    }

    throw new REOther("Impossible branch");
  });
}

export const library: FRFunction[] = [
// We might want to later add all of the options to make() that SampleSet has. For example, function() and list().
maker.make({
Expand Down
21 changes: 6 additions & 15 deletions packages/squiggle-lang/src/fr/genericDist.ts
Original file line number Diff line number Diff line change
@@ -1,33 +1,24 @@
import { BaseDist } from "../dist/BaseDist.js";
import { DistError } from "../dist/DistError.js";
import * as SymbolicDist from "../dist/SymbolicDist.js";
import {
BinaryOperation,
binaryOperations,
} from "../dist/distOperations/index.js";
import { REDistributionError } from "../errors/messages.js";
import { FRFunction } from "../library/registry/core.js";
import { makeDefinition } from "../library/registry/fnDefinition.js";
import { frDist, frNumber } from "../library/registry/frTypes.js";
import { FnFactory, unpackDistResult } from "../library/registry/helpers.js";
import {
FnFactory,
distResultToValue,
unpackDistResult,
} from "../library/registry/helpers.js";
import * as magicNumbers from "../magicNumbers.js";
import * as Result from "../utility/result.js";
import { Value, vArray, vDist, vNumber } from "../value/index.js";
import { vArray, vNumber } from "../value/index.js";

const maker = new FnFactory({
nameSpace: "",
requiresNamespace: false,
});

/**
 * Unwraps a distribution result into a runtime value.
 *
 * @returns the distribution wrapped with `vDist` when the result is ok
 * @throws REDistributionError carrying the error when the result is not ok
 */
export function distResultToValue(
  result: Result.result<BaseDist, DistError>
): Value {
  if (result.ok) {
    return vDist(result.value);
  }
  throw new REDistributionError(result.value);
}

type OpPair = [string, BinaryOperation];
const algebraicOps: OpPair[] = [
["add", binaryOperations.algebraicAdd],
Expand Down
Loading
Loading