aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--cryptography/hazmat/backends/multibackend.py35
-rw-r--r--tests/hazmat/backends/test_multibackend.py10
2 files changed, 35 insertions, 10 deletions
diff --git a/cryptography/hazmat/backends/multibackend.py b/cryptography/hazmat/backends/multibackend.py
index e560c7df..94152370 100644
--- a/cryptography/hazmat/backends/multibackend.py
+++ b/cryptography/hazmat/backends/multibackend.py
@@ -30,11 +30,19 @@ class PrioritizedMultiBackend(object):
def __init__(self, backends):
self._backends = backends
+ def _filtered_backends(self, interface):
+ for b in self._backends:
+ if isinstance(b, interface):
+ yield b
+
def cipher_supported(self, algorithm, mode):
- return any(b.cipher_supported(algorithm, mode) for b in self._backends)
+ return any(
+ b.cipher_supported(algorithm, mode)
+ for b in self._filtered_backends(CipherBackend)
+ )
def create_symmetric_encryption_ctx(self, algorithm, mode):
- for b in self._backends:
+ for b in self._filtered_backends(CipherBackend):
try:
return b.create_symmetric_encryption_ctx(algorithm, mode)
except UnsupportedAlgorithm:
@@ -42,7 +50,7 @@ class PrioritizedMultiBackend(object):
raise UnsupportedAlgorithm
def create_symmetric_decryption_ctx(self, algorithm, mode):
- for b in self._backends:
+ for b in self._filtered_backends(CipherBackend):
try:
return b.create_symmetric_decryption_ctx(algorithm, mode)
except UnsupportedAlgorithm:
@@ -50,10 +58,13 @@ class PrioritizedMultiBackend(object):
raise UnsupportedAlgorithm
def hash_supported(self, algorithm):
- return any(b.hash_supported(algorithm) for b in self._backends)
+ return any(
+ b.hash_supported(algorithm)
+ for b in self._filtered_backends(HashBackend)
+ )
def create_hash_ctx(self, algorithm):
- for b in self._backends:
+ for b in self._filtered_backends(HashBackend):
try:
return b.create_hash_ctx(algorithm)
except UnsupportedAlgorithm:
@@ -61,10 +72,13 @@ class PrioritizedMultiBackend(object):
raise UnsupportedAlgorithm
def hmac_supported(self, algorithm):
- return any(b.hmac_supported(algorithm) for b in self._backends)
+ return any(
+ b.hmac_supported(algorithm)
+ for b in self._filtered_backends(HMACBackend)
+ )
def create_hmac_ctx(self, key, algorithm):
- for b in self._backends:
+ for b in self._filtered_backends(HMACBackend):
try:
return b.create_hmac_ctx(key, algorithm)
except UnsupportedAlgorithm:
@@ -72,11 +86,14 @@ class PrioritizedMultiBackend(object):
raise UnsupportedAlgorithm
def pbkdf2_hmac_supported(self, algorithm):
- return any(b.pbkdf2_hmac_supported(algorithm) for b in self._backends)
+ return any(
+ b.pbkdf2_hmac_supported(algorithm)
+ for b in self._filtered_backends(PBKDF2HMACBackend)
+ )
def derive_pbkdf2_hmac(self, algorithm, length, salt, iterations,
key_material):
- for b in self._backends:
+ for b in self._filtered_backends(PBKDF2HMACBackend):
try:
return b.derive_pbkdf2_hmac(
algorithm, length, salt, iterations, key_material
diff --git a/tests/hazmat/backends/test_multibackend.py b/tests/hazmat/backends/test_multibackend.py
index 03b3187b..127c0d3e 100644
--- a/tests/hazmat/backends/test_multibackend.py
+++ b/tests/hazmat/backends/test_multibackend.py
@@ -13,13 +13,17 @@
import pytest
+from cryptography import utils
from cryptography.exceptions import UnsupportedAlgorithm
+from cryptography.hazmat.backends.interfaces import (
+ CipherBackend, HashBackend, HMACBackend, PBKDF2HMACBackend
+)
from cryptography.hazmat.backends.multibackend import PrioritizedMultiBackend
from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
-from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
+@utils.register_interface(CipherBackend)
class DummyCipherBackend(object):
def __init__(self, supported_ciphers):
self._ciphers = supported_ciphers
@@ -36,6 +40,7 @@ class DummyCipherBackend(object):
raise UnsupportedAlgorithm
+@utils.register_interface(HashBackend)
class DummyHashBackend(object):
def __init__(self, supported_algorithms):
self._algorithms = supported_algorithms
@@ -48,6 +53,7 @@ class DummyHashBackend(object):
raise UnsupportedAlgorithm
+@utils.register_interface(HMACBackend)
class DummyHMACBackend(object):
def __init__(self, supported_algorithms):
self._algorithms = supported_algorithms
@@ -60,6 +66,7 @@ class DummyHMACBackend(object):
raise UnsupportedAlgorithm
+@utils.register_interface(PBKDF2HMACBackend)
class DummyPBKDF2HMAC(object):
def __init__(self, supported_algorithms):
self._algorithms = supported_algorithms
@@ -77,6 +84,7 @@ class DummyPBKDF2HMAC(object):
class TestPrioritizedMultiBackend(object):
def test_ciphers(self):
backend = PrioritizedMultiBackend([
+ DummyHashBackend([]),
DummyCipherBackend([
(algorithms.AES, modes.CBC),
])