diff options
-rw-r--r-- | cryptography/bindings/openssl/api.py | 13 | ||||
-rw-r--r-- | cryptography/primitives/abc/__init__.py | 0 | ||||
-rw-r--r-- | cryptography/primitives/abc/block/__init__.py | 0 | ||||
-rw-r--r-- | cryptography/primitives/abc/block/modes.py | 21 | ||||
-rw-r--r-- | cryptography/primitives/block/modes.py | 9 | ||||
-rw-r--r-- | tests/bindings/test_openssl.py | 5 |
6 files changed, 36 insertions, 12 deletions
diff --git a/cryptography/bindings/openssl/api.py b/cryptography/bindings/openssl/api.py index fd54a8ff..f95e4d62 100644 --- a/cryptography/bindings/openssl/api.py +++ b/cryptography/bindings/openssl/api.py @@ -13,6 +13,8 @@ from __future__ import absolute_import, division, print_function +from cryptography.primitives.abc.block import modes + import cffi @@ -72,7 +74,7 @@ class API(object): ) evp_cipher = self._lib.EVP_get_cipherbyname(ciphername.encode("ascii")) assert evp_cipher != self._ffi.NULL - iv_nonce = mode.get_iv_or_nonce(self) + iv_nonce = self._introspect(mode) # TODO: Sometimes this needs to be a DecryptInit, when? res = self._lib.EVP_EncryptInit_ex( @@ -86,8 +88,13 @@ class API(object): self._lib.EVP_CIPHER_CTX_set_padding(ctx, 0) return ctx - def get_iv_for_ecb(self): - return self._ffi.NULL + def _introspect(self, mode): + if isinstance(mode, modes.ModeWithInitializationVector): + return mode.initialization_vector + elif isinstance(mode, modes.ModeWithNonce): + return mode.nonce + else: + return self._ffi.NULL def update_encrypt_context(self, ctx, plaintext): buf = self._ffi.new("unsigned char[]", len(plaintext)) diff --git a/cryptography/primitives/abc/__init__.py b/cryptography/primitives/abc/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/cryptography/primitives/abc/__init__.py diff --git a/cryptography/primitives/abc/block/__init__.py b/cryptography/primitives/abc/block/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/cryptography/primitives/abc/block/__init__.py diff --git a/cryptography/primitives/abc/block/modes.py b/cryptography/primitives/abc/block/modes.py new file mode 100644 index 00000000..609a2ae3 --- /dev/null +++ b/cryptography/primitives/abc/block/modes.py @@ -0,0 +1,21 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import, division, print_function + +import abc + + +ModeWithInitializationVector = abc.ABCMeta('ModeWithInitializationVector', + (object, ), {}) +ModeWithNonce = abc.ABCMeta('ModeWithNonce', (object, ), {}) diff --git a/cryptography/primitives/block/modes.py b/cryptography/primitives/block/modes.py index e4fc886e..1e9b14b7 100644 --- a/cryptography/primitives/block/modes.py +++ b/cryptography/primitives/block/modes.py @@ -13,6 +13,8 @@ from __future__ import absolute_import, division, print_function +from cryptography.primitives.abc.block import modes + class CBC(object): name = "CBC" @@ -21,12 +23,9 @@ class CBC(object): super(CBC, self).__init__() self.initialization_vector = initialization_vector - def get_iv_or_nonce(self, api): - return self.initialization_vector - class ECB(object): name = "ECB" - def get_iv_or_nonce(self, api): - return api.get_iv_for_ecb() + +modes.ModeWithInitializationVector.register(CBC) diff --git a/tests/bindings/test_openssl.py b/tests/bindings/test_openssl.py index f25236cc..b23c4ccc 100644 --- a/tests/bindings/test_openssl.py +++ b/tests/bindings/test_openssl.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from cryptography.bindings.openssl import api +from cryptography.bindings.openssl.api import api class TestOpenSSL(object): @@ -28,6 +28,3 @@ class TestOpenSSL(object): for every OpenSSL. """ assert api.openssl_version_text().startswith("OpenSSL") - - def test_get_iv_for_ecb(self): - assert api.get_iv_for_ecb() == api._ffi.NULL |