diff options
-rw-r--r-- | cryptography/hazmat/primitives/ciphers/base.py | 3 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/ciphers/modes.py | 21 | ||||
-rw-r--r-- | cryptography/hazmat/primitives/interfaces.py | 7 | ||||
-rw-r--r-- | docs/hazmat/primitives/interfaces.rst | 12 | ||||
-rw-r--r-- | tests/hazmat/bindings/test_openssl.py | 3 | ||||
-rw-r--r-- | tests/hazmat/primitives/test_block.py | 17 |
6 files changed, 61 insertions, 2 deletions
diff --git a/cryptography/hazmat/primitives/ciphers/base.py b/cryptography/hazmat/primitives/ciphers/base.py index 3d733afc..d046a012 100644 --- a/cryptography/hazmat/primitives/ciphers/base.py +++ b/cryptography/hazmat/primitives/ciphers/base.py @@ -28,6 +28,9 @@ class Cipher(object): if not isinstance(algorithm, interfaces.CipherAlgorithm): raise TypeError("Expected interface of interfaces.CipherAlgorithm") + if mode is not None: + mode.validate_for_algorithm(algorithm) + self.algorithm = algorithm self.mode = mode self._backend = backend diff --git a/cryptography/hazmat/primitives/ciphers/modes.py b/cryptography/hazmat/primitives/ciphers/modes.py index 1d0de689..597b4e3e 100644 --- a/cryptography/hazmat/primitives/ciphers/modes.py +++ b/cryptography/hazmat/primitives/ciphers/modes.py @@ -25,11 +25,20 @@ class CBC(object): def __init__(self, initialization_vector): self.initialization_vector = initialization_vector + def validate_for_algorithm(self, algorithm): + if len(self.initialization_vector) * 8 != algorithm.block_size: + raise ValueError("Invalid iv size ({0}) for {1}".format( + len(self.initialization_vector), self.name + )) + @utils.register_interface(interfaces.Mode) class ECB(object): name = "ECB" + def validate_for_algorithm(self, algorithm): + pass + @utils.register_interface(interfaces.Mode) @utils.register_interface(interfaces.ModeWithInitializationVector) @@ -39,6 +48,12 @@ class OFB(object): def __init__(self, initialization_vector): self.initialization_vector = initialization_vector + def validate_for_algorithm(self, algorithm): + if len(self.initialization_vector) * 8 != algorithm.block_size: + raise ValueError("Invalid iv size ({0}) for {1}".format( + len(self.initialization_vector), self.name + )) + @utils.register_interface(interfaces.Mode) @utils.register_interface(interfaces.ModeWithInitializationVector) @@ -48,6 +63,12 @@ class CFB(object): def __init__(self, initialization_vector): self.initialization_vector = initialization_vector + def validate_for_algorithm(self, algorithm): + if len(self.initialization_vector) * 8 != algorithm.block_size: + raise ValueError("Invalid iv size ({0}) for {1}".format( + len(self.initialization_vector), self.name + )) + @utils.register_interface(interfaces.Mode) @utils.register_interface(interfaces.ModeWithNonce) diff --git a/cryptography/hazmat/primitives/interfaces.py b/cryptography/hazmat/primitives/interfaces.py index 8cc9d42c..672ac96a 100644 --- a/cryptography/hazmat/primitives/interfaces.py +++ b/cryptography/hazmat/primitives/interfaces.py @@ -39,6 +39,13 @@ class Mode(six.with_metaclass(abc.ABCMeta)): A string naming this mode. (e.g. ECB, CBC) """ + @abc.abstractmethod + def validate_for_algorithm(self, algorithm): + """ + Checks that all the necessary invariants of this (mode, algorithm) + combination are met. + """ + class ModeWithInitializationVector(six.with_metaclass(abc.ABCMeta)): @abc.abstractproperty diff --git a/docs/hazmat/primitives/interfaces.rst b/docs/hazmat/primitives/interfaces.rst index 11cff51a..e798c0e6 100644 --- a/docs/hazmat/primitives/interfaces.rst +++ b/docs/hazmat/primitives/interfaces.rst @@ -56,6 +56,18 @@ Interfaces used by the symmetric cipher modes described in The name may be used by a backend to influence the operation of a cipher in conjunction with the algorithm's name. + .. method:: validate_for_algorithm(algorithm) + + :param CipherAlgorithm algorithm: + + Checks that the combination of this mode with the provided algorithm + meets any necessary invariants. This should raise an exception if they + are not met. + + For example, the :class:`~cryptography.hazmat.primitives.modes.CBC` + mode uses this method to check that the provided initialization + vector's length matches the block size of the algorithm. + .. class:: ModeWithInitializationVector diff --git a/tests/hazmat/bindings/test_openssl.py b/tests/hazmat/bindings/test_openssl.py index 9f27aab7..1cadc75c 100644 --- a/tests/hazmat/bindings/test_openssl.py +++ b/tests/hazmat/bindings/test_openssl.py @@ -23,7 +23,8 @@ from cryptography.hazmat.primitives.ciphers.modes import CBC class DummyMode(object): - pass + def validate_for_algorithm(self, algorithm): + pass @utils.register_interface(interfaces.CipherAlgorithm) diff --git a/tests/hazmat/primitives/test_block.py b/tests/hazmat/primitives/test_block.py index 9460c53d..b41f8922 100644 --- a/tests/hazmat/primitives/test_block.py +++ b/tests/hazmat/primitives/test_block.py @@ -30,6 +30,11 @@ class DummyCipher(object): pass +class DummyMode(object): + def validate_for_algorithm(self, algorithm): + pass + + class TestCipher(object): def test_instantiate_without_backend(self): Cipher( @@ -101,10 +106,20 @@ class TestCipherContext(object): def test_nonexistent_cipher(self, backend): cipher = Cipher( - DummyCipher(), object(), backend + DummyCipher(), DummyMode(), backend ) with pytest.raises(UnsupportedAlgorithm): cipher.encryptor() with pytest.raises(UnsupportedAlgorithm): cipher.decryptor() + + +class TestModeValidation(object): + def test_cbc(self, backend): + with pytest.raises(ValueError): + Cipher( + algorithms.AES(b"\x00" * 16), + modes.CBC(b"abc"), + backend, + ) |