Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor combiner #10533

Merged
merged 5 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions packages/phone-number-privacy/combiner/src/common/error.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { ErrorType } from '@celo/phone-number-privacy-common'

export class OdisError extends Error {
constructor(readonly code: ErrorType, readonly parent?: Error, readonly status: number = 500) {
// This is necessary when extending Error Classes
super(code) // 'Error' breaks prototype chain here
Object.setPrototypeOf(this, new.target.prototype) // restore prototype chain
}
}

export function wrapError<T>(
valueOrError: Promise<T>,
code: ErrorType,
status: number = 500
): Promise<T> {
return valueOrError.catch((parentErr) => {
throw new OdisError(code, parentErr, status)
})
}
155 changes: 142 additions & 13 deletions packages/phone-number-privacy/combiner/src/common/handlers.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
import {
ErrorMessage,
ErrorType,
OdisRequest,
OdisResponse,
PnpQuotaStatus,
SequentialDelayDomainState,
WarningMessage,
send,
} from '@celo/phone-number-privacy-common'
import opentelemetry, { SpanStatusCode } from '@opentelemetry/api'
import { SemanticAttributes } from '@opentelemetry/semantic-conventions'
import Logger from 'bunyan'
import { Request, Response } from 'express'
import { performance, PerformanceObserver } from 'perf_hooks'
import { sendFailure } from './io'
import { PerformanceObserver, performance } from 'perf_hooks'
import { getCombinerVersion } from '../config'
import { OdisError } from './error'

const tracer = opentelemetry.trace.getTracer('combiner-tracer')

