diff --git a/packages/cc/src/cc/Security2CC.ts b/packages/cc/src/cc/Security2CC.ts index b7f2b1a176b6..9db2b1edeb0f 100644 --- a/packages/cc/src/cc/Security2CC.ts +++ b/packages/cc/src/cc/Security2CC.ts @@ -1420,7 +1420,7 @@ export class Security2CCMessageEncapsulation extends Security2CC { ); offset += extensionLength; - const ext = Security2Extension.from(extensionData); + const ext = Security2Extension.parse(extensionData); switch (validateS2Extension(ext, wasEncrypted)) { case ValidateS2ExtensionResult.OK: diff --git a/packages/cc/src/lib/Security2/Extension.ts b/packages/cc/src/lib/Security2/Extension.ts index c579b71639d1..0f077e1e300a 100644 --- a/packages/cc/src/lib/Security2/Extension.ts +++ b/packages/cc/src/lib/Security2/Extension.ts @@ -1,17 +1,10 @@ +import { createSimpleReflectionDecorator } from "@zwave-js/core"; import { ZWaveError, ZWaveErrorCodes, - isZWaveError, validatePayload, } from "@zwave-js/core/safe"; -import { - Bytes, - type TypedClassDecorator, - buffer2hex, - getEnumMemberName, - isUint8Array, -} from "@zwave-js/shared/safe"; -import "reflect-metadata"; +import { Bytes, buffer2hex, getEnumMemberName } from "@zwave-js/shared/safe"; enum S2ExtensionType { SPAN = 0x01, @@ -20,81 +13,31 @@ enum S2ExtensionType { MOS = 0x04, } -const METADATA_S2ExtensionMap = Symbol("S2ExtensionMap"); -const METADATA_S2Extension = Symbol("S2Extension"); - -type S2ExtensionMap = Map< - S2ExtensionType, +const extensionTypeDecorator = createSimpleReflectionDecorator< + Security2Extension, + [type: S2ExtensionType], Security2ExtensionConstructor ->; +>({ + name: "extensionType", +}); -export type Security2ExtensionConstructor = - & typeof Security2Extension - & { - new (options: Security2ExtensionOptions): T; - }; +/** Defines which S2 extension type a subclass of S2Extension has */ +export const extensionType = extensionTypeDecorator.decorator; + +/** Returns which S2 extension type a subclass of S2Extension has */ +export const getExtensionType = extensionTypeDecorator.lookupValue; /** * Looks up the S2 extension constructor for a given S2 extension type */ -export function getS2ExtensionConstructor( - type: S2ExtensionType, -): Security2ExtensionConstructor | undefined { - // Retrieve the constructor map from the CommandClass class - const map = Reflect.getMetadata( - METADATA_S2ExtensionMap, - Security2Extension, - ) as S2ExtensionMap | undefined; - return map?.get(type); -} +export const getS2ExtensionConstructor = + extensionTypeDecorator.lookupConstructor; -/** - * Defines the command class associated with a Z-Wave message - */ -export function extensionType( - type: S2ExtensionType, -): TypedClassDecorator { - return (extensionClass) => { - Reflect.defineMetadata(METADATA_S2Extension, type, extensionClass); - - const map: S2ExtensionMap = - Reflect.getMetadata(METADATA_S2ExtensionMap, Security2Extension) - || new Map(); - map.set( - type, - extensionClass as unknown as Security2ExtensionConstructor< - Security2Extension - >, - ); - Reflect.defineMetadata( - METADATA_S2ExtensionMap, - map, - Security2Extension, - ); +export type Security2ExtensionConstructor = + & typeof Security2Extension + & { + new (options: Security2ExtensionOptions): T; }; -} - -/** - * Retrieves the command class defined for a Z-Wave message class - */ -export function getExtensionType( - ext: T, -): S2ExtensionType { - // get the class constructor - const constr = ext.constructor; - // retrieve the current metadata - const ret: S2ExtensionType | undefined = Reflect.getMetadata( - METADATA_S2Extension, - constr, - ); - if (ret == undefined) { - throw new ZWaveError( - `No S2 extension type defined for ${constr.name}!`, - ZWaveErrorCodes.CC_Invalid, - ); - } - return ret; -} export enum ValidateS2ExtensionResult { OK, @@ -138,40 +81,87 @@ export function validateS2Extension( return ValidateS2ExtensionResult.OK; } -interface Security2ExtensionCreationOptions { - critical: boolean; - payload?: Uint8Array; -} +export class Security2ExtensionRaw { + public constructor( + public type: S2ExtensionType, + public critical: boolean, + public readonly moreToFollow: boolean, + public payload: Uint8Array, + ) {} + + public static parse(data: Uint8Array): Security2ExtensionRaw { + validatePayload(data.length >= 2); + const totalLength = data[0]; + const moreToFollow = !!(data[1] & 0b1000_0000); + const critical = !!(data[1] & 0b0100_0000); + const type = data[1] & 0b11_1111; + const payload = data.subarray(2, totalLength); -interface Security2ExtensionDeserializationOptions { - data: Uint8Array; + return new Security2ExtensionRaw(type, critical, moreToFollow, payload); + } + + public withPayload(payload: Bytes): Security2ExtensionRaw { + return new Security2ExtensionRaw( + this.type, + this.critical, + this.moreToFollow, + payload, + ); + } } -type Security2ExtensionOptions = - | Security2ExtensionCreationOptions - | Security2ExtensionDeserializationOptions; +interface Security2ExtensionBaseOptions { + critical?: boolean; + moreToFollow?: boolean; +} -function gotDeserializationOptions( - options: Record, -): options is Security2ExtensionDeserializationOptions { - return "data" in options && isUint8Array(options.data); +interface Security2ExtensionOptions extends Security2ExtensionBaseOptions { + type?: S2ExtensionType; + payload?: Uint8Array; } export class Security2Extension { public constructor(options: Security2ExtensionOptions) { - if (gotDeserializationOptions(options)) { - validatePayload(options.data.length >= 2); - const totalLength = options.data[0]; - const controlByte = options.data[1]; - this.moreToFollow = !!(controlByte & 0b1000_0000); - this.critical = !!(controlByte & 0b0100_0000); - this.type = controlByte & 0b11_1111; - this.payload = options.data.subarray(2, totalLength); - } else { - this.type = getExtensionType(this); - this.critical = options.critical; - this.payload = options.payload ?? new Uint8Array(); + const { + // Try to determine the extension type if none is given + type = getExtensionType(this), + critical = false, + moreToFollow = false, + payload = new Uint8Array(), + } = options; + + if (type == undefined) { + throw new ZWaveError( + "A Security2Extension must have a given or predefined extension type", + ZWaveErrorCodes.Argument_Invalid, + ); } + + this.type = type; + this.critical = critical; + this.moreToFollow = moreToFollow; + this.payload = payload; + } + + public static parse( + data: Uint8Array, + ): Security2Extension { + const raw = Security2ExtensionRaw.parse(data); + const Constructor = getS2ExtensionConstructor(raw.type) + ?? Security2Extension; + return Constructor.from(raw); + } + + /** Creates an instance of the message that is serialized in the given buffer */ + public static from( + raw: Security2ExtensionRaw, + ): Security2Extension { + return new this({ + type: raw.type, + critical: raw.critical, + moreToFollow: raw.moreToFollow, + payload: raw.payload, + }); } public type: S2ExtensionType; @@ -227,34 +217,6 @@ export class Security2Extension { return 2 + this.payload.length; } - /** - * Retrieves the correct constructor for the next extension in the given Buffer. - * It is assumed that the buffer has been checked beforehand - */ - public static getConstructor( - data: Uint8Array, - ): Security2ExtensionConstructor { - const type = data[1] & 0b11_1111; - return getS2ExtensionConstructor(type) ?? Security2Extension; - } - - /** Creates an instance of the S2 extension that is serialized in the given buffer */ - public static from(data: Uint8Array): Security2Extension { - const Constructor = Security2Extension.getConstructor(data); - try { - const ret = new Constructor({ data }); - return ret; - } catch (e) { - if ( - isZWaveError(e) - && e.code === ZWaveErrorCodes.PacketFormat_InvalidPayload - ) { - return new InvalidExtension({ data }); - } - throw e; - } - } - public toLogEntry(): string { let ret = ` ยท type: ${getEnumMemberName(S2ExtensionType, this.type)}`; @@ -276,28 +238,30 @@ interface SPANExtensionOptions { @extensionType(S2ExtensionType.SPAN) export class SPANExtension extends Security2Extension { public constructor( - options: - | Security2ExtensionDeserializationOptions - | SPANExtensionOptions, + options: SPANExtensionOptions & Security2ExtensionBaseOptions, ) { - if (gotDeserializationOptions(options)) { - super(options); - validatePayload(this.payload.length === 16); - this.senderEI = this.payload; - } else { - super({ critical: true }); - if (options.senderEI.length !== 16) { - throw new ZWaveError( - "The sender's entropy must be a 16-byte buffer!", - ZWaveErrorCodes.Argument_Invalid, - ); - } - this.senderEI = options.senderEI; + if (options.senderEI.length !== 16) { + throw new ZWaveError( + "The sender's entropy must be a 16-byte buffer!", + ZWaveErrorCodes.Argument_Invalid, + ); } + super({ critical: true, ...options }); + this.senderEI = options.senderEI; } - public senderEI: Uint8Array; + public static from(raw: Security2ExtensionRaw): Security2Extension { + validatePayload(raw.payload.length === 16); + const senderEI = raw.payload; + return new SPANExtension({ + critical: raw.critical, + moreToFollow: raw.moreToFollow, + senderEI, + }); + } + + public senderEI: Uint8Array; public static readonly expectedLength = 18; public serialize(moreToFollow: boolean): Bytes { @@ -320,26 +284,30 @@ interface MPANExtensionOptions { @extensionType(S2ExtensionType.MPAN) export class MPANExtension extends Security2Extension { public constructor( - options: - | Security2ExtensionDeserializationOptions - | MPANExtensionOptions, + options: MPANExtensionOptions & Security2ExtensionBaseOptions, ) { - if (gotDeserializationOptions(options)) { - super(options); - validatePayload(this.payload.length === 17); - this.groupId = this.payload[0]; - this.innerMPANState = this.payload.subarray(1); - } else { - if (options.innerMPANState.length !== 16) { - throw new ZWaveError( - "The inner MPAN state must be a 16-byte buffer!", - ZWaveErrorCodes.Argument_Invalid, - ); - } - super({ critical: true }); - this.groupId = options.groupId; - this.innerMPANState = options.innerMPANState; + if (options.innerMPANState.length !== 16) { + throw new ZWaveError( + "The inner MPAN state must be a 16-byte buffer!", + ZWaveErrorCodes.Argument_Invalid, + ); } + super({ critical: true, ...options }); + this.groupId = options.groupId; + this.innerMPANState = options.innerMPANState; + } + + public static from(raw: Security2ExtensionRaw): Security2Extension { + validatePayload(raw.payload.length === 17); + const groupId = raw.payload[0]; + const innerMPANState = raw.payload.subarray(1); + + return new MPANExtension({ + critical: raw.critical, + moreToFollow: raw.moreToFollow, + groupId, + innerMPANState, + }); } public groupId: number; @@ -378,18 +346,21 @@ interface MGRPExtensionOptions { @extensionType(S2ExtensionType.MGRP) export class MGRPExtension extends Security2Extension { public constructor( - options: - | Security2ExtensionDeserializationOptions - | MGRPExtensionOptions, + options: MGRPExtensionOptions & Security2ExtensionBaseOptions, ) { - if (gotDeserializationOptions(options)) { - super(options); - validatePayload(this.payload.length === 1); - this.groupId = this.payload[0]; - } else { - super({ critical: true }); - this.groupId = options.groupId; - } + super({ critical: true, ...options }); + this.groupId = options.groupId; + } + + public static from(raw: Security2ExtensionRaw): Security2Extension { + validatePayload(raw.payload.length === 1); + const groupId = raw.payload[0]; + + return new MGRPExtension({ + critical: raw.critical, + moreToFollow: raw.moreToFollow, + groupId, + }); } public groupId: number; @@ -410,12 +381,8 @@ export class MGRPExtension extends Security2Extension { @extensionType(S2ExtensionType.MOS) export class MOSExtension extends Security2Extension { - public constructor(options?: Security2ExtensionDeserializationOptions) { - if (options && gotDeserializationOptions(options)) { - super(options); - } else { - super({ critical: false }); - } + public constructor(options: Security2ExtensionBaseOptions = {}) { + super({ critical: false, ...options }); } public static readonly expectedLength = 2;