|
|
|
@ -0,0 +1,308 @@ |
|
|
|
|
import enum |
|
|
|
|
import itertools |
|
|
|
|
from datetime import datetime |
|
|
|
|
from typing import Type, Sequence, Optional as O, Tuple, Union |
|
|
|
|
|
|
|
|
|
# Core imports |
|
|
|
|
import asn1crypto.algos |
|
|
|
|
from cryptography.exceptions import InvalidSignature |
|
|
|
|
from cryptography.x509 import Certificate, BasicConstraints, KeyUsage, ExtendedKeyUsage |
|
|
|
|
from cryptography.x509.oid import ObjectIdentifier |
|
|
|
|
from cryptography.hazmat.primitives.hashes import HashAlgorithm, Hash |
|
|
|
|
from cryptography.hazmat.primitives.asymmetric.padding import AsymmetricPadding |
|
|
|
|
try: |
|
|
|
|
from cryptography.hazmat.primitives.asymmetric.types import CERTIFICATE_PUBLIC_KEY_TYPES |
|
|
|
|
except ImportError: |
|
|
|
|
from cryptography.hazmat._types import _PUBLIC_KEY_TYPES as CERTIFICATE_PUBLIC_KEY_TYPES |
|
|
|
|
|
|
|
|
|
# Algorithm imports |
|
|
|
|
from cryptography.hazmat.primitives.hashes import ( |
|
|
|
|
MD5, SHA1, SHA224, SHA256, SHA384, SHA512, |
|
|
|
|
SHA3_224, SHA3_256, SHA3_384, SHA3_512, SHAKE128, SHAKE256, |
|
|
|
|
) |
|
|
|
|
from cryptography.hazmat.primitives.asymmetric.dsa import DSAPublicKey |
|
|
|
|
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey |
|
|
|
|
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey |
|
|
|
|
from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PublicKey |
|
|
|
|
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey |
|
|
|
|
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15, PSS, MGF1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VerifyAspect: |
|
|
|
|
used_extension_oids: list[ObjectIdentifier] = [] |
|
|
|
|
|
|
|
|
|
def verify(self, cert: Certificate, issuer: Certificate) -> bool: |
|
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
|
class VerifyFailure(Exception): |
|
|
|
|
def __init__(self, cert: Certificate, aspects: list[VerifyAspect]) -> None: |
|
|
|
|
self.cert = cert |
|
|
|
|
self.aspects = aspects |
|
|
|
|
super().__init__('{} failed verification: {}'.format( |
|
|
|
|
cert.subject.rfc4514_string(), |
|
|
|
|
', '.join(str(a) for a in aspects), |
|
|
|
|
)) |
|
|
|
|
|
|
|
|
|
class Verifier: |
|
|
|
|
def __init__(self, aspects: O[list[VerifyAspect]] = None) -> None: |
|
|
|
|
self.aspects = aspects or DEFAULT_VERIFY_ASPECTS |
|
|
|
|
|
|
|
|
|
def verify_cert(self, cert: Certificate, issuer: Certificate, extra_aspects: list[VerifyAspect] = []) -> None: |
|
|
|
|
failed_aspects = [] |
|
|
|
|
for a in itertools.chain(self.aspects, extra_aspects): |
|
|
|
|
if not a.verify(cert, issuer): |
|
|
|
|
failed_aspects.append(a) |
|
|
|
|
if failed_aspects: |
|
|
|
|
raise VerifyFailure(cert, failed_aspects) |
|
|
|
|
|
|
|
|
|
def verify_chain(self, cert: Certificate, roots: Sequence[Certificate], chain: O[Sequence[Certificate]] = None, |
|
|
|
|
extra_aspects: list[VerifyAspect] = [], extra_issuer_aspects: list[VerifyAspect] = [], extra_leaf_aspects: list[VerifyAspect] = []) -> None: |
|
|
|
|
if chain is None: |
|
|
|
|
chain = [] |
|
|
|
|
chain.extend(roots) |
|
|
|
|
|
|
|
|
|
path_len = 0 |
|
|
|
|
curr_cert = cert |
|
|
|
|
while curr_cert not in roots: |
|
|
|
|
aspects = extra_aspects.copy() |
|
|
|
|
if curr_cert == cert: |
|
|
|
|
aspects.extend(extra_leaf_aspects) |
|
|
|
|
else: |
|
|
|
|
aspects.extend(extra_issuer_aspects) |
|
|
|
|
for issuer in chain: |
|
|
|
|
if not VerifySignature().verify(curr_cert, issuer): |
|
|
|
|
continue |
|
|
|
|
self.verify_cert(curr_cert, issuer, extra_aspects=aspects) |
|
|
|
|
curr_cert = issuer |
|
|
|
|
break |
|
|
|
|
else: |
|
|
|
|
raise VerifyFailure(curr_cert, []) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_sig_algo_from_asn(name: str) -> Type[CERTIFICATE_PUBLIC_KEY_TYPES]: |
|
|
|
|
return { |
|
|
|
|
'dsa': DSAPublicKey, |
|
|
|
|
'rsassa': RSAPublicKey, |
|
|
|
|
'ecdsa': EllipticCurvePublicKey, |
|
|
|
|
'ed448': Ed448PublicKey, |
|
|
|
|
'ed25519': Ed25519PrivateKey, |
|
|
|
|
}[name] |
|
|
|
|
|
|
|
|
|
def get_hash_algo_from_asn(name: str) -> HashAlgorithm: |
|
|
|
|
return { |
|
|
|
|
'md5': MD5(), |
|
|
|
|
'sha1': SHA1(), |
|
|
|
|
'sha224': SHA224(), |
|
|
|
|
'sha256': SHA256(), |
|
|
|
|
'sha384': SHA384(), |
|
|
|
|
'sha512': SHA512(), |
|
|
|
|
'sha3_224': SHA3_224(), |
|
|
|
|
'sha3_256': SHA3_256(), |
|
|
|
|
'sha3_384': SHA3_384(), |
|
|
|
|
'sha3_512': SHA3_512(), |
|
|
|
|
'shake128': SHAKE128(32), |
|
|
|
|
'shake256': SHAKE256(64), |
|
|
|
|
}[name] |
|
|
|
|
|
|
|
|
|
def get_mask_algo_from_asn(name: str, params): |
|
|
|
|
return { |
|
|
|
|
'mgf1': MGF1(get_hash_algo_from_asn(params['algorithm'].native)), |
|
|
|
|
}[name] |
|
|
|
|
|
|
|
|
|
def get_pad_from_asn(name: str, params) -> AsymmetricPadding: |
|
|
|
|
if name == 'pkcs1v15': |
|
|
|
|
return PKCS1v15() |
|
|
|
|
elif name == 'pss': |
|
|
|
|
mgp = params['mask_gen_algorithm'] |
|
|
|
|
return PSS(get_mask_algo_from_asn(mgp['algorithm'].native, mgp['parameters']), params['salt_length'].native) |
|
|
|
|
else: |
|
|
|
|
raise KeyError(name) |
|
|
|
|
|
|
|
|
|
def get_sig_algo_from_oid(oid: str) -> Tuple[Type[CERTIFICATE_PUBLIC_KEY_TYPES], O[HashAlgorithm], O[AsymmetricPadding]]: |
|
|
|
|
algo = asn1crypto.algos.SignedDigestAlgorithm({ |
|
|
|
|
'algorithm': oid, |
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
sig_name, *sig_params = algo.signature_algo.split('_') |
|
|
|
|
sig_algo = get_sig_algo_from_asn(sig_name) |
|
|
|
|
if algo.signature_algo != algo['algorithm'].native: |
|
|
|
|
hash_algo = get_hash_algo_from_asn(algo.hash_algo) |
|
|
|
|
else: |
|
|
|
|
hash_algo = None |
|
|
|
|
pad_algo = get_pad_from_asn(sig_params[0], algo['parameters']) if sig_params else None |
|
|
|
|
return sig_algo, hash_algo, pad_algo |
|
|
|
|
|
|
|
|
|
def get_sig_algo_from_cert(cert: Certificate) -> Tuple[Type[CERTIFICATE_PUBLIC_KEY_TYPES], HashAlgorithm, O[AsymmetricPadding]]: |
|
|
|
|
return get_sig_algo_from_oid(cert.signature_algorithm_oid.dotted_string) |
|
|
|
|
|
|
|
|
|
def verify_signature(key: CERTIFICATE_PUBLIC_KEY_TYPES, hash_algo: HashAlgorithm, padding: AsymmetricPadding, signature: bytes, content: bytes) -> bool: |
|
|
|
|
try: |
|
|
|
|
key.verify(signature, content, padding, hash_algo) |
|
|
|
|
return True |
|
|
|
|
except InvalidSignature: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default aspects |
|
|
|
|
|
|
|
|
|
class VerifySignature(VerifyAspect): |
|
|
|
|
def verify(self, cert: Certificate, issuer: Certificate) -> bool: |
|
|
|
|
key_type, _, key_pad = get_sig_algo_from_cert(cert) |
|
|
|
|
key = issuer.public_key() |
|
|
|
|
if not isinstance(key, key_type): |
|
|
|
|
return False |
|
|
|
|
return verify_signature(key, |
|
|
|
|
cert.signature_hash_algorithm, |
|
|
|
|
key_pad, |
|
|
|
|
cert.signature, |
|
|
|
|
cert.tbs_certificate_bytes, |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
class VerifyIssuerBefore(VerifyAspect): |
|
|
|
|
def verify(self, cert: Certificate, issuer: Certificate) -> bool: |
|
|
|
|
return issuer.not_valid_before <= cert.not_valid_before <= issuer.not_valid_after |
|
|
|
|
|
|
|
|
|
class VerifyIssuerAfter(VerifyAspect): |
|
|
|
|
def verify(self, cert: Certificate, issuer: Certificate) -> bool: |
|
|
|
|
return issuer.not_valid_before <= cert.not_valid_after <= issuer.not_valid_after |
|
|
|
|
|
|
|
|
|
class VerifyIssuerPurpose(VerifyAspect): |
|
|
|
|
used_extension_oids = [BasicConstraints.oid, KeyUsage.oid] |
|
|
|
|
|
|
|
|
|
def verify(self, cert: Certificate, issuer: Certificate) -> bool: |
|
|
|
|
constraints = issuer.extensions.get_extension_for_class(BasicConstraints) |
|
|
|
|
if constraints: |
|
|
|
|
if not constraints.value.ca: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
usage = issuer.extensions.get_extension_for_class(KeyUsage) |
|
|
|
|
if usage and not usage.value.key_cert_sign: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
class VerifyCertBefore(VerifyAspect): |
|
|
|
|
def __init__(self, now: O[datetime] = None) -> None: |
|
|
|
|
self.now = now |
|
|
|
|
|
|
|
|
|
def verify(self, cert: Certificate, issuer: Certificate) -> bool: |
|
|
|
|
now = self.now or datetime.now() |
|
|
|
|
return cert.not_valid_before <= now |
|
|
|
|
|
|
|
|
|
class VerifyCertAfter(VerifyAspect): |
|
|
|
|
def __init__(self, now: O[datetime] = None) -> None: |
|
|
|
|
self.now = now |
|
|
|
|
|
|
|
|
|
def verify(self, cert: Certificate, issuer: Certificate) -> bool: |
|
|
|
|
now = self.now or datetime.now() |
|
|
|
|
return now <= cert.not_valid_after |
|
|
|
|
|
|
|
|
|
KEY_ALGORITHM_TYPE = Union[ |
|
|
|
|
CERTIFICATE_PUBLIC_KEY_TYPES, |
|
|
|
|
Tuple[CERTIFICATE_PUBLIC_KEY_TYPES, int], |
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
class VerifyKeyAlgorithms(VerifyAspect): |
|
|
|
|
def __init__(self, bad_algorithms: Sequence[KEY_ALGORITHM_TYPE]) -> None: |
|
|
|
|
self.bad_algorithms = tuple(bad_algorithms) |
|
|
|
|
|
|
|
|
|
def verify(self, cert: Certificate, issuer: Certificate) -> bool: |
|
|
|
|
key = cert.public_key() |
|
|
|
|
for entry in self.bad_algorithms: |
|
|
|
|
if isinstance(entry, tuple): |
|
|
|
|
algo, key_size = entry |
|
|
|
|
else: |
|
|
|
|
algo = entry |
|
|
|
|
key_size = None |
|
|
|
|
if isinstance(key, algo): |
|
|
|
|
if key_size is None or key.key_size < key_size: |
|
|
|
|
return False |
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
HASH_ALGORITHM_TYPE = HashAlgorithm |
|
|
|
|
|
|
|
|
|
class VerifyHashAlgorithms(VerifyAspect): |
|
|
|
|
def __init__(self, bad_algorithms: Sequence[HASH_ALGORITHM_TYPE]) -> None: |
|
|
|
|
self.bad_algorithms = tuple(bad_algorithms) |
|
|
|
|
|
|
|
|
|
def verify(self, cert: Certificate, issuer: Certificate) -> bool: |
|
|
|
|
if isinstance(cert.signature_hash_algorithm, self.bad_algorithms): |
|
|
|
|
return False |
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
class VerifyUnknownExtensions(VerifyAspect): |
|
|
|
|
def __init__(self, known_extensions: Sequence[ObjectIdentifier], only_critical: bool = True) -> None: |
|
|
|
|
self.known_extensions = set(known_extensions) |
|
|
|
|
self.only_critical = only_critical |
|
|
|
|
|
|
|
|
|
def verify(self, cert: Certificate, issuer: Certificate) -> bool: |
|
|
|
|
for ext in itertools.chain(cert.extensions, issuer.extensions): |
|
|
|
|
if ext.oid not in self.known_extensions: |
|
|
|
|
if ext.critical or not self.only_critical: |
|
|
|
|
return False |
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DEFAULT_BAD_KEY_ALGORITHMS = [ |
|
|
|
|
DSAPublicKey, |
|
|
|
|
(RSAPublicKey, 1024), |
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
DEFAULT_BAD_HASH_ALGORITHMS = [MD5, SHA1] |
|
|
|
|
|
|
|
|
|
def make_default_aspects( |
|
|
|
|
now: O[datetime] = None, |
|
|
|
|
bad_key_algos: Sequence[KEY_ALGORITHM_TYPE] = DEFAULT_BAD_KEY_ALGORITHMS, |
|
|
|
|
bad_hash_algos: Sequence[HashAlgorithm] = DEFAULT_BAD_HASH_ALGORITHMS, |
|
|
|
|
) -> list[VerifyAspect]: |
|
|
|
|
aspects = [ |
|
|
|
|
VerifySignature(), |
|
|
|
|
VerifyIssuerBefore(), |
|
|
|
|
VerifyIssuerAfter(), |
|
|
|
|
VerifyIssuerPurpose(), |
|
|
|
|
VerifyCertBefore(now), |
|
|
|
|
VerifyCertAfter(now), |
|
|
|
|
VerifyKeyAlgorithms(bad_key_algos), |
|
|
|
|
VerifyHashAlgorithms(bad_hash_algos), |
|
|
|
|
] |
|
|
|
|
used_extensions = [] |
|
|
|
|
for a in aspects: |
|
|
|
|
used_extensions.extend(a.used_extension_oids) |
|
|
|
|
aspects.append(VerifyUnknownExtensions(used_extensions)) |
|
|
|
|
return aspects |
|
|
|
|
|
|
|
|
|
DEFAULT_VERIFY_ASPECTS = make_default_aspects() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Specialty aspects |
|
|
|
|
|
|
|
|
|
class VerifyPathLength(VerifyAspect): |
|
|
|
|
used_extension_oids = [BasicConstraints.oid] |
|
|
|
|
|
|
|
|
|
def __init__(self, path_length: int = 0) -> None: |
|
|
|
|
self.path_length = path_length |
|
|
|
|
|
|
|
|
|
def verify(self, cert: Certificate, issuer: Certificate) -> bool: |
|
|
|
|
path_length = self.path_length |
|
|
|
|
cert_constraints = cert.extensions.get_extension_for_class(BasicConstraints) |
|
|
|
|
if cert_constraints and cert_constraints.value.ca: |
|
|
|
|
path_length += 1 |
|
|
|
|
constraints = issuer.extensions.get_extension_for_class(BasicConstraints) |
|
|
|
|
if constraints: |
|
|
|
|
if constraints.value.path_length is not None: |
|
|
|
|
if constraints.value.path_length < path_length: |
|
|
|
|
return False |
|
|
|
|
if cert_constraints.value.path_length is None or cert_constraints.value.path_length >= constraints.value.path_length: |
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
class VerifyExtendedKeyUsage(VerifyAspect): |
|
|
|
|
used_extension_oids = [ExtendedKeyUsage.oid] |
|
|
|
|
|
|
|
|
|
def __init__(self, usages: Sequence[ObjectIdentifier]) -> None: |
|
|
|
|
self.usages = set(usages) |
|
|
|
|
|
|
|
|
|
def verify(self, cert: Certificate, issuer: Certificate) -> bool: |
|
|
|
|
ext_key_usages = cert.extension.get_extension_for_class(ExtendedKeyUsage) |
|
|
|
|
if not ext_key_usages: |
|
|
|
|
return False |
|
|
|
|
return self.usages & set(ext_key_usages.value.usages) == self.usages |