Skip to content

Commit

Permalink
improve auth types
Browse files Browse the repository at this point in the history
  • Loading branch information
prostgles committed Dec 22, 2024
1 parent f258a26 commit b791e8c
Show file tree
Hide file tree
Showing 22 changed files with 305 additions and 258 deletions.
41 changes: 27 additions & 14 deletions lib/Auth/AuthHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { DB, DBHandlerServer, Prostgles } from "../Prostgles";
import {
Auth,
AuthClientRequest,
AuthResultOrError,
AuthResult,
AuthResultWithSID,
BasicSession,
ExpressReq,
Expand All @@ -26,7 +26,7 @@ import { getProviders } from "./setAuthProviders";
import { setupAuthRoutes } from "./setupAuthRoutes";
import { getClientRequestIPsInfo } from "./utils/getClientRequestIPsInfo";
import { getReturnUrl } from "./utils/getReturnUrl";
import { getUserFromRequest } from "./utils/getUserFromRequest";
import { getSidAndUserFromRequest } from "./utils/getSidAndUserFromRequest";

export { getClientRequestIPsInfo };
export const HTTP_FAIL_CODES = {
Expand Down Expand Up @@ -58,8 +58,8 @@ export const AUTH_ROUTES_AND_PARAMS = {
} as const;

export class AuthHandler {
protected prostgles: Prostgles;
protected opts: Auth;
protected readonly prostgles: Prostgles;
protected readonly opts: Auth;
dbo: DBHandlerServer;
db: DB;

Expand Down Expand Up @@ -147,7 +147,7 @@ export class AuthHandler {

getUserAndHandleError = async (localParams: AuthClientRequest): Promise<AuthResultWithSID> => {
const sid = this.getSID(localParams);
if (!sid) return undefined;
if (!sid) return { sid };
const handlerError = (code: AuthFailure["code"]) => {
if (localParams.httpReq) {
localParams.res
Expand All @@ -160,7 +160,7 @@ export class AuthHandler {
const userOrErrorCode = await this.throttledFunc(async () => {
return this.opts.getUser(
this.validateSid(sid),
this.dbo as any,
this.dbo as DBOFullyTyped,
this.db,
getClientRequestIPsInfo(localParams)
);
Expand All @@ -169,8 +169,12 @@ export class AuthHandler {
if (typeof userOrErrorCode === "string") {
return handlerError(userOrErrorCode);
}

return { sid, ...userOrErrorCode };
if (sid && userOrErrorCode?.user) {
return { sid, ...userOrErrorCode };
}
return {
sid,
};
} catch (_err) {
return handlerError("server-error");
}
Expand Down Expand Up @@ -362,7 +366,13 @@ export class AuthHandler {
}
};

getUserFromRequest = getUserFromRequest.bind(this);
getUserFromRequest = async (clientReq: AuthClientRequest): Promise<AuthResult> => {
const sidAndUser = await this.getSidAndUserFromRequest(clientReq);
if (sidAndUser.sid && sidAndUser.user) {
return sidAndUser;
}
};
getSidAndUserFromRequest = getSidAndUserFromRequest.bind(this);

isNonExpiredSocketSession = (
socket: PRGLIOSocket,
Expand All @@ -385,12 +395,15 @@ export class AuthHandler {

getClientAuth = async (
clientReq: AuthClientRequest
): Promise<{ auth: AuthSocketSchema; userData: AuthResultOrError }> => {
): Promise<{ auth: AuthSocketSchema; userData: AuthResultWithSID }> => {
let pathGuard = false;
if (this.opts.expressConfig?.publicRoutes && !this.opts.expressConfig.disableSocketAuthGuard) {
pathGuard = true;

if ("socket" in clientReq && clientReq.socket) {
/**
* Due to SPA nature of some clients, we need to check if the connected client ends up on a protected route
*/
if (clientReq.socket) {
const { socket } = clientReq;
socket.removeAllListeners(CHANNELS.AUTHGUARD);
socket.on(
Expand All @@ -415,7 +428,7 @@ export class AuthHandler {
pathname &&
typeof pathname === "string" &&
this.isUserRoute(pathname) &&
!(await this.getUserFromRequest({ socket }))?.user
!(await this.getUserFromRequest({ socket }))
) {
cb(null, { shouldReload: true });
} else {
Expand All @@ -430,15 +443,15 @@ export class AuthHandler {
}
}

const userData = await this.getUserFromRequest(clientReq);
const userData = await this.getSidAndUserFromRequest(clientReq);
const { email } = this.opts.expressConfig?.registrations ?? {};
const auth: AuthSocketSchema = {
providers: getProviders.bind(this)(),
register: email && {
type: email.signupType,
url: AUTH_ROUTES_AND_PARAMS.emailRegistration,
},
user: userData?.clientUser,
user: userData.clientUser,
loginType: email?.signupType ?? "withPassword",
pathGuard,
};
Expand Down
28 changes: 12 additions & 16 deletions lib/Auth/AuthTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -243,21 +243,13 @@ export type SessionUser<
export type AuthResultWithSID<SU = SessionUser> =
| (SU & { sid: string })
| {
sid?: string | undefined;
user?: undefined;
sessionFields?: undefined;
clientUser?: undefined;
}
| undefined;
sid: string | undefined;
user?: never;
sessionFields?: never;
clientUser?: never;
};

export type AuthResult<SU = SessionUser> =
| SU
| {
user?: undefined;
sessionFields?: undefined;
clientUser?: undefined;
}
| undefined;
export type AuthResult<SU = SessionUser> = SU | undefined;
export type AuthResultOrError<SU = SessionUser> = AuthFailure["code"] | AuthResult<SU>;

export type AuthRequestParams<S, SUser extends SessionUser> = {
Expand All @@ -274,7 +266,11 @@ export type Auth<S = void, SUser extends SessionUser = SessionUser> = {
sidKeyName?: string;

/**
* undefined sid is allowed to enable public users
* Used in:
* - WS AUTHGUARD - allows connected SPA client to check if on protected route and needs to reload to ne redirected to login
* - PublishParams - userData and/or sid (in testing) are passed to the publish function
* - auth.expressConfig.use - express middleware to get user data and
* undefined sid is allowed to enable public users
*/
getUser: (
sid: string | undefined,
Expand Down Expand Up @@ -328,7 +324,7 @@ export type LoginParams =
| ({ type: "username" } & AuthRequest.LoginData)
| ({ type: "provider" } & AuthProviderUserData);

type ExpressConfig<S, SUser extends SessionUser> = {
export type ExpressConfig<S, SUser extends SessionUser> = {
/**
* Express app instance. If provided Prostgles will attempt to set sidKeyName to user cookie
*/
Expand Down
13 changes: 5 additions & 8 deletions lib/Auth/authProviders/setEmailProvider.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import e from "express";
import { AUTH_ROUTES_AND_PARAMS, AuthHandler } from "../AuthHandler";
import { getConfirmEmailRequestHandler } from "../endpoints/getConfirmEmailRequestHandler";
import { getRegisterRequestHandler } from "../endpoints/getRegisterRequestHandler";
import { AuthHandler } from "../AuthHandler";
import { setConfirmEmailRequestHandler } from "../endpoints/setConfirmEmailRequestHandler";
import { setRegisterRequestHandler } from "../endpoints/setRegisterRequestHandler";
import { getOrSetTransporter } from "../sendEmail";
import { checkDmarc } from "../utils/checkDmarc";

Expand All @@ -18,12 +18,9 @@ export async function setEmailProvider(this: AuthHandler, app: e.Express) {
*/
getOrSetTransporter(email.smtp);

app.post(
AUTH_ROUTES_AND_PARAMS.emailRegistration,
getRegisterRequestHandler({ email, websiteUrl })
);
setRegisterRequestHandler({ email, websiteUrl }, app);

if (email.signupType === "withPassword") {
app.get(AUTH_ROUTES_AND_PARAMS.confirmEmailExpressRoute, getConfirmEmailRequestHandler(email));
setConfirmEmailRequestHandler(email, app);
}
}
107 changes: 107 additions & 0 deletions lib/Auth/endpoints/setCatchAllRequestHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import e, { RequestHandler, Request, Response } from "express";
import { AuthClientRequest } from "../AuthTypes";
import { AUTH_ROUTES_AND_PARAMS, AuthHandler, HTTP_FAIL_CODES } from "../AuthHandler";
import { getReturnUrl } from "../utils/getReturnUrl";
import { DBOFullyTyped } from "../../DBSchemaBuilder";

export function setCatchAllRequestHandler(this: AuthHandler, app: e.Express) {
const onLogout = async (req: Request, res: Response) => {
const sid = this.validateSid(req.cookies?.[this.sidKeyName]);
if (sid) {
try {
await this.throttledFunc(() => {
return this.opts.logout?.(req.cookies?.[this.sidKeyName], this.dbo as any, this.db);
});
} catch (err) {
console.error(err);
}
}
res.redirect("/");
};

const requestHandler: RequestHandler = async (req, res, next) => {
const { onGetRequestOK } = this.opts.expressConfig ?? {};
const clientReq: AuthClientRequest = { httpReq: req, res };
const getUser = async () => {
const userOrCode = await this.getUserAndHandleError(clientReq);
if (typeof userOrCode === "string") {
res.status(HTTP_FAIL_CODES.BAD_REQUEST).json({ success: false, code: userOrCode });
throw userOrCode;
}
return userOrCode;
};
const isLoggedInUser = async () => {
const userInfo = await getUser();
return !!userInfo.user;
};
if (this.prostgles.restApi) {
if (
Object.values(this.prostgles.restApi.routes).some((restRoute) =>
this.matchesRoute(restRoute.split("/:")[0], req.path)
)
) {
next();
return;
}
}
try {
const returnURL = getReturnUrl(req);

if (this.matchesRoute(AUTH_ROUTES_AND_PARAMS.logoutGetPath, req.path)) {
await onLogout(req, res);
return;
}

if (this.matchesRoute(AUTH_ROUTES_AND_PARAMS.loginWithProvider, req.path)) {
next();
return;
}
/**
* Requesting a User route
*/
if (this.isUserRoute(req.path)) {
/* Check auth. Redirect to login if unauthorized */
const u = await isLoggedInUser();
if (!u) {
res.redirect(
`${AUTH_ROUTES_AND_PARAMS.login}?returnURL=${encodeURIComponent(req.originalUrl)}`
);
return;
}

/* If authorized and going to returnUrl then redirect. Otherwise serve file */
} else if (returnURL && (await isLoggedInUser())) {
res.redirect(returnURL);
return;

/** If Logged in and requesting login then redirect to main page */
} else if (
this.matchesRoute(AUTH_ROUTES_AND_PARAMS.login, req.path) &&
(await isLoggedInUser())
) {
res.redirect("/");
return;
}

onGetRequestOK?.(req, res, {
getUser,
dbo: this.dbo as DBOFullyTyped,
db: this.db,
});
} catch (error) {
console.error(error);
const errorMessage =
typeof error === "string" ? error
: error instanceof Error ? error.message
: "";
res.status(HTTP_FAIL_CODES.BAD_REQUEST).json({
error:
"Something went wrong when processing your request" +
(errorMessage ? ": " + errorMessage : ""),
});
}
};

app.get(AUTH_ROUTES_AND_PARAMS.catchAll, requestHandler);
return requestHandler;
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import type { Request, Response } from "express";
import { AuthResponse } from "prostgles-types";
import { HTTP_FAIL_CODES } from "../AuthHandler";
import { AUTH_ROUTES_AND_PARAMS, HTTP_FAIL_CODES } from "../AuthHandler";
import { AuthRegistrationConfig } from "../AuthTypes";
import { getClientRequestIPsInfo } from "../utils/getClientRequestIPsInfo";
import e from "express";

export const getConfirmEmailRequestHandler = (
export const setConfirmEmailRequestHandler = (
emailAuthConfig: Extract<
Required<AuthRegistrationConfig<void>>["email"],
{ signupType: "withPassword" }
>
>,
app: e.Express
) => {
return async (
const requestHandler = async (
req: Request,
res: Response<
| AuthResponse.PasswordRegisterSuccess
Expand All @@ -36,4 +38,6 @@ export const getConfirmEmailRequestHandler = (
.json({ success: false, code: "server-error", message: "Failed to confirm email" });
}
};

app.get(AUTH_ROUTES_AND_PARAMS.confirmEmailExpressRoute, requestHandler);
};
18 changes: 18 additions & 0 deletions lib/Auth/endpoints/setLoginRequestHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import e from "express";
import { AUTH_ROUTES_AND_PARAMS, AuthHandler, HTTP_FAIL_CODES } from "../AuthHandler";
import { LoginParams } from "../AuthTypes";

export function setLoginRequestHandler(this: AuthHandler, app: e.Express) {
app.post(AUTH_ROUTES_AND_PARAMS.login, async (req, res) => {
try {
const loginParams: LoginParams = {
type: "username",
...req.body,
};

await this.loginThrottledAndSetCookie(req, res, loginParams);
} catch (error) {
res.status(HTTP_FAIL_CODES.BAD_REQUEST).json({ error });
}
});
}
Loading

0 comments on commit b791e8c

Please sign in to comment.