Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow named types in unions #469

Merged
merged 24 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 79 additions & 29 deletions lib/types.js
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class Type {
wrapUnions = 'auto';
} else if (typeof wrapUnions == 'string') {
wrapUnions = wrapUnions.toLowerCase();
} else if (typeof wrapUnions === 'function') {
wrapUnions = 'auto';
}
switch (wrapUnions) {
case 'always':
Expand Down Expand Up @@ -196,11 +198,26 @@ class Type {
let types = schema.map((obj) => {
return Type.forSchema(obj, opts);
});
let projectionFn;
if (!UnionType) {
UnionType = isAmbiguous(types) ? WrappedUnionType : UnwrappedUnionType;
// either automatic detection or we have a projection function
if (typeof opts.wrapUnions === 'function') {
// we have a projection function
joscha marked this conversation as resolved.
Show resolved Hide resolved
try {
projectionFn = opts.wrapUnions(types);
UnionType = typeof projectionFn !== 'undefined'
// projection function yields a function, we can use an Unwrapped type
? UnwrappedUnionType
: WrappedUnionType;
} catch(e) {
throw new Error(`Error generating projection function: ${e}`);
}
joscha marked this conversation as resolved.
Show resolved Hide resolved
} else {
UnionType = isAmbiguous(types) ? WrappedUnionType : UnwrappedUnionType;
}
}
LOGICAL_TYPE = logicalType;
type = new UnionType(types, opts);
type = new UnionType(types, opts, projectionFn);
} else { // New type definition.
type = (function (typeName) {
let Type = TYPES[typeName];
Expand Down Expand Up @@ -341,10 +358,10 @@ class Type {
return branchTypes[name];
}), opts);
} catch (err) {
opts.wrapUnions = wrapUnions;
throw err;
} finally {
opts.wrapUnions = wrapUnions;
}
opts.wrapUnions = wrapUnions;
return unionType;
}

Expand Down Expand Up @@ -1226,6 +1243,44 @@ UnionType.prototype._branchConstructor = function () {
throw new Error('unions cannot be directly wrapped');
};


function generateProjectionIndexer(projectionFn) {
return (val) => {
const index = projectionFn(val);
if (typeof index !== 'number') {
throw new Error(`Projected index '${index}' is not valid`);
}
return index;
};
}

function generateDefaultIndexer() {
joscha marked this conversation as resolved.
Show resolved Hide resolved
this._dynamicBranches = null;
this._bucketIndices = {};
this.types.forEach(function (type, index) {
if (Type.isType(type, 'abstract', 'logical')) {
if (!this._dynamicBranches) {
this._dynamicBranches = [];
}
this._dynamicBranches.push({index, type});
} else {
let bucket = getTypeBucket(type);
if (this._bucketIndices[bucket] !== undefined) {
throw new Error(`ambiguous unwrapped union: ${j(this)}`);
}
this._bucketIndices[bucket] = index;
}
}, this);
joscha marked this conversation as resolved.
Show resolved Hide resolved
return (val) => {
let index = this._bucketIndices[getValueBucket(val)];
if (this._dynamicBranches) {
// Slower path, we must run the value through all branches.
index = this._getBranchIndex(val, index);
}
return index;
};
}

