discojs/validator: use postprocessing
tharvik committed Oct 4, 2024
1 parent ee2edeb commit 9a611b7
Showing 3 changed files with 27 additions and 17 deletions.
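
In short, Validator#infer switches from yielding raw numeric predictions to yielding postprocessed, human-readable values typed as Inferred[D]. A minimal type-level sketch of the change, condensed from the validator diff below (the alias names are illustrative only):

    import type { DataType, Inferred } from "@epfml/discojs";

    // return type of Validator#infer before this commit (from the diff below)
    type InferBefore<D extends DataType> = AsyncGenerator<number, void>;
    // return type after this commit: postprocessed, human-readable values
    type InferAfter<D extends DataType> = AsyncGenerator<Inferred[D], void>;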
discojs/src/types.ts (4 additions, 0 deletions)
@@ -31,8 +31,12 @@ export interface ModelEncoded {
   text: [List<Token>, Token];
 }
 
+/** what gets outputted by the Validator, for humans */
 export interface Inferred {
+  // label of the image
   image: string;
+  // column name and its prediction
   tabular: Partial<Record<string, number>>;
+  // next token
   text: string;
 }
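
Since the webapp accesses this type as Inferred[D], here is a minimal sketch of what the indexed access resolves to for each case, assuming DataType is the union "image" | "tabular" | "text" used elsewhere in discojs:

    import type { Inferred } from "@epfml/discojs";

    type ImagePrediction = Inferred["image"];     // string: a human-readable label
    type TabularPrediction = Inferred["tabular"]; // Partial<Record<string, number>>
    type TextPrediction = Inferred["text"];       // string: the predicted next token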
discojs/src/validator.ts (14 additions, 6 deletions)
@@ -1,6 +1,7 @@
 import type {
   Dataset,
   DataType,
+  Inferred,
   Model,
   Raw,
   RawWithoutLabel,
@@ -26,21 +27,28 @@ export class Validator<D extends DataType> {
         (await this.#model.predict(batch.map(([inputs, _]) => inputs)))
           .zip(batch.map(([_, outputs]) => outputs))
           .map(([infered, truth]) => infered === truth),
-      );
+      )
+      .unbatch();
 
-    for await (const batch of results) for (const e of batch) yield e;
+    for await (const e of results) yield e;
   }
 
   /** use the model to predict every line of the dataset */
   async *infer(
     dataset: Dataset<RawWithoutLabel[D]>,
-  ): AsyncGenerator<number, void> {
-    const results = (
+  ): AsyncGenerator<Inferred[D], void> {
+    const modelPredictions = (
       await processing.preprocessWithoutLabel(this.task, dataset)
     )
       .batch(this.task.trainingInformation.batchSize)
-      .map((batch) => this.#model.predict(batch));
+      .map((batch) => this.#model.predict(batch))
+      .unbatch();
 
-    for await (const batch of results) for await (const e of batch) yield e;
+    const predictions = await processing.postprocess(
+      this.task,
+      modelPredictions,
+    );
+
+    for await (const e of predictions) yield e;
   }
 }
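
A hedged usage sketch of the new generator follows; the constructor shape new Validator(task, model) is an assumption based on the fields used above (this.task and this.#model), and task, model, and imageDataset stand in for values the caller already has:

    import { Validator } from "@epfml/discojs";

    const validator = new Validator<"image">(task, model);
    for await (const label of validator.infer(imageDataset)) {
      // after postprocessing, each yielded value is Inferred["image"],
      // i.e. a human-readable label string rather than a raw class index
      console.log(label);
    }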
webapp/src/components/testing/PredictSteps.vue (9 additions, 11 deletions)
@@ -104,7 +104,7 @@ import createDebug from "debug";
 import { List } from "immutable";
 import { computed, ref, toRaw } from "vue";
-import type { DataType, Model, Task } from "@epfml/discojs";
+import type { DataType, Inferred, Model, Task } from "@epfml/discojs";
 import { Validator } from "@epfml/discojs";
 import InfoIcon from "@/assets/svg/InfoIcon.vue";
@@ -141,7 +141,7 @@ interface Results {
 }
 const dataset = ref<UnlabeledDataset[D]>();
-const generator = ref<AsyncGenerator<number, void>>();
+const generator = ref<AsyncGenerator<Inferred[D], void>>();
 const predictions = ref<Results[D]>();
 const visitedSamples = computed<number>(() => {
@@ -205,10 +205,9 @@ async function startImageInference(
   let results: Results["image"] = List();
   try {
-    generator.value = validator.infer(dataset.map(({ image }) => image));
-    for await (const [{ filename, image }, prediction] of dataset.zip(
-      toRaw(generator.value),
-    )) {
+    const gen = validator.infer(dataset.map(({ image }) => image));
+    generator.value = gen as AsyncGenerator<Inferred[D], void>;
+    for await (const [{ filename, image }, output] of dataset.zip(toRaw(gen))) {
       results = results.push({
         input: {
           filename,
@@ -218,7 +217,7 @@
             image.height,
           ),
         },
-        output: labels.get(prediction) ?? prediction.toString(),
+        output,
       });
       predictions.value = results as Results[D];
@@ -241,10 +240,9 @@ async function startTabularInference(
   let results: Results["tabular"]["results"] = List();
   try {
-    generator.value = validator.infer(dataset);
-    for await (const [input, prediction] of dataset.zip(
-      toRaw(generator.value),
-    )) {
+    const gen = validator.infer(dataset);
+    generator.value = gen as AsyncGenerator<Inferred[D], void>;
+    for await (const [input, prediction] of dataset.zip(toRaw(gen))) {
       results = results.push({
         input: labels.input.map((label) => {
           const ret = input[label];
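
On the tabular path, each yielded prediction is Inferred["tabular"], a partial record keyed by column name. A small self-contained sketch with hypothetical sample values (the column name "Output" is illustrative, not taken from the commit):

    import type { Inferred } from "@epfml/discojs";

    const predictions: Inferred["tabular"][] = [{ Output: 1 }, { Output: 0 }];
    for (const prediction of predictions)
      for (const [column, value] of Object.entries(prediction))
        console.log(`${column}: ${value}`);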
