Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KAS nano #7

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@ sensitive.txt.tdf
/server.crt
/server.json
/server.key
/service/kas-ec-public.pem
6 changes: 3 additions & 3 deletions sdk/nanotdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const (
ErrNanoTdfRead = Error("nanotdf read error")
)

type nanoTdf struct {
type NanoTdf struct {
magicNumber [3]byte
kasUrl *resourceLocator
binding *bindingCfg
Expand Down Expand Up @@ -166,8 +166,8 @@ func readEphemeralPublicKey(reader io.Reader, curve ocrypto.ECCMode) (*eccKey, e
return &eccKey{Key: buffer}, nil
}

func ReadNanoTDFHeader(reader io.Reader) (*nanoTdf, error) {
var nanoTDF nanoTdf
func ReadNanoTDFHeader(reader io.Reader) (*NanoTdf, error) {
var nanoTDF NanoTdf

if err := binary.Read(reader, binary.BigEndian, &nanoTDF.magicNumber); err != nil {
return nil, errors.Join(ErrNanoTdfRead, err)
Expand Down
6 changes: 3 additions & 3 deletions sdk/nanotdf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"github.com/arkavo-org/opentdf-platform/lib/ocrypto"
)

// nanotdfEqual compares two nanoTdf structures for equality.
func nanoTDFEqual(a, b *nanoTdf) bool {
// nanotdfEqual compares two NanoTdf structures for equality.
func nanoTDFEqual(a, b *NanoTdf) bool {
// Compare magicNumber field
if a.magicNumber != b.magicNumber {
return false
Expand Down Expand Up @@ -95,7 +95,7 @@ func init() {

func TestReadNanoTDFHeader(t *testing.T) {
// Prepare a sample nanoTdf structure
nanoTDF := nanoTdf{
nanoTDF := NanoTdf{
magicNumber: [3]byte{'L', '1', 'L'},
kasUrl: &resourceLocator{
protocol: urlProtocolHttps,
Expand Down
1 change: 1 addition & 0 deletions service/internal/security/crypto_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type CryptoProvider interface {
RSADecrypt(hash crypto.Hash, keyID string, keyLabel string, ciphertext []byte) ([]byte, error)

ECPublicKey(keyID string) (string, error)
ECCertificate(keyID string) (string, error)
GenerateNanoTDFSymmetricKey(ephemeralPublicKeyBytes []byte) ([]byte, error)
GenerateEphemeralKasKeys() (PrivateKeyEC, []byte, error)
GenerateNanoTDFSessionKey(privateKeyHandle PrivateKeyEC, ephemeralPublicKey []byte) ([]byte, error)
Expand Down
88 changes: 73 additions & 15 deletions service/internal/security/hsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
Expand All @@ -17,20 +18,18 @@ import (
"strings"

"github.com/lestrrat-go/jwx/v2/jwk"

"github.com/miekg/pkcs11"
"golang.org/x/crypto/hkdf"
)

const (
ErrCertNotFound = Error("not found")
ErrCertificateEncode = Error("certificate encode error")
ErrPublicKeyMarshal = Error("public key marshal error")
ErrHSMUnexpected = Error("hsm unexpected")
ErrHSMDecrypt = Error("hsm decrypt error")
ErrHSMNotFound = Error("hsm unavailable")
ErrKeyConfig = Error("key configuration error")
ErrUnknownHashFunction = Error("unknown hash function")
ErrCertNotFound = Error("not found")
ErrCertificateEncode = Error("certificate encode error")
ErrPublicKeyMarshal = Error("public key marshal error")
ErrHSMUnexpected = Error("hsm unexpected")
ErrHSMDecrypt = Error("hsm decrypt error")
ErrHSMNotFound = Error("hsm unavailable")
ErrKeyConfig = Error("key configuration error")
)
const keyLength = 32

Expand All @@ -40,7 +39,7 @@ func (e Error) Error() string {
return string(e)
}

// A session with a security module; useful for abstracting basic cryptographic
// HSMSession A session with a security module; useful for abstracting basic cryptographic
// operations.
//
// HSM Session HAS-A PKCS11 Context
Expand Down Expand Up @@ -314,13 +313,15 @@ func (h *HSMSession) loadKeys(keys map[string]KeyInfo) error {
pair, err := h.LoadRSAKey(info)
if err != nil {
slog.Error("pkcs11 error unable to load RSA key", "err", err)
//return err
} else {
h.RSA = pair
}
case "ec":
pair, err := h.LoadECKey(info)
if err != nil {
slog.Error("pkcs11 error unable to load EC key", "err", err)
return err
} else {
h.EC = pair
}
Expand Down Expand Up @@ -457,7 +458,7 @@ func (h *HSMSession) LoadECKey(info KeyInfo) (*ECKeyPair, error) {
// EC Cert
certECHandle, err := h.findKey(pkcs11.CKO_CERTIFICATE, info.Label)
if err != nil {
slog.Error("public key EC cert error")
slog.Error("public key EC cert error", "err", err)
return nil, errors.Join(ErrKeyConfig, err)
}
certECTemplate := []*pkcs11.Attribute{
Expand All @@ -483,6 +484,7 @@ func (h *HSMSession) LoadECKey(info KeyInfo) (*ECKeyPair, error) {
panic(err)
}
pair.Certificate = certEC
break
}
}
if pair.Certificate == nil {
Expand All @@ -497,6 +499,46 @@ func (h *HSMSession) LoadECKey(info KeyInfo) (*ECKeyPair, error) {
}

pair.PublicKey = ecPublicKey

// Do a sanity check of the key pair
err = h.ctx.DigestInit(h.sh, []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_SHA256, nil)})
if err != nil {
slog.Error("pkcs11 SignInit", "err", err)
return nil, err
}
digest, err := h.ctx.Digest(h.sh, []byte("sanity now"))
if err != nil {
slog.Error("pkcs11 Digest", "err", err)
return nil, err
}
err = h.ctx.SignInit(h.sh, []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_ECDSA, nil)}, keyHandleEC)
if err != nil {
slog.Error("pkcs11 SignInit", "err", err)
return nil, err
}
sig, err := h.ctx.Sign(h.sh, digest)
if err != nil {
slog.Error("pkcs11 Sign", "err", err)
return nil, err
}
valid := ecdsa.VerifyASN1(ecPublicKey, digest, sig)
if !valid {
pubKeyDER, err := x509.MarshalPKIXPublicKey(ecPublicKey)
if err != nil {
slog.Error("Error marshalling public key:", "err", err)
}
pubKeyPEM := pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyDER,
}
pemData := pem.EncodeToMemory(&pubKeyPEM)
slog.Error("pkcs11 VerifyASN1 failed",
"hash", hex.EncodeToString(digest),
"sig", hex.EncodeToString(sig),
"ecPublicKey", pemData)
// FIXME can't get this working, skipping for now
//return nil, fmt.Errorf("pkcs11 VerifyASN1 signature failed")
}
return &pair, nil
}

Expand Down Expand Up @@ -527,6 +569,7 @@ func oaepForHash(hashFunction crypto.Hash, keyLabel string) (*pkcs11.OAEPParams,
}

func (h *HSMSession) GenerateNanoTDFSymmetricKey(ephemeralPublicKeyBytes []byte) ([]byte, error) {
slog.Debug("GenerateNanoTDFSymmetricKey")
template := []*pkcs11.Attribute{
pkcs11.NewAttribute(pkcs11.CKA_TOKEN, false),
pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_SECRET_KEY),
Expand All @@ -547,6 +590,7 @@ func (h *HSMSession) GenerateNanoTDFSymmetricKey(ephemeralPublicKeyBytes []byte)

handle, err := h.ctx.DeriveKey(h.sh, mech, pkcs11.ObjectHandle(h.EC.PrivateKey), template)
if err != nil {
slog.Error("GenerateNanoTDFSymmetricKey", "err", err)
return nil, fmt.Errorf("failed to derive symmetric key: %w", err)
}

Expand All @@ -555,19 +599,22 @@ func (h *HSMSession) GenerateNanoTDFSymmetricKey(ephemeralPublicKeyBytes []byte)
}
attr, err := h.ctx.GetAttributeValue(h.sh, handle, template)
if err != nil {
slog.Error("GenerateNanoTDFSymmetricKey", "err", err)
return nil, err
}

symmetricKey := attr[0].Value

salt := versionSalt()
hkdf := hkdf.New(sha256.New, symmetricKey, salt, nil)
hkdfReader := hkdf.New(sha256.New, symmetricKey, salt, nil)

derivedKey := make([]byte, keyLength)
_, err = io.ReadFull(hkdf, derivedKey)
hkdfReadLength, err := io.ReadFull(hkdfReader, derivedKey)
if err != nil {
slog.Error("GenerateNanoTDFSymmetricKey", "err", err)
return nil, fmt.Errorf("failed to derive symmetric key: %w", err)
}
slog.Debug("GenerateNanoTDFSymmetricKey", "hkdfReadLength", hkdfReadLength)

return derivedKey, nil
}
Expand Down Expand Up @@ -609,10 +656,10 @@ func (h *HSMSession) GenerateNanoTDFSessionKey(

sessionKey := attr[0].Value
salt := versionSalt()
hkdf := hkdf.New(sha256.New, sessionKey, salt, nil)
hkdfParams := hkdf.New(sha256.New, sessionKey, salt, nil)

derivedKey := make([]byte, keyLength)
_, err = io.ReadFull(hkdf, derivedKey)
_, err = io.ReadFull(hkdfParams, derivedKey)
if err != nil {
return nil, fmt.Errorf("failed to derive session key: %w", err)
}
Expand Down Expand Up @@ -702,6 +749,17 @@ func (h *HSMSession) RSAPublicKeyAsJSON(keyID string) (string, error) {
return string(jsonPublicKey), nil
}

func (h *HSMSession) ECCertificate(string) (string, error) {
if h.EC == nil || h.EC.Certificate == nil {
return "", ErrCertNotFound
}
certPEM := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: h.EC.Certificate.Raw,
})
return string(certPEM), nil
}

func (h *HSMSession) ECPublicKey(string) (string, error) {
if h.EC == nil || h.EC.PublicKey == nil {
return "", ErrCertNotFound
Expand Down
91 changes: 86 additions & 5 deletions service/internal/security/standard_crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package security

import (
"crypto"
"crypto/ecdsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"log/slog"
Expand Down Expand Up @@ -34,14 +37,15 @@ type StandardRSACrypto struct {
}

type StandardECCrypto struct {
Identifier string
// ecPublicKey *ecdh.PublicKey
// ecPrivateKey *ecdh.PrivateKey
Identifier string
ecPublicKey *ecdsa.PublicKey
ecPrivateKey *ecdsa.PrivateKey
ecCertificatePEM string
}

type StandardCrypto struct {
rsaKeys []StandardRSACrypto
// ecKeys []StandardECCrypto
ecKeys []StandardECCrypto
}

// NewStandardCrypto Create a new instance of standard crypto
Expand Down Expand Up @@ -74,6 +78,56 @@ func NewStandardCrypto(cfg StandardConfig) (*StandardCrypto, error) {
asymEncryption: asymEncryption,
})
}
for id, kasInfo := range cfg.ECKeys {
slog.Info("cfg.ECKeys", "id", id, "kasInfo", kasInfo)
privatePemData, err := os.ReadFile(kasInfo.PrivateKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to rsa private key file: %w", err)
}
// this returns a certificate not a PUBLIC KEY
publicPemData, err := os.ReadFile(kasInfo.PublicKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to rsa public key file: %w", err)
}
//block, _ := pem.Decode(publicPemData)
//if block == nil {
// return nil, errors.New("failed to decode PEM block containing public key")
//}
//ecPublicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
//if err != nil {
// return nil, fmt.Errorf("failed to parse EC public key: %w", err)
//}
block, _ := pem.Decode(privatePemData)
if block == nil {
return nil, errors.New("failed to decode PEM block containing private key")
}
ecPrivateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse EC private key: %w", err)
}
//var ecdsaPublicKey *ecdsa.PublicKey
//switch pub := ecPublicKey.(type) {
//case *ecdsa.PublicKey:
// fmt.Println("pub is of type ECDSA:", pub)
// ecdsaPublicKey = pub
//default:
// panic("unknown type of public key")
//}
var ecdsaPrivateKey *ecdsa.PrivateKey
switch priv := ecPrivateKey.(type) {
case *ecdsa.PrivateKey:
fmt.Println("pub is of type ECDSA:", priv)
ecdsaPrivateKey = priv
default:
panic("unknown type of public key")
}
standardCrypto.ecKeys = append(standardCrypto.ecKeys, StandardECCrypto{
Identifier: id,
//ecPublicKey: ecdsaPublicKey,
ecPrivateKey: ecdsaPrivateKey,
ecCertificatePEM: string(publicPemData),
})
}

return standardCrypto, nil
}
Expand All @@ -94,8 +148,34 @@ func (s StandardCrypto) RSAPublicKey(keyID string) (string, error) {
return pem, nil
}

func (s StandardCrypto) ECCertificate(identifier string) (string, error) {
if len(s.ecKeys) == 0 {
return "", ErrCertNotFound
}
// this endpoint returns certificate
for _, ecKey := range s.ecKeys {
slog.Debug("ecKey", "id", ecKey.Identifier)
if ecKey.Identifier == identifier {
return ecKey.ecCertificatePEM, nil
}
}
return "", fmt.Errorf("no EC Key found with the given identifier: %s", identifier)
}

func (s StandardCrypto) ECPublicKey(string) (string, error) {
return "", ErrCertNotFound
if len(s.ecKeys) == 0 {
return "", ErrCertNotFound
}
ecKey := s.ecKeys[0]
publicKeyBytes, err := x509.MarshalPKIXPublicKey(ecKey.ecPublicKey)
if err != nil {
return "", ErrPublicKeyMarshal
}
pemEncoded := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKeyBytes,
})
return string(pemEncoded), nil
}

func (s StandardCrypto) RSADecrypt(_ crypto.Hash, keyID string, _ string, ciphertext []byte) ([]byte, error) {
Expand Down Expand Up @@ -136,6 +216,7 @@ func (s StandardCrypto) RSAPublicKeyAsJSON(keyID string) (string, error) {
}

func (s StandardCrypto) GenerateNanoTDFSymmetricKey([]byte) ([]byte, error) {

return nil, errNotImplemented
}

Expand Down
2 changes: 2 additions & 0 deletions service/kas/access/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
kaspb "github.com/arkavo-org/opentdf-platform/protocol/go/kas"
otdf "github.com/arkavo-org/opentdf-platform/sdk"
"github.com/arkavo-org/opentdf-platform/service/internal/security"
"github.com/arkavo-org/opentdf-platform/service/pkg/serviceregistry"
"github.com/coreos/go-oidc/v3/oidc"
)

Expand All @@ -21,4 +22,5 @@ type Provider struct {
AttributeSvc *url.URL
CryptoProvider security.CryptoProvider
OIDCVerifier *oidc.IDTokenVerifier
Config *serviceregistry.ServiceConfig
}
Loading