Skip to content

Commit

Permalink
chore: properly implement oidc
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeidnx committed Nov 12, 2024
1 parent 868103c commit e86c915
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 101 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies {
implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1'
implementation 'com.github.FireMasterK.NewPipeExtractor:NewPipeExtractor:a64e202bb498032e817a702145263590829f3c1d'
implementation 'com.github.FireMasterK:nanojson:9f4af3b739cc13f3d0d9d4b758bbe2b2ae7119d7'
implementation 'com.nimbusds:oauth2-oidc-sdk:11.5'
implementation 'com.nimbusds:oauth2-oidc-sdk:11.20.1'
implementation 'com.fasterxml.jackson.core:jackson-core:2.17.2'
implementation 'com.fasterxml.jackson.core:jackson-annotations:2.17.2'
implementation 'com.fasterxml.jackson.core:jackson-databind:2.17.2'
Expand Down
14 changes: 7 additions & 7 deletions config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ hibernate.connection.password:changeme
#frontend.statusPageUrl:https://kavin.rocks
#frontend.donationUrl:https://kavin.rocks

# Oidc configuration
#oidc.provider.INSERT_HERE.name:INSERT_HERE
#oidc.provider.INSERT_HERE.clientId:INSERT_HERE
#oidc.provider.INSERT_HERE.clientSecret:INSERT_HERE
#oidc.provider.INSERT_HERE.authUri:INSERT_HERE
#oidc.provider.INSERT_HERE.tokenUri:INSERT_HERE
#oidc.provider.INSERT_HERE.userinfoUri:INSERT_HERE
# SSO via OIDC
# each provider needs to have these three options specified. <NAME> is the
# friendly name which will be shown to the clients and used in the database.
# If you want to change the name later, you will have to update the database.
# oidc.provider.<NAME>.clientId:<Client_id>
# oidc.provider.<NAME>.clientSecret:<Client_secret>
# oidc.provider.<NAME>.issuer:<Issuer_url>
8 changes: 2 additions & 6 deletions src/main/java/me/kavin/piped/consts/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -196,17 +196,13 @@ else if (key.startsWith("oidc.provider")) {
}
});
oidcProviderConfig.forEach((provider, config) -> {
ObjectNode providerNode = frontendProperties.putObject(provider);
OIDC_PROVIDERS.add(new OidcProvider(
getRequiredMapValue(config, "name"),
provider,
getRequiredMapValue(config, "clientId"),
getRequiredMapValue(config, "clientSecret"),
getRequiredMapValue(config, "authUri"),
getRequiredMapValue(config, "tokenUri"),
getRequiredMapValue(config, "userinfoUri")
getRequiredMapValue(config, "issuer")
));
providerNames.add(provider);
config.forEach(providerNode::put);
});
frontendProperties.put("imageProxyUrl", IMAGE_PROXY_PART);
frontendProperties.putArray("countries").addAll(
Expand Down
9 changes: 8 additions & 1 deletion src/main/java/me/kavin/piped/server/ServerLauncher.java
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ AsyncServlet mainServlet(Executor executor) {
return switch (function) {
case "login" -> UserHandlers.oidcLoginResponse(provider, request.getQueryParameter("redirect"));
case "callback" -> UserHandlers.oidcCallbackResponse(provider, URI.create(request.getFullUrl()));
case "delete" -> UserHandlers.oidcDeleteResponse(provider, URI.create(request.getFullUrl()));
case "delete" -> UserHandlers.oidcDeleteCallback(provider, URI.create(request.getFullUrl()));
default -> HttpResponse.ofCode(500).withHtml("Invalid function `" + function + "`");
};
} catch (Exception e) {
Expand Down Expand Up @@ -491,6 +491,13 @@ AsyncServlet mainServlet(Executor executor) {
} catch (Exception e) {
return getErrorResponse(e, request.getPath());
}
})).map(GET, "/user/delete", AsyncServlet.ofBlocking(executor, request -> {
try {
var session = request.getQueryParameter("session");
return UserHandlers.oidcDeleteRequest(session);
} catch (Exception e) {
return getErrorResponse(e, request.getPath());
}
})).map(POST, "/logout", AsyncServlet.ofBlocking(executor, request -> {
try {
return getJsonResponse(UserHandlers.logoutResponse(request.getHeader(AUTHORIZATION)), "private");
Expand Down
188 changes: 121 additions & 67 deletions src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
package me.kavin.piped.server.handlers.auth;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.JWTParser;
import com.nimbusds.oauth2.sdk.*;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod;
import com.nimbusds.oauth2.sdk.pkce.CodeVerifier;
import com.nimbusds.openid.connect.sdk.*;
import com.nimbusds.openid.connect.sdk.claims.IDTokenClaimsSet;
import com.nimbusds.openid.connect.sdk.claims.UserInfo;
import io.activej.http.HttpResponse;
import jakarta.persistence.criteria.CriteriaBuilder;
Expand Down Expand Up @@ -131,16 +138,20 @@ public static HttpResponse oidcLoginResponse(OidcProvider provider, String redir
}

URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback");
OidcData data = new OidcData(redirectUri);
CodeVerifier codeVerifier = new CodeVerifier();
OidcData data = new OidcData(redirectUri, codeVerifier);
String state = data.getState();

PENDING_OIDC.put(state, data);

AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder(
new ResponseType("code"),
new Scope("openid"),
provider.clientID, callback).endpointURI(provider.authUri)
.state(new State(state)).nonce(data.nonce).build();
provider.clientID, callback)
.endpointURI(provider.authUri)
.codeChallenge(codeVerifier, CodeChallengeMethod.S256)
.state(new State(state))
.nonce(data.nonce).build();

if (redirectUri.equals(Constants.FRONTEND_URL + "/login")) {
return HttpResponse.redirect302(oidcRequest.toURI().toString());
Expand All @@ -155,24 +166,25 @@ public static HttpResponse oidcLoginResponse(OidcProvider provider, String redir
"\">here</a></body></html>");
}
public static HttpResponse oidcCallbackResponse(OidcProvider provider, URI requestUri) throws Exception {
ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);

AuthenticationSuccessResponse sr = parseOidcUri(requestUri);
AuthenticationSuccessResponse authResponse = parseOidcUri(requestUri);

OidcData data = PENDING_OIDC.get(sr.getState().toString());
OidcData data = PENDING_OIDC.get(authResponse.getState().toString());
if (data == null) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent invalid state data. Try again or contact your oidc admin"
);
}

URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback");
AuthorizationCode code = sr.getAuthorizationCode();
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback);
AuthorizationCode code = authResponse.getAuthorizationCode();

AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback, data.pkceVerifier);

ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);
TokenRequest tokenReq = new TokenRequest(provider.tokenUri, clientAuth, codeGrant);
OIDCTokenResponse tokenResponse = (OIDCTokenResponse) OIDCTokenResponseParser.parse(tokenReq.toHTTPRequest().send());

com.nimbusds.oauth2.sdk.http.HTTPResponse tokenResponseText = tokenReq.toHTTPRequest().send();
OIDCTokenResponse tokenResponse = (OIDCTokenResponse) OIDCTokenResponseParser.parse(tokenResponseText);

if (!tokenResponse.indicatesSuccess()) {
TokenErrorResponse errorResponse = tokenResponse.toErrorResponse();
Expand All @@ -181,11 +193,17 @@ public static HttpResponse oidcCallbackResponse(OidcProvider provider, URI reque

OIDCTokenResponse successResponse = tokenResponse.toSuccessResponse();

if (data.isInvalidNonce((String) successResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("nonce"))) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent an invalid nonce. Try again or contact your oidc admin"
);
}
JWT idToken = JWTParser.parse(successResponse.getOIDCTokens().getIDTokenString());

try {
provider.validator.validate(idToken, data.nonce);
} catch (BadJOSEException e) {
System.out.println("Invalid token received: " + e.toString());
return HttpResponse.ofCode(400).withHtml("Received a bad token. Please try again");
} catch (JOSEException e) {
System.out.println("Token processing error" + e.toString());
return HttpResponse.ofCode(500).withHtml("Internal processing error. Please try again");
}

