Skip to content

Commit

Permalink
Implementing a more maintainable version of annotation based authenti…
Browse files Browse the repository at this point in the history
…cation
  • Loading branch information
NovaFox161 committed Oct 21, 2023
1 parent 6387ef4 commit e600214
Show file tree
Hide file tree
Showing 19 changed files with 271 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package org.dreamexposure.discal.cam.business

import org.dreamexposure.discal.core.business.ApiKeyService
import org.dreamexposure.discal.core.business.SessionService
import org.dreamexposure.discal.core.config.Config
import org.dreamexposure.discal.core.extensions.isExpiredTtl
import org.dreamexposure.discal.core.`object`.new.security.Scope
import org.dreamexposure.discal.core.`object`.new.security.TokenType
import org.springframework.stereotype.Component

@Component
class SecurityService(
private val sessionService: SessionService,
private val apiKeyService: ApiKeyService,
) {
suspend fun authenticateToken(token: String): Boolean {
val schema = getSchema(token)
val tokenStr = token.removePrefix(schema.schema)

return when (schema) {
TokenType.BEARER -> authenticateUserToken(tokenStr)
TokenType.APP -> authenticateAppToken(tokenStr)
TokenType.INTERNAL -> authenticateInternalToken(tokenStr)
else -> false
}
}

suspend fun validateTokenSchema(token: String, allowedSchemas: List<TokenType>): Boolean {
if (allowedSchemas.isEmpty()) return true // No schemas required
val schema = getSchema(token)

return allowedSchemas.contains(schema)
}

suspend fun authorizeToken(token: String, requiredScopes: List<Scope>): Boolean {
if (requiredScopes.isEmpty()) return true // No scopes required

val schema = getSchema(token)
val tokenStr = token.removePrefix(schema.schema)

val scopes = when (schema) {
TokenType.BEARER -> getScopesForUserToken(tokenStr)
TokenType.APP -> getScopesForAppToken(tokenStr)
TokenType.INTERNAL -> getScopesForInternalToken()
else -> return false
}

return scopes.containsAll(requiredScopes)
}


// Authentication based on token type
private suspend fun authenticateUserToken(token: String): Boolean {
val session = sessionService.getSession(token) ?: return false

return !session.expiresAt.isExpiredTtl()
}

private suspend fun authenticateAppToken(token: String): Boolean {
val key = apiKeyService.getKey(token) ?: return false

return !key.blocked
}

private fun authenticateInternalToken(token: String): Boolean {
return Config.SECRET_DISCAL_API_KEY.getString() == token
}

// Fetching scopes for tokens
private suspend fun getScopesForUserToken(token: String): List<Scope> {
return sessionService.getSession(token)?.scopes ?: emptyList()
}

private suspend fun getScopesForAppToken(token: String): List<Scope> {
return apiKeyService.getKey(token)?.scopes ?: emptyList()
}

private fun getScopesForInternalToken(): List<Scope> = Scope.entries.toList()

// Various other stuff
private fun getSchema(token: String): TokenType {
return when {
token.startsWith(TokenType.BEARER.schema) -> TokenType.BEARER
token.startsWith(TokenType.APP.schema) -> TokenType.APP
token.startsWith(TokenType.INTERNAL.schema) -> TokenType.INTERNAL
else -> TokenType.NONE
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package org.dreamexposure.discal.cam.controllers.v1

import discord4j.common.util.Snowflake
import org.dreamexposure.discal.cam.managers.CalendarAuthManager
import org.dreamexposure.discal.core.annotations.Authentication
import org.dreamexposure.discal.core.annotations.SecurityRequirement
import org.dreamexposure.discal.core.enums.calendar.CalendarHost
import org.dreamexposure.discal.core.`object`.network.discal.CredentialData
import org.dreamexposure.discal.core.`object`.new.security.Scope.CALENDAR_TOKEN_READ
import org.dreamexposure.discal.core.`object`.new.security.TokenType.INTERNAL
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RequestParam
Expand All @@ -15,7 +17,7 @@ import org.springframework.web.bind.annotation.RestController
class TokenController(
private val calendarAuthManager: CalendarAuthManager,
) {
@Authentication(access = Authentication.AccessLevel.ADMIN)
@SecurityRequirement(schemas = [INTERNAL], scopes = [CALENDAR_TOKEN_READ])
@GetMapping(produces = ["application/json"])
suspend fun getToken(@RequestParam host: CalendarHost, @RequestParam id: Int, @RequestParam guild: Snowflake?): CredentialData? {
return calendarAuthManager.getCredentialData(host, id, guild)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import org.dreamexposure.discal.cam.json.discal.LoginResponse
import org.dreamexposure.discal.cam.json.discal.TokenRequest
import org.dreamexposure.discal.cam.json.discal.TokenResponse
import org.dreamexposure.discal.cam.managers.DiscordOauthManager
import org.dreamexposure.discal.core.annotations.Authentication
import org.dreamexposure.discal.core.annotations.SecurityRequirement
import org.dreamexposure.discal.core.`object`.new.security.Scope.OAUTH2_DISCORD
import org.dreamexposure.discal.core.`object`.new.security.TokenType.BEARER
import org.springframework.web.bind.annotation.*

@RestController
Expand All @@ -14,20 +16,20 @@ class DiscordOauthController(
) {

@GetMapping("login")
@Authentication(access = Authentication.AccessLevel.PUBLIC)
@SecurityRequirement(disableSecurity = true, scopes = [])
suspend fun login(): LoginResponse {
val link = discordOauthManager.getOauthLinkForLogin()
return LoginResponse(link)
}

@GetMapping("logout")
@Authentication(access = Authentication.AccessLevel.WRITE)
@SecurityRequirement(schemas = [BEARER], scopes = [OAUTH2_DISCORD])
suspend fun logout(@RequestHeader("Authorization") token: String) {
discordOauthManager.handleLogout(token)
}

@PostMapping("code")
@Authentication(access = Authentication.AccessLevel.PUBLIC)
@SecurityRequirement(disableSecurity = true, scopes = [])
suspend fun token(@RequestBody body: TokenRequest): TokenResponse {
return discordOauthManager.handleCodeExchange(body.state, body.code)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class CalendarAuthManager(
}
} catch (ex: Exception) {
LOGGER.error("Get CredentialData Exception | guildId:$guild | credentialId:$id | calendarHost:${host.name}", ex)
null
throw ex // rethrow
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.dreamexposure.discal.core.business.SessionService
import org.dreamexposure.discal.core.config.Config
import org.dreamexposure.discal.core.crypto.KeyGenerator
import org.dreamexposure.discal.core.`object`.WebSession
import org.dreamexposure.discal.core.`object`.new.security.Scope
import org.dreamexposure.discal.core.utils.GlobalVal.discordApiUrl
import org.springframework.http.HttpStatus
import org.springframework.stereotype.Component
Expand Down Expand Up @@ -51,7 +52,8 @@ class DiscordOauthManager(
apiToken,
authInfo.user!!.id,
accessToken = dTokens.accessToken,
refreshToken = dTokens.refreshToken
refreshToken = dTokens.refreshToken,
scopes = Scope.defaultWebsiteLoginScopes(),
)

sessionService.removeAndInsertSession(session)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package org.dreamexposure.discal.cam.security

import com.fasterxml.jackson.databind.ObjectMapper
import kotlinx.coroutines.reactive.awaitFirst
import kotlinx.coroutines.reactive.awaitFirstOrNull
import kotlinx.coroutines.reactor.mono
import org.dreamexposure.discal.cam.business.SecurityService
import org.dreamexposure.discal.core.annotations.SecurityRequirement
import org.dreamexposure.discal.core.extensions.spring.writeJsonString
import org.dreamexposure.discal.core.`object`.rest.ErrorResponse
import org.springframework.http.HttpStatus
import org.springframework.stereotype.Component
import org.springframework.web.method.HandlerMethod
import org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerMapping
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.server.WebFilter
import org.springframework.web.server.WebFilterChain
import reactor.core.publisher.Mono

@Component
class SecurityWebFilter(
private val securityService: SecurityService,
private val handlerMapping: RequestMappingHandlerMapping,
private val objectMapper: ObjectMapper,
) : WebFilter {

override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> {
return mono {
doSecurityFilter(exchange, chain)
}.then(chain.filter(exchange))
}

suspend fun doSecurityFilter(exchange: ServerWebExchange, chain: WebFilterChain) {
val handlerMethod = handlerMapping.getHandler(exchange)
.cast(HandlerMethod::class.java)
.awaitFirst()

if (!handlerMethod.hasMethodAnnotation(SecurityRequirement::class.java)) {
throw IllegalStateException("No SecurityRequirement annotation!")
}

val authAnnotation = handlerMethod.getMethodAnnotation(SecurityRequirement::class.java)!!
val authHeader = exchange.request.headers.getOrEmpty("Authorization").firstOrNull()


if (authAnnotation.disableSecurity) return

if (authHeader == null) {
exchange.response.statusCode = HttpStatus.UNAUTHORIZED
exchange.response.writeJsonString(
objectMapper.writeValueAsString(ErrorResponse("Missing Authorization header"))
).awaitFirstOrNull()
return
}

if (authHeader.equals("teapot", ignoreCase = true)) {
exchange.response.statusCode = HttpStatus.I_AM_A_TEAPOT
exchange.response.writeJsonString(
objectMapper.writeValueAsString(ErrorResponse("I'm a teapot"))
).awaitFirstOrNull()
return
}

if (!securityService.authenticateToken(authHeader)) {
exchange.response.statusCode = HttpStatus.UNAUTHORIZED
exchange.response.writeJsonString(
objectMapper.writeValueAsString(ErrorResponse("Unauthenticated"))
).awaitFirstOrNull()
return
}

if (!securityService.validateTokenSchema(authHeader, authAnnotation.schemas.toList())) {
exchange.response.statusCode = HttpStatus.UNAUTHORIZED
exchange.response.writeJsonString(
objectMapper.writeValueAsString(ErrorResponse("Unsupported schema"))
).awaitFirstOrNull()
return
}

if (!securityService.authorizeToken(authHeader, authAnnotation.scopes.toList())) {
exchange.response.statusCode = HttpStatus.FORBIDDEN
exchange.response.writeJsonString(
objectMapper.writeValueAsString(ErrorResponse("Access denied"))
).awaitFirstOrNull()
return
}

// If we made it to the end, everything is good to go.
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.dreamexposure.discal.core.annotations

import org.dreamexposure.discal.core.`object`.new.security.Scope
import org.dreamexposure.discal.core.`object`.new.security.TokenType

@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.FUNCTION)
annotation class SecurityRequirement(
val schemas: Array<TokenType> = [], // Default to allowing any token kind
val scopes: Array<Scope>,
val disableSecurity: Boolean = false,
)
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class DefaultCredentialService(
var credential = credentialsCache.get(key = number)
if (credential != null) return credential

val data = credentialsRepository.findByCredentialNumber(number).awaitSingle()
val data = credentialsRepository.findByCredentialNumber(number).awaitSingleOrNull() ?: return null
credential = Credential(data)

credentialsCache.put(key = number, value = credential)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class DefaultSessionService(
expiresAt = session.expiresAt,
accessToken = session.accessToken,
refreshToken = session.refreshToken,
scopes = session.scopes.joinToString(",") { it.name }
)).map(::WebSession).awaitSingle()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ data class ApiData(
val apiKey: String,
val blocked: Boolean,
val timeIssued: Long,
val scopes: String,
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ data class SessionData(
val expiresAt: Instant,
val accessToken: String,
val refreshToken: String,
val scopes: String,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package org.dreamexposure.discal.core.extensions.spring

import org.springframework.http.MediaType
import org.springframework.http.server.reactive.ServerHttpResponse
import reactor.core.publisher.Mono

fun ServerHttpResponse.writeJsonString(json: String): Mono<Void> {
val factory = bufferFactory()
val buffer = factory.wrap(json.toByteArray())

headers.contentType = MediaType.APPLICATION_JSON
return writeWith(Mono.just(buffer))
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@ package org.dreamexposure.discal.core.`object`
import discord4j.common.util.Snowflake
import org.dreamexposure.discal.core.database.SessionData
import org.dreamexposure.discal.core.extensions.asSnowflake
import org.dreamexposure.discal.core.extensions.asStringListFromDatabase
import org.dreamexposure.discal.core.`object`.new.security.Scope
import java.time.Instant
import java.time.temporal.ChronoUnit

data class WebSession(
val token: String,

val user: Snowflake,

val expiresAt: Instant = Instant.now().plus(7, ChronoUnit.DAYS),

val accessToken: String,

val refreshToken: String,
val scopes: List<Scope>,
) {
constructor(data: SessionData) : this(
token = data.token,
user = data.userId.asSnowflake(),
expiresAt = data.expiresAt,
accessToken = data.accessToken,
refreshToken = data.refreshToken,
scopes = data.scopes.asStringListFromDatabase().map(Scope::valueOf),
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,22 @@ package org.dreamexposure.discal.core.`object`.new
import discord4j.common.util.Snowflake
import org.dreamexposure.discal.core.database.ApiData
import org.dreamexposure.discal.core.extensions.asInstantMilli
import org.dreamexposure.discal.core.extensions.asStringListFromDatabase
import org.dreamexposure.discal.core.`object`.new.security.Scope
import java.time.Instant

data class ApiKey(
val userId: Snowflake,
val key: String,
val blocked: Boolean,
val timeIssued: Instant,
val userId: Snowflake,
val key: String,
val blocked: Boolean,
val timeIssued: Instant,
val scopes: List<Scope>,
) {
constructor(data: ApiData): this(
userId = Snowflake.of(data.apiKey),
key = data.apiKey,
blocked = data.blocked,
timeIssued = data.timeIssued.asInstantMilli(),
scopes = data.scopes.asStringListFromDatabase().map(Scope::valueOf),
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.dreamexposure.discal.core.`object`.new.security

enum class Scope {
CALENDAR_TOKEN_READ,

OAUTH2_DISCORD,
;

companion object {
fun defaultWebsiteLoginScopes() = listOf(
OAUTH2_DISCORD,
)

fun defaultBasicAppScopes() = listOf<Scope>()
}
}
Loading

0 comments on commit e600214

Please sign in to comment.