export interface Locals {
logger: Logger
Expand All @@ -18,31 +27,62 @@ export type PromiseHandler<R extends OdisRequest> = (
res: Response<OdisResponse<R>, Locals>
) => Promise<void>

type ParentHandler = (req: Request<{}, {}, any>, res: Response<any, Locals>) => Promise<void>

export function catchErrorHandler<R extends OdisRequest>(
handler: PromiseHandler<R>
): ParentHandler {
): PromiseHandler<R> {
return async (req, res) => {
const logger: Logger = res.locals.logger
try {
await handler(req, res)
} catch (err) {
const logger: Logger = res.locals.logger
logger.error(ErrorMessage.CAUGHT_ERROR_IN_ENDPOINT_HANDLER)
logger.error(err)
if (!res.headersSent) {
logger.info('Responding with error in outer endpoint handler')
res.status(500).json({
success: false,
error: ErrorMessage.UNKNOWN_ERROR,
})
if (err instanceof OdisError) {
sendFailure(err.code, err.status, res, req.url)
} else {
sendFailure(ErrorMessage.UNKNOWN_ERROR, 500, res, req.url)
}
} else {
logger.error(ErrorMessage.ERROR_AFTER_RESPONSE_SENT)
}
}
}
}

export function tracingHandler<R extends OdisRequest>(
handler: PromiseHandler<R>
): PromiseHandler<R> {
return async (req, res) => {
return tracer.startActiveSpan(
req.url,
{
attributes: {
[SemanticAttributes.HTTP_ROUTE]: req.path,
[SemanticAttributes.HTTP_METHOD]: req.method,
[SemanticAttributes.HTTP_CLIENT_IP]: req.ip,
},
},
async (span) => {
try {
await handler(req, res)
span.setStatus({
code: SpanStatusCode.OK,
})
} catch (err: any) {
span.setStatus({
code: SpanStatusCode.ERROR,
message: err instanceof Error ? err.message : 'Fail',
})
throw err
} finally {
span.end()
}
}
)
}
}

export function meteringHandler<R extends OdisRequest>(
handler: PromiseHandler<R>
): PromiseHandler<R> {
Expand Down Expand Up @@ -86,9 +126,98 @@ export function meteringHandler<R extends OdisRequest>(
}
}

export function timeoutHandler<R extends OdisRequest>(
gastonponti marked this conversation as resolved.
Show resolved Hide resolved
timeoutMs: number,
handler: PromiseHandler<R>
): PromiseHandler<R> {
return async (req, res) => {
const timeoutSignal = (AbortSignal as any).timeout(timeoutMs)
timeoutSignal.addEventListener(
'abort',
() => {
if (!res.headersSent) {
sendFailure(ErrorMessage.TIMEOUT_FROM_SIGNER, 500, res, req.url)
}
},
{ once: true }
)

await handler(req, res)
}
}

export function withEnableHandler<R extends OdisRequest>(
gastonponti marked this conversation as resolved.
Show resolved Hide resolved
enabled: boolean,
handler: PromiseHandler<R>
): PromiseHandler<R> {
return async (req, res) => {
if (enabled) {
return handler(req, res)
} else {
sendFailure(WarningMessage.API_UNAVAILABLE, 503, res, req.url)
}
}
}

export async function disabledHandler<R extends OdisRequest>(
_: Request<{}, {}, R>,
req: Request<{}, {}, R>,
response: Response<OdisResponse<R>, Locals>
): Promise<void> {
sendFailure(WarningMessage.API_UNAVAILABLE, 503, response)
sendFailure(WarningMessage.API_UNAVAILABLE, 503, response, req.url)
}

export function sendFailure(
alecps marked this conversation as resolved.
Show resolved Hide resolved
error: ErrorType,
status: number,
response: Response,
_endpoint: string,
gastonponti marked this conversation as resolved.
Show resolved Hide resolved
body?: Record<any, any> // TODO remove any
) {
send(
response,
{
success: false,
version: getCombinerVersion(),
error,
...body,
},
status,
response.locals.logger
)
}

export interface Result<R extends OdisRequest> {
status: number
body: OdisResponse<R>
}

export type ResultHandler<R extends OdisRequest> = (
request: Request<{}, {}, R>,
res: Response<OdisResponse<R>, Locals>
) => Promise<Result<R>>

export function resultHandler<R extends OdisRequest>(
resHandler: ResultHandler<R>
): PromiseHandler<R> {
return async (req, res) => {
const result = await resHandler(req, res)
send(res, result.body, result.status, res.locals.logger)
}
}

export function errorResult(
status: number,
error: string,
quotaStatus?: PnpQuotaStatus | { status: SequentialDelayDomainState }
): Result<any> {
// TODO remove any
return {
status,
body: {
success: false,
version: getCombinerVersion(),
error,
...quotaStatus,
},
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,28 @@ import {
DomainSchema,
ErrorMessage,
getSignerEndpoint,
send,
SequentialDelayDomainStateSchema,
verifyDisableDomainRequestAuthenticity,
WarningMessage,
} from '@celo/phone-number-privacy-common'
import { Signer, thresholdCallToSigners } from '../../../common/combine'
import { PromiseHandler } from '../../../common/handlers'
import { getKeyVersionInfo, sendFailure } from '../../../common/io'
import { errorResult, ResultHandler } from '../../../common/handlers'
import { getKeyVersionInfo } from '../../../common/io'
import { getCombinerVersion, OdisConfig } from '../../../config'
import { logDomainResponseDiscrepancies } from '../../services/log-responses'
import { findThresholdDomainState } from '../../services/threshold-state'

export function createDisableDomainHandler(
export function disableDomain(
signers: Signer[],
config: OdisConfig
): PromiseHandler<DisableDomainRequest> {
): ResultHandler<DisableDomainRequest> {
return async (request, response) => {
if (!disableDomainRequestSchema(DomainSchema).is(request.body)) {
sendFailure(WarningMessage.INVALID_INPUT, 400, response)
return
return errorResult(400, WarningMessage.INVALID_INPUT)
}

if (!verifyDisableDomainRequestAuthenticity(request.body)) {
sendFailure(WarningMessage.UNAUTHENTICATED_USER, 401, response)
return
return errorResult(401, WarningMessage.UNAUTHENTICATED_USER)
}

// TODO remove?
Expand All @@ -57,18 +54,14 @@ export function createDisableDomainHandler(
signers.length
)
if (disableDomainStatus.disabled) {
send(
response,
{
return {
status: 200,
body: {
success: true,
version: getCombinerVersion(),
status: disableDomainStatus,
},
200,
response.locals.logger
)

return
}
}
} catch (err) {
response.locals.logger.error(
Expand All @@ -77,6 +70,6 @@ export function createDisableDomainHandler(
)
}

sendFailure(ErrorMessage.THRESHOLD_DISABLE_DOMAIN_FAILURE, maxErrorCode ?? 500, response)
return errorResult(maxErrorCode ?? 500, ErrorMessage.THRESHOLD_DISABLE_DOMAIN_FAILURE)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,28 @@ import {
DomainSchema,
ErrorMessage,
getSignerEndpoint,
send,
SequentialDelayDomainStateSchema,
verifyDomainQuotaStatusRequestAuthenticity,
WarningMessage,
} from '@celo/phone-number-privacy-common'
import { Signer, thresholdCallToSigners } from '../../../common/combine'
import { PromiseHandler } from '../../../common/handlers'
import { getKeyVersionInfo, sendFailure } from '../../../common/io'
import { errorResult, ResultHandler } from '../../../common/handlers'
import { getKeyVersionInfo } from '../../../common/io'
import { getCombinerVersion, OdisConfig } from '../../../config'
import { logDomainResponseDiscrepancies } from '../../services/log-responses'
import { findThresholdDomainState } from '../../services/threshold-state'

export function createDomainQuotaHandler(
export function domainQuota(
signers: Signer[],
config: OdisConfig
): PromiseHandler<DomainQuotaStatusRequest> {
): ResultHandler<DomainQuotaStatusRequest> {
return async (request, response) => {
if (!domainQuotaStatusRequestSchema(DomainSchema).is(request.body)) {
sendFailure(WarningMessage.INVALID_INPUT, 400, response)
return
return errorResult(400, WarningMessage.INVALID_INPUT)
}

if (!verifyDomainQuotaStatusRequestAuthenticity(request.body)) {
sendFailure(WarningMessage.UNAUTHENTICATED_USER, 401, response)
return
return errorResult(401, WarningMessage.UNAUTHENTICATED_USER)
}

// TODO remove?
Expand All @@ -49,21 +46,18 @@ export function createDomainQuotaHandler(
logDomainResponseDiscrepancies(response.locals.logger, signerResponses)
if (signerResponses.length >= keyVersionInfo.threshold) {
try {
send(
response,
{
return {
status: 200,
body: {
success: true,
version: getCombinerVersion(),
status: findThresholdDomainState(keyVersionInfo, signerResponses, signers.length),
},
200,
response.locals.logger
)
return
}
} catch (err) {
response.locals.logger.error(err, 'Error combining signer quota status responses')
}
}
sendFailure(ErrorMessage.THRESHOLD_DOMAIN_QUOTA_STATUS_FAILURE, maxErrorCode ?? 500, response)
return errorResult(maxErrorCode ?? 500, ErrorMessage.THRESHOLD_DOMAIN_QUOTA_STATUS_FAILURE)
}
}
Loading
Loading