UserInfoRequest ur = new UserInfoRequest(provider.userinfoUri, successResponse.getOIDCTokens().getBearerAccessToken());
UserInfoResponse userInfoResponse = UserInfoResponse.parse(ur.toHTTPRequest().send());
Expand All @@ -200,38 +218,87 @@ public static HttpResponse oidcCallbackResponse(OidcProvider provider, URI reque

UserInfo userInfo = userInfoResponse.toSuccessResponse().getUserInfo();


String uid = userInfo.getSubject().toString();
String sub = userInfo.getSubject().toString();
String sessionId;
try (Session s = DatabaseSessionFactory.createSession()) {
// TODO: Add oidc provider to database
String dbName = provider + "-" + uid;
CriteriaBuilder cb = s.getCriteriaBuilder();
CriteriaQuery<User> cr = cb.createQuery(User.class);
Root<User> root = cr.from(User.class);
cr.select(root).where(root.get("username").in(
dbName
));
CriteriaQuery<OidcUserData> cr = cb.createQuery(OidcUserData.class);
Root<OidcUserData> root = cr.from(OidcUserData.class);

User dbuser = s.createQuery(cr).uniqueResult();
cr.select(root).where(root.get("sub").in(sub));

if (dbuser == null) {
User newuser = new User(dbName, "", Set.of());
OidcUserData dbuser = s.createQuery(cr).uniqueResult();

if (dbuser != null) {
sessionId = dbuser.getUser().getSessionId();
} else {
String username = userInfo.getPreferredUsername();
OidcUserData newUser = new OidcUserData(sub, username, provider.name);

var tr = s.beginTransaction();
s.persist(newuser);
s.persist(newUser);
tr.commit();


sessionId = newuser.getSessionId();
} else sessionId = dbuser.getSessionId();
sessionId = newUser.getUser().getSessionId();
}
}
return HttpResponse.redirect302(data.data + "?session=" + sessionId);

}

public static HttpResponse oidcDeleteResponse(OidcProvider provider, URI requestUri) throws Exception {
ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);
public static HttpResponse oidcDeleteRequest(String session) throws Exception {

if (StringUtils.isBlank(session)) {
return HttpResponse.ofCode(400).withHtml("session is a required parameter");
}

OidcProvider provider = null;
try (Session s = DatabaseSessionFactory.createSession()) {

User user = DatabaseHelper.getUserFromSession(session);

if (user == null) {
return HttpResponse.ofCode(400).withHtml("User not found");
}

CriteriaBuilder cb = s.getCriteriaBuilder();
CriteriaQuery<OidcUserData> cr = cb.createQuery(OidcUserData.class);
Root<OidcUserData> root = cr.from(OidcUserData.class);
cr.select(root).where(cb.equal(root.get("user"), user));

OidcUserData oidcUserData = s.createQuery(cr).uniqueResult();

for (OidcProvider test: Constants.OIDC_PROVIDERS) {
if (test.name.equals(oidcUserData.getProvider())) {
provider = test;
}
}
}

if (provider == null) {
return HttpResponse.ofCode(400).withHtml("Invalid user");
}
CodeVerifier pkceVerifier = new CodeVerifier();

URI callback = URI.create(String.format("%s/oidc/%s/delete", Constants.PUBLIC_URL, provider.name));
OidcData data = new OidcData(session + "|" + Instant.now().getEpochSecond(), pkceVerifier);
String state = data.getState();
PENDING_OIDC.put(state, data);

AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder(
new ResponseType("code"),
new Scope("openid"), provider.clientID, callback)
.endpointURI(provider.authUri)
.codeChallenge(pkceVerifier, CodeChallengeMethod.S256)
.state(new State(state))
.nonce(data.nonce)
// This parameter is optional and the idp does't have to honor it.
.maxAge(0)
.build();

return HttpResponse.redirect302(oidcRequest.toURI().toString());
}
public static HttpResponse oidcDeleteCallback(OidcProvider provider, URI requestUri) throws Exception {

AuthenticationSuccessResponse sr = parseOidcUri(requestUri);

Expand All @@ -247,8 +314,10 @@ public static HttpResponse oidcDeleteResponse(OidcProvider provider, URI request

URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/delete");
AuthorizationCode code = sr.getAuthorizationCode();
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback);
AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback, data.pkceVerifier);


ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret);

