Skip to content

Commit

Permalink
feat: CommonDaoCfg.patchInTransaction
Browse files Browse the repository at this point in the history
  • Loading branch information
kirillgroshkov committed Jan 19, 2024
1 parent 9b03a07 commit 59e40bd
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 36 deletions.
10 changes: 9 additions & 1 deletion src/commondao/common.dao.model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ export interface CommonDaoCfg<
* Set to false to disable auto-generation of `id`.
* Useful e.g when your DB is generating ids by itself (e.g mysql auto_increment).
*/
createId?: boolean
generateId?: boolean

/**
* See the same option in CommonDB.
Expand Down Expand Up @@ -223,6 +223,14 @@ export interface CommonDaoCfg<
* @deprecated
*/
filterNullishValues?: boolean

/**
* Defaults to false.
* If true - run patch operations (patch, patchById) in a Transaction.
*
* @experimental
*/
patchInTransaction?: boolean
}

/**
Expand Down
109 changes: 74 additions & 35 deletions src/commondao/common.dao.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ export class CommonDao<
// otherwise to log Operations
// e.g in Dev (local machine), Test - it will log operations (useful for debugging)
logLevel: isGAE || isCI ? CommonDaoLogLevel.NONE : CommonDaoLogLevel.OPERATIONS,
createId: true,
generateId: true,
assignGeneratedIds: false,
useCreatedProperty: true,
useUpdatedProperty: true,
Expand All @@ -106,7 +106,7 @@ export class CommonDao<
} satisfies Partial<CommonDaoHooks<BM, DBM, TM>>,
}

if (this.cfg.createId) {
if (this.cfg.generateId) {
this.cfg.hooks!.createRandomId ||= () => stringId()
} else {
delete this.cfg.hooks!.createRandomId
Expand All @@ -130,7 +130,7 @@ export class CommonDao<
const table = opt.table || this.cfg.table
const started = this.logStarted(op, table)

let dbm = (await this.cfg.db.getByIds<DBM>(table, [id]))[0]
let dbm = (await (opt.tx || this.cfg.db).getByIds<DBM>(table, [id]))[0]
if (dbm && !opt.raw && this.cfg.hooks!.afterLoad) {
dbm = (await this.cfg.hooks!.afterLoad(dbm)) || undefined
}
Expand Down Expand Up @@ -170,7 +170,7 @@ export class CommonDao<
const op = `getByIdAsDBM(${id})`
const table = opt.table || this.cfg.table
const started = this.logStarted(op, table)
let [dbm] = await this.cfg.db.getByIds<DBM>(table, [id])
let [dbm] = await (opt.tx || this.cfg.db).getByIds<DBM>(table, [id])
if (dbm && !opt.raw && this.cfg.hooks!.afterLoad) {
dbm = (await this.cfg.hooks!.afterLoad(dbm)) || undefined
}
Expand All @@ -189,7 +189,7 @@ export class CommonDao<
const op = `getByIdAsTM(${id})`
const table = opt.table || this.cfg.table
const started = this.logStarted(op, table)
let [dbm] = await this.cfg.db.getByIds<DBM>(table, [id])
let [dbm] = await (opt.tx || this.cfg.db).getByIds<DBM>(table, [id])
if (dbm && !opt.raw && this.cfg.hooks!.afterLoad) {
dbm = (await this.cfg.hooks!.afterLoad(dbm)) || undefined
}
Expand Down Expand Up @@ -226,7 +226,7 @@ export class CommonDao<
const op = `getByIdsAsDBM ${ids.length} id(s) (${_truncate(ids.slice(0, 10).join(', '), 50)})`
const table = opt.table || this.cfg.table
const started = this.logStarted(op, table)
let dbms = await this.cfg.db.getByIds<DBM>(table, ids)
let dbms = await (opt.tx || this.cfg.db).getByIds<DBM>(table, ids)
if (!opt.raw && this.cfg.hooks!.afterLoad && dbms.length) {
dbms = (await pMap(dbms, async dbm => await this.cfg.hooks!.afterLoad!(dbm))).filter(
_isTruthy,
Expand Down Expand Up @@ -689,7 +689,7 @@ export class CommonDao<
obj.updated = opt.preserveUpdatedCreated && obj.updated ? obj.updated : now
}

if (this.cfg.createId) {
if (this.cfg.generateId) {
obj.id ||= this.cfg.hooks!.createNaturalId?.(obj as any) || this.cfg.hooks!.createRandomId!()
}

Expand All @@ -708,7 +708,7 @@ export class CommonDao<
return bm as Saved<BM>
}

const idWasGenerated = !bm.id && this.cfg.createId
const idWasGenerated = !bm.id && this.cfg.generateId
this.assignIdCreatedUpdated(bm, opt) // mutates
let dbm = await this.bmToDBM(bm, opt)

Expand All @@ -727,7 +727,7 @@ export class CommonDao<
const { excludeFromIndexes } = this.cfg
const assignGeneratedIds = opt.assignGeneratedIds || this.cfg.assignGeneratedIds

await this.cfg.db.saveBatch(table, [dbm], {
await (opt.tx || this.cfg.db).saveBatch(table, [dbm], {
excludeFromIndexes,
assignGeneratedIds,
...opt,
Expand Down Expand Up @@ -784,6 +784,13 @@ export class CommonDao<
patch: Partial<BM>,
opt: CommonDaoSaveBatchOptions<DBM> = {},
): Promise<Saved<BM>> {
if (this.cfg.patchInTransaction && !opt.tx) {
// patchInTransaction means that we should run this op in Transaction
// But if opt.tx is passed - means that we are already in a Transaction,
// and should just continue as-is
return await this.patchByIdInTransaction(id, patch, opt)
}

let patched: Saved<BM>
const loaded = await this.getById(id, opt)

Expand All @@ -801,6 +808,19 @@ export class CommonDao<
return await this.save(patched, opt)
}

/**
* Like patchById, but runs all operations within a Transaction.
*/
async patchByIdInTransaction(
id: string,
patch: Partial<BM>,
opt?: CommonDaoSaveBatchOptions<DBM>,
): Promise<Saved<BM>> {
return await this.runInTransaction(async daoTx => {
return await this.patchById(id, patch, { ...opt, tx: daoTx.tx })
})
}

