diff --git a/build.gradle b/build.gradle index 7b7f3223..32b1a006 100644 --- a/build.gradle +++ b/build.gradle @@ -18,6 +18,7 @@ dependencies { implementation 'org.bouncycastle:bcprov-jdk18on:1.78.1' implementation 'com.github.FireMasterK.NewPipeExtractor:NewPipeExtractor:a64e202bb498032e817a702145263590829f3c1d' implementation 'com.github.FireMasterK:nanojson:9f4af3b739cc13f3d0d9d4b758bbe2b2ae7119d7' + implementation 'com.nimbusds:oauth2-oidc-sdk:11.20.1' implementation 'com.fasterxml.jackson.core:jackson-core:2.17.2' implementation 'com.fasterxml.jackson.core:jackson-annotations:2.17.2' implementation 'com.fasterxml.jackson.core:jackson-databind:2.17.2' diff --git a/config.properties b/config.properties index ff8592f7..56f9d306 100644 --- a/config.properties +++ b/config.properties @@ -89,3 +89,17 @@ hibernate.connection.password:changeme # Frontend configuration #frontend.statusPageUrl:https://kavin.rocks #frontend.donationUrl:https://kavin.rocks + +# SSO via OIDC +# Each provider needs to have these three options specified. is the +# friendly name which will be shown to the clients and used in the database. +# If you want to change the name later, you will have to update the database. +#oidc.provider..clientId:example_piped_client_id +#oidc.provider..clientSecret:example_piped_client_secret +#oidc.provider..issuer:https://idm.example.com + +# Ask the provider to re-authenticate the user when account deletion is +# requested. This field is optional and you should only set this to false +# if your provider doesn't support the max_age parameter. You will know when +# trying to delete an account. +#oidc.provider..sendMaxAge = true diff --git a/src/main/java/me/kavin/piped/Main.java b/src/main/java/me/kavin/piped/Main.java index ab710fa9..2c34df96 100644 --- a/src/main/java/me/kavin/piped/Main.java +++ b/src/main/java/me/kavin/piped/Main.java @@ -9,6 +9,7 @@ import me.kavin.piped.utils.*; import me.kavin.piped.utils.matrix.SyncRunner; import me.kavin.piped.utils.obj.MatrixHelper; +import me.kavin.piped.utils.obj.db.OidcData; import me.kavin.piped.utils.obj.db.PlaylistVideo; import me.kavin.piped.utils.obj.db.PubSub; import me.kavin.piped.utils.obj.db.Video; @@ -253,5 +254,32 @@ public void run() { } }, 0, TimeUnit.MINUTES.toMillis(60)); + new Timer().scheduleAtFixedRate(new TimerTask() { + @Override + public void run() { + try (StatelessSession s = DatabaseSessionFactory.createStatelessSession()) { + + var cb = s.getCriteriaBuilder(); + var cd = cb.createCriteriaDelete(OidcData.class); + var root = cd.from(OidcData.class); + cd.where(cb.lessThan(root.get("start"), System.currentTimeMillis() - TimeUnit.MINUTES.toMillis(3))); + + var tr = s.beginTransaction(); + + var query = s.createMutationQuery(cd); + + int affected = query.executeUpdate(); + + tr.commit(); + + if (affected > 0) { + System.out.printf("Cleanup: Removed %o orphaned oidc logins%n", affected); + } + } catch (Exception e) { + e.printStackTrace(); + } + } + }, 0, TimeUnit.MINUTES.toMillis(5)); + } } diff --git a/src/main/java/me/kavin/piped/consts/Constants.java b/src/main/java/me/kavin/piped/consts/Constants.java index 2a2624d9..0675e160 100644 --- a/src/main/java/me/kavin/piped/consts/Constants.java +++ b/src/main/java/me/kavin/piped/consts/Constants.java @@ -3,11 +3,15 @@ import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.json.JsonMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.nimbusds.oauth2.sdk.GeneralException; import io.minio.MinioClient; import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectArrayList; import me.kavin.piped.utils.PageMixin; +import me.kavin.piped.utils.obj.OidcProvider; import me.kavin.piped.utils.resp.ListLinkHandlerMixin; import okhttp3.OkHttpClient; import okhttp3.brotli.BrotliInterceptor; @@ -21,8 +25,10 @@ import java.io.File; import java.io.FileReader; +import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Properties; @@ -103,6 +109,7 @@ public class Constants { public static String YOUTUBE_COUNTRY; public static final String VERSION; + public static final List OIDC_PROVIDERS; public static final ObjectMapper mapper = JsonMapper.builder() .addMixIn(Page.class, PageMixin.class) @@ -170,12 +177,39 @@ public class Constants { MATRIX_SERVER = getProperty(prop, "MATRIX_SERVER", "https://matrix-client.matrix.org"); MATRIX_TOKEN = getProperty(prop, "MATRIX_TOKEN"); GEO_RESTRICTION_CHECKER_URL = getProperty(prop, "GEO_RESTRICTION_CHECKER_URL"); + + OIDC_PROVIDERS = new ObjectArrayList<>(); + + Map> oidcProviderConfig = new Object2ObjectOpenHashMap<>(); + ArrayNode providerNames = frontendProperties.putArray("oidcProviders"); prop.forEach((_key, _value) -> { String key = String.valueOf(_key), value = String.valueOf(_value); if (key.startsWith("hibernate")) hibernateProperties.put(key, value); else if (key.startsWith("frontend.")) frontendProperties.put(StringUtils.substringAfter(key, "frontend."), value); + else if (key.startsWith("oidc.provider")) { + String[] split = key.split("\\."); + if (split.length != 4) return; + oidcProviderConfig + .computeIfAbsent(split[2], k -> new Object2ObjectOpenHashMap<>()) + .put(split[3], value); + } + }); + oidcProviderConfig.forEach((provider, config) -> { + try { + OIDC_PROVIDERS.add(new OidcProvider( + provider, + getRequiredMapValue(config, "clientId"), + getRequiredMapValue(config, "clientSecret"), + getRequiredMapValue(config, "issuer"), + getOptionalMapValue(config, "sendMaxAge", "true") + )); + } catch (GeneralException | IOException e) { + System.err.println("Failed to get configuration for '" + provider + "': " + e); + System.exit(1); + } + providerNames.add(provider); }); frontendProperties.put("imageProxyUrl", IMAGE_PROXY_PART); frontendProperties.putArray("countries").addAll( @@ -220,4 +254,21 @@ private static String getProperty(final Properties prop, String key, String def) return prop.getProperty(key, def); } + + private static String getRequiredMapValue(final Map map, Object key) { + String value = map.get(key); + if (StringUtils.isBlank(value)) { + System.err.println("Missing '" + key + "' in sub-configuration"); + System.exit(1); + } + return value; + } + + private static String getOptionalMapValue(final Map map, Object key, String def) { + String value = map.get(key); + if (StringUtils.isBlank(value)) { + return def; + } + return value; + } } diff --git a/src/main/java/me/kavin/piped/server/ServerLauncher.java b/src/main/java/me/kavin/piped/server/ServerLauncher.java index 8c7fe63b..f48d5073 100644 --- a/src/main/java/me/kavin/piped/server/ServerLauncher.java +++ b/src/main/java/me/kavin/piped/server/ServerLauncher.java @@ -15,7 +15,9 @@ import me.kavin.piped.server.handlers.auth.FeedHandlers; import me.kavin.piped.server.handlers.auth.StorageHandlers; import me.kavin.piped.server.handlers.auth.UserHandlers; +import me.kavin.piped.utils.ErrorResponse; import me.kavin.piped.utils.*; +import me.kavin.piped.utils.obj.OidcProvider; import me.kavin.piped.utils.resp.*; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.exception.ExceptionUtils; @@ -23,6 +25,7 @@ import org.jetbrains.annotations.NotNull; import java.net.InetSocketAddress; +import java.net.URI; import java.util.Objects; import java.util.concurrent.Executor; @@ -258,6 +261,22 @@ AsyncServlet mainServlet(Executor executor) { } catch (Exception e) { return getErrorResponse(e, request.getPath()); } + })).map(GET, "/oidc/:provider/:function", AsyncServlet.ofBlocking(executor, request -> { + try { + String function = request.getPathParameter("function"); + OidcProvider provider = getOidcProvider(request.getPathParameter("provider")); + if (provider == null) + return HttpResponse.ofCode(500).withHtml("Can't find the provider on the server"); + + return switch (function) { + case "login" -> UserHandlers.oidcLoginRequest(provider, request.getQueryParameter("redirect")); + case "callback" -> UserHandlers.oidcLoginCallback(provider, URI.create(request.getFullUrl())); + case "delete" -> UserHandlers.oidcDeleteCallback(provider, URI.create(request.getFullUrl())); + default -> HttpResponse.ofCode(500).withHtml("Invalid function `" + function + "`"); + }; + } catch (Exception e) { + return getErrorResponse(e, request.getPath()); + } })).map(POST, "/login", AsyncServlet.ofBlocking(executor, request -> { try { LoginRequest body = mapper.readValue(request.loadBody().getResult().asArray(), @@ -469,6 +488,14 @@ AsyncServlet mainServlet(Executor executor) { } catch (Exception e) { return getErrorResponse(e, request.getPath()); } + })).map(GET, "/user/delete", AsyncServlet.ofBlocking(executor, request -> { + try { + String session = request.getQueryParameter("session"); + String redirect = request.getQueryParameter("redirect"); + return UserHandlers.oidcDeleteRequest(session, redirect); + } catch (Exception e) { + return getErrorResponse(e, request.getPath()); + } })).map(POST, "/logout", AsyncServlet.ofBlocking(executor, request -> { try { return getJsonResponse(UserHandlers.logoutResponse(request.getHeader(AUTHORIZATION)), "private"); @@ -506,6 +533,15 @@ AsyncServlet mainServlet(Executor executor) { return new CustomServletDecorator(router); } + private static OidcProvider getOidcProvider(String provider) { + for (int i = 0; i < Constants.OIDC_PROVIDERS.size(); i++) { + OidcProvider curr = Constants.OIDC_PROVIDERS.get(i); + if (curr == null || !curr.name.equals(provider)) continue; + return curr; + } + return null; + } + private static String[] getArray(String s) { if (s == null) { diff --git a/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java b/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java index 3e0bfe58..1412ce84 100644 --- a/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java +++ b/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java @@ -1,6 +1,20 @@ package me.kavin.piped.server.handlers.auth; import com.fasterxml.jackson.core.JsonProcessingException; +import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.proc.BadJOSEException; +import com.nimbusds.jwt.JWT; +import com.nimbusds.jwt.JWTParser; +import com.nimbusds.oauth2.sdk.*; +import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; +import com.nimbusds.oauth2.sdk.id.State; +import com.nimbusds.oauth2.sdk.pkce.CodeChallengeMethod; +import com.nimbusds.oauth2.sdk.pkce.CodeVerifier; +import com.nimbusds.openid.connect.sdk.*; +import com.nimbusds.openid.connect.sdk.claims.IDTokenClaimsSet; +import com.nimbusds.openid.connect.sdk.claims.UserInfo; +import io.activej.http.HttpResponse; import jakarta.persistence.criteria.CriteriaBuilder; import jakarta.persistence.criteria.CriteriaQuery; import jakarta.persistence.criteria.Root; @@ -9,9 +23,13 @@ import me.kavin.piped.utils.DatabaseSessionFactory; import me.kavin.piped.utils.ExceptionHandler; import me.kavin.piped.utils.RequestUtils; +import me.kavin.piped.utils.obj.OidcProvider; +import me.kavin.piped.utils.obj.db.OidcData; +import me.kavin.piped.utils.obj.db.OidcUserData; import me.kavin.piped.utils.obj.db.User; import me.kavin.piped.utils.resp.*; import org.apache.commons.codec.digest.DigestUtils; +import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.StringUtils; import org.hibernate.Session; import org.hibernate.StatelessSession; @@ -19,6 +37,8 @@ import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import java.io.IOException; +import java.net.URI; +import java.time.Instant; import java.util.Set; import java.util.UUID; @@ -76,6 +96,9 @@ public static byte[] registerResponse(String user, String pass) throws Exception } private static boolean hashMatch(String hash, String pass) { + if (hash.isBlank()) { + return false; + } return hash.startsWith("$argon2") ? argon2PasswordEncoder.matches(pass, hash) : bcryptPasswordEncoder.matches(pass, hash); @@ -109,10 +132,277 @@ public static byte[] loginResponse(String user, String pass) } } - public static byte[] deleteUserResponse(String session, String pass) throws IOException { + public static HttpResponse oidcLoginRequest(OidcProvider provider, String redirectUri) throws Exception { + if (StringUtils.isBlank(redirectUri)) { + return HttpResponse.ofCode(400).withHtml("redirect is a required parameter"); + } + + URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback"); + CodeVerifier codeVerifier = new CodeVerifier(); + OidcData data = new OidcData(redirectUri, codeVerifier); + String state = data.getState(); + + DatabaseHelper.setOidcData(data); + + AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder( + new ResponseType("code"), + new Scope("openid"), + provider.clientID, callback + ) + .endpointURI(provider.authUri) + .codeChallenge(codeVerifier, CodeChallengeMethod.S256) + .state(new State(state)) + .nonce(data.getOidNonce()).build(); + + if (redirectUri.equals(Constants.FRONTEND_URL + "/login")) { + return HttpResponse.redirect302(oidcRequest.toURI().toString()); + } + return HttpResponse.ok200().withHtml( + "" + + "

Warning:

You are trying to give
" +
+                        redirectUri +
+                        "
access to your Piped account. If you wish to continue click " + + "here"); + } + + public static HttpResponse oidcLoginCallback(OidcProvider provider, URI requestUri) throws Exception { + AuthenticationSuccessResponse authResponse = parseOidcUri(requestUri); + + OidcData data = DatabaseHelper.getOidcData(authResponse.getState().toString()); + if (data == null) { + return HttpResponse.ofCode(400).withHtml( + "Your oidc provider sent invalid state data. Try again or contact your oidc admin" + ); + } + + URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback"); + AuthorizationCode code = authResponse.getAuthorizationCode(); + + if (code == null) { + return HttpResponse.ofCode(400).withHtml( + "Your oidc provider sent an invalid code. Try again or contact your oidc admin" + ); + } + + AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback, data.getOidVerifier()); + + ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret); + TokenRequest tokenReq = new TokenRequest.Builder(provider.tokenUri, clientAuth, codeGrant).build(); - if (StringUtils.isBlank(session) || StringUtils.isBlank(pass)) - ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session and password are required parameters")); + com.nimbusds.oauth2.sdk.http.HTTPResponse tokenResponseText = tokenReq.toHTTPRequest().send(); + OIDCTokenResponse tokenResponse = (OIDCTokenResponse) OIDCTokenResponseParser.parse(tokenResponseText); + + if (!tokenResponse.indicatesSuccess()) { + TokenErrorResponse errorResponse = tokenResponse.toErrorResponse(); + return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription()); + } + + OIDCTokenResponse successResponse = tokenResponse.toSuccessResponse(); + + JWT idToken = JWTParser.parse(successResponse.getOIDCTokens().getIDTokenString()); + + try { + provider.validator.validate(idToken, data.getOidNonce()); + } catch (BadJOSEException e) { + System.err.println("Invalid token received: " + e); + return HttpResponse.ofCode(400).withHtml("Received a bad token. Please try again"); + } catch (JOSEException e) { + System.err.println("Token processing error: " + e); + return HttpResponse.ofCode(500).withHtml("Internal processing error. Please try again"); + } + + UserInfoRequest ur = new UserInfoRequest(provider.userinfoUri, successResponse.getOIDCTokens().getBearerAccessToken()); + UserInfoResponse userInfoResponse = UserInfoResponse.parse(ur.toHTTPRequest().send()); + + if (!userInfoResponse.indicatesSuccess()) { + return HttpResponse.ofCode(500).withHtml( + "The userinfo endpoint returned an error. Please try again or contact your oidc admin\n\n" + + userInfoResponse.toErrorResponse().getErrorObject().getDescription()); + } + + UserInfo userInfo = userInfoResponse.toSuccessResponse().getUserInfo(); + + String sub = userInfo.getSubject().toString(); + String sessionId; + try (Session s = DatabaseSessionFactory.createSession()) { + CriteriaBuilder cb = s.getCriteriaBuilder(); + CriteriaQuery cr = cb.createQuery(OidcUserData.class); + Root root = cr.from(OidcUserData.class); + + cr.select(root).where(root.get("sub").in(sub)); + + OidcUserData dbuser = s.createQuery(cr).uniqueResult(); + + if (dbuser != null) { + sessionId = dbuser.getUser().getSessionId(); + } else { + OidcUserData newUser = new OidcUserData(sub, RandomStringUtils.randomAlphabetic(24), provider.name); + + var tr = s.beginTransaction(); + s.persist(newUser); + tr.commit(); + + sessionId = newUser.getUser().getSessionId(); + } + } + return HttpResponse.redirect302(data.data + "?session=" + sessionId); + } + + public static HttpResponse oidcDeleteRequest(String session, String redirect) throws Exception { + + if (StringUtils.isBlank(session)) { + return HttpResponse.ofCode(400).withHtml("session is a required parameter"); + } + + if (StringUtils.isBlank(redirect)) { + return HttpResponse.ofCode(400).withHtml("redirect is a required parameter"); + } + + OidcProvider provider = null; + + try (Session s = DatabaseSessionFactory.createSession()) { + User user = DatabaseHelper.getUserFromSession(session); + + if (user == null) { + return HttpResponse.ofCode(400).withHtml("User not found"); + } + + CriteriaBuilder cb = s.getCriteriaBuilder(); + CriteriaQuery cr = cb.createQuery(OidcUserData.class); + Root root = cr.from(OidcUserData.class); + cr.select(root).where(cb.equal(root.get("user"), user)); + + OidcUserData oidcUserData = s.createQuery(cr).uniqueResult(); + + if (oidcUserData == null) { + return HttpResponse.ofCode(400).withHtml("User doesn't have an oidc account"); + } + + for (OidcProvider oidcProvider : Constants.OIDC_PROVIDERS) { + if (oidcProvider.name.equals(oidcUserData.getProvider())) { + provider = oidcProvider; + break; + } + } + } + + if (provider == null) { + return HttpResponse.ofCode(400).withHtml("Invalid user"); + } + + CodeVerifier pkceVerifier = new CodeVerifier(); + + URI callback = URI.create(String.format("%s/oidc/%s/delete", Constants.PUBLIC_URL, provider.name)); + OidcData data = new OidcData(session + "|" + redirect, pkceVerifier); + String state = data.getState(); + + DatabaseHelper.setOidcData(data); + + com.nimbusds.openid.connect.sdk.AuthenticationRequest.Builder oidcRequestBuilder = new AuthenticationRequest.Builder( + new ResponseType("code"), + new Scope("openid"), provider.clientID, callback + ) + .endpointURI(provider.authUri) + .codeChallenge(pkceVerifier, CodeChallengeMethod.S256) + .state(new State(state)) + .nonce(data.getOidNonce()); + + if (provider.sendMaxAge) { + // This parameter is optional and the idp doesn't have to honor it. + oidcRequestBuilder.maxAge(0); + } + + return HttpResponse.redirect302(oidcRequestBuilder.build().toURI().toString()); + } + + public static HttpResponse oidcDeleteCallback(OidcProvider provider, URI requestUri) throws Exception { + + AuthenticationSuccessResponse sr = parseOidcUri(requestUri); + + OidcData data = DatabaseHelper.getOidcData(sr.getState().toString()); + + if (data == null) { + return HttpResponse.ofCode(400).withHtml( + "Your oidc provider sent invalid state data. Try again or contact your oidc admin" + ); + } + + String redirect = data.data.split("\\|")[1]; + String session = data.data.split("\\|")[0]; + + URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/delete"); + AuthorizationCode code = sr.getAuthorizationCode(); + + if (code == null) { + return HttpResponse.ofCode(400).withHtml( + "Your oidc provider sent an invalid code. Try again or contact your oidc admin" + ); + } + + AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback, data.getOidVerifier()); + + ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret); + + TokenRequest tokenRequest = new TokenRequest.Builder(provider.tokenUri, clientAuth, codeGrant).build(); + TokenResponse tokenResponse = OIDCTokenResponseParser.parse(tokenRequest.toHTTPRequest().send()); + + if (!tokenResponse.indicatesSuccess()) { + TokenErrorResponse errorResponse = tokenResponse.toErrorResponse(); + return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription()); + } + + OIDCTokenResponse successResponse = (OIDCTokenResponse) tokenResponse.toSuccessResponse(); + + JWT idToken = JWTParser.parse(successResponse.getOIDCTokens().getIDTokenString()); + + IDTokenClaimsSet claims; + try { + claims = provider.validator.validate(idToken, data.getOidNonce()); + } catch (BadJOSEException e) { + System.err.println("Invalid token received: " + e); + return HttpResponse.ofCode(400).withHtml("Received a bad token. Please try again"); + } catch (JOSEException e) { + System.err.println("Token processing error: " + e); + return HttpResponse.ofCode(500).withHtml("Internal processing error. Please try again"); + } + + if (provider.sendMaxAge) { + Long authTime = (Long) claims.getNumberClaim("auth_time"); + + if (authTime == null) { + return HttpResponse.ofCode(400).withHtml("Couldn't get the `auth_time` claim from the provided id token"); + } + + if (authTime <= data.start) { + return HttpResponse.ofCode(500).withHtml( + "Your oidc provider didn't verify your identity. Please try again or contact your oidc admin." + ); + } + } + + try (Session s = DatabaseSessionFactory.createSession()) { + var tr = s.beginTransaction(); + + User toDelete = DatabaseHelper.getUserFromSession(session); + + CriteriaBuilder cb = s.getCriteriaBuilder(); + CriteriaQuery cr = cb.createQuery(OidcUserData.class); + Root root = cr.from(OidcUserData.class); + + cr.select(root).where(cb.equal(root.get("user"), toDelete)); + + s.remove(s.createQuery(cr).uniqueResult()); + tr.commit(); + } + + return HttpResponse.redirect302(redirect + "?deleted=true"); + } + + public static byte[] deleteUserResponse(String session, String pass) throws IOException { + if (StringUtils.isBlank(session)) + ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session is a required parameter")); try (Session s = DatabaseSessionFactory.createSession()) { User user = DatabaseHelper.getUserFromSession(session); @@ -151,4 +441,14 @@ public static byte[] logoutResponse(String session) throws JsonProcessingExcepti return Constants.mapper.writeValueAsBytes(new AuthenticationFailureResponse()); } + + private static AuthenticationSuccessResponse parseOidcUri(URI uri) throws Exception { + AuthenticationResponse response = AuthenticationResponseParser.parse(uri); + + if (response instanceof AuthenticationErrorResponse) { + System.err.println(response.toErrorResponse().getErrorObject()); + throw new Exception(response.toErrorResponse().getErrorObject().toString()); + } + return response.toSuccessResponse(); + } } diff --git a/src/main/java/me/kavin/piped/utils/DatabaseHelper.java b/src/main/java/me/kavin/piped/utils/DatabaseHelper.java index a4af6272..b7b9b151 100644 --- a/src/main/java/me/kavin/piped/utils/DatabaseHelper.java +++ b/src/main/java/me/kavin/piped/utils/DatabaseHelper.java @@ -8,6 +8,7 @@ import me.kavin.piped.utils.obj.db.*; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.exception.ExceptionUtils; +import org.hibernate.Session; import org.hibernate.SharedSessionContract; import org.hibernate.StatelessSession; import org.schabi.newpipe.extractor.channel.ChannelInfo; @@ -236,4 +237,34 @@ public static Channel saveChannel(String channelId) { return channel; } + + public static void setOidcData(OidcData data) { + try (Session s = DatabaseSessionFactory.createSession()) { + var tr = s.beginTransaction(); + s.persist(data); + tr.commit(); + } + } + + public static OidcData getOidcData(String state) { + try (Session s = DatabaseSessionFactory.createSession()) { + + CriteriaBuilder cb = s.getCriteriaBuilder(); + CriteriaQuery cr = cb.createQuery(OidcData.class); + Root root = cr.from(OidcData.class); + cr.select(root).where(cb.equal(root.get("state"), state)); + + OidcData data = s.createQuery(cr).uniqueResult(); + + if (data == null){ + return null; + } + + var tr = s.beginTransaction(); + s.remove(data); + tr.commit(); + + return data; + } + } } diff --git a/src/main/java/me/kavin/piped/utils/DatabaseSessionFactory.java b/src/main/java/me/kavin/piped/utils/DatabaseSessionFactory.java index edc60892..f333cc2b 100644 --- a/src/main/java/me/kavin/piped/utils/DatabaseSessionFactory.java +++ b/src/main/java/me/kavin/piped/utils/DatabaseSessionFactory.java @@ -20,7 +20,7 @@ public class DatabaseSessionFactory { sessionFactory = configuration.addAnnotatedClass(User.class).addAnnotatedClass(Channel.class) .addAnnotatedClass(Video.class).addAnnotatedClass(PubSub.class).addAnnotatedClass(Playlist.class) - .addAnnotatedClass(PlaylistVideo.class).addAnnotatedClass(UnauthenticatedSubscription.class).buildSessionFactory(); + .addAnnotatedClass(PlaylistVideo.class).addAnnotatedClass(UnauthenticatedSubscription.class).addAnnotatedClass(OidcUserData.class).addAnnotatedClass(OidcData.class).buildSessionFactory(); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java b/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java new file mode 100644 index 00000000..89fb38fd --- /dev/null +++ b/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java @@ -0,0 +1,37 @@ +package me.kavin.piped.utils.obj; + +import com.nimbusds.oauth2.sdk.GeneralException; +import com.nimbusds.oauth2.sdk.auth.Secret; +import com.nimbusds.oauth2.sdk.id.ClientID; +import com.nimbusds.oauth2.sdk.id.Issuer; +import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata; +import com.nimbusds.openid.connect.sdk.validators.IDTokenValidator; + +import java.io.IOException; +import java.net.URI; + +public class OidcProvider { + + public final String name; + public final ClientID clientID; + public final Secret clientSecret; + public final URI authUri; + public final URI tokenUri; + public final Boolean sendMaxAge; + public URI userinfoUri; + public IDTokenValidator validator; + + public OidcProvider(String name, String clientId, String clientSecret, String issuer, String sendMaxAge) throws GeneralException, IOException { + this.name = name; + this.clientID = new ClientID(clientId); + this.clientSecret = new Secret(clientSecret); + this.sendMaxAge = Boolean.valueOf(sendMaxAge); + + Issuer iss = new Issuer(issuer); + OIDCProviderMetadata providerData = OIDCProviderMetadata.resolve(iss); + this.authUri = providerData.getAuthorizationEndpointURI(); + this.tokenUri = providerData.getTokenEndpointURI(); + this.userinfoUri = providerData.getUserInfoEndpointURI(); + this.validator = new IDTokenValidator(iss, this.clientID, providerData.getIDTokenJWSAlgs().getFirst(), providerData.getJWKSetURI().toURL()); + } +} diff --git a/src/main/java/me/kavin/piped/utils/obj/db/OidcData.java b/src/main/java/me/kavin/piped/utils/obj/db/OidcData.java new file mode 100644 index 00000000..cb283446 --- /dev/null +++ b/src/main/java/me/kavin/piped/utils/obj/db/OidcData.java @@ -0,0 +1,77 @@ +package me.kavin.piped.utils.obj.db; + +import com.nimbusds.oauth2.sdk.pkce.CodeVerifier; +import com.nimbusds.openid.connect.sdk.Nonce; + +import java.io.Serializable; + +import org.apache.commons.codec.binary.Base64; +import org.apache.commons.codec.digest.DigestUtils; +import jakarta.persistence.*; + +@Entity +@Table(name = "oidc_logins") +public class OidcData implements Serializable { + + @Column(name = "nonce", unique = true, length = 256) + @Id + public String nonce; + + @Column(name = "verifier", length = 128) + public String verifierSecret; + + @Column(name = "data") + public String data; + + @Column(name = "state") + public String state; + + @Column(name = "start") + public long start; + + public OidcData(String data, CodeVerifier pkceVerifier) { + this.nonce = new Nonce().toString(); + this.verifierSecret = pkceVerifier.getValue(); + this.data = data; + this.start = System.currentTimeMillis() / 1000L; + this.state = getState(); + } + + public OidcData() { + } + + public boolean validateNonce(String nonce) { + return this.nonce.equals(nonce); + } + + public String getState() { + String value = this.nonce + this.data; + + byte[] hash = DigestUtils.sha256(value); + return Base64.encodeBase64String(hash); + } + + public Nonce getOidNonce(){ + return new Nonce(this.nonce); + } + + public void setNonce(String nonce) { + this.nonce = nonce; + } + + public CodeVerifier getOidVerifier(){ + return new CodeVerifier(this.verifierSecret); + } + + public void setVerifier(String verifier) { + this.verifierSecret = verifier; + } + + public void setData(String data) { + this.data = data; + } + + public void setState(String state) { + this.state = state; + } +} diff --git a/src/main/java/me/kavin/piped/utils/obj/db/OidcUserData.java b/src/main/java/me/kavin/piped/utils/obj/db/OidcUserData.java new file mode 100644 index 00000000..748e7d02 --- /dev/null +++ b/src/main/java/me/kavin/piped/utils/obj/db/OidcUserData.java @@ -0,0 +1,57 @@ +package me.kavin.piped.utils.obj.db; + +import java.util.Set; + +import org.hibernate.annotations.Cascade; + +import java.io.Serializable; + +import jakarta.persistence.*; + +@Entity +@Table(name = "oidc_user_data") +public class OidcUserData implements Serializable { + + public OidcUserData() { + } + + public OidcUserData(String sub, String username, String provider) { + this.sub = sub; + this.provider = provider; + this.user = new User(username,"", Set.of()); + } + + @Column(name = "sub", unique = true, length = 255) + @Id + private String sub; + + @OneToOne + @Cascade(org.hibernate.annotations.CascadeType.ALL) + private User user; + + @Column(name = "provider", nullable = false) + private String provider; + + public User getUser() { + return user; + } + + public void setUser(User user) { + this.user = user; + } + public String getSub() { + return sub; + } + + public void setSub(String sub) { + this.sub = sub; + } + + public String getProvider() { + return provider; + } + + public void setProvider(String provider) { + this.provider = provider; + } +} diff --git a/src/main/resources/changelog/db.changelog-master.xml b/src/main/resources/changelog/db.changelog-master.xml index 4d3a056a..40184931 100644 --- a/src/main/resources/changelog/db.changelog-master.xml +++ b/src/main/resources/changelog/db.changelog-master.xml @@ -6,4 +6,5 @@ + diff --git a/src/main/resources/changelog/version/2-add-oidc.xml b/src/main/resources/changelog/version/2-add-oidc.xml new file mode 100644 index 00000000..b413f18d --- /dev/null +++ b/src/main/resources/changelog/version/2-add-oidc.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +