Skip to content

Commit

Permalink
throw error code in auth.getUser
Browse files Browse the repository at this point in the history
  • Loading branch information
prostgles committed Dec 21, 2024
1 parent bc2876b commit 47d9bda
Show file tree
Hide file tree
Showing 42 changed files with 647 additions and 756 deletions.
148 changes: 45 additions & 103 deletions lib/Auth/AuthHandler.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import {
AnyObject,
AuthFailure,
AuthGuardLocation,
AuthGuardLocationResponse,
AuthSocketSchema,
CHANNELS,
} from "prostgles-types";
import { LocalParams, PRGLIOSocket } from "../DboBuilder/DboBuilder";
import { PRGLIOSocket } from "../DboBuilder/DboBuilder";
import { DBOFullyTyped } from "../DBSchemaBuilder";
import { removeExpressRoute } from "../FileManager/FileManager";
import { DB, DBHandlerServer, Prostgles } from "../Prostgles";
import {
Auth,
AuthClientRequest,
AuthResult,
AuthResultOrError,
BasicSession,
ExpressReq,
ExpressRes,
Expand All @@ -23,6 +26,7 @@ import { getProviders } from "./setAuthProviders";
import { setupAuthRoutes } from "./setupAuthRoutes";
import { getClientRequestIPsInfo } from "./utils/getClientRequestIPsInfo";
import { getReturnUrl } from "./utils/getReturnUrl";
import { getUserFromRequest } from "./utils/getUserFromRequest";

export { getClientRequestIPsInfo };
export const HTTP_FAIL_CODES = {
Expand Down Expand Up @@ -141,23 +145,35 @@ export class AuthHandler {
}
};

getUser = async (clientReq: { httpReq: ExpressReq }): Promise<AuthResult> => {
const sid = clientReq.httpReq.cookies?.[this.sidKeyName];
getUserAndHandleError = async (localParams: AuthClientRequest): Promise<AuthResult> => {
const sid = this.getSID(localParams);
if (!sid) return undefined;

const handlerError = (code: AuthFailure["code"]) => {
if (localParams.httpReq) {
localParams.res
.status(HTTP_FAIL_CODES.BAD_REQUEST)
.json({ success: false, code, error: code });
}
throw code;
};
try {
return this.throttledFunc(async () => {
return this.opts!.getUser(
const userOrErrorCode = await this.throttledFunc(async () => {
return this.opts.getUser(
this.validateSid(sid),
this.dbo as any,
this.db,
getClientRequestIPsInfo(clientReq)
getClientRequestIPsInfo(localParams)
);
}, 50);
} catch (err) {
console.error(err);

if (typeof userOrErrorCode === "string") {
return handlerError(userOrErrorCode);
}

return userOrErrorCode;
} catch (_err) {
return handlerError("server-error");
}
return undefined;
};

init = setupAuthRoutes.bind(this);
Expand Down Expand Up @@ -265,7 +281,7 @@ export class AuthHandler {
const start = Date.now();
const errCodeOrSession = await this.loginThrottledAndValidate(
loginParams,
getClientRequestIPsInfo({ httpReq: req })
getClientRequestIPsInfo({ httpReq: req, res })
);
const loginResponse =
typeof errCodeOrSession === "string" ?
Expand All @@ -292,30 +308,31 @@ export class AuthHandler {
* query params
* Based on sid names in auth
*/
getSID(localParams: LocalParams | undefined): string | undefined {
if (!localParams) return undefined;
getSID(maybeClientReq: AuthClientRequest | undefined): string | undefined {
if (!maybeClientReq) return undefined;
const { sidKeyName } = this;
if (localParams.socket) {
const { handshake } = localParams.socket;
if (maybeClientReq.socket) {
const { handshake } = maybeClientReq.socket;
const querySid = handshake.auth?.[sidKeyName] || handshake.query?.[sidKeyName];
let rawSid = querySid;
if (!rawSid) {
const cookie_str = localParams.socket.handshake.headers?.cookie;
const cookie_str = maybeClientReq.socket.handshake.headers?.cookie;
const cookie = parseCookieStr(cookie_str);
rawSid = cookie[sidKeyName];
}
return this.validateSid(rawSid);
} else if (localParams.httpReq) {
const [tokenType, base64Token] = localParams.httpReq.headers.authorization?.split(" ") ?? [];
} else {
const [tokenType, base64Token] =
maybeClientReq.httpReq.headers.authorization?.split(" ") ?? [];
let bearerSid: string | undefined;
if (tokenType && base64Token) {
if (tokenType.trim() !== "Bearer") {
throw "Only Bearer Authorization header allowed";
}
bearerSid = Buffer.from(base64Token, "base64").toString();
}
return this.validateSid(bearerSid ?? localParams.httpReq.cookies?.[sidKeyName]);
} else throw "socket OR httpReq missing from localParams";
return this.validateSid(bearerSid ?? maybeClientReq.httpReq.cookies?.[sidKeyName]);
}

function parseCookieStr(cookie_str: string | undefined): any {
if (!cookie_str || typeof cookie_str !== "string") {
Expand All @@ -336,91 +353,16 @@ export class AuthHandler {
/**
* Used for logging
*/
getSIDNoError = (localParams: LocalParams | undefined): string | undefined => {
if (!localParams) return undefined;
getSIDNoError = (clientReq: AuthClientRequest | undefined): string | undefined => {
if (!clientReq) return undefined;
try {
return this.getSID(localParams);
return this.getSID(clientReq);
} catch {
return undefined;
}
};

/**
* For a given sid return the user data if available
*/
async getClientInfo(
localParams: Pick<LocalParams, "socket" | "httpReq">
): Promise<AuthResult<any>> {
/**
* Get cached session if available
*/
const getSession = this.opts.cacheSession?.getSession;
const isSocket = "socket" in localParams;
if (isSocket && getSession && localParams.socket?.__prglCache) {
const { session, user, clientUser } = localParams.socket.__prglCache;
const isValid = this.isNonExpiredSocketSession(localParams.socket, session);
if (isValid) {
return {
sid: session.sid,
user,
clientUser,
};
} else
return {
sid: session.sid,
};
}

/**
* Get sid from request and fetch user data
*/
const authStart = Date.now();
const clientInfo = await this.throttledFunc(async () => {
const { getUser } = this.opts;

if (localParams.httpReq || localParams.socket) {
const sid = this.getSID(localParams);
const clientReq =
localParams.httpReq ? { httpReq: localParams.httpReq } : { socket: localParams.socket! };
let user, clientUser;
if (sid) {
const clientInfo = await getUser(
sid,
this.dbo as any,
this.db,
getClientRequestIPsInfo(clientReq)
);
if (typeof clientInfo === "string") throw clientInfo;
user = clientInfo?.user;
clientUser = clientInfo?.clientUser;
}
if (getSession && isSocket) {
const session = await getSession(sid, this.dbo as any, this.db);
if (session && session.expires && user && clientUser && localParams.socket) {
localParams.socket.__prglCache = {
session,
user,
clientUser,
};
}
}
if (sid) {
return { sid, user, clientUser };
}
}

return {};
}, 5);

await this.prostgles.opts.onLog?.({
type: "auth",
command: "getClientInfo",
duration: Date.now() - authStart,
sid: clientInfo.sid,
socketId: localParams.socket?.id,
});
return clientInfo;
}
getUserFromRequest = getUserFromRequest.bind(this);

isNonExpiredSocketSession = (
socket: PRGLIOSocket,
Expand All @@ -442,8 +384,8 @@ export class AuthHandler {
};

getClientAuth = async (
clientReq: Pick<LocalParams, "socket" | "httpReq">
): Promise<{ auth: AuthSocketSchema; userData: AuthResult }> => {
clientReq: AuthClientRequest
): Promise<{ auth: AuthSocketSchema; userData: AuthResultOrError }> => {
let pathGuard = false;
if (this.opts.expressConfig?.publicRoutes && !this.opts.expressConfig.disableSocketAuthGuard) {
pathGuard = true;
Expand Down Expand Up @@ -473,7 +415,7 @@ export class AuthHandler {
pathname &&
typeof pathname === "string" &&
this.isUserRoute(pathname) &&
!(await this.getClientInfo({ socket }))?.user
!(await this.getUserFromRequest({ socket }))?.user
) {
cb(null, { shouldReload: true });
} else {
Expand All @@ -488,7 +430,7 @@ export class AuthHandler {
}
}

const userData = await this.getClientInfo(clientReq);
const userData = await this.getUserFromRequest(clientReq);
const { email } = this.opts.expressConfig?.registrations ?? {};
const auth: AuthSocketSchema = {
providers: getProviders.bind(this)(),
Expand Down
27 changes: 19 additions & 8 deletions lib/Auth/AuthTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ export type BasicSession = {
/** On expired */
onExpiration: "redirect" | "show_error";
};
export type AuthClientRequest =
| { socket: PRGLIOSocket; httpReq?: undefined }
| { httpReq: ExpressReq; socket?: undefined };

type SocketClientRequest = { socket: PRGLIOSocket; httpReq?: undefined };
type HttpClientRequest = { httpReq: ExpressReq; res: ExpressRes; socket?: undefined };
export type AuthClientRequest = SocketClientRequest | HttpClientRequest;

type ThirdPartyProviders = {
facebook?: Pick<FacebookStrategy, "clientID" | "clientSecret"> & {
Expand Down Expand Up @@ -239,20 +240,30 @@ export type SessionUser<
clientUser: ClientUser;
};

export type AuthResult<SU = SessionUser> =
| AuthFailure["code"]
export type AuthResultWithSID<SU = SessionUser> =
| (SU & { sid: string })
| {
sid?: string | undefined;
user?: undefined;
sessionFields?: undefined;
clientUser?: undefined;
}
| undefined;

export type AuthResult<SU = SessionUser> =
| SU
| {
user?: undefined;
sessionFields?: undefined;
clientUser?: undefined;
sid?: string | undefined;
}
| undefined;
export type AuthResultOrError<SU = SessionUser> = AuthFailure["code"] | AuthResult<SU>;

export type AuthRequestParams<S, SUser extends SessionUser> = {
db: DB;
dbo: DBOFullyTyped<S>;
getUser: () => Promise<AuthResult<SUser>>;
getUser: () => Promise<AuthResultOrError<SUser>>;
};

export type Auth<S = void, SUser extends SessionUser = SessionUser> = {
Expand All @@ -270,7 +281,7 @@ export type Auth<S = void, SUser extends SessionUser = SessionUser> = {
dbo: DBOFullyTyped<S>,
db: DB,
client: AuthClientRequest & LoginClientInfo
) => Awaitable<AuthResult<SUser>>;
) => Awaitable<AuthResultOrError<SUser>>;

/**
* Will setup auth routes
Expand Down
2 changes: 1 addition & 1 deletion lib/Auth/authProviders/setOAuthProviders.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ export function setOAuthProviders(

app.get(callbackPath, async (req, res) => {
try {
const clientInfo = getClientRequestIPsInfo({ httpReq: req });
const clientInfo = getClientRequestIPsInfo({ httpReq: req, res });
const db = this.db;
const dbo = this.dbo as any;
const args = { provider: providerName, req, res, clientInfo, db, dbo };
Expand Down
2 changes: 1 addition & 1 deletion lib/Auth/endpoints/getConfirmEmailRequestHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export const getConfirmEmailRequestHandler = (
if (!id || typeof id !== "string") {
return res.send({ success: false, code: "something-went-wrong", message: "Invalid code" });
}
const { httpReq, ...clientInfo } = getClientRequestIPsInfo({ httpReq: req });
const { httpReq, ...clientInfo } = getClientRequestIPsInfo({ httpReq: req, res });
await emailAuthConfig.onEmailConfirmation({
confirmationCode: id,
clientInfo,
Expand Down
2 changes: 1 addition & 1 deletion lib/Auth/endpoints/getRegisterRequestHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export const getRegisterRequestHandler = ({
}
}
try {
const { httpReq, ...clientInfo } = getClientRequestIPsInfo({ httpReq: req });
const { httpReq, ...clientInfo } = getClientRequestIPsInfo({ httpReq: req, res });
const { smtp } = emailAuthConfig;
const errCodeOrResult =
emailAuthConfig.signupType === "withPassword" ?
Expand Down
Loading

0 comments on commit 47d9bda

Please sign in to comment.