Skip to content

Commit

Permalink
Send gRPC errors properly (#1983)
Browse files Browse the repository at this point in the history
According to the [gRPC spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md), all
gRPC responses, including errors, should have an HTTP status of 200. gRPC errors are signaled using
the `grpc-status` header. This brings Misk in compliance with the spec, translating HTTP errors into
properly formatted gRPC errors.


Co-authored-by: Jesse Wilson <[email protected]>
Co-authored-by: Eric Wolak <[email protected]>
  • Loading branch information
ewolak-sq and squarejesse authored Jun 30, 2021
1 parent 068a7bd commit 171a7ce
Show file tree
Hide file tree
Showing 10 changed files with 195 additions and 24 deletions.
24 changes: 15 additions & 9 deletions misk-actions/src/main/kotlin/misk/exceptions/Exceptions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,25 @@ open class WebActionException(
) : Exception(message, cause) {
val isClientError = code in (400..499)
val isServerError = code in (500..599)

constructor(
code: Int,
message: String,
cause: Throwable? = null
) : this(code, message, message, cause)
}

/** Base exception for when resources are not found */
open class NotFoundException(message: String = "", cause: Throwable? = null) :
WebActionException(HTTP_NOT_FOUND, message, message, cause)
WebActionException(HTTP_NOT_FOUND, message, cause)

/** Base exception for when authentication fails */
open class UnauthenticatedException(message: String = "", cause: Throwable? = null) :
WebActionException(HTTP_UNAUTHORIZED, message, message, cause)
WebActionException(HTTP_UNAUTHORIZED, message, cause)

/** Base exception for when authenticated credentials lack access to a resource */
open class UnauthorizedException(message: String = "", cause: Throwable? = null) :
WebActionException(HTTP_FORBIDDEN, message, message, cause)
WebActionException(HTTP_FORBIDDEN, message, cause)

/**
* Base exception for when a resource is unavailable. The message is not exposed to the caller.
Expand All @@ -49,19 +55,19 @@ open class ResourceUnavailableException(message: String = "", cause: Throwable?

/** Base exception for bad client requests */
open class BadRequestException(message: String = "", cause: Throwable? = null) :
WebActionException(HTTP_BAD_REQUEST, message, message, cause)
WebActionException(HTTP_BAD_REQUEST, message, cause)

/** Base exception for when a request causes a conflict */
open class ConflictException(message: String = "", cause: Throwable? = null) :
WebActionException(HTTP_CONFLICT, message, message, cause)
WebActionException(HTTP_CONFLICT, message, cause)

/** This exception is custom to Misk. */
open class UnprocessableEntityException(message: String = "", cause: Throwable? = null) :
WebActionException(422, message, message, cause)
WebActionException(422, message, cause)

/** This exception is custom to Misk. */
open class TooManyRequestsException(message: String = "", cause: Throwable? = null) :
WebActionException(429, message, message, cause)
WebActionException(429, message, cause)

/** This exception is custom to Misk. */
open class ClientClosedRequestException(message: String = "", cause: Throwable? = null) :
Expand All @@ -75,10 +81,10 @@ open class GatewayTimeoutException(message: String = "", cause: Throwable? = nul
WebActionException(HTTP_GATEWAY_TIMEOUT, "GATEWAY_TIMEOUT", message, cause)

open class PayloadTooLargeException(message: String = "", cause: Throwable? = null) :
WebActionException(HTTP_ENTITY_TOO_LARGE, message, message, cause)
WebActionException(HTTP_ENTITY_TOO_LARGE, message, cause)

open class UnsupportedMediaTypeException(message: String = "", cause: Throwable? = null) :
WebActionException(HTTP_UNSUPPORTED_TYPE, message, message, cause)
WebActionException(HTTP_UNSUPPORTED_TYPE, message, cause)

/** Similar to [kotlin.require], but throws [BadRequestException] if the check fails */
inline fun requireRequest(check: Boolean, lazyMessage: () -> String) {
Expand Down
3 changes: 3 additions & 0 deletions misk-grpc-tests/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ sourceSets {

dependencies {
implementation(Dependencies.assertj)
implementation(Dependencies.awaitility)
implementation(Dependencies.junitApi)
implementation(Dependencies.kotlinTest)
implementation(Dependencies.docker)
Expand All @@ -76,6 +77,8 @@ dependencies {
implementation(project(":misk-actions"))
implementation(project(":misk-core"))
implementation(project(":misk-inject"))
implementation(project(":misk-metrics"))
implementation(project(":misk-metrics-testing"))
implementation(project(":misk-service"))
implementation(project(":misk-testing"))

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package misk.grpc.miskserver

import misk.exceptions.WebActionException
import misk.web.actions.WebAction
import misk.web.interceptors.LogRequestResponse
import routeguide.Feature
Expand All @@ -10,6 +11,9 @@ import javax.inject.Inject
class GetFeatureGrpcAction @Inject constructor() : WebAction, RouteGuideGetFeatureBlockingServer {
@LogRequestResponse(bodySampling = 1.0, errorBodySampling = 1.0)
override fun GetFeature(request: Point): Feature {
if (request.latitude == -1) {
throw WebActionException(request.longitude ?: 500, "unexpected latitude error!")
}
return Feature(name = "maple tree", location = request)
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package misk.grpc.miskserver

import com.google.inject.Provides
import com.google.inject.util.Modules
import misk.MiskTestingServiceModule
import misk.inject.KAbstractModule
import misk.metrics.FakeMetricsModule
import misk.web.WebActionModule
import misk.web.WebServerTestingModule
import misk.web.jetty.JettyService
Expand All @@ -12,7 +14,7 @@ import javax.inject.Named
class RouteGuideMiskServiceModule : KAbstractModule() {
override fun configure() {
install(WebServerTestingModule(webConfig = WebServerTestingModule.TESTING_WEB_CONFIG))
install(MiskTestingServiceModule())
install(Modules.override(MiskTestingServiceModule()).with(FakeMetricsModule()))
install(WebActionModule.create<GetFeatureGrpcAction>())
install(WebActionModule.create<RouteChatGrpcAction>())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package misk.grpc

import com.google.inject.Guice
import com.google.inject.util.Modules
import javax.inject.Inject
import javax.inject.Named
import com.squareup.wire.GrpcException
import com.squareup.wire.GrpcStatus
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.runBlocking
Expand All @@ -13,18 +13,30 @@ import misk.grpc.miskserver.RouteChatGrpcAction
import misk.grpc.miskserver.RouteGuideMiskServiceModule
import misk.inject.getInstance
import misk.logging.LogCollectorModule
import misk.metrics.FakeMetrics
import misk.testing.MiskTest
import misk.testing.MiskTestModule
import misk.web.interceptors.RequestLoggingInterceptor
import okhttp3.HttpUrl
import org.assertj.core.api.Assertions.assertThat
import org.awaitility.Durations.ONE_HUNDRED_MILLISECONDS
import org.awaitility.Durations.ONE_MILLISECOND
import org.awaitility.kotlin.atMost
import org.awaitility.kotlin.await
import org.awaitility.kotlin.matches
import org.awaitility.kotlin.untilCallTo
import org.awaitility.kotlin.withPollDelay
import org.awaitility.kotlin.withPollInterval
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import routeguide.Feature
import routeguide.Point
import routeguide.RouteGuideClient
import routeguide.RouteNote
import wisp.logging.LogCollector
import javax.inject.Inject
import javax.inject.Named
import kotlin.test.assertFailsWith

@MiskTest(startService = true)
class MiskClientMiskServerTest {
Expand All @@ -37,6 +49,7 @@ class MiskClientMiskServerTest {
@Inject lateinit var logCollector: LogCollector
@Inject lateinit var routeChatGrpcAction: RouteChatGrpcAction
@Inject @field:Named("grpc server") lateinit var serverUrl: HttpUrl
@Inject lateinit var metrics: FakeMetrics

private lateinit var routeGuide: RouteGuideClient
private lateinit var callCounter: RouteGuideCallCounter
Expand Down Expand Up @@ -105,4 +118,60 @@ class MiskClientMiskServerTest {
sendChannel.close()
}
}

@Test
fun serverFailureGeneric() {
val point = Point(
latitude = -1,
longitude = 500
)

runBlocking {
val e = assertFailsWith<GrpcException> {
routeGuide.GetFeature().execute(point)
}
assertThat(e.grpcMessage).isEqualTo("Internal Server Error")
assertThat(e.grpcStatus).isEqualTo(GrpcStatus.UNKNOWN)

// Assert that _metrics_ counted a 500 and no 200s, even though an HTTP 200 was returned
// over HTTP. The 200 is implicitly asserted by the fact that we got a GrpcException, which
// is only thrown if a properly constructed gRPC error is received.
assertResponseCount(200, 0)
assertResponseCount(500, 1)
}
}

@Test
fun serverFailureNotFound() {
val point = Point(
latitude = -1,
longitude = 404
)

runBlocking {
val e = assertFailsWith<GrpcException> {
routeGuide.GetFeature().execute(point)
}
assertThat(e.grpcMessage).isEqualTo("unexpected latitude error!")
assertThat(e.grpcStatus).isEqualTo(GrpcStatus.UNIMPLEMENTED)
.withFailMessage("wrong gRPC status ${e.grpcStatus.name}")

// Assert that _metrics_ counted a 404 and no 200s, even though an HTTP 200 was returned
// over HTTP. The 200 is implicitly asserted by the fact that we got a GrpcException, which
// is only thrown if a properly constructed gRPC error is received.
assertResponseCount(200, 0)
assertResponseCount(404, 1)
}
}

private fun assertResponseCount(code: Int, count: Int) {
await withPollInterval ONE_MILLISECOND atMost ONE_HUNDRED_MILLISECONDS untilCallTo {
metrics.histogramCount(
"http_request_latency_ms",
"action" to "GetFeatureGrpcAction",
"caller" to "unknown",
"code" to code.toString(),
)?.toInt() ?: 0
} matches { it == count }
}
}
6 changes: 6 additions & 0 deletions misk-testing/src/main/kotlin/misk/web/FakeHttpCall.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ data class FakeHttpCall(
override val dispatchMechanism: DispatchMechanism = DispatchMechanism.GET,
override val requestHeaders: Headers = headersOf(),
override var statusCode: Int = 200,
override var networkStatusCode: Int = 200,
val headersBuilder: Headers.Builder = Headers.Builder(),
var sendTrailers: Boolean = false,
val trailersBuilder: Headers.Builder = Headers.Builder(),
Expand All @@ -28,6 +29,11 @@ data class FakeHttpCall(
override val responseHeaders: Headers
get() = headersBuilder.build()

override fun setStatusCodes(statusCode: Int, networkStatusCode: Int) {
this.statusCode = statusCode
this.networkStatusCode = networkStatusCode
}

override fun setResponseHeader(name: String, value: String) {
headersBuilder.set(name, value)
}
Expand Down
12 changes: 11 additions & 1 deletion misk/src/main/kotlin/misk/web/HttpCall.kt
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,20 @@ interface HttpCall {
val dispatchMechanism: DispatchMechanism
val requestHeaders: Headers

/** The HTTP response under construction. */
/** Meaningful HTTP status about what actually happened. Not sent over the wire in the case
* of gRPC, which always returns HTTP 200 even for errors. */
var statusCode: Int

/** The HTTP status code actually sent over the network. For gRPC, this is always 200, even
* for errors, per the spec. **/
val networkStatusCode: Int

val responseHeaders: Headers

/** Set both the raw network status code and the meaningful status code that's
* recorded in metrics */
fun setStatusCodes(statusCode: Int, networkStatusCode: Int)

fun setResponseHeader(name: String, value: String)
fun addResponseHeaders(headers: Headers)

Expand Down
12 changes: 11 additions & 1 deletion misk/src/main/kotlin/misk/web/ServletHttpCall.kt
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,26 @@ internal data class ServletHttpCall(
var responseBody: BufferedSink? = null,
var webSocket: WebSocket? = null
) : HttpCall {
private var _actualStatusCode: Int? = null

override var statusCode: Int
get() = upstreamResponse.statusCode
get() = _actualStatusCode ?: upstreamResponse.statusCode
set(value) {
_actualStatusCode = value
upstreamResponse.statusCode = value
}

override val networkStatusCode: Int
get() = upstreamResponse.statusCode

override val responseHeaders: Headers
get() = upstreamResponse.headers

override fun setStatusCodes(statusCode: Int, networkStatusCode: Int) {
_actualStatusCode = statusCode
upstreamResponse.statusCode = networkStatusCode
}

override fun setResponseHeader(name: String, value: String) {
upstreamResponse.setHeader(name, value)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
package misk.web.exceptions

import com.google.common.util.concurrent.UncheckedExecutionException
import com.squareup.wire.GrpcStatus
import com.squareup.wire.ProtoAdapter
import misk.Action
import misk.exceptions.UnauthenticatedException
import misk.exceptions.UnauthorizedException
import misk.grpc.GrpcMessageSink
import misk.web.DispatchMechanism
import misk.web.HttpCall
import misk.web.NetworkChain
import misk.web.NetworkInterceptor
import misk.web.Response
import misk.web.ResponseBody
import misk.web.mediatype.MediaTypes
import misk.web.toResponseBody
import okhttp3.Headers.Companion.toHeaders
import okio.Buffer
import okio.BufferedSink
import okio.ByteString
import wisp.logging.getLogger
import wisp.logging.log
import java.lang.reflect.InvocationTargetException
Expand All @@ -37,13 +45,67 @@ class ExceptionHandlingInterceptor(
} catch (th: Throwable) {
val response = toResponse(th)
chain.httpCall.statusCode = response.statusCode
chain.httpCall.takeResponseBody()?.use { sink ->
chain.httpCall.addResponseHeaders(response.headers)
(response.body as ResponseBody).writeTo(sink)
if (chain.httpCall.dispatchMechanism == DispatchMechanism.GRPC) {
sendGrpcFailure(chain.httpCall, response)
} else {
sendHttpFailure(chain.httpCall, response)
}
}
}

private fun sendHttpFailure(httpCall: HttpCall, response: Response<*>) {
httpCall.takeResponseBody()?.use { sink ->
httpCall.addResponseHeaders(response.headers)
(response.body as ResponseBody).writeTo(sink)
}
}

/**
* Borrow behavior from [GrpcFeatureBinding] to send a gRPC error with an HTTP 200 status code.
* This is weird but it's how gRPC clients work.
*
* One thing to note is for our metrics we want to pretend that the HTTP code is what we sent.
* Otherwise gRPC requests that crashed and yielded an HTTP 200 code will confuse operators.
*/
private fun sendGrpcFailure(httpCall: HttpCall, response: Response<*>) {
httpCall.setStatusCodes(httpCall.statusCode, 200)
httpCall.requireTrailers()
httpCall.setResponseHeader("grpc-encoding", "identity")
httpCall.setResponseHeader("Content-Type", MediaTypes.APPLICATION_GRPC)
httpCall.setResponseTrailer(
"grpc-status",
toGrpcStatus(response.statusCode).code.toString()
)
httpCall.setResponseTrailer("grpc-message", this.grpcMessage(response))
httpCall.takeResponseBody()?.use { responseBody: BufferedSink ->
GrpcMessageSink(responseBody, ProtoAdapter.BYTES, grpcEncoding = "identity")
.use { messageSink ->
messageSink.write(ByteString.EMPTY)
}
}
}

private fun grpcMessage(response: Response<*>): String {
val buffer = Buffer()
(response.body as ResponseBody).writeTo(buffer)
return buffer.readUtf8()
}

/** https://grpc.github.io/grpc/core/md_doc_http-grpc-status-mapping.html */
private fun toGrpcStatus(statusCode: Int): GrpcStatus {
return when (statusCode) {
400 -> GrpcStatus.INTERNAL
401 -> GrpcStatus.UNAUTHENTICATED
403 -> GrpcStatus.PERMISSION_DENIED
404 -> GrpcStatus.UNIMPLEMENTED
429 -> GrpcStatus.UNAVAILABLE
502 -> GrpcStatus.UNAVAILABLE
503 -> GrpcStatus.UNAVAILABLE
504 -> GrpcStatus.UNAVAILABLE
else -> GrpcStatus.UNKNOWN
}
}

private fun toResponse(th: Throwable): Response<*> = when (th) {
is UnauthenticatedException -> UNAUTHENTICATED_RESPONSE
is UnauthorizedException -> UNAUTHORIZED_RESPONSE
Expand Down
Loading

0 comments on commit 171a7ce

Please sign in to comment.