/**
* Same as patchById, but takes the whole object as input.
* This "whole object" is mutated with the patch and returned.
Expand All @@ -812,9 +832,12 @@ export class CommonDao<
patch: Partial<BM>,
opt: CommonDaoSaveBatchOptions<DBM> = {},
): Promise<Saved<BM>> {
_assert(bm.id, 'patch argument object should have an id', {
bm,
})
if (this.cfg.patchInTransaction && !opt.tx) {
// patchInTransaction means that we should run this op in Transaction
// But if opt.tx is passed - means that we are already in a Transaction,
// and should just continue as-is
return await this.patchInTransaction(bm, patch, opt)
}

const loaded = await this.getById(bm.id, opt)

Expand All @@ -835,6 +858,19 @@ export class CommonDao<
return await this.save(bm, opt)
}

/**
* Like patch, but runs all operations within a Transaction.
*/
async patchInTransaction(
bm: Saved<BM>,
patch: Partial<BM>,
opt?: CommonDaoSaveBatchOptions<DBM>,
): Promise<Saved<BM>> {
return await this.runInTransaction(async daoTx => {
return await this.patch(bm, patch, { ...opt, tx: daoTx.tx })
})
}

async saveAsDBM(dbm: DBM, opt: CommonDaoSaveBatchOptions<DBM> = {}): Promise<Saved<DBM>> {
this.requireWriteAccess()
const table = opt.table || this.cfg.table
Expand All @@ -843,7 +879,7 @@ export class CommonDao<
// will override/set `updated` field, unless opts.preserveUpdated is set
let row = dbm as Saved<DBM>
if (!opt.raw) {
const idWasGenerated = !dbm.id && this.cfg.createId
const idWasGenerated = !dbm.id && this.cfg.generateId
this.assignIdCreatedUpdated(dbm, opt) // mutates
row = this.anyToDBM(dbm, opt)
if (opt.ensureUniqueId && idWasGenerated) await this.ensureUniqueId(table, row)
Expand All @@ -861,7 +897,7 @@ export class CommonDao<
if (row === null) return dbm as Saved<DBM>
}

await this.cfg.db.saveBatch(table, [row], {
await (opt.tx || this.cfg.db).saveBatch(table, [row], {
excludeFromIndexes,
assignGeneratedIds,
...opt,
Expand Down Expand Up @@ -952,7 +988,7 @@ export class CommonDao<
)
}

await this.cfg.db.saveBatch(table, rows, {
await (opt.tx || this.cfg.db).saveBatch(table, rows, {
excludeFromIndexes,
assignGeneratedIds,
...opt,
Expand Down Expand Up @@ -1163,7 +1199,7 @@ export class CommonDao<
const bm = await this.cfg.hooks!.beforeDBMToBM!(dbm)

// Validate/convert BM
// eslint-disable-next-line @typescript-eslint/return-await

return this.validateAndConvert(bm, this.cfg.bmSchema, DBModelType.BM, opt)
}

Expand Down Expand Up @@ -1192,7 +1228,7 @@ export class CommonDao<
const dbm = { ...(await this.cfg.hooks!.beforeBMToDBM!(bm)) }

// Validate/convert DBM
// eslint-disable-next-line @typescript-eslint/return-await

return this.validateAndConvert(dbm, this.cfg.dbmSchema, DBModelType.DBM, opt)
}

Expand Down Expand Up @@ -1251,14 +1287,14 @@ export class CommonDao<
*
* Does NOT mutate the object.
*/
validateAndConvert<IN, OUT = IN>(
obj: Partial<IN>,
schema: ObjectSchema<IN> | AjvSchema<IN> | ZodSchema<IN> | undefined,
validateAndConvert<T>(
obj: Partial<T>,
schema: ObjectSchema<T> | AjvSchema<T> | ZodSchema<T> | undefined,
modelType: DBModelType,
opt: CommonDaoOptions = {},
): OUT {
): any {
// `raw` option completely bypasses any processing
if (opt.raw) return obj as any as OUT
if (opt.raw) return obj as any

// Kirill 2021-10-18: I realized that there's little reason to keep removing null values
// So, from now on we'll preserve them
Expand All @@ -1277,31 +1313,31 @@ export class CommonDao<

// Pre-validation hooks
if (modelType === DBModelType.DBM) {
obj = this.cfg.hooks!.beforeDBMValidate!(obj as any) as IN
obj = this.cfg.hooks!.beforeDBMValidate!(obj as any) as T
}

// Return as is if no schema is passed or if `skipConversion` is set
if (!schema || opt.skipConversion) {
return obj as OUT
return obj
}

// This will Convert and Validate
const table = opt.table || this.cfg.table
const objectName = table + (modelType || '')

let error: JoiValidationError | AjvValidationError | ZodValidationError<IN> | undefined
let error: JoiValidationError | AjvValidationError | ZodValidationError<T> | undefined
let convertedValue: any

if (schema instanceof ZodSchema) {
// Zod schema
const vr = zSafeValidate(obj as IN, schema)
const vr = zSafeValidate(obj as T, schema)
error = vr.error
convertedValue = vr.data
} else if (schema instanceof AjvSchema) {
// Ajv schema
convertedValue = obj // because Ajv mutates original object

error = schema.getValidationError(obj as IN, {
error = schema.getValidationError(obj as T, {
objectName,
})
} else {
Expand Down Expand Up @@ -1337,20 +1373,24 @@ export class CommonDao<
await this.cfg.db.ping()
}

async runInTransaction(
fn: CommonDaoTransactionFn,
async runInTransaction<T = void>(
fn: CommonDaoTransactionFn<T>,
opt?: CommonDBTransactionOptions,
): Promise<void> {
): Promise<T> {
let r: T

await this.cfg.db.runInTransaction(async tx => {
const daoTx = new CommonDaoTransaction(tx, this.cfg.logger!)

try {
await fn(daoTx)
r = await fn(daoTx)
} catch (err) {
await daoTx.rollback() // graceful rollback that "never throws"
throw err
}
}, opt)

return r!
}

protected logResult(started: number, op: string, res: any, table: string): void {
Expand Down Expand Up @@ -1415,15 +1455,15 @@ export class CommonDao<
*
* Transaction is rolled back when the function returns rejected Promise (aka "throws").
*/
export type CommonDaoTransactionFn = (tx: CommonDaoTransaction) => Promise<void>
export type CommonDaoTransactionFn<T = void> = (tx: CommonDaoTransaction) => Promise<T>

/**
* Transaction context.
* Has similar API than CommonDao, but all operations are performed in the context of the transaction.
*/
export class CommonDaoTransaction {
constructor(
private tx: DBTransaction,
public tx: DBTransaction,
private logger: CommonLogger,
) {}

Expand All @@ -1444,8 +1484,7 @@ export class CommonDaoTransaction {
id?: string | null,
opt?: CommonDaoOptions,
): Promise<Saved<BM> | null> {
if (!id) return null
return (await this.getByIds(dao, [id], opt))[0] || null
return await dao.getById(id, { ...opt, tx: this.tx })
}

async getByIds<BM extends PartialObjectWithId, DBM extends PartialObjectWithId>(
Expand Down

0 comments on commit 59e40bd

Please sign in to comment.