pkcs7: fix hashing, return recipient with decryption

This commit is contained in:
Shiz 2022-09-28 05:16:54 +02:00
parent 26e1475c51
commit eab8316ca2
1 changed files with 15 additions and 24 deletions

View File

@ -11,7 +11,7 @@ from cryptography.x509 import load_der_x509_certificate
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption
from cryptography.hazmat.primitives.asymmetric.utils import Prehashed
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.hashes import Hash, MD5, SHA1
from cryptography.hazmat.primitives.hashes import Hash, SHA1
from cryptography.hazmat.primitives.padding import PKCS7 as PKCS7Padding
from cryptography.hazmat.primitives.cmac import CMAC
@ -21,7 +21,10 @@ from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.ciphers.algorithms import TripleDES, AES
from cryptography.hazmat.primitives.ciphers.modes import ECB, CBC, OFB, CFB
from cryptography_x509_verify import VerifyAspect, Verifier, VerifyFailure, get_sig_algo_from_oid
from cryptography_x509_verify import (
VerifyAspect, Verifier, VerifyFailure,
get_sig_algo_from_oid, get_hash_algo_from_asn,
)
def to_cms_cert(cert):
@ -45,10 +48,7 @@ def to_cms_hash_algo(algo):
def from_cms_hash_algo(desc):
algo, *params = desc['algorithm'].native.split('_')
return {
'md5': MD5(),
'sha1': SHA1(),
}[algo]
return get_hash_algo_from_asn(algo)
def to_cms_asym_pad(pad):
return {
@ -118,9 +118,8 @@ def match_cms_cert_id(recip, cms_cert):
return match_cms_cert_issuer_serial(recip, cms_cert.issuer, cms_cert.serial_number)
def match_cms_cert_issuer_serial(recip, issuer, serial):
issuer_name = asn1crypto.x509.Name.load(issuer.dump())
return (isinstance(recip, asn1crypto.cms.IssuerAndSerialNumber)
and match_x509_name(issuer_name, recip['issuer']) and recip['serial_number'].native == serial)
and match_x509_name(issuer, recip['issuer']) and recip['serial_number'].native == serial)
def p7data_hash(hash_algo, attrs, content):
@ -261,7 +260,7 @@ class PKCS7Signer:
self.msg = msg
def matches(self, cert):
return match_cms_cert_id(self.msg['sid'].chosen, cert)
return match_cms_cert_id(self.msg['sid'].chosen, to_cms_cert(cert))
def get_hash_algo(self):
return from_cms_hash_algo(self.msg['digest_algorithm'])
@ -287,7 +286,7 @@ class PKCS7Recipient:
self.msg = msg
def matches(self, cert):
return match_cms_cert_id(self.msg['rid'].chosen, cert)
return match_cms_cert_id(self.msg['rid'].chosen, to_cms_cert(cert))
def get_key_algo(self):
return from_cms_asym_algo(self.msg['key_encryption_algorithm'])
@ -338,7 +337,7 @@ class PKCS7:
'sid': to_cms_cert_id(cert),
'digest_algorithm': to_cms_hash_algo(algo),
'signature_algorithm': to_cms_sign_algo(key.public_key(), pad),
'signature': p7data_sign(key, pad, algo, 'data', data),
'signature': p7data_sign(key, pad, algo, 'data', data, use_attrs=True),
})
@classmethod
@ -467,15 +466,14 @@ class PKCS7:
return self.decrypt_content(key, content)
def decrypt_for_cert(self, cert, key, content=None):
cms_cert = to_cms_cert(cert)
for recip in self.get_recipients():
if recip.matches(cms_cert):
if recip.matches(cert):
return self.decrypt_for_recipient(recip, key, content)
return None
return None, None
def decrypt_for_recipient(self, recip, key, content=None):
key_dec = recip.decrypt_key(key)
return self.decrypt_content(key_dec, content)
return self.decrypt_content(key_dec, content), recip
def decrypt_content(self, key: bytes, content=None) -> bytes:
algo, mode = self.get_encryption_algo(key)
@ -486,13 +484,7 @@ class PKCS7:
content_pad = cipher.update(content) + cipher.finalize()
unpadder = PKCS7Padding(algo.block_size).unpadder()
content_dec = unpadder.update(content_pad) + unpadder.finalize()
meta = {
'key': key,
}
if hasattr(mode, 'initialization_vector'):
meta['iv'] = mode.initialization_vector
return content_dec, meta
return content_dec
def verify_content(self, content=None):
if not content:
@ -505,9 +497,8 @@ class PKCS7:
)
def verify_for_cert(self, cert, key, content=None):
cms_cert = to_cms_cert(cert)
for recip in self.get_recipients():
if recip.matches(cms_cert):
if recip.matches(cert):
return self.verify_for_recipient(recip, key, content)
return False