Skip to content

Commit

Permalink
... return full chain of certs ...
Browse files Browse the repository at this point in the history
  • Loading branch information
milahu committed Jun 2, 2024
1 parent ecd1def commit e5b05bc
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 119 deletions.
229 changes: 125 additions & 104 deletions aia.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@

# pyopenssl
import OpenSSL
from OpenSSL.crypto import (
X509,
Error,
_openssl_assert,
)
from OpenSSL._util import (
lib as _lib,
ffi as _ffi,
)

# https://cryptography.io/en/latest/x509/
import cryptography
Expand Down Expand Up @@ -160,6 +169,7 @@ def __init__(
cafile=None,
cache_db=None,
cache_dir=None,
verify_depth=100, # default is -1 = infinite
# TODO load/store trusted root certs
# trusted_db=None,
# trusted_dir=None,
Expand All @@ -179,9 +189,12 @@ def __init__(
self.cache_db_con = None
self.cache_db_cur = None
self.cache_dir = cache_dir
self._context = OpenSSL.SSL.Context(method=OpenSSL.SSL.TLS_CLIENT_METHOD)
self._ssl_context = OpenSSL.SSL.Context(method=OpenSSL.SSL.TLS_CLIENT_METHOD)
if verify_depth:
self._ssl_context.set_verify_depth(verify_depth)
# logger.debug(f"verify_depth = {self._ssl_context.get_verify_depth()}")
# this throws OpenSSL.SSL.Error if cafile is missing or empty
self._context.load_verify_locations(cafile=self.cafile)
self._ssl_context.load_verify_locations(cafile=self.cafile)
self._cadata_from_host_regex = dict()
self._trusted_root_certs = list()

Expand All @@ -198,7 +211,7 @@ def get_host_cert_chain(self, host, timeout=5):
port = int(port)
# https://stackoverflow.com/a/67212703/10440128
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
conn = OpenSSL.SSL.Connection(self._context, socket=sock)
conn = OpenSSL.SSL.Connection(self._ssl_context, socket=sock)
conn.settimeout(timeout)
# NOTE this block can throw OpenSSL.SSL.Error ...
conn.connect((host, port))
Expand Down Expand Up @@ -228,16 +241,31 @@ def do_handshake():

conn.close()

full_cert_chain = conn.get_peer_cert_chain()
verified_cert_chain = conn.get_verified_chain()
host_cert_chain = conn.get_peer_cert_chain()

if len(verified_cert_chain) == len(full_cert_chain):
return host_cert_chain

# TODO remove
"""
verified_chain = conn.get_verified_chain()
return host_cert_chain, verified_chain
verified_chain = conn.get_verified_chain()
# TODO remove
'''
if len(verified_chain) == len(peer_cert_chain):
# rest_cert_chain is empty
# this does not mean that the chain is valid
# the server can return only 1 cert
return verified_cert_chain, None
rest_cert_chain = full_cert_chain[len(verified_cert_chain):]
return verified_cert_chain, rest_cert_chain
return verified_chain, None
'''
rest_cert_chain = peer_cert_chain[len(verified_chain):]
return verified_chain, rest_cert_chain
"""

def _init_cache_db(self):
if self.cache_db_con:
Expand Down Expand Up @@ -300,7 +328,7 @@ def _read_cert_cache(self, url_parsed):
OpenSSL.crypto.FILETYPE_ASN1, cert_der
)
return cert
logger.debug("not found cert in cache: {url}")
logger.debug(f"not found cert in cache: {url}")

def _write_cert_cache(self, url_parsed, cert):
if not self.cache_dir and not self.cache_db:
Expand All @@ -314,7 +342,6 @@ def _write_cert_cache(self, url_parsed, cert):
query = "insert into certs (url, cert_der) values (?, ?)"
args = (url, cert_der)
cur = self.cache_db_cur.execute(query, args)
print("cur", cur)
if cur.rowcount != 1:
logger.warning(f"failed to add cert to cache_db: {url}")
# write to disk
Expand Down Expand Up @@ -493,7 +520,7 @@ def remove_trusted_root_cert(self, cert):
len2 = len(self._trusted_root_certs)
return len1 != len2 # return True if cert was removed

def aia_chase(self, host, timeout=5, max_chain_depth=100):
def aia_chase(self, host, timeout=5):
"""
Get the certificate chain for host,
up to (and including) the root certificate.
Expand All @@ -502,141 +529,135 @@ def aia_chase(self, host, timeout=5, max_chain_depth=100):
0 = verified_cert_chain
1 = missing_certs
missing_certs are the extra certs
that had to be fetched to verify the chain.
The first cert in cert_chain is the host certificate,
The first cert in verified_cert_chain is the host certificate,
the next certs are the intermediary certificates,
the last cert is the root certificate.
missing_certs are the extra certs
that had to be fetched to verify the chain.
"""

# TODO throw this when an intermediary cert could not be fetched
# raise ssl.SSLCertVerificationError("unable to get local issuer certificate")

# TODO throw a different error when the root cert is not trusted

# note: at this point, verified_cert_chain can be not-yet fully verified.
# it is not-yet fully verified when the last cert is not a trusted root cert.
verified_cert_chain, rest_cert_chain = self.get_host_cert_chain(host, timeout)

# TODO what to do with rest_cert_chain
# avoid fetching certs if we have them already
host_cert_chain = self.get_host_cert_chain(host, timeout)

print("verified_cert_chain")
print_chain(verified_cert_chain)
print("rest_cert_chain")
print_chain(rest_cert_chain)
# print_chain(host_cert_chain, "host_cert_chain")

# no. when the server sends only 1 cert
# then rest_cert_chain is empty, but the chain can be invalid
# if not rest_cert_chain:
# # full chain is valid, no missing certs were fetched
# return verified_cert_chain, None

# the first cert (leaf cert) is always in verified_cert_chain
if not verified_cert_chain:
# the first cert (leaf cert) is always in host_cert_chain
if not host_cert_chain:
# no certs were received
# TODO throw error?
# assuming the user wants to establish a TLS connection
# assuming the user wants to establish a SSL connection
# but the server did return no certificates
return None, None

def is_root_cert(cert):
if cert.get_subject() != cert.get_issuer():
return False
return True

# TODO test with self-signed leaf cert. require len > 1?
# check if verified_cert_chain is complete
cert = verified_cert_chain[-1]
if is_root_cert(cert):
return verified_cert_chain, None

# chase: fetch missing certs
# https://groups.google.com/a/chromium.org/g/net-dev/c/H-ysp5UM_rk

# store = OpenSSL.crypto.X509Store()
leaf_cert = host_cert_chain[0]

# note: rest_certs can be any certs
# not necessarily a chain with leaf_cert
rest_certs = host_cert_chain[1:]

# local cert store so we can add temporary certs
# https://www.pyopenssl.org/en/stable/api/crypto.html#x509store-objects
store = self._context.get_cert_store()
cert_store = self._ssl_context.get_cert_store()

# add trusted root certs
for cert in self._trusted_root_certs:
store.add_cert(cert)
cert_store.add_cert(cert)

# verified_cert_chain[0:-1] certs are verified
# verified_cert_chain[-1] cert is not verified
missing_certs = []

cert = verified_cert_chain[-1] # not verified
# avoid infinite loop
verify_depth = self._ssl_context.get_verify_depth()
if verify_depth == -1:
verify_depth = 1000

missing_certs = []
for verify_chain_idx in range(verify_depth):

for verify_chain_idx in range(max_chain_depth):
# logger.debug("verify_chain_idx", verify_chain_idx)

print("cert subject", cert.get_subject())
print("cert issuer ", cert.get_issuer())
cert_store_ctx = OpenSSL.crypto.X509StoreContext(
cert_store, leaf_cert, rest_certs + missing_certs
)

aia_ca_issuers = get_ca_issuers_of_cert(cert)
print("aia_ca_issuers", aia_ca_issuers)
if len(aia_ca_issuers) == 0:
raise Exception(
"unable to get local issuer certificate. cert has no aia_ca_issuers"
)
assert len(aia_ca_issuers) > 0
# assert len(aia_ca_issuers) == 1 # ?
issuer_cert = self._get_ca_issuer_cert(aia_ca_issuers[0], timeout)
print("issuer_cert subject", issuer_cert.get_subject())
print("issuer_cert issuer ", issuer_cert.get_issuer())
missing_certs.append(issuer_cert)
# missing_certs = [issuer_cert]

print("missing_certs")
print_chain(missing_certs)

# verify this cert
# while True:
# for i in range(2):
# print("verify try", i)
# https://github.com/pyca/pyopenssl/pull/948
ctx = OpenSSL.crypto.X509StoreContext(store, cert, missing_certs)
# no. this adds *untrusted* certs
# ctx = OpenSSL.crypto.X509StoreContext(
# store, cert, self._trusted_root_certs + missing_certs)
try:
ctx.verify_certificate()
# print("cert is valid. full chain is valid, "
# "no missing certs were fetched")
# cert is valid
# full chain is valid, no missing certs were fetched
# verified_cert_chain.append(issuer_cert.to_cryptography())
verified_cert_chain.append(issuer_cert)
verified_cert_chain = list(
map(lambda c: c.to_cryptography(), verified_cert_chain)
)
cert_store_ctx.verify_certificate()
# full chain is valid
verified_cert_chain = self._get_verified_cert_chain(cert_store_ctx)
return verified_cert_chain, missing_certs

except OpenSSL.crypto.X509StoreContextError as exc:

if exc.errors[0] == 20:
# exc.errors [20, 1, 'unable to get local issuer certificate']
print("chain is not complete -> continuing chase")
# import time; time.sleep(5)
cert = issuer_cert
cert = exc.certificate
logger.debug(
f"fetching missing issuer cert for cert {cert.get_subject()}"
)
aia_ca_issuers = get_ca_issuers_of_cert(cert)
logger.debug("aia_ca_issuers", aia_ca_issuers)
if len(aia_ca_issuers) == 0:
raise Exception(
"unable to get local issuer certificate. "
"cert has no aia_ca_issuers"
)
issuer_cert = self._get_ca_issuer_cert(aia_ca_issuers[0], timeout)
logger.debug("issuer_cert subject", issuer_cert.get_subject())
# logger.debug("issuer_cert issuer ", issuer_cert.get_issuer())
missing_certs.append(issuer_cert)
# print_chain(missing_certs, "missing_certs")
continue

if exc.errors[0] == 19:
# exc.errors [19, 1, 'self-signed certificate in certificate chain']
print(
logger.debug(
"chain ends with untrusted root cert. "
"hint: aia_session.add_trusted_root_cert(cert)"
)
raise
print("exc.args", exc.args)
print("exc.certificate", exc.certificate)
print("exc.errors", exc.errors)
print("exc str", str(exc))
# add return values to exception
# exc._aia_verified_cert_chain[-1] is the untrusted root cert
verified_cert_chain = self._get_verified_cert_chain(cert_store_ctx)
exc._aia_verified_cert_chain = verified_cert_chain
exc._aia_missing_certs = missing_certs
raise exc

raise

# on success, we return from the previous for loop
# TODO use a more specific exception
raise Exception("exceeded max_chain_depth")
raise Exception("exceeded verify_depth")

def _get_verified_cert_chain(self, cert_store_ctx):
# based on OpenSSL.crypto._verify_certificate
_store_ctx = _lib.X509_STORE_CTX_new()
_openssl_assert(_store_ctx != _ffi.NULL)
_store_ctx = _ffi.gc(_store_ctx, _lib.X509_STORE_CTX_free)
ret = _lib.X509_STORE_CTX_init(
_store_ctx,
cert_store_ctx._store._store,
cert_store_ctx._cert._x509,
cert_store_ctx._chain,
)
_openssl_assert(ret == 1)
ret = _lib.X509_verify_cert(_store_ctx)
if ret < 0:
raise cert_store_ctx._exception_from_context(_store_ctx)
# ret == 1: ok
# ret == 0: cert was not verified
# based on OpenSSL.crypto.get_verified_chain
_cert_stack = _lib.X509_STORE_CTX_get1_chain(_store_ctx)
_openssl_assert(_cert_stack != _ffi.NULL)
cert_chain = []
for i in range(_lib.sk_X509_num(_cert_stack)):
_cert = _lib.sk_X509_value(_cert_stack, i)
_openssl_assert(_cert != _ffi.NULL)
cert = X509._from_raw_x509_ptr(_cert)
cert_chain.append(cert)
return cert_chain

def cadata_from_host(self, host, **kwargs):
"""
Expand Down
Loading

0 comments on commit e5b05bc

Please sign in to comment.