/**
* "Natural" union type.
*
Expand All @@ -1246,38 +1301,33 @@ UnionType.prototype._branchConstructor = function () {
* + `map`, `record`
*/
class UnwrappedUnionType extends UnionType {
constructor (schema, opts) {
/**
*
* @param {*} schema
* @param {*} opts
* @param {Function|undefined} projectionFn The projection function used
* to determine the bucket for the
* Union. Falls back to generate
* from `wrapUnions` parameter
* if given.
*/
mtth marked this conversation as resolved.
Show resolved Hide resolved
constructor (schema, opts, projectionFn) {
super(schema, opts);

this._dynamicBranches = null;
this._bucketIndices = {};
this.types.forEach(function (type, index) {
if (Type.isType(type, 'abstract', 'logical')) {
if (!this._dynamicBranches) {
this._dynamicBranches = [];
}
this._dynamicBranches.push({index, type});
} else {
let bucket = getTypeBucket(type);
if (this._bucketIndices[bucket] !== undefined) {
throw new Error(`ambiguous unwrapped union: ${j(this)}`);
}
this._bucketIndices[bucket] = index;
if (!projectionFn && opts && typeof opts.wrapUnions === 'function') {
try {
projectionFn = opts.wrapUnions(this.types);
} catch(e) {
throw new Error(`Error generating projection function: ${e}`);
}
}, this);
}
this._getIndex = projectionFn
? generateProjectionIndexer(projectionFn)
: generateDefaultIndexer.bind(this)(this.types);
joscha marked this conversation as resolved.
Show resolved Hide resolved

Object.freeze(this);
}

_getIndex (val) {
let index = this._bucketIndices[getValueBucket(val)];
if (this._dynamicBranches) {
// Slower path, we must run the value through all branches.
index = this._getBranchIndex(val, index);
}
return index;
}

_getBranchIndex (any, index) {
joscha marked this conversation as resolved.
Show resolved Hide resolved
let logicalBranches = this._dynamicBranches;
for (let i = 0, l = logicalBranches.length; i < l; i++) {
Expand Down
48 changes: 48 additions & 0 deletions test/test_types.js
Original file line number Diff line number Diff line change
Expand Up @@ -3505,6 +3505,54 @@ suite('types', () => {
assert(Type.isType(t.field('unwrapped').type, 'union:unwrapped'));
});

test('union projection', () => {
joscha marked this conversation as resolved.
Show resolved Hide resolved
const Dog = {
type: 'record',
name: 'Dog',
fields: [
{ type: 'string', name: 'bark' }
],
};
const Cat = {
type: 'record',
name: 'Cat',
fields: [
{ type: 'string', name: 'meow' }
],
};
const animalTypes = [Dog, Cat];

const wrapUnions = (types) => {
assert.deepEqual(types.map(t => t.name), ['Dog', 'Cat']);
return (animal) => {
const animalType = ((animal) => {
if ('bark' in animal) {
return 'Dog';
} else if ('meow' in animal) {
return 'Cat';
}
throw new Error('Unknown animal');
})(animal);
return types.indexOf(types.find(type => type.name === animalType));
joscha marked this conversation as resolved.
Show resolved Hide resolved
}
};

// TODO: replace this with a mock when available
// currently we're on mocha without sinon
function mockWrapUnions() {
mockWrapUnions.calls = typeof mockWrapUnions.calls === 'undefined'
? 1
: ++mockWrapUnions.calls;
return wrapUnions.apply(null, arguments);
}
joscha marked this conversation as resolved.
Show resolved Hide resolved

// Ambiguous, but we have a projection function
const Animal = Type.forSchema(animalTypes, { wrapUnions: mockWrapUnions });
Animal.toBuffer({ meow: '🐈' });
assert.equal(mockWrapUnions.calls, 1);
assert.throws(() => Animal.toBuffer({ snap: '🐊' }), /Unknown animal/)
mtth marked this conversation as resolved.
Show resolved Hide resolved
});

test('invalid wrap unions option', () => {
assert.throws(() => {
Type.forSchema('string', {wrapUnions: 'FOO'});
Expand Down
13 changes: 12 additions & 1 deletion types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ interface EncoderOptions {
syncMarker: Buffer;
}

/**
* A projection function that is used when unwrapping unions.
* This function is called at schema parsing time on each union with its branches' types.
* If it returns a non-null (function) value, that function will be called each time a value's branch needs to be inferred and should return the branch's index.
* The index muss be a number between 0 and length-1 of the passed types.
* In this case (a branch index) the union will use an unwrapped representation. Otherwise (undefined), the union will be wrapped.
joscha marked this conversation as resolved.
Show resolved Hide resolved
*/
type BranchProjection = (types: ReadonlyArray<Type>) =>
| ((val: unknown) => number)
| undefined;

interface ForSchemaOptions {
assertLogicalTypes: boolean;
logicalTypes: { [type: string]: new (schema: Schema, opts?: any) => types.LogicalType; };
Expand All @@ -104,7 +115,7 @@ interface ForSchemaOptions {
omitRecordMethods: boolean;
registry: { [name: string]: Type };
typeHook: (schema: Schema | string, opts: ForSchemaOptions) => Type | undefined;
wrapUnions: boolean | 'auto' | 'always' | 'never';
wrapUnions: BranchProjection | boolean | 'auto' | 'always' | 'never';
}

interface TypeOptions extends ForSchemaOptions {
Expand Down