Skip to content

Commit

Permalink
Update to typeorm 0.3.x (#45)
Browse files Browse the repository at this point in the history
* Bump typeorm deps from 0.2.41 to 0.3.17

* Rename Connection to DataSource

* * Rename `findOne({ id })` to `findOneBy({ id })` and `findOneOrFail({ id })` to `findOneByOrFail({ id })`
* Add missing `where` key  to `find()` args

* Use findOne with a (mandatory) where clause

* Replace db.close with ds.destroy

* wait for dataSource initialization

* Missing renames of Connection to Datasource

* Remove unnecessary if (finOneOrFail would already throw an error)

* Replace `connect` with `initialize` and `close` with `destroy`

* Send retry request for CIPHERTEXT msgs on getMessages (#47)

* Add `node` to `FullBaileysMessage` (message.original`)

* When calling `fetchMessages` send a retry request for any message that was not decrypted before

* Bump baileys
  • Loading branch information
javiercr authored Dec 29, 2023
1 parent 6c6a236 commit 7df831b
Show file tree
Hide file tree
Showing 25 changed files with 238 additions and 174 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"pino": "^7.8.0",
"react": "https://github.com/TextsHQ/react-global-shim#main",
"sanitize-filename": "^1.6.3",
"typeorm": "^0.2.41",
"typeorm": "^0.3.17",
"typeorm-naming-strategies": "^2.0.0"
},
"devDependencies": {
Expand Down
38 changes: 24 additions & 14 deletions src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import makeWASocket, { Browsers, ChatModification, ConnectionState, delay, Socke
import { texts, StickerPack, PlatformAPI, OnServerEventCallback, MessageSendOptions, InboxName, LoginResult, OnConnStateChangeCallback, ReAuthError, CurrentUser, MessageContent, ConnectionError, PaginationArg, ClientContext, ActivityType, Thread, Paginated, User, PhoneNumber, ServerEvent, ConnectionStatus, ServerEventType, GetAssetOptions, AssetInfo, MessageLink, Attachment, ThreadFolderName, UserID, PaginatedWithCursors } from '@textshq/platform-sdk'
import { smartJSONStringify } from '@textshq/platform-sdk/dist/json'
import type { Logger } from 'pino'
import { LessThanOrEqual, type Connection } from 'typeorm'
import type { DataSource } from 'typeorm'
import { PassThrough } from 'stream'
import NodeCache from 'node-cache'
import getConnection from './utils/get-connection'
import getDataSource from './utils/get-data-source'
import DBUser from './entities/DBUser'
import { canReconnect, CONNECTION_STATE_MAP, generateInstanceId, isLoggedIn, LOGGED_OUT_CODES, makeMutex, mapMessageID, numberFromJid, PARTICIPANT_ACTION_MAP, PRESENCE_MAP, profilePictureUrl, waitForAllEventsToBeHandled } from './utils/generics'
import DBMessage from './entities/DBMessage'
Expand Down Expand Up @@ -113,7 +113,7 @@ export default class WhatsAppAPI implements PlatformAPI {

logger: Logger

db: Connection
db: DataSource

get meID(): string | undefined {
if (!this.client) return
Expand Down Expand Up @@ -170,7 +170,7 @@ export default class WhatsAppAPI implements PlatformAPI {

this.logger.info({ dbPath, waVersion: this.latestWAVersion }, 'platform whatsapp init')

this.db = await getConnection(this.accountID, dbPath, this.logger)
this.db = await getDataSource(this.accountID, dbPath, this.logger)

this.dataStore = makeTextsBaileysStore(
this.publishEvent,
Expand Down Expand Up @@ -233,7 +233,7 @@ export default class WhatsAppAPI implements PlatformAPI {
}

await this.dataStore?.wait()
await this.db?.close()
await this.db?.destroy()

this.logger?.info('disposed')
}
Expand Down Expand Up @@ -336,7 +336,7 @@ export default class WhatsAppAPI implements PlatformAPI {
}

getCurrentUser = async (): Promise<CurrentUser> => {
let user: User | undefined = await this.db.getRepository(DBUser).findOne({ where: { isSelf: true } })
let user: User | null = await this.db.getRepository(DBUser).findOne({ where: { isSelf: true } })
if (!user) {
const id = this.meID
if (!id) {
Expand Down Expand Up @@ -423,7 +423,7 @@ export default class WhatsAppAPI implements PlatformAPI {

private loadWAMessageFromDB = async (threadID: string, messageID: string) => {
const repo = this.db.getRepository(DBMessage)
const dbmsg = await repo.findOne({ id: messageID, threadID })
const dbmsg = await repo.findOneBy({ id: messageID, threadID })
if (dbmsg) {
await remapMessagesAndSave(repo, [dbmsg], this)
}
Expand Down Expand Up @@ -647,7 +647,7 @@ export default class WhatsAppAPI implements PlatformAPI {
},
)
} else {
const user = await this.db.getRepository(DBUser).findOne({ id: thread.id })
const user = await this.db.getRepository(DBUser).findOneBy({ id: thread.id })
thread.user = user || null
}

Expand All @@ -657,7 +657,7 @@ export default class WhatsAppAPI implements PlatformAPI {
deleteThread = async (threadID: string) => {
// thread deletes are local on WA multi-device
const repo = this.db.getRepository(DBThread)
const item = await repo.findOne({ id: threadID })
const item = await repo.findOneBy({ id: threadID })
if (item) {
await repo.remove(item)
}
Expand Down Expand Up @@ -703,10 +703,20 @@ export default class WhatsAppAPI implements PlatformAPI {
await delay(50)
}

const result = await fetchMessages(this, threadID, pagination)
const result = await fetchMessages(this, threadID, pagination, this.senderRetryRequest)

return result
}

senderRetryRequest = async (message: DBMessage) => {
if (!this.client) {
throw new Error('client not initialized')
}
if (message.original.node) {
this.client.sendRetryRequest(message.original.node)
}
}

getUser = async (ids: { userID: UserID } | { username: string } | { phoneNumber: PhoneNumber } | { email: string }): Promise<User | undefined> => {
if (!('phoneNumber' in ids)) return
const { phoneNumber } = ids
Expand All @@ -729,7 +739,7 @@ export default class WhatsAppAPI implements PlatformAPI {

getOriginalObject = async (objName: 'thread' | 'message', objectID: string) => {
const repo = this.db.getRepository(objName === 'thread' ? DBThread : DBMessage)
const item = await repo.findOne({ id: objectID })
const item = await repo.findOneBy({ id: objectID })
return smartJSONStringify(item?.original)
}

Expand Down Expand Up @@ -839,7 +849,7 @@ export default class WhatsAppAPI implements PlatformAPI {

getMessage = async (threadID: string, messageID: string) => {
const repo = this.db.getRepository(DBMessage)
const msg = await repo.findOne({ threadID, id: messageID })
const msg = await repo.findOneBy({ threadID, id: messageID })
if (msg) {
await remapMessagesAndSave(repo, [msg], this)
}
Expand Down Expand Up @@ -869,7 +879,7 @@ export default class WhatsAppAPI implements PlatformAPI {

forwardMessage = async (threadID: string, messageID: string, threadIDs?: string[]) => {
await this.waitForConnectionOpen()
const { original: { message } } = await this.db.getRepository(DBMessage).findOneOrFail({
const { original: { message } } = await this.db.getRepository(DBMessage).findOneByOrFail({
id: messageID,
threadID,
})
Expand Down Expand Up @@ -1104,7 +1114,7 @@ export default class WhatsAppAPI implements PlatformAPI {

private getChat = (threadID: string) => {
const repo = this.db.getRepository(DBThread)
return repo.findOne({ id: threadID })
return repo.findOneBy({ id: threadID })
}

getStickerPacks = async (): Promise<PaginatedWithCursors<StickerPack>> => {
Expand Down
2 changes: 1 addition & 1 deletion src/entities/DBMessage-util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ export const MessageTransformer: ValueTransformer = {
},
to: (item: FullBaileysMessage | null) => {
if (item) {
return serialize({ ...item, message: WAProto.WebMessageInfo.encode(item.message).finish() })
return serialize({ ...item, node: item.message.node, message: WAProto.WebMessageInfo.encode(item.message).finish() })
}
return null
},
Expand Down
24 changes: 12 additions & 12 deletions src/tests/db.test.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// eslint-disable-next-line import/order
import getConnection from '../utils/get-connection'
import getDataSource from '../utils/get-data-source'

import { Chat, delay, generateMessageID, makeEventBuffer, unixTimestampSeconds, WAMessageStubType, WAProto } from 'baileys'
import { unlink, stat } from 'fs/promises'
import type { Connection } from 'typeorm'
import type { DataSource } from 'typeorm'
import DBMessage from '../entities/DBMessage'
import DBThread from '../entities/DBThread'
import type { MappingContextWithDB } from '../types'
Expand All @@ -20,7 +20,7 @@ logger.level = 'trace'
jest.setTimeout(30_000)

describe('Database Sync Tests', () => {
let db: Connection
let db: DataSource
let store: ReturnType<typeof makeTextsBaileysStore>

const mappingCtx: MappingContextWithDB = {
Expand All @@ -38,7 +38,7 @@ describe('Database Sync Tests', () => {
logger.info('removing existing DB')
await unlink(DB_PATH)
}
db = await getConnection('default', DB_PATH, logger)
db = await getDataSource('default', DB_PATH, logger)
mappingCtx.db = db
store = makeTextsBaileysStore(() => { }, () => { throw new Error('no') }, mappingCtx)
ev.process(events => store.process(events).then(() => { }))
Expand Down Expand Up @@ -67,13 +67,13 @@ describe('Database Sync Tests', () => {
await delay(500)

expect(
await db.getRepository(DBThread).findOne({
await db.getRepository(DBThread).findOneBy({
id: msg.key.remoteJid!,
}),
).toBeTruthy()

expect(
await db.getRepository(DBMessage).findOne({
await db.getRepository(DBMessage).findOneBy({
id: mapMessageID(msg.key),
}),
).toBeTruthy()
Expand Down Expand Up @@ -124,7 +124,7 @@ describe('Database Sync Tests', () => {
await delay(200)

const repo = db.getRepository(DBThread)
const thread = await repo.findOne({ id: jid })
const thread = await repo.findOneBy({ id: jid })
expect(thread?.unreadCount).toEqual(1)
expect(thread?.timestamp).toEqual(new Date(ogTimstamp * 1000))
})
Expand Down Expand Up @@ -163,7 +163,7 @@ describe('Database Sync Tests', () => {
await delay(200)

const repo = db.getRepository(DBThread)
const thread = await repo.findOne({ id: jid })
const thread = await repo.findOneBy({ id: jid })
expect(thread?.original?.chat?.ephemeralExpiration).toEqual(60 * 60 * 24)
})

Expand Down Expand Up @@ -237,7 +237,7 @@ describe('Database Sync Tests', () => {
})

const repo = db.getRepository(DBMessage)
const messages = await repo.find({ threadID: jid })
const messages = await repo.find({ where: { threadID: jid }})
expect(messages).toHaveLength(3)
})

Expand Down Expand Up @@ -341,13 +341,13 @@ describe('Database Sync Tests', () => {
)

await store.wait()
await db.close()
await db.destroy()
await tasks

await db.connect()
await db.initialize()

const repo = db.getRepository(DBMessage)
const dbMessages = await repo.find({ threadID: jid })
const dbMessages = await repo.find({ where: { threadID: jid }})
expect(dbMessages).toHaveLength(3)
})
})
7 changes: 4 additions & 3 deletions src/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { Chat, Contact, GroupMetadata, GroupParticipant, WAMessage } from 'baileys'
import type { BinaryNode, Chat, Contact, GroupMetadata, GroupParticipant, WAMessage } from 'baileys'
import type { Logger } from 'pino'
import type { Connection, EntityManager } from 'typeorm'
import type { DataSource, EntityManager } from 'typeorm'

export type FullBaileysChat = {
chat: Partial<Chat> & {
Expand All @@ -14,6 +14,7 @@ export type FullBaileysChat = {

export type FullBaileysMessage = {
message: WAMessage
node?: BinaryNode
seenByMe?: boolean
// the last version of platform-whatsapp the message was mapped on
lastMappedVersion: number | undefined
Expand All @@ -30,7 +31,7 @@ export type MappingContext = {
logger: Logger
}

export type MappingContextWithDB = MappingContext & { db: Connection | EntityManager }
export type MappingContextWithDB = MappingContext & { db: DataSource | EntityManager }

export type Transaction = any

Expand Down
6 changes: 3 additions & 3 deletions src/utils/db-get-earliest-msg-order-key.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import type { Connection, EntityManager } from 'typeorm'
import { Not, type DataSource, type EntityManager, IsNull } from 'typeorm'
import DBMessage from '../entities/DBMessage'

export default async (db: EntityManager | Connection) => {
export default async (db: EntityManager | DataSource) => {
const msg = await db.getRepository(DBMessage)
.findOne({ order: { orderKey: 'ASC' }, select: ['orderKey'] })
.findOne({ where: { id: Not(IsNull()) }, order: { orderKey: 'ASC' }, select: ['orderKey'] })
return msg?.orderKey
}
6 changes: 3 additions & 3 deletions src/utils/db-get-latest-msg-order-key.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import type { Connection, EntityManager } from 'typeorm'
import { Not, type DataSource, type EntityManager, IsNull } from 'typeorm'
import DBMessage from '../entities/DBMessage'

export default async (db: EntityManager | Connection) => {
export default async (db: EntityManager | DataSource) => {
const msg = await db.getRepository(DBMessage)
.findOne({ order: { orderKey: 'DESC' }, select: ['orderKey'] })
.findOne({ where: { id: Not(IsNull()) }, order: { orderKey: 'DESC' }, select: ['orderKey'] })
return msg?.orderKey
}
14 changes: 7 additions & 7 deletions src/utils/db-key-store.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { Connection } from 'typeorm'
import type { DataSource } from 'typeorm'
import type { Logger } from 'pino'
import { SignalKeyStore, SignalDataTypeMap, makeCacheableSignalKeyStore } from 'baileys'
import AccountKeyValue from '../entities/AccountKeyValue'
Expand All @@ -16,12 +16,12 @@ const KEY_MAP: { [T in keyof SignalDataTypeMap]: string } = {
* Key store required for baileys.
* Stores all keys in the sqlite database
*/
const _makeDBKeyStore = (db: Connection): SignalKeyStore => {
const repo = db.getRepository(AccountKeyValue)
const _makeDBKeyStore = (ds: DataSource): SignalKeyStore => {
const repo = ds.getRepository(AccountKeyValue)

return {
get: async (type, ids) => {
const items = await db
const items = await ds
.createQueryBuilder(AccountKeyValue, 'acc')
.where('category = :category AND id IN (:...ids)', {
category: KEY_MAP[type],
Expand Down Expand Up @@ -51,7 +51,7 @@ const _makeDBKeyStore = (db: Connection): SignalKeyStore => {
}
}

await db.transaction(
await ds.transaction(
async db => {
const repo = db.getRepository(AccountKeyValue)

Expand All @@ -75,8 +75,8 @@ const _makeDBKeyStore = (db: Connection): SignalKeyStore => {
}
}

export const makeDBKeyStore = (db: Connection, logger: Logger) => {
const store = _makeDBKeyStore(db)
export const makeDBKeyStore = (ds: DataSource, logger: Logger) => {
const store = _makeDBKeyStore(ds)
return makeCacheableSignalKeyStore(
store,
logger,
Expand Down
14 changes: 7 additions & 7 deletions src/utils/db-mutex-all-transactions.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint-disable no-param-reassign */
import type { Logger } from 'pino'
import type { Connection } from 'typeorm'
import type { DataSource } from 'typeorm'
import { makeMutex } from './generics'

/**
Expand All @@ -9,17 +9,17 @@ import { makeMutex } from './generics'
*
* to prevent that, we queue each transaction with this wrapper
*/
const dbMutexAllTransactions = (db: Connection, logger: Logger) => {
const dbMutexAllTransactions = (ds: DataSource, logger: Logger) => {
logger = logger.child({ class: 'transactions' })

const { mutex } = makeMutex()
const { transaction, close } = db
const { transaction, destroy } = ds

db.transaction = (...args: any) => {
ds.transaction = (...args: any) => {
if (logger.level === 'trace') logger.trace('called transaction')
return mutex(async () => {
try {
const result = await transaction.apply(db, args)
const result = await transaction.apply(ds, args)
return result
} catch (error) {
logger.error({ trace: error?.stack }, `error in transaction: ${error}`)
Expand All @@ -30,8 +30,8 @@ const dbMutexAllTransactions = (db: Connection, logger: Logger) => {
})
}

db.close = async () => {
await mutex(() => close.apply(db))
ds.destroy = async () => {
await mutex(() => destroy.apply(ds))
}
}

Expand Down
Loading

0 comments on commit 7df831b

Please sign in to comment.