diff --git a/cc/jwt/internal/BUILD.bazel b/cc/jwt/internal/BUILD.bazel index b1d51d209d..f0f14de62f 100644 --- a/cc/jwt/internal/BUILD.bazel +++ b/cc/jwt/internal/BUILD.bazel @@ -87,6 +87,7 @@ cc_library( "//util:status", "//util:statusor", "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", ], ) @@ -94,6 +95,7 @@ cc_test( name = "jwt_format_test", srcs = ["jwt_format_test.cc"], deps = [ + ":json_util", ":jwt_format", "//util:test_matchers", "//util:test_util", @@ -528,6 +530,7 @@ cc_library( hdrs = ["jwt_public_key_verify_impl.h"], include_prefix = "tink/jwt/internal", deps = [ + ":json_util", ":jwt_format", "//:public_key_verify", "//jwt:jwt_public_key_verify", diff --git a/cc/jwt/internal/CMakeLists.txt b/cc/jwt/internal/CMakeLists.txt index 7cd94d62a1..b9f0f105b3 100644 --- a/cc/jwt/internal/CMakeLists.txt +++ b/cc/jwt/internal/CMakeLists.txt @@ -54,6 +54,7 @@ tink_cc_library( jwt_format.cc jwt_format.h DEPS + protobuf::libprotobuf tink::jwt::internal::json_util tink::util::status tink::util::statusor @@ -64,6 +65,7 @@ tink_cc_test( NAME jwt_format_test SRCS jwt_format_test.cc DEPS + tink::jwt::internal::json_util tink::jwt::internal::jwt_format tink::util::test_matchers tink::util::test_util @@ -488,6 +490,7 @@ tink_cc_library( jwt_public_key_verify_impl.cc jwt_public_key_verify_impl.h DEPS + tink::jwt::internal::json_util tink::jwt::internal::jwt_format tink::core::public_key_verify tink::jwt::jwt_public_key_verify diff --git a/cc/jwt/internal/jwt_format.cc b/cc/jwt/internal/jwt_format.cc index e6bc55aab6..663ca78fee 100644 --- a/cc/jwt/internal/jwt_format.cc +++ b/cc/jwt/internal/jwt_format.cc @@ -51,22 +51,29 @@ bool DecodeHeader(absl::string_view header, std::string* json_header) { return StrictWebSafeBase64Unescape(header, json_header); } -std::string CreateHeader(absl::string_view algorithm) { - std::string header = absl::StrCat(R"({"alg":")", algorithm, R"("})"); - return EncodeHeader(header); +std::string CreateHeader(absl::string_view algorithm, + absl::optional type_header) { + google::protobuf::Struct header; + auto fields = header.mutable_fields(); + if (type_header.has_value()) { + google::protobuf::Value type_value; + type_value.set_string_value(std::string(type_header.value())); + (*fields)["typ"] = type_value; + } + google::protobuf::Value alg_value; + alg_value.set_string_value(std::string(algorithm)); + (*fields)["alg"] = alg_value; + util::StatusOr json_or = + jwt_internal::ProtoStructToJsonString(header); + if (!json_or.ok()) { + // do something + } + return EncodeHeader(json_or.ValueOrDie()); } -util::Status ValidateHeader(absl::string_view encoded_header, +util::Status ValidateHeader(const google::protobuf::Struct& header, absl::string_view algorithm) { - std::string json_header; - if (!DecodeHeader(encoded_header, &json_header)) { - return util::Status(util::error::INVALID_ARGUMENT, "invalid header"); - } - auto proto_or = JsonStringToProtoStruct(json_header); - if (!proto_or.ok()) { - return proto_or.status(); - } - auto fields = proto_or.ValueOrDie().fields(); + auto fields = header.fields(); auto it = fields.find("alg"); if (it == fields.end()) { return util::Status(util::error::INVALID_ARGUMENT, "header is missing alg"); @@ -85,6 +92,19 @@ util::Status ValidateHeader(absl::string_view encoded_header, return util::OkStatus(); } +absl::optional GetTypeHeader( + const google::protobuf::Struct& header) { + auto it = header.fields().find("typ"); + if (it == header.fields().end()) { + return absl::nullopt; + } + const auto& value = it->second; + if (value.kind_case() != google::protobuf::Value::kStringValue) { + return absl::nullopt; + } + return value.string_value(); +} + std::string EncodePayload(absl::string_view json_payload) { return absl::WebSafeBase64Escape(json_payload); } diff --git a/cc/jwt/internal/jwt_format.h b/cc/jwt/internal/jwt_format.h index bf4c7464cb..0ba56e1c1d 100644 --- a/cc/jwt/internal/jwt_format.h +++ b/cc/jwt/internal/jwt_format.h @@ -17,6 +17,7 @@ #ifndef TINK_JWT_INTERNAL_JWT_FORMAT_H_ #define TINK_JWT_INTERNAL_JWT_FORMAT_H_ +#include "google/protobuf/struct.pb.h" #include "tink/util/status.h" #include "tink/util/statusor.h" @@ -27,9 +28,12 @@ namespace jwt_internal { std::string EncodeHeader(absl::string_view json_header); bool DecodeHeader(absl::string_view header, std::string* json_header); -std::string CreateHeader(absl::string_view algorithm); -util::Status ValidateHeader(absl::string_view encoded_header, +std::string CreateHeader(absl::string_view algorithm, + absl::optional type_header); +util::Status ValidateHeader(const google::protobuf::Struct& header, absl::string_view algorithm); +absl::optional GetTypeHeader( + const google::protobuf::Struct& header); std::string EncodePayload(absl::string_view json_payload); bool DecodePayload(absl::string_view payload, std::string* json_payload); diff --git a/cc/jwt/internal/jwt_format_test.cc b/cc/jwt/internal/jwt_format_test.cc index f68302e4c5..8d17b5b228 100644 --- a/cc/jwt/internal/jwt_format_test.cc +++ b/cc/jwt/internal/jwt_format_test.cc @@ -18,6 +18,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "tink/jwt/internal/json_util.h" #include "tink/util/test_matchers.h" #include "tink/util/test_util.h" @@ -79,72 +80,105 @@ TEST(JwtFormat, DecodeAndValidateFixedHeaderHS256) { // Example from https://tools.ietf.org/html/rfc7515#appendix-A.1 std::string encoded_header = "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9"; - std::string output; - ASSERT_TRUE(DecodeHeader(encoded_header, &output)); - EXPECT_THAT(output, Eq("{\"typ\":\"JWT\",\r\n \"alg\":\"HS256\"}")); + std::string json_header; + ASSERT_TRUE(DecodeHeader(encoded_header, &json_header)); + EXPECT_THAT(json_header, Eq("{\"typ\":\"JWT\",\r\n \"alg\":\"HS256\"}")); + + util::StatusOr header_or = + JsonStringToProtoStruct(json_header); + EXPECT_THAT(header_or.status(), IsOk()); - EXPECT_THAT(ValidateHeader(encoded_header, "HS256"), IsOk()); - EXPECT_FALSE(ValidateHeader(encoded_header, "RS256").ok()); + EXPECT_THAT(ValidateHeader(header_or.ValueOrDie(), "HS256"), IsOk()); + EXPECT_FALSE(ValidateHeader(header_or.ValueOrDie(), "RS256").ok()); } TEST(JwtFormat, DecodeAndValidateFixedHeaderRS256) { // Example from https://tools.ietf.org/html/rfc7515#appendix-A.2 std::string encoded_header = "eyJhbGciOiJSUzI1NiJ9"; - std::string output; - ASSERT_TRUE(DecodeHeader(encoded_header, &output)); - EXPECT_THAT(output, Eq(R"({"alg":"RS256"})")); + std::string json_header; + ASSERT_TRUE(DecodeHeader(encoded_header, &json_header)); + EXPECT_THAT(json_header, Eq(R"({"alg":"RS256"})")); + + util::StatusOr header_or = + JsonStringToProtoStruct(json_header); + EXPECT_THAT(header_or.status(), IsOk()); - EXPECT_THAT(ValidateHeader(encoded_header, "RS256"), IsOk()); - EXPECT_FALSE(ValidateHeader(encoded_header, "HS256").ok()); + EXPECT_THAT(ValidateHeader(header_or.ValueOrDie(), "RS256"), IsOk()); + EXPECT_FALSE(ValidateHeader(header_or.ValueOrDie(), "HS256").ok()); } TEST(JwtFormat, CreateValidateHeader) { - std::string encoded_header = CreateHeader("PS384"); - EXPECT_THAT(ValidateHeader(encoded_header, "PS384"), IsOk()); - EXPECT_FALSE(ValidateHeader(encoded_header, "HS256").ok()); -} + std::string encoded_header = CreateHeader("PS384", absl::nullopt); -TEST(JwtFormat, ValidateEmptyHeaderFails) { - std::string header = "{}"; - EXPECT_FALSE(ValidateHeader(EncodeHeader(header), "HS256").ok()); + std::string json_header; + ASSERT_TRUE(DecodeHeader(encoded_header, &json_header)); + + util::StatusOr header_or = + JsonStringToProtoStruct(json_header); + EXPECT_THAT(header_or.status(), IsOk()); + + EXPECT_THAT(ValidateHeader(header_or.ValueOrDie(), "PS384"), IsOk()); + EXPECT_FALSE(ValidateHeader(header_or.ValueOrDie(), "HS256").ok()); } -TEST(JwtFormat, ValidateInvalidEncodedHeaderFails) { - EXPECT_FALSE( - ValidateHeader("eyJ0eXAiOiJKV1Q?LA0KICJhbGciOiJIUzI1NiJ9", "HS256").ok()); +TEST(JwtFormat, CreateValidateHeaderWithType) { + std::string encoded_header = CreateHeader("PS384", "JWT"); + + std::string json_header; + ASSERT_TRUE(DecodeHeader(encoded_header, &json_header)); + + util::StatusOr header_or = + JsonStringToProtoStruct(json_header); + EXPECT_THAT(header_or.status(), IsOk()); + + EXPECT_THAT(ValidateHeader(header_or.ValueOrDie(), "PS384"), IsOk()); + EXPECT_FALSE(ValidateHeader(header_or.ValueOrDie(), "HS256").ok()); } -TEST(JwtFormat, ValidateInvalidJsonHeaderFails) { - std::string header = R"({"alg":"HS256")"; // missing } - EXPECT_FALSE(ValidateHeader(EncodeHeader(header), "HS256").ok()); +TEST(JwtFormat, ValidateEmptyHeaderFails) { + google::protobuf::Struct empty_header; + EXPECT_FALSE(ValidateHeader(empty_header, "HS256").ok()); } -TEST(JwtFormat, ValidateHeaderIgnoresTyp) { - std::string header = R"({"alg":"HS256","typ":"unknown"})"; - EXPECT_THAT(ValidateHeader(EncodeHeader(header), "HS256"), IsOk()); +TEST(JwtFormat, ValidateHeaderWithUnknownTypeOk) { + std::string json_header = R"({"alg":"HS256","typ":"unknown"})"; + util::StatusOr header_or = + JsonStringToProtoStruct(json_header); + EXPECT_THAT(header_or.status(), IsOk()); + + EXPECT_THAT(ValidateHeader(header_or.ValueOrDie(), "HS256"), IsOk()); } TEST(JwtFormat, ValidateHeaderRejectsCrit) { - std::string header = + std::string json_header = R"({"alg":"HS256","crit":["http://example.invalid/UNDEFINED"],)" R"("http://example.invalid/UNDEFINED":true})"; - EXPECT_FALSE(ValidateHeader(EncodeHeader(header), "HS256").ok()); + util::StatusOr header_or = + JsonStringToProtoStruct(json_header); + EXPECT_THAT(header_or.status(), IsOk()); + EXPECT_FALSE(ValidateHeader(header_or.ValueOrDie(), "HS256").ok()); } TEST(JwtFormat, ValidateHeaderWithUnknownEntry) { - std::string header = R"({"alg":"HS256","unknown":"header"})"; - EXPECT_THAT(ValidateHeader(EncodeHeader(header), "HS256"), IsOk()); + std::string json_header = R"({"alg":"HS256","unknown":"header"})"; + util::StatusOr header_or = + JsonStringToProtoStruct(json_header); + EXPECT_THAT(header_or.status(), IsOk()); + EXPECT_THAT(ValidateHeader(header_or.ValueOrDie(), "HS256"), IsOk()); } TEST(JwtFormat, ValidateHeaderWithInvalidAlgTypFails) { - std::string header = R"({"alg":true})"; - EXPECT_FALSE(ValidateHeader(EncodeHeader(header), "HS256").ok()); + std::string json_header = R"({"alg":true})"; + util::StatusOr header_or = + JsonStringToProtoStruct(json_header); + EXPECT_THAT(header_or.status(), IsOk()); + EXPECT_FALSE(ValidateHeader(header_or.ValueOrDie(), "HS256").ok()); } TEST(JwtFormat, DecodeFixedPayload) { // Example from https://tools.ietf.org/html/rfc7519#section-3.1 - std::string encoded_header = + std::string encoded_payload = "eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0" "dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ"; @@ -152,7 +186,7 @@ TEST(JwtFormat, DecodeFixedPayload) { "{\"iss\":\"joe\",\r\n \"exp\":1300819380,\r\n " "\"http://example.com/is_root\":true}"; std::string output; - ASSERT_TRUE(DecodePayload(encoded_header, &output)); + ASSERT_TRUE(DecodePayload(encoded_payload, &output)); EXPECT_THAT(output, Eq(expected)); } diff --git a/cc/jwt/internal/jwt_mac_impl.cc b/cc/jwt/internal/jwt_mac_impl.cc index f7ad3a5c7d..8d4a8f5e2b 100644 --- a/cc/jwt/internal/jwt_mac_impl.cc +++ b/cc/jwt/internal/jwt_mac_impl.cc @@ -27,8 +27,16 @@ namespace jwt_internal { util::StatusOr JwtMacImpl::ComputeMacAndEncode( const RawJwt& token) const { - std::string encoded_header = CreateHeader(algorithm_); - util::StatusOr payload_or = token.ToString(); + absl::optional type_header; + if (token.HasTypeHeader()) { + util::StatusOr type_or = token.GetTypeHeader(); + if (!type_or.ok()) { + return type_or.status(); + } + type_header = type_or.ValueOrDie(); + } + std::string encoded_header = CreateHeader(algorithm_, type_header); + util::StatusOr payload_or = token.GetJsonPayload(); if (!payload_or.ok()) { return payload_or.status(); } @@ -64,7 +72,16 @@ util::StatusOr JwtMacImpl::VerifyMacAndDecode( util::error::INVALID_ARGUMENT, "only tokens in JWS compact serialization format are supported"); } - util::Status validate_header_result = ValidateHeader(parts[0], algorithm_); + std::string json_header; + if (!DecodeHeader(parts[0], &json_header)) { + return util::Status(util::error::INVALID_ARGUMENT, "invalid header"); + } + auto header_or = JsonStringToProtoStruct(json_header); + if (!header_or.ok()) { + return header_or.status(); + } + util::Status validate_header_result = + ValidateHeader(header_or.ValueOrDie(), algorithm_); if (!validate_header_result.ok()) { return validate_header_result; } @@ -72,7 +89,8 @@ util::StatusOr JwtMacImpl::VerifyMacAndDecode( if (!DecodePayload(parts[1], &json_payload)) { return util::Status(util::error::INVALID_ARGUMENT, "invalid JWT payload"); } - auto raw_jwt_or = RawJwt::FromString(json_payload); + auto raw_jwt_or = + RawJwt::FromJson(GetTypeHeader(header_or.ValueOrDie()), json_payload); if (!raw_jwt_or.ok()) { return raw_jwt_or.status(); } diff --git a/cc/jwt/internal/jwt_mac_impl_test.cc b/cc/jwt/internal/jwt_mac_impl_test.cc index 754f80eab4..31b0cb2f7e 100644 --- a/cc/jwt/internal/jwt_mac_impl_test.cc +++ b/cc/jwt/internal/jwt_mac_impl_test.cc @@ -34,6 +34,7 @@ #include "tink/util/test_util.h" using ::crypto::tink::test::IsOk; +using ::crypto::tink::test::IsOkAndHolds; namespace crypto { namespace tink { @@ -67,13 +68,16 @@ TEST(JwtMacImplTest, CreateAndValidateToken) { std::unique_ptr jwt_mac = std::move(jwt_mac_or.ValueOrDie()); absl::Time now = absl::Now(); - auto builder = RawJwtBuilder().SetIssuer("issuer"); + auto builder = + RawJwtBuilder().SetTypeHeader("typeHeader").SetIssuer("issuer"); ASSERT_THAT(builder.SetNotBefore(now - absl::Seconds(300)), IsOk()); ASSERT_THAT(builder.SetIssuedAt(now), IsOk()); ASSERT_THAT(builder.SetExpiration(now + absl::Seconds(300)), IsOk()); auto raw_jwt_or = builder.Build(); ASSERT_THAT(raw_jwt_or.status(), IsOk()); RawJwt raw_jwt = raw_jwt_or.ValueOrDie(); + EXPECT_TRUE(raw_jwt.HasTypeHeader()); + EXPECT_THAT(raw_jwt.GetTypeHeader(), IsOkAndHolds("typeHeader")); util::StatusOr compact_or = jwt_mac->ComputeMacAndEncode(raw_jwt); @@ -86,7 +90,8 @@ TEST(JwtMacImplTest, CreateAndValidateToken) { jwt_mac->VerifyMacAndDecode(compact, validator); ASSERT_THAT(verified_jwt_or.status(), IsOk()); auto verified_jwt = verified_jwt_or.ValueOrDie(); - EXPECT_THAT(verified_jwt.GetIssuer(), test::IsOkAndHolds("issuer")); + EXPECT_THAT(verified_jwt.GetTypeHeader(), IsOkAndHolds("typeHeader")); + EXPECT_THAT(verified_jwt.GetIssuer(), IsOkAndHolds("issuer")); JwtValidator validator2 = JwtValidatorBuilder().SetIssuer("unknown").Build(); EXPECT_FALSE(jwt_mac->VerifyMacAndDecode(compact, validator2).ok()); @@ -110,9 +115,9 @@ TEST(JwtMacImplTest, ValidateFixedToken) { jwt_mac->VerifyMacAndDecode(compact, validator_1970); ASSERT_THAT(verified_jwt_or.status(), IsOk()); auto verified_jwt = verified_jwt_or.ValueOrDie(); - EXPECT_THAT(verified_jwt.GetIssuer(), test::IsOkAndHolds("joe")); + EXPECT_THAT(verified_jwt.GetIssuer(), IsOkAndHolds("joe")); EXPECT_THAT(verified_jwt.GetBooleanClaim("http://example.com/is_root"), - test::IsOkAndHolds(true)); + IsOkAndHolds(true)); // verification fails because token is expired JwtValidator validator_now = JwtValidatorBuilder().Build(); diff --git a/cc/jwt/internal/jwt_public_key_sign_impl.cc b/cc/jwt/internal/jwt_public_key_sign_impl.cc index 9207298585..ef1fda6cff 100644 --- a/cc/jwt/internal/jwt_public_key_sign_impl.cc +++ b/cc/jwt/internal/jwt_public_key_sign_impl.cc @@ -26,8 +26,16 @@ namespace jwt_internal { util::StatusOr JwtPublicKeySignImpl::SignAndEncode( const RawJwt& token) const { - std::string encoded_header = CreateHeader(algorithm_); - util::StatusOr payload_or = token.ToString(); + absl::optional type_header; + if (token.HasTypeHeader()) { + util::StatusOr type_or = token.GetTypeHeader(); + if (!type_or.ok()) { + return type_or.status(); + } + type_header = type_or.ValueOrDie(); + } + std::string encoded_header = CreateHeader(algorithm_, type_header); + util::StatusOr payload_or = token.GetJsonPayload(); if (!payload_or.ok()) { return payload_or.status(); } diff --git a/cc/jwt/internal/jwt_public_key_sign_verify_impl_test.cc b/cc/jwt/internal/jwt_public_key_sign_verify_impl_test.cc index 7faf810f14..ab8737454c 100644 --- a/cc/jwt/internal/jwt_public_key_sign_verify_impl_test.cc +++ b/cc/jwt/internal/jwt_public_key_sign_verify_impl_test.cc @@ -66,7 +66,8 @@ class JwtSignatureImplTest : public ::testing::Test { TEST_F(JwtSignatureImplTest, CreateAndValidateToken) { absl::Time now = absl::Now(); - auto builder = RawJwtBuilder().SetIssuer("issuer"); + auto builder = + RawJwtBuilder().SetTypeHeader("typeHeader").SetIssuer("issuer"); ASSERT_THAT(builder.SetNotBefore(now - absl::Seconds(300)), IsOk()); ASSERT_THAT(builder.SetIssuedAt(now), IsOk()); ASSERT_THAT(builder.SetExpiration(now + absl::Seconds(300)), IsOk()); @@ -86,6 +87,7 @@ TEST_F(JwtSignatureImplTest, CreateAndValidateToken) { jwt_verify_->VerifyAndDecode(compact, validator); ASSERT_THAT(verified_jwt_or.status(), IsOk()); auto verified_jwt = verified_jwt_or.ValueOrDie(); + EXPECT_THAT(verified_jwt.GetTypeHeader(), test::IsOkAndHolds("typeHeader")); EXPECT_THAT(verified_jwt.GetIssuer(), test::IsOkAndHolds("issuer")); // Fails with wrong issuer diff --git a/cc/jwt/internal/jwt_public_key_verify_impl.cc b/cc/jwt/internal/jwt_public_key_verify_impl.cc index 9b5efd773f..4d1f5c5cab 100644 --- a/cc/jwt/internal/jwt_public_key_verify_impl.cc +++ b/cc/jwt/internal/jwt_public_key_verify_impl.cc @@ -18,6 +18,7 @@ #include "absl/strings/escaping.h" #include "absl/strings/str_split.h" +#include "tink/jwt/internal/json_util.h" #include "tink/jwt/internal/jwt_format.h" namespace crypto { @@ -46,7 +47,16 @@ util::StatusOr JwtPublicKeyVerifyImpl::VerifyAndDecode( util::error::INVALID_ARGUMENT, "only tokens in JWS compact serialization format are supported"); } - util::Status validate_header_result = ValidateHeader(parts[0], algorithm_); + std::string json_header; + if (!DecodeHeader(parts[0], &json_header)) { + return util::Status(util::error::INVALID_ARGUMENT, "invalid header"); + } + auto header_or = JsonStringToProtoStruct(json_header); + if (!header_or.ok()) { + return header_or.status(); + } + util::Status validate_header_result = + ValidateHeader(header_or.ValueOrDie(), algorithm_); if (!validate_header_result.ok()) { return validate_header_result; } @@ -54,7 +64,8 @@ util::StatusOr JwtPublicKeyVerifyImpl::VerifyAndDecode( if (!DecodePayload(parts[1], &json_payload)) { return util::Status(util::error::INVALID_ARGUMENT, "invalid JWT payload"); } - auto raw_jwt_or = RawJwt::FromString(json_payload); + auto raw_jwt_or = + RawJwt::FromJson(GetTypeHeader(header_or.ValueOrDie()), json_payload); if (!raw_jwt_or.ok()) { return raw_jwt_or.status(); } diff --git a/cc/jwt/raw_jwt.cc b/cc/jwt/raw_jwt.cc index ed8dcf5238..a89972fe32 100644 --- a/cc/jwt/raw_jwt.cc +++ b/cc/jwt/raw_jwt.cc @@ -148,8 +148,9 @@ util::Status ValidateAndFixAudienceClaim(google::protobuf::Struct* json_proto) { } // namespace -util::StatusOr RawJwt::FromString(absl::string_view json_string) { - auto proto_or = jwt_internal::JsonStringToProtoStruct(json_string); +util::StatusOr RawJwt::FromJson(absl::optional type_header, + absl::string_view json_payload) { + auto proto_or = jwt_internal::JsonStringToProtoStruct(json_payload); if (!proto_or.ok()) { return proto_or.status(); } @@ -166,20 +167,31 @@ util::StatusOr RawJwt::FromString(absl::string_view json_string) { if (!audStatus.ok()) { return audStatus; } - RawJwt token(proto); + RawJwt token(type_header, proto); return token; } -util::StatusOr RawJwt::ToString() const { +util::StatusOr RawJwt::GetJsonPayload() const { return jwt_internal::ProtoStructToJsonString(json_proto_); } RawJwt::RawJwt() {} -RawJwt::RawJwt(google::protobuf::Struct json_proto) { +RawJwt::RawJwt(absl::optional type_header, + google::protobuf::Struct json_proto) { + type_header_ = type_header; json_proto_ = json_proto; } +bool RawJwt::HasTypeHeader() const { return type_header_.has_value(); } + +util::StatusOr RawJwt::GetTypeHeader() const { + if (!type_header_.has_value()) { + return util::Status(util::error::INVALID_ARGUMENT, "No type header found"); + } + return *type_header_; +} + bool RawJwt::HasIssuer() const { return json_proto_.fields().contains(std::string(kJwtClaimIssuer)); } @@ -457,6 +469,11 @@ std::vector RawJwt::CustomClaimNames() const { RawJwtBuilder::RawJwtBuilder() {} +RawJwtBuilder& RawJwtBuilder::SetTypeHeader(absl::string_view type_header) { + type_header_ = std::string(type_header); + return *this; +} + RawJwtBuilder& RawJwtBuilder::SetIssuer(absl::string_view issuer) { auto fields = json_proto_.mutable_fields(); google::protobuf::Value value; @@ -612,7 +629,7 @@ util::Status RawJwtBuilder::AddJsonArrayClaim(absl::string_view name, } util::StatusOr RawJwtBuilder::Build() { - RawJwt token(json_proto_); + RawJwt token(type_header_, json_proto_); return token; } diff --git a/cc/jwt/raw_jwt.h b/cc/jwt/raw_jwt.h index 73bc8550d8..8ccca582e1 100644 --- a/cc/jwt/raw_jwt.h +++ b/cc/jwt/raw_jwt.h @@ -36,6 +36,8 @@ class RawJwt { public: RawJwt(); + bool HasTypeHeader() const; + util::StatusOr GetTypeHeader() const; bool HasIssuer() const; util::StatusOr GetIssuer() const; bool HasSubject() const; @@ -63,8 +65,9 @@ class RawJwt { util::StatusOr GetJsonArrayClaim(absl::string_view name) const; std::vector CustomClaimNames() const; - static util::StatusOr FromString(absl::string_view json_string); - util::StatusOr ToString() const; + static util::StatusOr FromJson( + absl::optional type_header, absl::string_view json_payload); + util::StatusOr GetJsonPayload() const; // RawJwt objects are copiable and movable. RawJwt(const RawJwt&) = default; @@ -73,8 +76,10 @@ class RawJwt { RawJwt& operator=(RawJwt&& other) = default; private: - explicit RawJwt(google::protobuf::Struct json_proto); + explicit RawJwt(absl::optional type_header, + google::protobuf::Struct json_proto); friend class RawJwtBuilder; + absl::optional type_header_; google::protobuf::Struct json_proto_; }; @@ -82,6 +87,7 @@ class RawJwtBuilder { public: RawJwtBuilder(); + RawJwtBuilder& SetTypeHeader(absl::string_view type_header); RawJwtBuilder& SetIssuer(absl::string_view issuer); RawJwtBuilder& SetSubject(absl::string_view subject); RawJwtBuilder& AddAudience(absl::string_view audience); @@ -107,6 +113,7 @@ class RawJwtBuilder { RawJwtBuilder& operator=(RawJwtBuilder&& other) = default; private: + absl::optional type_header_; google::protobuf::Struct json_proto_; }; diff --git a/cc/jwt/raw_jwt_test.cc b/cc/jwt/raw_jwt_test.cc index 2253038472..5b3daacc96 100644 --- a/cc/jwt/raw_jwt_test.cc +++ b/cc/jwt/raw_jwt_test.cc @@ -31,8 +31,9 @@ using ::testing::UnorderedElementsAreArray; namespace crypto { namespace tink { -TEST(RawJwt, GetIssuerSubjectJwtIdOK) { +TEST(RawJwt, GetTypeHeaderIssuerSubjectJwtIdOK) { auto jwt_or = RawJwtBuilder() + .SetTypeHeader("typeHeader") .SetIssuer("issuer") .SetSubject("subject") .SetJwtId("jwt_id") @@ -40,6 +41,8 @@ TEST(RawJwt, GetIssuerSubjectJwtIdOK) { ASSERT_THAT(jwt_or.status(), IsOk()); auto jwt = jwt_or.ValueOrDie(); + EXPECT_TRUE(jwt.HasTypeHeader()); + EXPECT_THAT(jwt.GetTypeHeader(), IsOkAndHolds("typeHeader")); EXPECT_TRUE(jwt.HasIssuer()); EXPECT_THAT(jwt.GetIssuer(), IsOkAndHolds("issuer")); EXPECT_TRUE(jwt.HasSubject()); @@ -255,6 +258,7 @@ TEST(RawJwt, EmptyTokenHasAndIsReturnsFalse) { ASSERT_THAT(jwt_or.status(), IsOk()); auto jwt = jwt_or.ValueOrDie(); + EXPECT_FALSE(jwt.HasTypeHeader()); EXPECT_FALSE(jwt.HasIssuer()); EXPECT_FALSE(jwt.HasSubject()); EXPECT_FALSE(jwt.HasAudiences()); @@ -275,6 +279,7 @@ TEST(RawJwt, EmptyTokenGetReturnsNotOK) { ASSERT_THAT(jwt_or.status(), IsOk()); auto jwt = jwt_or.ValueOrDie(); + EXPECT_FALSE(jwt.GetTypeHeader().ok()); EXPECT_FALSE(jwt.GetIssuer().ok()); EXPECT_FALSE(jwt.GetSubject().ok()); EXPECT_FALSE(jwt.GetAudiences().ok()); @@ -307,12 +312,14 @@ TEST(RawJwt, BuildCanBeCalledTwice) { EXPECT_THAT(jwt2.GetSubject(), IsOkAndHolds("subject2")); } -TEST(RawJwt, FromString) { - auto jwt_or = RawJwt::FromString( +TEST(RawJwt, FromJson) { + auto jwt_or = RawJwt::FromJson( + absl::nullopt, R"({"iss":"issuer", "sub":"subject", "exp":123, "aud":["a1", "a2"]})"); ASSERT_THAT(jwt_or.status(), IsOk()); RawJwt jwt = jwt_or.ValueOrDie(); + EXPECT_FALSE(jwt.HasTypeHeader()); EXPECT_THAT(jwt.GetIssuer(), IsOkAndHolds("issuer")); EXPECT_THAT(jwt.GetSubject(), IsOkAndHolds("subject")); EXPECT_THAT(jwt.GetExpiration(), IsOkAndHolds(absl::FromUnixSeconds(123))); @@ -320,8 +327,17 @@ TEST(RawJwt, FromString) { EXPECT_THAT(jwt.GetAudiences(), IsOkAndHolds(expected_audiences)); } -TEST(RawJwt, FromStringExpExpiration) { - auto jwt_or = RawJwt::FromString(R"({"exp":1e10})"); +TEST(RawJwt, FromJsonWithTypeHeader) { + auto jwt_or = RawJwt::FromJson("typeHeader", R"({"iss":"issuer"})"); + ASSERT_THAT(jwt_or.status(), IsOk()); + RawJwt jwt = jwt_or.ValueOrDie(); + + EXPECT_THAT(jwt.GetTypeHeader(), IsOkAndHolds("typeHeader")); + EXPECT_THAT(jwt.GetIssuer(), IsOkAndHolds("issuer")); +} + +TEST(RawJwt, FromJsonExpExpiration) { + auto jwt_or = RawJwt::FromJson(absl::nullopt, R"({"exp":1e10})"); ASSERT_THAT(jwt_or.status(), IsOk()); RawJwt jwt = jwt_or.ValueOrDie(); @@ -329,18 +345,18 @@ TEST(RawJwt, FromStringExpExpiration) { IsOkAndHolds(absl::FromUnixSeconds(10000000000))); } -TEST(RawJwt, FromStringExpirationTooLarge) { - auto jwt_or = RawJwt::FromString(R"({"exp":1e30})"); +TEST(RawJwt, FromJsonExpirationTooLarge) { + auto jwt_or = RawJwt::FromJson(absl::nullopt, R"({"exp":1e30})"); EXPECT_FALSE(jwt_or.ok()); } -TEST(RawJwt, FromStringNegativeExpirationAreInvalid) { - auto jwt_or = RawJwt::FromString(R"({"exp":-1})"); +TEST(RawJwt, FromJsonNegativeExpirationAreInvalid) { + auto jwt_or = RawJwt::FromJson(absl::nullopt, R"({"exp":-1})"); EXPECT_FALSE(jwt_or.ok()); } -TEST(RawJwt, FromStringConvertsStringAudIntoListOfStrings) { - auto jwt_or = RawJwt::FromString(R"({"aud":"audience"})"); +TEST(RawJwt, FromJsonConvertsStringAudIntoListOfStrings) { + auto jwt_or = RawJwt::FromJson(absl::nullopt, R"({"aud":"audience"})"); ASSERT_THAT(jwt_or.status(), IsOk()); RawJwt jwt = jwt_or.ValueOrDie(); @@ -349,23 +365,23 @@ TEST(RawJwt, FromStringConvertsStringAudIntoListOfStrings) { EXPECT_THAT(jwt.GetAudiences(), IsOkAndHolds(expected)); } -TEST(RawJwt, FromStringWithBadRegisteredTypes) { - EXPECT_FALSE(RawJwt::FromString(R"({"iss":123})").ok()); - EXPECT_FALSE(RawJwt::FromString(R"({"sub":123})").ok()); - EXPECT_FALSE(RawJwt::FromString(R"({"aud":123})").ok()); - EXPECT_FALSE(RawJwt::FromString(R"({"aud":[]})").ok()); - EXPECT_FALSE(RawJwt::FromString(R"({"aud":["abc",123]})").ok()); - EXPECT_FALSE(RawJwt::FromString(R"({"exp":"abc"})").ok()); - EXPECT_FALSE(RawJwt::FromString(R"({"nbf":"abc"})").ok()); - EXPECT_FALSE(RawJwt::FromString(R"({"iat":"abc"})").ok()); +TEST(RawJwt, FromJsonWithBadRegisteredTypes) { + EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"iss":123})").ok()); + EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"sub":123})").ok()); + EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"aud":123})").ok()); + EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"aud":[]})").ok()); + EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"aud":["abc",123]})").ok()); + EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"exp":"abc"})").ok()); + EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"nbf":"abc"})").ok()); + EXPECT_FALSE(RawJwt::FromJson(absl::nullopt, R"({"iat":"abc"})").ok()); } -TEST(RawJwt, ToString) { +TEST(RawJwt, GetJsonPayload) { auto jwt_or = RawJwtBuilder().SetIssuer("issuer").Build(); ASSERT_THAT(jwt_or.status(), IsOk()); auto jwt = jwt_or.ValueOrDie(); - ASSERT_THAT(jwt.ToString(), IsOkAndHolds(R"({"iss":"issuer"})")); + ASSERT_THAT(jwt.GetJsonPayload(), IsOkAndHolds(R"({"iss":"issuer"})")); } } // namespace tink diff --git a/cc/jwt/verified_jwt.cc b/cc/jwt/verified_jwt.cc index 708230a7ce..6db74d3fd8 100644 --- a/cc/jwt/verified_jwt.cc +++ b/cc/jwt/verified_jwt.cc @@ -30,6 +30,12 @@ VerifiedJwt::VerifiedJwt(const RawJwt& raw_jwt) { raw_jwt_ = raw_jwt; } +bool VerifiedJwt::HasTypeHeader() const { return raw_jwt_.HasTypeHeader(); } + +util::StatusOr VerifiedJwt::GetTypeHeader() const { + return raw_jwt_.GetTypeHeader(); +} + bool VerifiedJwt::HasIssuer() const { return raw_jwt_.HasIssuer(); } @@ -139,8 +145,8 @@ std::vector VerifiedJwt::CustomClaimNames() const { return raw_jwt_.CustomClaimNames(); } -util::StatusOr VerifiedJwt::ToString() { - return raw_jwt_.ToString(); +util::StatusOr VerifiedJwt::GetJsonPayload() { + return raw_jwt_.GetJsonPayload(); } } // namespace tink diff --git a/cc/jwt/verified_jwt.h b/cc/jwt/verified_jwt.h index 67d187c20a..bf9c42d475 100644 --- a/cc/jwt/verified_jwt.h +++ b/cc/jwt/verified_jwt.h @@ -48,6 +48,8 @@ class VerifiedJwt { VerifiedJwt(const VerifiedJwt&) = default; VerifiedJwt& operator=(const VerifiedJwt&) = default; + bool HasTypeHeader() const; + util::StatusOr GetTypeHeader() const; bool HasIssuer() const; util::StatusOr GetIssuer() const; bool HasSubject() const; @@ -76,7 +78,7 @@ class VerifiedJwt { util::StatusOr GetJsonArrayClaim(absl::string_view name) const; std::vector CustomClaimNames() const; - util::StatusOr ToString(); + util::StatusOr GetJsonPayload(); private: VerifiedJwt(); diff --git a/cc/jwt/verified_jwt_test.cc b/cc/jwt/verified_jwt_test.cc index 7bfe6041f4..57e0e93e65 100644 --- a/cc/jwt/verified_jwt_test.cc +++ b/cc/jwt/verified_jwt_test.cc @@ -76,8 +76,9 @@ util::StatusOr CreateVerifiedJwt(const RawJwt& raw_jwt) { validator_builder.Build()); } -TEST(VerifiedJwt, GetIssuerSubjectJwtIdOK) { +TEST(VerifiedJwt, GetTypeIssuerSubjectJwtIdOK) { auto raw_jwt_or = RawJwtBuilder() + .SetTypeHeader("typeHeader") .SetIssuer("issuer") .SetSubject("subject") .SetJwtId("jwt_id") @@ -87,6 +88,8 @@ TEST(VerifiedJwt, GetIssuerSubjectJwtIdOK) { ASSERT_THAT(verified_jwt_or.status(), IsOk()); VerifiedJwt jwt = verified_jwt_or.ValueOrDie(); + EXPECT_TRUE(jwt.HasTypeHeader()); + EXPECT_THAT(jwt.GetTypeHeader(), IsOkAndHolds("typeHeader")); EXPECT_TRUE(jwt.HasIssuer()); EXPECT_THAT(jwt.GetIssuer(), IsOkAndHolds("issuer")); EXPECT_TRUE(jwt.HasSubject()); @@ -255,6 +258,7 @@ TEST(VerifiedJwt, EmptyTokenHasAndIsReturnsFalse) { ASSERT_THAT(verified_jwt_or.status(), IsOk()); VerifiedJwt jwt = verified_jwt_or.ValueOrDie(); + EXPECT_FALSE(jwt.HasTypeHeader()); EXPECT_FALSE(jwt.HasIssuer()); EXPECT_FALSE(jwt.HasSubject()); EXPECT_FALSE(jwt.HasAudiences()); @@ -277,6 +281,7 @@ TEST(VerifiedJwt, EmptyTokenGetReturnsNotOK) { ASSERT_THAT(verified_jwt_or.status(), IsOk()); VerifiedJwt jwt = verified_jwt_or.ValueOrDie(); + EXPECT_FALSE(jwt.GetTypeHeader().ok()); EXPECT_FALSE(jwt.GetIssuer().ok()); EXPECT_FALSE(jwt.GetSubject().ok()); EXPECT_FALSE(jwt.GetAudiences().ok()); @@ -292,14 +297,14 @@ TEST(VerifiedJwt, EmptyTokenGetReturnsNotOK) { EXPECT_FALSE(jwt.GetJsonArrayClaim("array_claim").ok()); } -TEST(VerifiedJwt, ToString) { +TEST(VerifiedJwt, GetJsonPayload) { auto raw_jwt_or = RawJwtBuilder().SetIssuer("issuer").Build(); ASSERT_THAT(raw_jwt_or.status(), IsOk()); auto verified_jwt_or = CreateVerifiedJwt(raw_jwt_or.ValueOrDie()); ASSERT_THAT(verified_jwt_or.status(), IsOk()); VerifiedJwt jwt = verified_jwt_or.ValueOrDie(); - EXPECT_THAT(jwt.ToString(), IsOkAndHolds(R"({"iss":"issuer"})")); + EXPECT_THAT(jwt.GetJsonPayload(), IsOkAndHolds(R"({"iss":"issuer"})")); } TEST(VerifiedJwt, MoveMakesCopy) {