TokenRequest tokenRequest = new TokenRequest(provider.tokenUri, clientAuth, codeGrant);
TokenResponse tokenResponse = OIDCTokenResponseParser.parse(tokenRequest.toHTTPRequest().send());
Expand All @@ -260,15 +329,26 @@ public static HttpResponse oidcDeleteResponse(OidcProvider provider, URI request

OIDCTokenResponse successResponse = (OIDCTokenResponse) tokenResponse.toSuccessResponse();

JWTClaimsSet claims = successResponse.getOIDCTokens().getIDToken().getJWTClaimsSet();
JWT idToken = JWTParser.parse(successResponse.getOIDCTokens().getIDTokenString());

IDTokenClaimsSet claims;
try {
claims = provider.validator.validate(idToken, data.nonce);
} catch (BadJOSEException e) {
System.out.println("Invalid token received: " + e.toString());
return HttpResponse.ofCode(400).withHtml("Received a bad token. Please try again");
} catch (JOSEException e) {
System.out.println("Token processing error" + e.toString());
return HttpResponse.ofCode(500).withHtml("Internal processing error. Please try again");
}

if (data.isInvalidNonce((String) claims.getClaim("nonce"))) {
return HttpResponse.ofCode(400).withHtml(
"Your oidc provider sent an invalid nonce. Please try again or contact your oidc admin."
);
}

long authTime = (long) claims.getClaim("auth_time");

Long authTime = (Long) claims.getNumberClaim("auth_time");

if (authTime == null) {
return HttpResponse.ofCode(400).withHtml("Couldn't get the `auth_time` claim from the provided id token");
}

if (authTime < start) {
return HttpResponse.ofCode(500).withHtml(
Expand All @@ -277,7 +357,6 @@ public static HttpResponse oidcDeleteResponse(OidcProvider provider, URI request
}

try (Session s = DatabaseSessionFactory.createSession()) {

var tr = s.beginTransaction();
s.remove(DatabaseHelper.getUserFromSession(session));
tr.commit();
Expand All @@ -297,31 +376,6 @@ public static byte[] deleteUserResponse(String session, String pass) throws IOEx

String hash = user.getPassword();

if (hash.isEmpty()) {

CriteriaBuilder cb = s.getCriteriaBuilder();
CriteriaQuery<OidcUserData> cr = cb.createQuery(OidcUserData.class);
Root<OidcUserData> root = cr.from(OidcUserData.class);
cr.select(root).where(cb.equal(root.get("user"), user.getId()));

OidcUserData oidcUserData = s.createQuery(cr).uniqueResult();

//TODO: Get user from oidc table and lookup provider
OidcProvider provider = Constants.OIDC_PROVIDERS.get(0);
URI callback = URI.create(String.format("%s/oidc/%s/delete", Constants.PUBLIC_URL, provider.name));
OidcData data = new OidcData(session + "|" + Instant.now().getEpochSecond());
String state = data.getState();
PENDING_OIDC.put(state, data);

AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder(
new ResponseType("code"),
new Scope("openid"), provider.clientID, callback).endpointURI(provider.authUri)
.state(new State(state)).nonce(data.nonce).maxAge(0).build();


return mapper.writeValueAsBytes(mapper.createObjectNode()
.put("redirect", oidcRequest.toURI().toString()));
}
if (!hashMatch(hash, pass))
ExceptionHandler.throwErrorResponse(new IncorrectCredentialsResponse());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class DatabaseSessionFactory {

sessionFactory = configuration.addAnnotatedClass(User.class).addAnnotatedClass(Channel.class)
.addAnnotatedClass(Video.class).addAnnotatedClass(PubSub.class).addAnnotatedClass(Playlist.class)
.addAnnotatedClass(PlaylistVideo.class).addAnnotatedClass(UnauthenticatedSubscription.class).buildSessionFactory();
.addAnnotatedClass(PlaylistVideo.class).addAnnotatedClass(UnauthenticatedSubscription.class).addAnnotatedClass(OidcUserData.class).buildSessionFactory();
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down
Loading

0 comments on commit e86c915

Please sign in to comment.