Merge pull request #658 from epfml/656-improve-gpt-julien
Fix gpt-tfjs bugs, add tests and refactor code
Showing 27 changed files with 933 additions and 770 deletions.
discojs/discojs-core/src/dataset/data/preprocessing/text_preprocessing.spec.ts (97 additions & 0 deletions)
import { TEXT_PREPROCESSING } from './index.js'
import { expect } from 'chai'

import type { Task } from '../../../index.js'
import * as tf from '@tensorflow/tfjs'

describe('text preprocessing', function () {
  const [tokenize, leftPadding] = TEXT_PREPROCESSING
  // Use a function to create a different task object for each test (otherwise the tokenizer gets cached)
  function initMockTask (): Task {
    return {
      id: 'mock-task-id',
      displayInformation: {},
      trainingInformation: {
        modelID: 'model-id',
        epochs: 1,
        roundDuration: 1,
        validationSplit: 0,
        batchSize: 8,
        scheme: 'local',
        dataType: 'text',
        tokenizer: 'Xenova/gpt2'
      }
    }
  }

  const text = "Hello world, a bc 1 2345, '? 976. Wikipedia is a free content online encyclopedia written and maintained by a community \n of volunteers, known as Wikipedians. Founded by Jimmy Wales and Larry Sanger on January 15, 2001, Wikipedia is hosted by the Wikimedia Foundation, an American nonprofit organization that employs a staff of over 700 people.[7]"
  const expectedTokens = [15496, 995, 11, 257, 47125, 352, 2242, 2231, 11, 705, 30, 860, 4304, 13, 15312, 318, 257, 1479, 2695, 2691, 45352, 3194, 290, 9456, 416, 257, 2055, 220, 198, 286, 11661, 11, 1900, 355, 11145, 46647, 1547, 13, 4062, 276, 416, 12963, 11769, 290, 13633, 311, 2564, 319, 3269, 1315, 11, 5878, 11, 15312, 318, 12007, 416, 262, 44877, 5693, 11, 281, 1605, 15346, 4009, 326, 24803, 257, 3085, 286, 625, 13037, 661, 3693, 22, 60]

  it('can tokenize text', async () => {
    const { tokens } = await tokenize.apply(Promise.resolve(text), initMockTask()) as { tokens: number[] }
    expect(tokens).to.be.deep.equal(expectedTokens)
  })

  it('can truncate inputs when tokenizing', async () => {
    const truncationTask = initMockTask()
    truncationTask.trainingInformation.maxSequenceLength = 10
    const { tokens } = await tokenize.apply(Promise.resolve(text), truncationTask) as { tokens: number[] }
    // + 1 because tokenization keeps one extra token to serve as the label for next-token prediction
    const expectedLength = truncationTask.trainingInformation.maxSequenceLength + 1
    expect(tokens.length).to.be.equal(expectedLength)
    expect(tokens).to.be.deep.equal(expectedTokens.slice(0, expectedLength))
  })

  it('can left-pad tokens', async () => {
    // Create a task where output token sequences should all have length 20
    const paddingTask = initMockTask()
    paddingTask.trainingInformation.maxSequenceLength = 20

    // Create a token sequence of length 10
    const tokens = { tokens: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] }
    const { xs, ys } = await leftPadding.apply(Promise.resolve(tokens), paddingTask) as { xs: tf.Tensor1D, ys: tf.Tensor2D }
    const xsArray = await xs.array()
    const ysArray = await ys.array()

    // Output tensors should have shapes (20) and (20, 50258), 50258 being the size of the gpt2 vocabulary
    expect(xsArray.length).to.be.equal(paddingTask.trainingInformation.maxSequenceLength)
    expect(ysArray.length).to.be.equal(paddingTask.trainingInformation.maxSequenceLength)
    expect(ysArray[0].length).to.be.equal(50258)

    // xs should be left-padded with gpt2's padding token (50256) to reach length 20.
    // The last input token (9) is excluded from xs since it has no next token to predict.
    const paddingToken = 50256
    const expectedXs = Array.from({ length: 11 }).map(() => paddingToken).concat(tokens.tokens.slice(0, 9))
    expect(xsArray).to.be.deep.equal(expectedXs)

    // ys should be a one-hot encoding of the next token in xs:
    // if the input tokens are [0, 1, 2, 3] then the labels are [1, 2, 3], which are one-hot encoded,
    // so the sum of each row should be equal to 1
    const expectedOneHot = Array.from({ length: 20 }).map(() => 1)
    expect(await ys.sum(-1).array()).to.be.deep.equal(expectedOneHot)

    // In each row, the index of the 1 should be the token id
    const expectedYs = Array.from({ length: 10 }).map(() => paddingToken).concat(tokens.tokens)
    expect(await ys.argMax(-1).array()).to.be.deep.equal(expectedYs)
  })

  it('throws an error if no tokenizer is specified', async () => {
    const invalidTask = initMockTask()
    invalidTask.trainingInformation.tokenizer = undefined
    try {
      await tokenize.apply(Promise.resolve("input text doesn't matter"), invalidTask)
    } catch {
      return
    }
    throw new Error('undefined tokenizer should have thrown an error')
  })

  it('throws an error if the tokenizer name is invalid', async () => {
    const invalidTask = initMockTask()
    invalidTask.trainingInformation.tokenizer = 'invalid-tokenizer-name'
    try {
      await tokenize.apply(Promise.resolve("input text doesn't matter"), invalidTask)
    } catch {
      return
    }
    throw new Error('invalid tokenizer name should have thrown an error')
  })
})
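For readers of this diff, here is a minimal sketch of the padding behavior the left-padding test pins down. This is an illustration, not the code under test: the name padLeftSketch is hypothetical, the PAD_TOKEN and VOCAB_SIZE constants are taken from the assertions above, and the input is assumed to have already been truncated to maxSequenceLength + 1 tokens by the tokenize step.

import * as tf from '@tensorflow/tfjs'

// Assumptions from the test above: gpt2's end-of-text token (50256) is reused
// as the padding token, and the one-hot depth is 50258.
const PAD_TOKEN = 50256
const VOCAB_SIZE = 50258

// Hypothetical sketch of the expected left-padding behavior, not the library's implementation
function padLeftSketch (tokens: number[], maxSequenceLength: number): { xs: tf.Tensor1D, ys: tf.Tensor2D } {
  // Left-pad so that, after dropping the final label token, xs has exactly maxSequenceLength entries
  const padded = Array<number>(Math.max(0, maxSequenceLength + 1 - tokens.length))
    .fill(PAD_TOKEN)
    .concat(tokens)
  // xs drops the last token: it has no successor left to predict
  const xs = tf.tensor1d(padded.slice(0, maxSequenceLength), 'int32')
  // ys one-hot encodes each position's next token
  const ys = tf.oneHot(tf.tensor1d(padded.slice(1), 'int32'), VOCAB_SIZE) as tf.Tensor2D
  return { xs, ys }
}

On the test's example, padLeftSketch([0,...,9], 20) yields xs of length 20 (11 pads followed by tokens 0 through 8) and ys of shape (20, 50258), matching the expectedXs and expectedYs assertions.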
New test file (44 additions & 0 deletions; its path is not shown in this view):
import { expect } from 'chai'
import * as tf from '@tensorflow/tfjs-node'
import { AutoTokenizer } from '@xenova/transformers'
import { GPT } from './index.js'
import { type GPTConfig } from './config.js'

describe('gpt-tfjs', function () {
  this.timeout(50_000)
  const data = 'Lorem ipsum dolor sit'

  const config: GPTConfig = {
    modelType: 'gpt-nano',
    lr: 0.01,
    maxIter: 10,
    evaluateEvery: 10,
    maxEvalBatches: 10,
    blockSize: 8,
    vocabSize: 50258
  }

  it('can overfit one sentence', async () => {
    const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2')
    const datasetSource = new tf.data.FileDataSource(Buffer.from(data))
    const textDataset = new tf.data.TextLineDataset(datasetSource)
    const tokenDataset = textDataset.map((text: string) => {
      const { input_ids: tokens } = tokenizer(text, {
        padding: true,
        truncation: true,
        return_tensor: false,
        max_length: config.blockSize + 1
      }) as { input_ids: number[] }
      // Labels are the input tokens shifted by one, one-hot encoded over the vocabulary
      const ys = tf.oneHot(tokens.slice(1), tokenizer.model.vocab.length + 1)
      const xs = tf.tensor(tokens.slice(0, config.blockSize), undefined, 'int32')
      return { xs, ys }
    }).repeat().batch(64)

    const model = new GPT(config)
    const logGenerator = model.train(tokenDataset, undefined, 5) // 5 epochs
    for await (const _ of logGenerator); // await the end of training
    const generation = await model.generate('Lorem ipsum dolor', tokenizer, 1)
    console.log(generation)
    expect(generation).equal(data) // assert that the model completes 'Lorem ipsum dolor' with 'sit'
  })
})
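As a standalone cross-check of the tokenization both specs rely on, here is a short sketch using only the @xenova/transformers calls already shown in the diff; the script itself is illustrative and not part of the commit.

import { AutoTokenizer } from '@xenova/transformers'

// Tokenize the training sentence exactly as the pipeline above does
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2')
const { input_ids } = tokenizer('Lorem ipsum dolor sit', {
  padding: true,
  truncation: true,
  return_tensor: false,
  max_length: 9 // blockSize + 1: the extra token is the label for the last position
}) as { input_ids: number[] }
// input_ids.slice(0, 8) is what feeds xs; input_ids.slice(1) is what ys one-hot encodes
console.log(input_ids)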