/*
 * caam - Freescale FSL CAAM support for ahash functions of crypto API
 *
 * Copyright (C) 2011-2013 Freescale Semiconductor, Inc.
 *
 * Based on caamalg.c crypto API driver.
 *
 * relationship of digest job descriptor or first job descriptor after init to
 * shared descriptors:
 *
 * ---------------                     ---------------
 * | JobDesc #1  |-------------------->|  ShareDesc  |
 * | *(packet 1) |                     |  (hashKey)  |
 * ---------------                     | (operation) |
 *                                     ---------------
 *
 * relationship of subsequent job descriptors to shared descriptors:
 *
 * ---------------                     ---------------
 * | JobDesc #2  |-------------------->|  ShareDesc  |
 * | *(packet 2) |      |------------->|  (hashKey)  |
 * ---------------      |    |-------->| (operation) |
 *       .              |    |         | (load ctx2) |
 *       .              |    |         ---------------
 * ---------------      |    |
 * | JobDesc #3  |------|    |
 * | *(packet 3) |           |
 * ---------------           |
 *       .                   |
 *       .                   |
 * ---------------           |
 * | JobDesc #4  |------------
 * | *(packet 4) |
 * ---------------
 *
 * The SharedDesc never changes for a connection unless rekeyed, but
 * each packet will likely be in a different place. So all we need
 * to know to process the packet is where the input is, where the
 * output goes, and what context we want to process with. Context is
 * in the SharedDesc, packet references in the JobDesc.
 *
 * So, a job desc looks like:
 *
 * ---------------------
 * | Header            |
 * | ShareDesc Pointer |
 * | SEQ_OUT_PTR       |
 * | (output buffer)   |
 * | (output length)   |
 * | SEQ_IN_PTR        |
 * | (input buffer)    |
 * | (input length)    |
 * ---------------------
 */

#include "compat.h"

#include "regs.h"
#include "intern.h"
#include "desc_constr.h"
#include "jr.h"
#include "error.h"
#include "sg_sw_sec4.h"
#include "key_gen.h"

#define CAAM_CRA_PRIORITY		3000

/* max hash key is max split key size */
#define CAAM_MAX_HASH_KEY_SIZE		(SHA512_DIGEST_SIZE * 2)

#define CAAM_MAX_HASH_BLOCK_SIZE	SHA512_BLOCK_SIZE
#define CAAM_MAX_HASH_DIGEST_SIZE	SHA512_DIGEST_SIZE

/* length of descriptors text */
#define DESC_JOB_IO_LEN			(CAAM_CMD_SZ * 5 + CAAM_PTR_SZ * 3)

#define DESC_AHASH_BASE			(4 * CAAM_CMD_SZ)
#define DESC_AHASH_UPDATE_LEN		(6 * CAAM_CMD_SZ)
#define DESC_AHASH_UPDATE_FIRST_LEN	(DESC_AHASH_BASE + 4 * CAAM_CMD_SZ)
#define DESC_AHASH_FINAL_LEN		(DESC_AHASH_BASE + 5 * CAAM_CMD_SZ)
#define DESC_AHASH_FINUP_LEN		(DESC_AHASH_BASE + 5 * CAAM_CMD_SZ)
#define DESC_AHASH_DIGEST_LEN		(DESC_AHASH_BASE + 4 * CAAM_CMD_SZ)

/*
 * Upper bound on a shared descriptor: the FINAL descriptor text plus room
 * for the largest (split) key as immediate data. The _LEN variant expresses
 * the same bound in CAAM_CMD_SZ-sized words.
 */
#define DESC_HASH_MAX_USED_BYTES	(DESC_AHASH_FINAL_LEN + \
					 CAAM_MAX_HASH_KEY_SIZE)
#define DESC_HASH_MAX_USED_LEN		(DESC_HASH_MAX_USED_BYTES / CAAM_CMD_SZ)

/* caam context sizes for hashes: running digest + 8 */
#define HASH_MSG_LEN			8
#define MAX_CTX_LEN			(HASH_MSG_LEN + SHA512_DIGEST_SIZE)

#ifdef DEBUG
/* for print_hex_dumps with line references */
#define xstr(s) str(s)
#define str(s) #s
/*
 * "##" drops the preceding comma when no variadic arguments are given, so
 * debug("message\n") compiles as well as debug("%d\n", x).
 */
#define debug(format, arg...) printk(format, ##arg)
#else
#define debug(format, arg...)
#endif /* ahash per-session context */ struct caam_hash_ctx { struct device *jrdev; u32 sh_desc_update[DESC_HASH_MAX_USED_LEN]; u32 sh_desc_update_first[DESC_HASH_MAX_USED_LEN]; u32 sh_desc_fin[DESC_HASH_MAX_USED_LEN]; u32 sh_desc_digest[DESC_HASH_MAX_USED_LEN]; u32 sh_desc_finup[DESC_HASH_MAX_USED_LEN]; dma_addr_t sh_desc_update_dma; dma_addr_t sh_desc_update_first_dma; dma_addr_t sh_desc_fin_dma; dma_addr_t sh_desc_digest_dma; dma_addr_t sh_desc_finup_dma; u32 alg_type; u32 alg_op; u8 key[CAAM_MAX_HASH_KEY_SIZE]; dma_addr_t key_dma; int ctx_len; unsigned int split_key_len; unsigned int split_key_pad_len; }; /* ahash state */ struct caam_hash_state { dma_addr_t buf_dma; dma_addr_t ctx_dma; u8 buf_0[CAAM_MAX_HASH_BLOCK_SIZE] ____cacheline_aligned; int buflen_0; u8 buf_1[CAAM_MAX_HASH_BLOCK_SIZE] ____cacheline_aligned; int buflen_1; u8 caam_ctx[MAX_CTX_LEN]; int (*update)(struct ahash_request *req); int (*final)(struct ahash_request *req); int (*finup)(struct ahash_request *req); int current_buf; }; /* Common job descriptor seq in/out ptr routines */ /* Map state->caam_ctx, and append seq_out_ptr command that points to it */ static inline void map_seq_out_ptr_ctx(u32 *desc, struct device *jrdev, struct caam_hash_state *state, int ctx_len) { state->ctx_dma = dma_map_single(jrdev, state->caam_ctx, ctx_len, DMA_FROM_DEVICE); append_seq_out_ptr(desc, state->ctx_dma, ctx_len, 0); } /* Map req->result, and append seq_out_ptr command that points to it */ static inline dma_addr_t map_seq_out_ptr_result(u32 *desc, struct device *jrdev, u8 *result, int digestsize) { dma_addr_t dst_dma; dst_dma = dma_map_single(jrdev, result, digestsize, DMA_FROM_DEVICE); append_seq_out_ptr(desc, dst_dma, digestsize, 0); return dst_dma; } /* Map current buffer in state and put it in link table */ static inline dma_addr_t buf_map_to_sec4_sg(struct device *jrdev, struct sec4_sg_entry *sec4_sg, u8 *buf, int buflen) { dma_addr_t buf_dma; buf_dma = dma_map_single(jrdev, buf, buflen, DMA_TO_DEVICE); 
dma_sync_single_for_device(jrdev, buf_dma, buflen, DMA_TO_DEVICE); dma_to_sec4_sg_one(sec4_sg, buf_dma, buflen, 0); return buf_dma; } /* Map req->src and put it in link table */ static inline void src_map_to_sec4_sg(struct device *jrdev, struct scatterlist *src, int src_nents, struct sec4_sg_entry *sec4_sg, bool chained) { dma_map_sg_chained(jrdev, src, src_nents, DMA_TO_DEVICE, chained); sg_to_sec4_sg_last(src, src_nents, sec4_sg, 0); } /* * Only put buffer in link table if it contains data, which is possible, * since a buffer has previously been used, and needs to be unmapped, */ static inline dma_addr_t try_buf_map_to_sec4_sg(struct device *jrdev, struct sec4_sg_entry *sec4_sg, u8 *buf, dma_addr_t buf_dma, int buflen, int last_buflen) { if (buf_dma && !dma_mapping_error(jrdev, buf_dma)) dma_unmap_single(jrdev, buf_dma, last_buflen, DMA_TO_DEVICE); if (buflen) buf_dma = buf_map_to_sec4_sg(jrdev, sec4_sg, buf, buflen); else buf_dma = 0; return buf_dma; } /* Map state->caam_ctx, and add it to link table */ static inline void ctx_map_to_sec4_sg(u32 *desc, struct device *jrdev, struct caam_hash_state *state, int ctx_len, struct sec4_sg_entry *sec4_sg, u32 flag) { state->ctx_dma = dma_map_single(jrdev, state->caam_ctx, ctx_len, flag); if ((flag == DMA_TO_DEVICE) || (flag == DMA_BIDIRECTIONAL)) dma_sync_single_for_device(jrdev, state->ctx_dma, ctx_len, flag); dma_to_sec4_sg_one(sec4_sg, state->ctx_dma, ctx_len, 0); } /* Common shared descriptor commands */ static inline void append_key_ahash(u32 *desc, struct caam_hash_ctx *ctx) { append_key_as_imm(desc, ctx->key, ctx->split_key_pad_len, ctx->split_key_len, CLASS_2 | KEY_DEST_MDHA_SPLIT | KEY_ENC); } /* Append key if it has been set */ static inline void init_sh_desc_key_ahash(u32 *desc, struct caam_hash_ctx *ctx) { u32 *key_jump_cmd; init_sh_desc(desc, HDR_SHARE_SERIAL); if (ctx->split_key_len) { /* Skip if already shared */ key_jump_cmd = append_jump(desc, JUMP_JSL | JUMP_TEST_ALL | JUMP_COND_SHRD); 
append_key_ahash(desc, ctx); set_jump_tgt_here(desc, key_jump_cmd); } /* Propagate errors from shared to job descriptor */ append_cmd(desc, SET_OK_NO_PROP_ERRORS | CMD_LOAD); } /* * For ahash read data from seqin following state->caam_ctx, * and write resulting class2 context to seqout, which may be state->caam_ctx * or req->result */ static inline void ahash_append_load_str(u32 *desc, int digestsize) { /* Calculate remaining bytes to read */ append_math_add(desc, VARSEQINLEN, SEQINLEN, REG0, CAAM_CMD_SZ); /* Read remaining bytes */ append_seq_fifo_load(desc, 0, FIFOLD_CLASS_CLASS2 | FIFOLD_TYPE_LAST2 | FIFOLD_TYPE_MSG | KEY_VLF); /* Store class2 context bytes */ append_seq_store(desc, digestsize, LDST_CLASS_2_CCB | LDST_SRCDST_BYTE_CONTEXT); } /* * For ahash update, final and finup, import context, read and write to seqout */ static inline void ahash_ctx_data_to_out(u32 *desc, u32 op, u32 state, int digestsize, struct caam_hash_ctx *ctx) { init_sh_desc_key_ahash(desc, ctx); /* Import context from software */ append_cmd(desc, CMD_SEQ_LOAD | LDST_SRCDST_BYTE_CONTEXT | LDST_CLASS_2_CCB | ctx->ctx_len); /* Class 2 operation */ append_operation(desc, op | state | OP_ALG_ENCRYPT); /* * Load from buf and/or src and write to req->result or state->context */ ahash_append_load_str(desc, digestsize); } /* For ahash firsts and digest, read and write to seqout */ static inline void ahash_data_to_out(u32 *desc, u32 op, u32 state, int digestsize, struct caam_hash_ctx *ctx) { init_sh_desc_key_ahash(desc, ctx); /* Class 2 operation */ append_operation(desc, op | state | OP_ALG_ENCRYPT); /* * Load from buf and/or src and write to req->result or state->context */ ahash_append_load_str(desc, digestsize); } static int ahash_set_sh_desc(struct crypto_ahash *ahash) { struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); int digestsize = crypto_ahash_digestsize(ahash); struct device *jrdev = ctx->jrdev; u32 have_key = 0; u32 *desc; if (ctx->split_key_len) have_key = 
OP_ALG_AAI_HMAC_PRECOMP; /* ahash_update shared descriptor */ desc = ctx->sh_desc_update; init_sh_desc(desc, HDR_SHARE_SERIAL); /* Import context from software */ append_cmd(desc, CMD_SEQ_LOAD | LDST_SRCDST_BYTE_CONTEXT | LDST_CLASS_2_CCB | ctx->ctx_len); /* Class 2 operation */ append_operation(desc, ctx->alg_type | OP_ALG_AS_UPDATE | OP_ALG_ENCRYPT); /* Load data and write to result or context */ ahash_append_load_str(desc, ctx->ctx_len); ctx->sh_desc_update_dma = dma_map_single(jrdev, desc, desc_bytes(desc), DMA_TO_DEVICE); if (dma_mapping_error(jrdev, ctx->sh_desc_update_dma)) { dev_err(jrdev, "unable to map shared descriptor\n"); return -ENOMEM; } #ifdef DEBUG print_hex_dump(KERN_ERR, "ahash update shdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif /* ahash_update_first shared descriptor */ desc = ctx->sh_desc_update_first; ahash_data_to_out(desc, have_key | ctx->alg_type, OP_ALG_AS_INIT, ctx->ctx_len, ctx); ctx->sh_desc_update_first_dma = dma_map_single(jrdev, desc, desc_bytes(desc), DMA_TO_DEVICE); if (dma_mapping_error(jrdev, ctx->sh_desc_update_first_dma)) { dev_err(jrdev, "unable to map shared descriptor\n"); return -ENOMEM; } #ifdef DEBUG print_hex_dump(KERN_ERR, "ahash update first shdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif dma_sync_single_for_device(jrdev, ctx->sh_desc_update_first_dma, desc_bytes(desc), DMA_TO_DEVICE); /* ahash_final shared descriptor */ desc = ctx->sh_desc_fin; ahash_ctx_data_to_out(desc, have_key | ctx->alg_type, OP_ALG_AS_FINALIZE, digestsize, ctx); ctx->sh_desc_fin_dma = dma_map_single(jrdev, desc, desc_bytes(desc), DMA_TO_DEVICE); if (dma_mapping_error(jrdev, ctx->sh_desc_fin_dma)) { dev_err(jrdev, "unable to map shared descriptor\n"); return -ENOMEM; } #ifdef DEBUG print_hex_dump(KERN_ERR, "ahash final shdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif dma_sync_single_for_device(jrdev, 
ctx->sh_desc_fin_dma, desc_bytes(desc), DMA_TO_DEVICE); /* ahash_finup shared descriptor */ desc = ctx->sh_desc_finup; ahash_ctx_data_to_out(desc, have_key | ctx->alg_type, OP_ALG_AS_FINALIZE, digestsize, ctx); ctx->sh_desc_finup_dma = dma_map_single(jrdev, desc, desc_bytes(desc), DMA_TO_DEVICE); if (dma_mapping_error(jrdev, ctx->sh_desc_finup_dma)) { dev_err(jrdev, "unable to map shared descriptor\n"); return -ENOMEM; } #ifdef DEBUG print_hex_dump(KERN_ERR, "ahash finup shdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif dma_sync_single_for_device(jrdev, ctx->sh_desc_finup_dma, desc_bytes(desc), DMA_TO_DEVICE); /* ahash_digest shared descriptor */ desc = ctx->sh_desc_digest; ahash_data_to_out(desc, have_key | ctx->alg_type, OP_ALG_AS_INITFINAL, digestsize, ctx); ctx->sh_desc_digest_dma = dma_map_single(jrdev, desc, desc_bytes(desc), DMA_TO_DEVICE); if (dma_mapping_error(jrdev, ctx->sh_desc_digest_dma)) { dev_err(jrdev, "unable to map shared descriptor\n"); return -ENOMEM; } #ifdef DEBUG print_hex_dump(KERN_ERR, "ahash digest shdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif dma_sync_single_for_device(jrdev, ctx->sh_desc_digest_dma, desc_bytes(desc), DMA_TO_DEVICE); return 0; } static u32 gen_split_hash_key(struct caam_hash_ctx *ctx, const u8 *key_in, u32 keylen) { return gen_split_key(ctx->jrdev, ctx->key, ctx->split_key_len, ctx->split_key_pad_len, key_in, keylen, ctx->alg_op); } /* Digest hash size if it is too large */ static u32 hash_digest_key(struct caam_hash_ctx *ctx, const u8 *key_in, u32 *keylen, u8 *key_out, u32 digestsize) { struct device *jrdev = ctx->jrdev; u32 *desc; struct split_key_result result; dma_addr_t src_dma, dst_dma; int ret = 0; /* * Hashing descriptor is 6 commands (including header), 2 pointers, * and 2 extended lengths */ desc = kmalloc((CAAM_CMD_SZ * 6 + CAAM_PTR_SZ * 2 + CAAM_EXTLEN_SZ * 2), GFP_KERNEL | GFP_DMA); init_job_desc(desc, 0); src_dma = 
dma_map_single(jrdev, (void *)key_in, *keylen, DMA_TO_DEVICE); if (dma_mapping_error(jrdev, src_dma)) { dev_err(jrdev, "unable to map key input memory\n"); kfree(desc); return -ENOMEM; } dma_sync_single_for_device(jrdev, src_dma, *keylen, DMA_TO_DEVICE); dst_dma = dma_map_single(jrdev, (void *)key_out, digestsize, DMA_FROM_DEVICE); if (dma_mapping_error(jrdev, dst_dma)) { dev_err(jrdev, "unable to map key output memory\n"); dma_unmap_single(jrdev, src_dma, *keylen, DMA_TO_DEVICE); kfree(desc); return -ENOMEM; } /* Job descriptor to perform unkeyed hash on key_in */ append_operation(desc, ctx->alg_type | OP_ALG_ENCRYPT | OP_ALG_AS_INITFINAL); append_seq_in_ptr(desc, src_dma, *keylen, 0); append_seq_fifo_load(desc, *keylen, FIFOLD_CLASS_CLASS2 | FIFOLD_TYPE_LAST2 | FIFOLD_TYPE_MSG); append_seq_out_ptr(desc, dst_dma, digestsize, 0); append_seq_store(desc, digestsize, LDST_CLASS_2_CCB | LDST_SRCDST_BYTE_CONTEXT); #ifdef DEBUG print_hex_dump(KERN_ERR, "key_in@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, key_in, *keylen, 1); print_hex_dump(KERN_ERR, "jobdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif result.err = 0; init_completion(&result.completion); ret = caam_jr_enqueue(jrdev, desc, split_key_done, &result); if (!ret) { /* in progress */ wait_for_completion_interruptible(&result.completion); ret = result.err; #ifdef DEBUG print_hex_dump(KERN_ERR, "digested key@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, key_in, digestsize, 1); #endif } *keylen = digestsize; dma_unmap_single(jrdev, src_dma, *keylen, DMA_TO_DEVICE); dma_sync_single_for_cpu(jrdev, dst_dma, digestsize, DMA_FROM_DEVICE); dma_unmap_single(jrdev, dst_dma, digestsize, DMA_FROM_DEVICE); kfree(desc); return ret; } static int ahash_setkey(struct crypto_ahash *ahash, const u8 *key, unsigned int keylen) { /* Sizes for MDHA pads (*not* keys): MD5, SHA1, 224, 256, 384, 512 */ static const u8 mdpadlen[] = { 16, 20, 32, 32, 64, 64 }; struct caam_hash_ctx *ctx = 
crypto_ahash_ctx(ahash); struct device *jrdev = ctx->jrdev; int blocksize = crypto_tfm_alg_blocksize(&ahash->base); int digestsize = crypto_ahash_digestsize(ahash); int ret = 0; u8 *hashed_key = NULL; #ifdef DEBUG printk(KERN_ERR "keylen %d\n", keylen); #endif if (keylen > blocksize) { hashed_key = kmalloc(sizeof(u8) * digestsize, GFP_KERNEL | GFP_DMA); if (!hashed_key) return -ENOMEM; ret = hash_digest_key(ctx, key, &keylen, hashed_key, digestsize); if (ret) goto badkey; key = hashed_key; } /* Pick class 2 key length from algorithm submask */ ctx->split_key_len = mdpadlen[(ctx->alg_op & OP_ALG_ALGSEL_SUBMASK) >> OP_ALG_ALGSEL_SHIFT] * 2; ctx->split_key_pad_len = ALIGN(ctx->split_key_len, 16); #ifdef DEBUG printk(KERN_ERR "split_key_len %d split_key_pad_len %d\n", ctx->split_key_len, ctx->split_key_pad_len); print_hex_dump(KERN_ERR, "key in @"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, key, keylen, 1); #endif ret = gen_split_hash_key(ctx, key, keylen); if (ret) goto badkey; ctx->key_dma = dma_map_single(jrdev, ctx->key, ctx->split_key_pad_len, DMA_TO_DEVICE); if (dma_mapping_error(jrdev, ctx->key_dma)) { dev_err(jrdev, "unable to map key i/o memory\n"); return -ENOMEM; } dma_sync_single_for_device(jrdev, ctx->key_dma, ctx->split_key_pad_len, DMA_TO_DEVICE); #ifdef DEBUG print_hex_dump(KERN_ERR, "ctx.key@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, ctx->key, ctx->split_key_pad_len, 1); #endif ret = ahash_set_sh_desc(ahash); if (ret) { dma_unmap_single(jrdev, ctx->key_dma, ctx->split_key_pad_len, DMA_TO_DEVICE); } kfree(hashed_key); return ret; badkey: kfree(hashed_key); crypto_ahash_set_flags(ahash, CRYPTO_TFM_RES_BAD_KEY_LEN); return -EINVAL; } /* * ahash_edesc - s/w-extended ahash descriptor * @dst_dma: physical mapped address of req->result * @sec4_sg_dma: physical mapped address of h/w link table * @chained: if source is chained * @src_nents: number of segments in input scatterlist * @sec4_sg_bytes: length of dma mapped sec4_sg space * @sec4_sg: pointer 
to h/w link table * @hw_desc: the h/w job descriptor followed by any referenced link tables */ struct ahash_edesc { dma_addr_t dst_dma; dma_addr_t sec4_sg_dma; bool chained; int src_nents; int sec4_sg_bytes; struct sec4_sg_entry *sec4_sg; u32 hw_desc[0]; }; static inline void ahash_unmap(struct device *dev, struct ahash_edesc *edesc, struct ahash_request *req, int dst_len) { if (edesc->src_nents) dma_unmap_sg_chained(dev, req->src, edesc->src_nents, DMA_TO_DEVICE, edesc->chained); if (edesc->dst_dma) { dma_sync_single_for_cpu(dev, edesc->dst_dma, dst_len, DMA_FROM_DEVICE); dma_unmap_single(dev, edesc->dst_dma, dst_len, DMA_FROM_DEVICE); } if (edesc->sec4_sg_bytes) dma_unmap_single(dev, edesc->sec4_sg_dma, edesc->sec4_sg_bytes, DMA_TO_DEVICE); } static inline void ahash_unmap_ctx(struct device *dev, struct ahash_edesc *edesc, struct ahash_request *req, int dst_len, u32 flag) { struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct caam_hash_state *state = ahash_request_ctx(req); if (state->ctx_dma) { if ((flag == DMA_FROM_DEVICE) || (flag == DMA_BIDIRECTIONAL)) dma_sync_single_for_cpu(dev, state->ctx_dma, ctx->ctx_len, flag); dma_unmap_single(dev, state->ctx_dma, ctx->ctx_len, flag); } ahash_unmap(dev, edesc, req, dst_len); } static void ahash_done(struct device *jrdev, u32 *desc, u32 err, void *context) { struct ahash_request *req = context; struct ahash_edesc *edesc; struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); int digestsize = crypto_ahash_digestsize(ahash); #ifdef DEBUG struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct caam_hash_state *state = ahash_request_ctx(req); dev_err(jrdev, "%s %d: err 0x%x\n", __func__, __LINE__, err); #endif edesc = (struct ahash_edesc *)((char *)desc - offsetof(struct ahash_edesc, hw_desc)); if (err) { char tmp[CAAM_ERROR_STR_MAX]; dev_err(jrdev, "%08x: %s\n", err, caam_jr_strstatus(tmp, err)); } ahash_unmap(jrdev, edesc, req, digestsize); kfree(edesc); 
#ifdef DEBUG print_hex_dump(KERN_ERR, "ctx@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, state->caam_ctx, ctx->ctx_len, 1); if (req->result) print_hex_dump(KERN_ERR, "result@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, req->result, digestsize, 1); #endif req->base.complete(&req->base, err); } static void ahash_done_bi(struct device *jrdev, u32 *desc, u32 err, void *context) { struct ahash_request *req = context; struct ahash_edesc *edesc; struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); #ifdef DEBUG struct caam_hash_state *state = ahash_request_ctx(req); int digestsize = crypto_ahash_digestsize(ahash); dev_err(jrdev, "%s %d: err 0x%x\n", __func__, __LINE__, err); #endif edesc = (struct ahash_edesc *)((char *)desc - offsetof(struct ahash_edesc, hw_desc)); if (err) { char tmp[CAAM_ERROR_STR_MAX]; dev_err(jrdev, "%08x: %s\n", err, caam_jr_strstatus(tmp, err)); } ahash_unmap_ctx(jrdev, edesc, req, ctx->ctx_len, DMA_BIDIRECTIONAL); kfree(edesc); #ifdef DEBUG print_hex_dump(KERN_ERR, "ctx@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, state->caam_ctx, ctx->ctx_len, 1); if (req->result) print_hex_dump(KERN_ERR, "result@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, req->result, digestsize, 1); #endif req->base.complete(&req->base, err); } static void ahash_done_ctx_src(struct device *jrdev, u32 *desc, u32 err, void *context) { struct ahash_request *req = context; struct ahash_edesc *edesc; struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); int digestsize = crypto_ahash_digestsize(ahash); #ifdef DEBUG struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct caam_hash_state *state = ahash_request_ctx(req); dev_err(jrdev, "%s %d: err 0x%x\n", __func__, __LINE__, err); #endif edesc = (struct ahash_edesc *)((char *)desc - offsetof(struct ahash_edesc, hw_desc)); if (err) { char tmp[CAAM_ERROR_STR_MAX]; dev_err(jrdev, "%08x: %s\n", err, caam_jr_strstatus(tmp, err)); } ahash_unmap_ctx(jrdev, edesc, req, 
digestsize, DMA_FROM_DEVICE); kfree(edesc); #ifdef DEBUG print_hex_dump(KERN_ERR, "ctx@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, state->caam_ctx, ctx->ctx_len, 1); if (req->result) print_hex_dump(KERN_ERR, "result@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, req->result, digestsize, 1); #endif req->base.complete(&req->base, err); } static void ahash_done_ctx_dst(struct device *jrdev, u32 *desc, u32 err, void *context) { struct ahash_request *req = context; struct ahash_edesc *edesc; struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); #ifdef DEBUG struct caam_hash_state *state = ahash_request_ctx(req); int digestsize = crypto_ahash_digestsize(ahash); dev_err(jrdev, "%s %d: err 0x%x\n", __func__, __LINE__, err); #endif edesc = (struct ahash_edesc *)((char *)desc - offsetof(struct ahash_edesc, hw_desc)); if (err) { char tmp[CAAM_ERROR_STR_MAX]; dev_err(jrdev, "%08x: %s\n", err, caam_jr_strstatus(tmp, err)); } ahash_unmap_ctx(jrdev, edesc, req, ctx->ctx_len, DMA_TO_DEVICE); kfree(edesc); #ifdef DEBUG print_hex_dump(KERN_ERR, "ctx@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, state->caam_ctx, ctx->ctx_len, 1); if (req->result) print_hex_dump(KERN_ERR, "result@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, req->result, digestsize, 1); #endif req->base.complete(&req->base, err); } /* submit update job descriptor */ static int ahash_update_ctx(struct ahash_request *req) { struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct caam_hash_state *state = ahash_request_ctx(req); struct device *jrdev = ctx->jrdev; gfp_t flags = (req->base.flags & (CRYPTO_TFM_REQ_MAY_BACKLOG | CRYPTO_TFM_REQ_MAY_SLEEP)) ? GFP_KERNEL : GFP_ATOMIC; u8 *buf = state->current_buf ? state->buf_1 : state->buf_0; int *buflen = state->current_buf ? &state->buflen_1 : &state->buflen_0; u8 *next_buf = state->current_buf ? 
state->buf_0 : state->buf_1; int *next_buflen = state->current_buf ? &state->buflen_0 : &state->buflen_1, last_buflen; int in_len = *buflen + req->nbytes, to_hash; u32 *sh_desc = ctx->sh_desc_update, *desc = NULL; dma_addr_t ptr = ctx->sh_desc_update_dma; int src_nents, sec4_sg_bytes, sec4_sg_src_index; struct ahash_edesc *edesc; bool chained = false; int ret = 0; int sh_len; last_buflen = *next_buflen; *next_buflen = in_len & (crypto_tfm_alg_blocksize(&ahash->base) - 1); to_hash = in_len - *next_buflen; if (to_hash) { src_nents = __sg_count(req->src, req->nbytes - (*next_buflen), &chained); sec4_sg_src_index = 1 + (*buflen ? 1 : 0); sec4_sg_bytes = (sec4_sg_src_index + src_nents) * sizeof(struct sec4_sg_entry); /* * allocate space for base edesc and hw desc commands, * link tables */ edesc = kzalloc(sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN + sec4_sg_bytes, GFP_DMA | flags); if (!edesc) { dev_err(jrdev, "could not allocate extended descriptor\n"); return -ENOMEM; } edesc->src_nents = src_nents; edesc->chained = chained; edesc->sec4_sg_bytes = sec4_sg_bytes; edesc->sec4_sg = (void *)edesc + sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN; edesc->sec4_sg_dma = dma_map_single(jrdev, edesc->sec4_sg, sec4_sg_bytes, DMA_TO_DEVICE); ctx_map_to_sec4_sg(desc, jrdev, state, ctx->ctx_len, edesc->sec4_sg, DMA_BIDIRECTIONAL); state->buf_dma = try_buf_map_to_sec4_sg(jrdev, edesc->sec4_sg + 1, buf, state->buf_dma, *buflen, last_buflen); if (src_nents) { src_map_to_sec4_sg(jrdev, req->src, src_nents, edesc->sec4_sg + sec4_sg_src_index, chained); if (*next_buflen) { sg_copy_part(next_buf, req->src, to_hash - *buflen, req->nbytes); state->current_buf = !state->current_buf; } } else { (edesc->sec4_sg + sec4_sg_src_index - 1)->len |= SEC4_SG_LEN_FIN; } sh_len = desc_len(sh_desc); desc = edesc->hw_desc; init_job_desc_shared(desc, ptr, sh_len, HDR_SHARE_DEFER | HDR_REVERSE); append_seq_in_ptr(desc, edesc->sec4_sg_dma, ctx->ctx_len + to_hash, LDST_SGF); append_seq_out_ptr(desc, 
state->ctx_dma, ctx->ctx_len, 0); dma_sync_single_for_device(jrdev, edesc->sec4_sg_dma, sec4_sg_bytes, DMA_TO_DEVICE); #ifdef DEBUG print_hex_dump(KERN_ERR, "jobdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif ret = caam_jr_enqueue(jrdev, desc, ahash_done_bi, req); if (!ret) { ret = -EINPROGRESS; } else { ahash_unmap_ctx(jrdev, edesc, req, ctx->ctx_len, DMA_BIDIRECTIONAL); kfree(edesc); } } else if (*next_buflen) { sg_copy(buf + *buflen, req->src, req->nbytes); *buflen = *next_buflen; *next_buflen = last_buflen; } #ifdef DEBUG print_hex_dump(KERN_ERR, "buf@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, buf, *buflen, 1); print_hex_dump(KERN_ERR, "next buf@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, next_buf, *next_buflen, 1); #endif return ret; } static int ahash_final_ctx(struct ahash_request *req) { struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct caam_hash_state *state = ahash_request_ctx(req); struct device *jrdev = ctx->jrdev; gfp_t flags = (req->base.flags & (CRYPTO_TFM_REQ_MAY_BACKLOG | CRYPTO_TFM_REQ_MAY_SLEEP)) ? GFP_KERNEL : GFP_ATOMIC; u8 *buf = state->current_buf ? state->buf_1 : state->buf_0; int buflen = state->current_buf ? state->buflen_1 : state->buflen_0; int last_buflen = state->current_buf ? state->buflen_0 : state->buflen_1; u32 *sh_desc = ctx->sh_desc_fin, *desc; dma_addr_t ptr = ctx->sh_desc_fin_dma; int sec4_sg_bytes; int digestsize = crypto_ahash_digestsize(ahash); struct ahash_edesc *edesc; int ret = 0; int sh_len; sec4_sg_bytes = (1 + (buflen ? 
1 : 0)) * sizeof(struct sec4_sg_entry); /* allocate space for base edesc and hw desc commands, link tables */ edesc = kzalloc(sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN + sec4_sg_bytes, GFP_DMA | flags); if (!edesc) { dev_err(jrdev, "could not allocate extended descriptor\n"); return -ENOMEM; } sh_len = desc_len(sh_desc); desc = edesc->hw_desc; init_job_desc_shared(desc, ptr, sh_len, HDR_SHARE_DEFER | HDR_REVERSE); edesc->sec4_sg_bytes = sec4_sg_bytes; edesc->sec4_sg = (void *)edesc + sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN; edesc->sec4_sg_dma = dma_map_single(jrdev, edesc->sec4_sg, sec4_sg_bytes, DMA_TO_DEVICE); edesc->src_nents = 0; ctx_map_to_sec4_sg(desc, jrdev, state, ctx->ctx_len, edesc->sec4_sg, DMA_TO_DEVICE); state->buf_dma = try_buf_map_to_sec4_sg(jrdev, edesc->sec4_sg + 1, buf, state->buf_dma, buflen, last_buflen); (edesc->sec4_sg + sec4_sg_bytes - 1)->len |= SEC4_SG_LEN_FIN; append_seq_in_ptr(desc, edesc->sec4_sg_dma, ctx->ctx_len + buflen, LDST_SGF); edesc->dst_dma = map_seq_out_ptr_result(desc, jrdev, req->result, digestsize); dma_sync_single_for_device(jrdev, edesc->sec4_sg_dma, sec4_sg_bytes, DMA_TO_DEVICE); #ifdef DEBUG print_hex_dump(KERN_ERR, "jobdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif ret = caam_jr_enqueue(jrdev, desc, ahash_done_ctx_src, req); if (!ret) { ret = -EINPROGRESS; } else { ahash_unmap_ctx(jrdev, edesc, req, digestsize, DMA_FROM_DEVICE); kfree(edesc); } return ret; } static int ahash_finup_ctx(struct ahash_request *req) { struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct caam_hash_state *state = ahash_request_ctx(req); struct device *jrdev = ctx->jrdev; gfp_t flags = (req->base.flags & (CRYPTO_TFM_REQ_MAY_BACKLOG | CRYPTO_TFM_REQ_MAY_SLEEP)) ? GFP_KERNEL : GFP_ATOMIC; u8 *buf = state->current_buf ? state->buf_1 : state->buf_0; int buflen = state->current_buf ? 
state->buflen_1 : state->buflen_0; int last_buflen = state->current_buf ? state->buflen_0 : state->buflen_1; u32 *sh_desc = ctx->sh_desc_finup, *desc; dma_addr_t ptr = ctx->sh_desc_finup_dma; int sec4_sg_bytes, sec4_sg_src_index; int src_nents; int digestsize = crypto_ahash_digestsize(ahash); struct ahash_edesc *edesc; bool chained = false; int ret = 0; int sh_len; src_nents = __sg_count(req->src, req->nbytes, &chained); sec4_sg_src_index = 1 + (buflen ? 1 : 0); sec4_sg_bytes = (sec4_sg_src_index + src_nents) * sizeof(struct sec4_sg_entry); /* allocate space for base edesc and hw desc commands, link tables */ edesc = kzalloc(sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN + sec4_sg_bytes, GFP_DMA | flags); if (!edesc) { dev_err(jrdev, "could not allocate extended descriptor\n"); return -ENOMEM; } sh_len = desc_len(sh_desc); desc = edesc->hw_desc; init_job_desc_shared(desc, ptr, sh_len, HDR_SHARE_DEFER | HDR_REVERSE); edesc->src_nents = src_nents; edesc->chained = chained; edesc->sec4_sg_bytes = sec4_sg_bytes; edesc->sec4_sg = (void *)edesc + sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN; edesc->sec4_sg_dma = dma_map_single(jrdev, edesc->sec4_sg, sec4_sg_bytes, DMA_TO_DEVICE); ctx_map_to_sec4_sg(desc, jrdev, state, ctx->ctx_len, edesc->sec4_sg, DMA_TO_DEVICE); state->buf_dma = try_buf_map_to_sec4_sg(jrdev, edesc->sec4_sg + 1, buf, state->buf_dma, buflen, last_buflen); src_map_to_sec4_sg(jrdev, req->src, src_nents, edesc->sec4_sg + sec4_sg_src_index, chained); append_seq_in_ptr(desc, edesc->sec4_sg_dma, ctx->ctx_len + buflen + req->nbytes, LDST_SGF); edesc->dst_dma = map_seq_out_ptr_result(desc, jrdev, req->result, digestsize); dma_sync_single_for_device(jrdev, edesc->sec4_sg_dma, sec4_sg_bytes, DMA_TO_DEVICE); #ifdef DEBUG print_hex_dump(KERN_ERR, "jobdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif ret = caam_jr_enqueue(jrdev, desc, ahash_done_ctx_src, req); if (!ret) { ret = -EINPROGRESS; } else { ahash_unmap_ctx(jrdev, edesc, 
req, digestsize, DMA_FROM_DEVICE); kfree(edesc); } return ret; } static int ahash_digest(struct ahash_request *req) { struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct device *jrdev = ctx->jrdev; gfp_t flags = (req->base.flags & (CRYPTO_TFM_REQ_MAY_BACKLOG | CRYPTO_TFM_REQ_MAY_SLEEP)) ? GFP_KERNEL : GFP_ATOMIC; u32 *sh_desc = ctx->sh_desc_digest, *desc; dma_addr_t ptr = ctx->sh_desc_digest_dma; int digestsize = crypto_ahash_digestsize(ahash); int src_nents, sec4_sg_bytes; dma_addr_t src_dma; struct ahash_edesc *edesc; bool chained = false; int ret = 0; u32 options; int sh_len; src_nents = sg_count(req->src, req->nbytes, &chained); dma_map_sg_chained(jrdev, req->src, src_nents ? : 1, DMA_TO_DEVICE, chained); sec4_sg_bytes = src_nents * sizeof(struct sec4_sg_entry); /* allocate space for base edesc and hw desc commands, link tables */ edesc = kzalloc(sizeof(struct ahash_edesc) + sec4_sg_bytes + DESC_JOB_IO_LEN, GFP_DMA | flags); if (!edesc) { dev_err(jrdev, "could not allocate extended descriptor\n"); return -ENOMEM; } edesc->sec4_sg = (void *)edesc + sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN; edesc->sec4_sg_dma = dma_map_single(jrdev, edesc->sec4_sg, sec4_sg_bytes, DMA_TO_DEVICE); edesc->sec4_sg_bytes = sec4_sg_bytes; edesc->src_nents = src_nents; edesc->chained = chained; sh_len = desc_len(sh_desc); desc = edesc->hw_desc; init_job_desc_shared(desc, ptr, sh_len, HDR_SHARE_DEFER | HDR_REVERSE); if (src_nents) { sg_to_sec4_sg_last(req->src, src_nents, edesc->sec4_sg, 0); src_dma = edesc->sec4_sg_dma; options = LDST_SGF; } else { src_dma = sg_dma_address(req->src); options = 0; } append_seq_in_ptr(desc, src_dma, req->nbytes, options); dma_sync_single_for_device(jrdev, edesc->sec4_sg_dma, edesc->sec4_sg_bytes, DMA_TO_DEVICE); edesc->dst_dma = map_seq_out_ptr_result(desc, jrdev, req->result, digestsize); #ifdef DEBUG print_hex_dump(KERN_ERR, "jobdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, 
desc_bytes(desc), 1); #endif ret = caam_jr_enqueue(jrdev, desc, ahash_done, req); if (!ret) { ret = -EINPROGRESS; } else { ahash_unmap(jrdev, edesc, req, digestsize); kfree(edesc); } return ret; } /* submit ahash final if it the first job descriptor */ static int ahash_final_no_ctx(struct ahash_request *req) { struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct caam_hash_state *state = ahash_request_ctx(req); struct device *jrdev = ctx->jrdev; gfp_t flags = (req->base.flags & (CRYPTO_TFM_REQ_MAY_BACKLOG | CRYPTO_TFM_REQ_MAY_SLEEP)) ? GFP_KERNEL : GFP_ATOMIC; u8 *buf = state->current_buf ? state->buf_1 : state->buf_0; int buflen = state->current_buf ? state->buflen_1 : state->buflen_0; u32 *sh_desc = ctx->sh_desc_digest, *desc; dma_addr_t ptr = ctx->sh_desc_digest_dma; int digestsize = crypto_ahash_digestsize(ahash); struct ahash_edesc *edesc; int ret = 0; int sh_len; /* allocate space for base edesc and hw desc commands, link tables */ edesc = kzalloc(sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN, GFP_DMA | flags); if (!edesc) { dev_err(jrdev, "could not allocate extended descriptor\n"); return -ENOMEM; } sh_len = desc_len(sh_desc); desc = edesc->hw_desc; init_job_desc_shared(desc, ptr, sh_len, HDR_SHARE_DEFER | HDR_REVERSE); state->buf_dma = dma_map_single(jrdev, buf, buflen, DMA_TO_DEVICE); append_seq_in_ptr(desc, state->buf_dma, buflen, 0); edesc->dst_dma = map_seq_out_ptr_result(desc, jrdev, req->result, digestsize); edesc->src_nents = 0; dma_sync_single_for_device(jrdev, state->buf_dma, buflen, DMA_TO_DEVICE); #ifdef DEBUG print_hex_dump(KERN_ERR, "jobdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif ret = caam_jr_enqueue(jrdev, desc, ahash_done, req); if (!ret) { ret = -EINPROGRESS; } else { ahash_unmap(jrdev, edesc, req, digestsize); kfree(edesc); } return ret; } /* submit ahash update if it the first job descriptor after update */ static int 
ahash_update_no_ctx(struct ahash_request *req) { struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct caam_hash_state *state = ahash_request_ctx(req); struct device *jrdev = ctx->jrdev; gfp_t flags = (req->base.flags & (CRYPTO_TFM_REQ_MAY_BACKLOG | CRYPTO_TFM_REQ_MAY_SLEEP)) ? GFP_KERNEL : GFP_ATOMIC; u8 *buf = state->current_buf ? state->buf_1 : state->buf_0; int *buflen = state->current_buf ? &state->buflen_1 : &state->buflen_0; u8 *next_buf = state->current_buf ? state->buf_0 : state->buf_1; int *next_buflen = state->current_buf ? &state->buflen_0 : &state->buflen_1; int in_len = *buflen + req->nbytes, to_hash; int sec4_sg_bytes, src_nents; struct ahash_edesc *edesc; u32 *desc, *sh_desc = ctx->sh_desc_update_first; dma_addr_t ptr = ctx->sh_desc_update_first_dma; bool chained = false; int ret = 0; int sh_len; *next_buflen = in_len & (crypto_tfm_alg_blocksize(&ahash->base) - 1); to_hash = in_len - *next_buflen; if (to_hash) { src_nents = __sg_count(req->src, req->nbytes - (*next_buflen), &chained); sec4_sg_bytes = (1 + src_nents) * sizeof(struct sec4_sg_entry); /* * allocate space for base edesc and hw desc commands, * link tables */ edesc = kzalloc(sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN + sec4_sg_bytes, GFP_DMA | flags); if (!edesc) { dev_err(jrdev, "could not allocate extended descriptor\n"); return -ENOMEM; } edesc->src_nents = src_nents; edesc->chained = chained; edesc->sec4_sg_bytes = sec4_sg_bytes; edesc->sec4_sg = (void *)edesc + sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN; edesc->sec4_sg_dma = dma_map_single(jrdev, edesc->sec4_sg, sec4_sg_bytes, DMA_TO_DEVICE); state->buf_dma = buf_map_to_sec4_sg(jrdev, edesc->sec4_sg, buf, *buflen); src_map_to_sec4_sg(jrdev, req->src, src_nents, edesc->sec4_sg + 1, chained); if (*next_buflen) { sg_copy_part(next_buf, req->src, to_hash - *buflen, req->nbytes); state->current_buf = !state->current_buf; } sh_len = desc_len(sh_desc); desc = edesc->hw_desc; 
init_job_desc_shared(desc, ptr, sh_len, HDR_SHARE_DEFER | HDR_REVERSE); append_seq_in_ptr(desc, edesc->sec4_sg_dma, to_hash, LDST_SGF); map_seq_out_ptr_ctx(desc, jrdev, state, ctx->ctx_len); dma_sync_single_for_device(jrdev, edesc->sec4_sg_dma, sec4_sg_bytes, DMA_TO_DEVICE); #ifdef DEBUG print_hex_dump(KERN_ERR, "jobdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif ret = caam_jr_enqueue(jrdev, desc, ahash_done_ctx_dst, req); if (!ret) { ret = -EINPROGRESS; state->update = ahash_update_ctx; state->finup = ahash_finup_ctx; state->final = ahash_final_ctx; } else { ahash_unmap_ctx(jrdev, edesc, req, ctx->ctx_len, DMA_TO_DEVICE); kfree(edesc); } } else if (*next_buflen) { sg_copy(buf + *buflen, req->src, req->nbytes); *buflen = *next_buflen; *next_buflen = 0; } #ifdef DEBUG print_hex_dump(KERN_ERR, "buf@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, buf, *buflen, 1); print_hex_dump(KERN_ERR, "next buf@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, next_buf, *next_buflen, 1); #endif return ret; } /* submit ahash finup if it the first job descriptor after update */ static int ahash_finup_no_ctx(struct ahash_request *req) { struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct caam_hash_state *state = ahash_request_ctx(req); struct device *jrdev = ctx->jrdev; gfp_t flags = (req->base.flags & (CRYPTO_TFM_REQ_MAY_BACKLOG | CRYPTO_TFM_REQ_MAY_SLEEP)) ? GFP_KERNEL : GFP_ATOMIC; u8 *buf = state->current_buf ? state->buf_1 : state->buf_0; int buflen = state->current_buf ? state->buflen_1 : state->buflen_0; int last_buflen = state->current_buf ? 
state->buflen_0 : state->buflen_1; u32 *sh_desc = ctx->sh_desc_digest, *desc; dma_addr_t ptr = ctx->sh_desc_digest_dma; int sec4_sg_bytes, sec4_sg_src_index, src_nents; int digestsize = crypto_ahash_digestsize(ahash); struct ahash_edesc *edesc; bool chained = false; int sh_len; int ret = 0; src_nents = __sg_count(req->src, req->nbytes, &chained); sec4_sg_src_index = 2; sec4_sg_bytes = (sec4_sg_src_index + src_nents) * sizeof(struct sec4_sg_entry); /* allocate space for base edesc and hw desc commands, link tables */ edesc = kzalloc(sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN + sec4_sg_bytes, GFP_DMA | flags); if (!edesc) { dev_err(jrdev, "could not allocate extended descriptor\n"); return -ENOMEM; } sh_len = desc_len(sh_desc); desc = edesc->hw_desc; init_job_desc_shared(desc, ptr, sh_len, HDR_SHARE_DEFER | HDR_REVERSE); edesc->src_nents = src_nents; edesc->chained = chained; edesc->sec4_sg_bytes = sec4_sg_bytes; edesc->sec4_sg = (void *)edesc + sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN; edesc->sec4_sg_dma = dma_map_single(jrdev, edesc->sec4_sg, sec4_sg_bytes, DMA_TO_DEVICE); state->buf_dma = try_buf_map_to_sec4_sg(jrdev, edesc->sec4_sg, buf, state->buf_dma, buflen, last_buflen); src_map_to_sec4_sg(jrdev, req->src, src_nents, edesc->sec4_sg + 1, chained); append_seq_in_ptr(desc, edesc->sec4_sg_dma, buflen + req->nbytes, LDST_SGF); edesc->dst_dma = map_seq_out_ptr_result(desc, jrdev, req->result, digestsize); dma_sync_single_for_device(jrdev, edesc->sec4_sg_dma, sec4_sg_bytes, DMA_TO_DEVICE); #ifdef DEBUG print_hex_dump(KERN_ERR, "jobdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif ret = caam_jr_enqueue(jrdev, desc, ahash_done, req); if (!ret) { ret = -EINPROGRESS; } else { ahash_unmap(jrdev, edesc, req, digestsize); kfree(edesc); } return ret; } /* submit first update job descriptor after init */ static int ahash_update_first(struct ahash_request *req) { struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct 
caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct caam_hash_state *state = ahash_request_ctx(req); struct device *jrdev = ctx->jrdev; gfp_t flags = (req->base.flags & (CRYPTO_TFM_REQ_MAY_BACKLOG | CRYPTO_TFM_REQ_MAY_SLEEP)) ? GFP_KERNEL : GFP_ATOMIC; u8 *next_buf = state->buf_0 + state->current_buf * CAAM_MAX_HASH_BLOCK_SIZE; int *next_buflen = &state->buflen_0 + state->current_buf; int to_hash; u32 *sh_desc = ctx->sh_desc_update_first, *desc; dma_addr_t ptr = ctx->sh_desc_update_first_dma; int sec4_sg_bytes, src_nents; dma_addr_t src_dma; u32 options; struct ahash_edesc *edesc; bool chained = false; int ret = 0; int sh_len; *next_buflen = req->nbytes & (crypto_tfm_alg_blocksize(&ahash->base) - 1); to_hash = req->nbytes - *next_buflen; if (to_hash) { src_nents = sg_count(req->src, req->nbytes - (*next_buflen), &chained); dma_map_sg_chained(jrdev, req->src, src_nents ? : 1, DMA_TO_DEVICE, chained); sec4_sg_bytes = src_nents * sizeof(struct sec4_sg_entry); /* * allocate space for base edesc and hw desc commands, * link tables */ edesc = kzalloc(sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN + sec4_sg_bytes, GFP_DMA | flags); if (!edesc) { dev_err(jrdev, "could not allocate extended descriptor\n"); return -ENOMEM; } edesc->src_nents = src_nents; edesc->chained = chained; edesc->sec4_sg_bytes = sec4_sg_bytes; edesc->sec4_sg = (void *)edesc + sizeof(struct ahash_edesc) + DESC_JOB_IO_LEN; edesc->sec4_sg_dma = dma_map_single(jrdev, edesc->sec4_sg, sec4_sg_bytes, DMA_TO_DEVICE); if (src_nents) { sg_to_sec4_sg_last(req->src, src_nents, edesc->sec4_sg, 0); src_dma = edesc->sec4_sg_dma; options = LDST_SGF; } else { src_dma = sg_dma_address(req->src); options = 0; } if (*next_buflen) sg_copy_part(next_buf, req->src, to_hash, req->nbytes); sh_len = desc_len(sh_desc); desc = edesc->hw_desc; init_job_desc_shared(desc, ptr, sh_len, HDR_SHARE_DEFER | HDR_REVERSE); append_seq_in_ptr(desc, src_dma, to_hash, options); map_seq_out_ptr_ctx(desc, jrdev, state, ctx->ctx_len); 
dma_sync_single_for_device(jrdev, edesc->sec4_sg_dma, sec4_sg_bytes, DMA_TO_DEVICE); #ifdef DEBUG print_hex_dump(KERN_ERR, "jobdesc@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, desc, desc_bytes(desc), 1); #endif ret = caam_jr_enqueue(jrdev, desc, ahash_done_ctx_dst, req); if (!ret) { ret = -EINPROGRESS; state->update = ahash_update_ctx; state->finup = ahash_finup_ctx; state->final = ahash_final_ctx; } else { ahash_unmap_ctx(jrdev, edesc, req, ctx->ctx_len, DMA_TO_DEVICE); kfree(edesc); } } else if (*next_buflen) { state->update = ahash_update_no_ctx; state->finup = ahash_finup_no_ctx; state->final = ahash_final_no_ctx; sg_copy(next_buf, req->src, req->nbytes); } #ifdef DEBUG print_hex_dump(KERN_ERR, "next buf@"xstr(__LINE__)": ", DUMP_PREFIX_ADDRESS, 16, 4, next_buf, *next_buflen, 1); #endif return ret; } static int ahash_finup_first(struct ahash_request *req) { return ahash_digest(req); } static int ahash_init(struct ahash_request *req) { struct caam_hash_state *state = ahash_request_ctx(req); memset(state, 0, sizeof(struct caam_hash_state)); state->update = ahash_update_first; state->finup = ahash_finup_first; state->final = ahash_final_no_ctx; return 0; } static int ahash_update(struct ahash_request *req) { struct caam_hash_state *state = ahash_request_ctx(req); return state->update(req); } static int ahash_finup(struct ahash_request *req) { struct caam_hash_state *state = ahash_request_ctx(req); return state->finup(req); } static int ahash_final(struct ahash_request *req) { struct caam_hash_state *state = ahash_request_ctx(req); return state->final(req); } static int ahash_export(struct ahash_request *req, void *out) { struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct caam_hash_state *state = ahash_request_ctx(req); memcpy(out, ctx, sizeof(struct caam_hash_ctx)); memcpy(out + sizeof(struct caam_hash_ctx), state, sizeof(struct caam_hash_state)); return 0; } static int ahash_import(struct 
ahash_request *req, const void *in) { struct crypto_ahash *ahash = crypto_ahash_reqtfm(req); struct caam_hash_ctx *ctx = crypto_ahash_ctx(ahash); struct caam_hash_state *state = ahash_request_ctx(req); memcpy(ctx, in, sizeof(struct caam_hash_ctx)); memcpy(state, in + sizeof(struct caam_hash_ctx), sizeof(struct caam_hash_state)); return 0; } struct caam_hash_template { char name[CRYPTO_MAX_ALG_NAME]; char driver_name[CRYPTO_MAX_ALG_NAME]; char hmac_name[CRYPTO_MAX_ALG_NAME]; char hmac_driver_name[CRYPTO_MAX_ALG_NAME]; unsigned int blocksize; struct ahash_alg template_ahash; u32 alg_type; u32 alg_op; }; /* ahash descriptors */ static struct caam_hash_template driver_hash[] = { { .name = "sha1", .driver_name = "sha1-caam", .hmac_name = "hmac(sha1)", .hmac_driver_name = "hmac-sha1-caam", .blocksize = SHA1_BLOCK_SIZE, .template_ahash = { .init = ahash_init, .update = ahash_update, .final = ahash_final, .finup = ahash_finup, .digest = ahash_digest, .export = ahash_export, .import = ahash_import, .setkey = ahash_setkey, .halg = { .digestsize = SHA1_DIGEST_SIZE, }, }, .alg_type = OP_ALG_ALGSEL_SHA1, .alg_op = OP_ALG_ALGSEL_SHA1 | OP_ALG_AAI_HMAC, }, { .name = "sha224", .driver_name = "sha224-caam", .hmac_name = "hmac(sha224)", .hmac_driver_name = "hmac-sha224-caam", .blocksize = SHA224_BLOCK_SIZE, .template_ahash = { .init = ahash_init, .update = ahash_update, .final = ahash_final, .finup = ahash_finup, .digest = ahash_digest, .export = ahash_export, .import = ahash_import, .setkey = ahash_setkey, .halg = { .digestsize = SHA224_DIGEST_SIZE, }, }, .alg_type = OP_ALG_ALGSEL_SHA224, .alg_op = OP_ALG_ALGSEL_SHA224 | OP_ALG_AAI_HMAC, }, { .name = "sha256", .driver_name = "sha256-caam", .hmac_name = "hmac(sha256)", .hmac_driver_name = "hmac-sha256-caam", .blocksize = SHA256_BLOCK_SIZE, .template_ahash = { .init = ahash_init, .update = ahash_update, .final = ahash_final, .finup = ahash_finup, .digest = ahash_digest, .export = ahash_export, .import = ahash_import, .setkey = 
ahash_setkey, .halg = { .digestsize = SHA256_DIGEST_SIZE, }, }, .alg_type = OP_ALG_ALGSEL_SHA256, .alg_op = OP_ALG_ALGSEL_SHA256 | OP_ALG_AAI_HMAC, }, { .name = "sha384", .driver_name = "sha384-caam", .hmac_name = "hmac(sha384)", .hmac_driver_name = "hmac-sha384-caam", .blocksize = SHA384_BLOCK_SIZE, .template_ahash = { .init = ahash_init, .update = ahash_update, .final = ahash_final, .finup = ahash_finup, .digest = ahash_digest, .export = ahash_export, .import = ahash_import, .setkey = ahash_setkey, .halg = { .digestsize = SHA384_DIGEST_SIZE, }, }, .alg_type = OP_ALG_ALGSEL_SHA384, .alg_op = OP_ALG_ALGSEL_SHA384 | OP_ALG_AAI_HMAC, }, { .name = "sha512", .driver_name = "sha512-caam", .hmac_name = "hmac(sha512)", .hmac_driver_name = "hmac-sha512-caam", .blocksize = SHA512_BLOCK_SIZE, .template_ahash = { .init = ahash_init, .update = ahash_update, .final = ahash_final, .finup = ahash_finup, .digest = ahash_digest, .export = ahash_export, .import = ahash_import, .setkey = ahash_setkey, .halg = { .digestsize = SHA512_DIGEST_SIZE, }, }, .alg_type = OP_ALG_ALGSEL_SHA512, .alg_op = OP_ALG_ALGSEL_SHA512 | OP_ALG_AAI_HMAC, }, { .name = "md5", .driver_name = "md5-caam", .hmac_name = "hmac(md5)", .hmac_driver_name = "hmac-md5-caam", .blocksize = MD5_BLOCK_WORDS * 4, .template_ahash = { .init = ahash_init, .update = ahash_update, .final = ahash_final, .finup = ahash_finup, .digest = ahash_digest, .export = ahash_export, .import = ahash_import, .setkey = ahash_setkey, .halg = { .digestsize = MD5_DIGEST_SIZE, }, }, .alg_type = OP_ALG_ALGSEL_MD5, .alg_op = OP_ALG_ALGSEL_MD5 | OP_ALG_AAI_HMAC, }, }; struct caam_hash_alg { struct list_head entry; struct device *ctrldev; int alg_type; int alg_op; struct ahash_alg ahash_alg; }; static int caam_hash_cra_init(struct crypto_tfm *tfm) { struct crypto_ahash *ahash = __crypto_ahash_cast(tfm); struct crypto_alg *base = tfm->__crt_alg; struct hash_alg_common *halg = container_of(base, struct hash_alg_common, base); struct ahash_alg *alg = 
container_of(halg, struct ahash_alg, halg); struct caam_hash_alg *caam_hash = container_of(alg, struct caam_hash_alg, ahash_alg); struct caam_hash_ctx *ctx = crypto_tfm_ctx(tfm); struct caam_drv_private *priv = dev_get_drvdata(caam_hash->ctrldev); /* Sizes for MDHA running digests: MD5, SHA1, 224, 256, 384, 512 */ static const u8 runninglen[] = { HASH_MSG_LEN + MD5_DIGEST_SIZE, HASH_MSG_LEN + SHA1_DIGEST_SIZE, HASH_MSG_LEN + 32, HASH_MSG_LEN + SHA256_DIGEST_SIZE, HASH_MSG_LEN + 64, HASH_MSG_LEN + SHA512_DIGEST_SIZE }; int tgt_jr = atomic_inc_return(&priv->tfm_count); int ret = 0; /* * distribute tfms across job rings to ensure in-order * crypto request processing per tfm */ ctx->jrdev = priv->jrdev[tgt_jr % priv->total_jobrs]; /* copy descriptor header template value */ ctx->alg_type = OP_TYPE_CLASS2_ALG | caam_hash->alg_type; ctx->alg_op = OP_TYPE_CLASS2_ALG | caam_hash->alg_op; ctx->ctx_len = runninglen[(ctx->alg_op & OP_ALG_ALGSEL_SUBMASK) >> OP_ALG_ALGSEL_SHIFT]; crypto_ahash_set_reqsize(__crypto_ahash_cast(tfm), sizeof(struct caam_hash_state)); ret = ahash_set_sh_desc(ahash); return ret; } static void caam_hash_cra_exit(struct crypto_tfm *tfm) { struct caam_hash_ctx *ctx = crypto_tfm_ctx(tfm); if (ctx->sh_desc_update_dma && !dma_mapping_error(ctx->jrdev, ctx->sh_desc_update_dma)) dma_unmap_single(ctx->jrdev, ctx->sh_desc_update_dma, desc_bytes(ctx->sh_desc_update), DMA_TO_DEVICE); if (ctx->sh_desc_update_first_dma && !dma_mapping_error(ctx->jrdev, ctx->sh_desc_update_first_dma)) dma_unmap_single(ctx->jrdev, ctx->sh_desc_update_first_dma, desc_bytes(ctx->sh_desc_update_first), DMA_TO_DEVICE); if (ctx->sh_desc_fin_dma && !dma_mapping_error(ctx->jrdev, ctx->sh_desc_fin_dma)) dma_unmap_single(ctx->jrdev, ctx->sh_desc_fin_dma, desc_bytes(ctx->sh_desc_fin), DMA_TO_DEVICE); if (ctx->sh_desc_digest_dma && !dma_mapping_error(ctx->jrdev, ctx->sh_desc_digest_dma)) dma_unmap_single(ctx->jrdev, ctx->sh_desc_digest_dma, desc_bytes(ctx->sh_desc_digest), DMA_TO_DEVICE); if 
(ctx->sh_desc_finup_dma && !dma_mapping_error(ctx->jrdev, ctx->sh_desc_finup_dma)) dma_unmap_single(ctx->jrdev, ctx->sh_desc_finup_dma, desc_bytes(ctx->sh_desc_finup), DMA_TO_DEVICE); } static struct caam_hash_alg * caam_hash_alloc(struct device *ctrldev, struct caam_hash_template *template, bool keyed) { struct caam_hash_alg *t_alg; struct ahash_alg *halg; struct crypto_alg *alg; t_alg = kzalloc(sizeof(struct caam_hash_alg), GFP_KERNEL); if (!t_alg) { dev_err(ctrldev, "failed to allocate t_alg\n"); return ERR_PTR(-ENOMEM); } t_alg->ahash_alg = template->template_ahash; halg = &t_alg->ahash_alg; alg = &halg->halg.base; if (keyed) { snprintf(alg->cra_name, CRYPTO_MAX_ALG_NAME, "%s", template->hmac_name); snprintf(alg->cra_driver_name, CRYPTO_MAX_ALG_NAME, "%s", template->hmac_driver_name); } else { snprintf(alg->cra_name, CRYPTO_MAX_ALG_NAME, "%s", template->name); snprintf(alg->cra_driver_name, CRYPTO_MAX_ALG_NAME, "%s", template->driver_name); } alg->cra_module = THIS_MODULE; alg->cra_init = caam_hash_cra_init; alg->cra_exit = caam_hash_cra_exit; alg->cra_ctxsize = sizeof(struct caam_hash_ctx); alg->cra_priority = CAAM_CRA_PRIORITY; alg->cra_blocksize = template->blocksize; alg->cra_alignmask = 0; alg->cra_flags = CRYPTO_ALG_ASYNC | CRYPTO_ALG_TYPE_AHASH; alg->cra_type = &crypto_ahash_type; t_alg->alg_type = template->alg_type; t_alg->alg_op = template->alg_op; t_alg->ctrldev = ctrldev; return t_alg; } int caam_algapi_hash_startup(struct platform_device *pdev) { struct device *ctrldev; struct caam_drv_private *priv; int i = 0, err = 0, md_limit = 0, md_inst; u64 cha_inst; ctrldev = &pdev->dev; priv = dev_get_drvdata(ctrldev); INIT_LIST_HEAD(&priv->hash_list); atomic_set(&priv->tfm_count, -1); /* register algorithms the device supports */ cha_inst = rd_reg64(&priv->ctrl->perfmon.cha_num); md_inst = (cha_inst & CHA_ID_MD_MASK) >> CHA_ID_MD_SHIFT; if (md_inst) { md_limit = SHA512_DIGEST_SIZE; if ((rd_reg64(&priv->ctrl->perfmon.cha_id) & CHA_ID_MD_MASK) == 
CHA_ID_MD_LP256) /* LP256 limits digest size */ md_limit = SHA256_DIGEST_SIZE; } for (i = 0; i < ARRAY_SIZE(driver_hash); i++) { struct caam_hash_alg *t_alg; /* If no MD instantiated, or MD too small, skip */ if ((!md_inst) || (driver_hash[i].template_ahash.halg.digestsize > md_limit)) continue; /* register hmac version */ t_alg = caam_hash_alloc(ctrldev, &driver_hash[i], true); if (IS_ERR(t_alg)) { err = PTR_ERR(t_alg); dev_warn(ctrldev, "%s alg allocation failed\n", driver_hash[i].driver_name); continue; } err = crypto_register_ahash(&t_alg->ahash_alg); if (err) { dev_warn(ctrldev, "%s alg registration failed\n", t_alg->ahash_alg.halg.base.cra_driver_name); kfree(t_alg); } else list_add_tail(&t_alg->entry, &priv->hash_list); /* register unkeyed version */ t_alg = caam_hash_alloc(ctrldev, &driver_hash[i], false); if (IS_ERR(t_alg)) { err = PTR_ERR(t_alg); dev_warn(ctrldev, "%s alg allocation failed\n", driver_hash[i].driver_name); continue; } err = crypto_register_ahash(&t_alg->ahash_alg); if (err) { dev_warn(ctrldev, "%s alg registration failed\n", t_alg->ahash_alg.halg.base.cra_driver_name); kfree(t_alg); } else list_add_tail(&t_alg->entry, &priv->hash_list); } return err; } void caam_algapi_hash_shutdown(struct platform_device *pdev) { struct device *ctrldev; struct caam_drv_private *priv; struct caam_hash_alg *t_alg, *n; ctrldev = &pdev->dev; priv = dev_get_drvdata(ctrldev); if (!priv->hash_list.next) return; list_for_each_entry_safe(t_alg, n, &priv->hash_list, entry) { crypto_unregister_ahash(&t_alg->ahash_alg); list_del(&t_alg->entry); kfree(t_alg); } } #ifdef CONFIG_OF static void __exit caam_algapi_hash_exit(void) { struct device_node *dev_node; struct platform_device *pdev; dev_node = of_find_compatible_node(NULL, NULL, "fsl,sec-v4.0"); if (!dev_node) { dev_node = of_find_compatible_node(NULL, NULL, "fsl,sec4.0"); if (!dev_node) return; } pdev = of_find_device_by_node(dev_node); if (!pdev) return; of_node_put(dev_node); } static int __init 
caam_algapi_hash_init(void) { struct device_node *dev_node; struct platform_device *pdev; int err = 0; dev_node = of_find_compatible_node(NULL, NULL, "fsl,sec-v4.0"); if (!dev_node) { dev_node = of_find_compatible_node(NULL, NULL, "fsl,sec4.0"); if (!dev_node) return -ENODEV; } pdev = of_find_device_by_node(dev_node); if (!pdev) return -ENODEV; of_node_put(dev_node); return caam_algapi_hash_startup(pdev); } module_init(caam_algapi_hash_init); module_exit(caam_algapi_hash_exit); MODULE_LICENSE("GPL"); MODULE_DESCRIPTION("FSL CAAM support for ahash functions of crypto API"); MODULE_AUTHOR("Freescale Semiconductor - NMG"); #endif /a> 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 
1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 2233 2234 2235 2236 2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 
2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259 2260 2261 2262 2263 2264 2265 2266 2267 2268 2269 2270 2271 2272 2273 2274 2275 2276 2277 2278 2279 2280 2281 2282 2283 2284 2285 2286 2287 2288 2289 2290 2291 2292 2293 2294 2295 2296 2297 2298 2299 2300 2301 2302 2303 2304 2305 2306 2307 2308 2309 2310 2311 2312 2313 2314 2315 2316 2317 2318 2319 2320 2321 2322 2323 2324 2325 2326 2327 2328 2329 2330 2331 2332 2333 2334 2335 2336 2337 2338 2339 2340 2341 2342 2343 2344 2345 2346 2347 2348 2349 2350 2351 2352 2353 2354 2355 2356 2357 2358 2359 2360 2361 2362 2363 2364 2365 2366 2367 2368 2369 2370 2371 2372 2373 2374 2375 2376 2377 2378 2379 2380 2381 2382 2383 2384 2385 2386 2387 2388 2389 2390 2391 2392 2393 2394 2395 2396 2397 2398 2399 2400 2401 2402 2403 2404 2405 2406 2407 2408 2409 2410 2411 2412 2413 2414 2415 2416 2417 2418 2419 2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2430 2431 2432 2433 2434 2435 2436 2437 2438 2439 2440 2441 2442 2443 2444 2445 2446 2447 2448 2449 2450 2451 2452 2453 2454 2455 2456 2457 2458 2459 2460 2461 2462 2463 2464 2465 2466 2467 2468 2469 2470 2471 2472 2473 2474 2475 2476 2477 2478 2479 2480 2481 2482 2483 2484 2485 2486 2487 2488 2489 2490 2491 2492 2493
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

from __future__ import absolute_import, division, print_function

import binascii
import itertools
import math
import os

import pytest

from cryptography.exceptions import (
    AlreadyFinalized, InvalidSignature, _Reasons
)
from cryptography.hazmat.backends.interfaces import (
    PEMSerializationBackend, RSABackend
)
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import (
    padding, rsa, utils as asym_utils
)
from cryptography.hazmat.primitives.asymmetric.rsa import (
    RSAPrivateNumbers, RSAPublicNumbers
)
from cryptography.utils import CryptographyDeprecationWarning

from .fixtures_rsa import (
    RSA_KEY_1024, RSA_KEY_1025, RSA_KEY_1026, RSA_KEY_1027, RSA_KEY_1028,
    RSA_KEY_1029, RSA_KEY_1030, RSA_KEY_1031, RSA_KEY_1536, RSA_KEY_2048,
    RSA_KEY_2048_ALT, RSA_KEY_512, RSA_KEY_512_ALT, RSA_KEY_522, RSA_KEY_599,
    RSA_KEY_745, RSA_KEY_768,
)
from .utils import (
    _check_rsa_private_numbers, generate_rsa_verification_test
)
from ...doubles import (
    DummyAsymmetricPadding, DummyHashAlgorithm, DummyKeySerializationEncryption
)
from ...utils import (
    load_nist_vectors, load_pkcs1_vectors, load_rsa_nist_vectors,
    load_vectors_from_file, raises_unsupported_algorithm
)


class DummyMGF(object):
    """Placeholder passed as ``mgf`` to ``padding.PSS`` in the tests.

    It carries only a ``_salt_length`` attribute and is not an MGF1
    instance; the signing tests use it to exercise the UNSUPPORTED_MGF
    error path.
    """

    # Only attribute this stand-in provides.
    _salt_length = 0


def _check_rsa_private_numbers_if_serializable(key):
    """Validate ``key.private_numbers()`` when the key can expose them.

    Objects that are not ``RSAPrivateKeyWithSerialization`` instances are
    ignored without error.
    """
    if not isinstance(key, rsa.RSAPrivateKeyWithSerialization):
        return
    _check_rsa_private_numbers(key.private_numbers())


def test_check_rsa_private_numbers_if_serializable():
    """A non-key argument is skipped by the helper without raising."""
    not_a_key = "notserializable"
    _check_rsa_private_numbers_if_serializable(not_a_key)


def _flatten_pkcs1_examples(vectors):
    flattened_vectors = []
    for vector in vectors:
        examples = vector[0].pop("examples")
        for example in examples:
            merged_vector = (vector[0], vector[1], example)
            flattened_vectors.append(merged_vector)

    return flattened_vectors


def _build_oaep_sha2_vectors():
    """Load the custom OAEP vectors for every MGF1/OAEP hash pairing.

    Every ordered pair drawn from SHA1, SHA224, SHA256, SHA384 and SHA512
    is loaded from ``asymmetric/RSA/oaep-custom/oaep-<mgf1>-<oaep>.txt``,
    except SHA1/SHA1, which is covered by earlier vectors and has no custom
    file.

    :returns: list of ``(private, public, vector, mgf1alg, oaepalg)`` tuples.
    """
    base_path = os.path.join("asymmetric", "RSA", "oaep-custom")
    hash_algs = [
        hashes.SHA1(),
        hashes.SHA224(),
        hashes.SHA256(),
        hashes.SHA384(),
        hashes.SHA512(),
    ]
    vectors = []
    for mgf1_hash, oaep_hash in itertools.product(hash_algs, repeat=2):
        if mgf1_hash.name == "sha1" and oaep_hash.name == "sha1":
            # Already covered elsewhere; no custom vector file exists.
            continue

        file_name = "oaep-{0}-{1}.txt".format(
            mgf1_hash.name, oaep_hash.name
        )
        examples = _flatten_pkcs1_examples(
            load_vectors_from_file(
                os.path.join(base_path, file_name),
                load_pkcs1_vectors
            )
        )
        # The loaders know nothing about the hash algorithms, so append
        # them to each example tuple here.
        vectors.extend(
            (private, public, vector, mgf1_hash, oaep_hash)
            for private, public, vector in examples
        )
    return vectors


def _skip_pss_hash_algorithm_unsupported(backend, hash_alg):
    """Skip the running test unless ``backend`` supports PSS with ``hash_alg``.

    Support is probed with ``hash_alg`` in MGF1 and a ``MAX_LENGTH`` salt.
    When it is missing, ``pytest.skip`` is called, which raises
    ``pytest.skip.Exception``.
    """
    pss_padding = padding.PSS(
        mgf=padding.MGF1(hash_alg),
        salt_length=padding.PSS.MAX_LENGTH
    )
    if backend.rsa_padding_supported(pss_padding):
        return
    pytest.skip(
        "Does not support {0} in MGF1 using PSS.".format(hash_alg.name)
    )


@pytest.mark.requires_backend_interface(interface=RSABackend)
def test_skip_pss_hash_algorithm_unsupported(backend):
    """The skip helper raises pytest's skip exception for a dummy hash."""
    # The test expects the backend to reject DummyHashAlgorithm in MGF1.
    fake_hash = DummyHashAlgorithm()
    with pytest.raises(pytest.skip.Exception):
        _skip_pss_hash_algorithm_unsupported(backend, fake_hash)


def test_modular_inverse():
    """Compare ``rsa._modinv(q, p)`` against a precomputed expected value."""
    def from_hex(*chunks):
        # Join the wrapped hex chunks and parse them as one integer.
        return int("".join(chunks), 16)

    p = from_hex(
        "d1f9f6c09fd3d38987f7970247b85a6da84907753d42ec52bc23b745093f4fff5cff3",
        "617ce43d00121a9accc0051f519c76e08cf02fc18acfe4c9e6aea18da470a2b611d2e",
        "56a7b35caa2c0239bc041a53cc5875ca0b668ae6377d4b23e932d8c995fd1e58ecfd8",
        "c4b73259c0d8a54d691cca3f6fb85c8a5c1baf588e898d481",
    )
    q = from_hex(
        "d1519255eb8f678c86cfd06802d1fbef8b664441ac46b73d33d13a8404580a33a8e74",
        "cb2ea2e2963125b3d454d7a922cef24dd13e55f989cbabf64255a736671f4629a47b5",
        "b2347cfcd669133088d1c159518531025297c2d67c9da856a12e80222cd03b4c6ec0f",
        "86c957cb7bb8de7a127b645ec9e820aa94581e4762e209f01",
    )
    expected = from_hex(
        "0275e06afa722999315f8f322275483e15e2fb46d827b17800f99110b269a6732748f",
        "624a382fa2ed1ec68c99f7fc56fb60e76eea51614881f497ba7034c17dde955f92f15",
        "772f8b2b41f3e56d88b1e096cdd293eba4eae1e82db815e0fadea0c4ec971bc6fd875",
        "c20e67e48c31a611e98d32c6213ae4c4d7b53023b2f80c538",
    )
    result = rsa._modinv(q, p)
    assert result == expected


@pytest.mark.requires_backend_interface(interface=RSABackend)
class TestRSA(object):
    """RSA key generation, key-number loading and OAEP label tests."""

    # Every combination of public exponent (3, 5, 65537) and key size,
    # including the non-byte-aligned sizes 1025-1031.
    @pytest.mark.parametrize(
        ("public_exponent", "key_size"),
        itertools.product(
            (3, 5, 65537),
            (1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1536, 2048)
        )
    )
    def test_generate_rsa_keys(self, backend, public_exponent, key_size):
        """Generated keys have the requested size and valid public numbers."""
        skey = rsa.generate_private_key(public_exponent, key_size, backend)
        assert skey.key_size == key_size

        _check_rsa_private_numbers_if_serializable(skey)
        pkey = skey.public_key()
        assert isinstance(pkey.public_numbers(), rsa.RSAPublicNumbers)

    def test_generate_bad_public_exponent(self, backend):
        """Public exponents 1 and 4 are rejected with ValueError."""
        with pytest.raises(ValueError):
            rsa.generate_private_key(public_exponent=1,
                                     key_size=2048,
                                     backend=backend)

        with pytest.raises(ValueError):
            rsa.generate_private_key(public_exponent=4,
                                     key_size=2048,
                                     backend=backend)

    def test_cant_generate_insecure_tiny_key(self, backend):
        """Key sizes of 511 and 256 bits are rejected with ValueError."""
        with pytest.raises(ValueError):
            rsa.generate_private_key(public_exponent=65537,
                                     key_size=511,
                                     backend=backend)

        with pytest.raises(ValueError):
            rsa.generate_private_key(public_exponent=65537,
                                     key_size=256,
                                     backend=backend)

    @pytest.mark.parametrize(
        "pkcs1_example",
        load_vectors_from_file(
            os.path.join(
                "asymmetric", "RSA", "pkcs-1v2-1d2-vec", "pss-vect.txt"),
            load_pkcs1_vectors
        )
    )
    def test_load_pss_vect_example_keys(self, pkcs1_example):
        """Key numbers from the PSS vectors load, validate and agree.

        The public numbers embedded in the private numbers must match the
        public numbers listed separately in the vector file.
        """
        secret, public = pkcs1_example

        private_num = rsa.RSAPrivateNumbers(
            p=secret["p"],
            q=secret["q"],
            d=secret["private_exponent"],
            dmp1=secret["dmp1"],
            dmq1=secret["dmq1"],
            iqmp=secret["iqmp"],
            public_numbers=rsa.RSAPublicNumbers(
                e=secret["public_exponent"],
                n=secret["modulus"]
            )
        )
        _check_rsa_private_numbers(private_num)

        public_num = rsa.RSAPublicNumbers(
            e=public["public_exponent"],
            n=public["modulus"]
        )
        assert public_num

        public_num2 = private_num.public_numbers
        assert public_num2

        assert public_num.n == public_num2.n
        assert public_num.e == public_num2.e

    # NOTE(review): the `supported` gate probes OAEP label support with
    # SHA256, but the test body decrypts with SHA512. Presumably label
    # support does not depend on the hash on the tested backends -- confirm.
    @pytest.mark.parametrize(
        "vector",
        load_vectors_from_file(
            os.path.join("asymmetric", "RSA", "oaep-label.txt"),
            load_nist_vectors)
    )
    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=b"label"
            )
        ),
        skip_message="Does not support RSA OAEP labels"
    )
    def test_oaep_label_decrypt(self, vector, backend):
        """Decrypt the oaep-label.txt vectors (SHA512 for OAEP and MGF1)."""
        private_key = serialization.load_der_private_key(
            binascii.unhexlify(vector["key"]), None, backend
        )
        # The padding below is hard-coded to SHA512, so guard against a
        # vector file that uses a different digest.
        assert vector["oaepdigest"] == b"SHA512"
        decrypted = private_key.decrypt(
            binascii.unhexlify(vector["input"]),
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA512()),
                algorithm=hashes.SHA512(),
                label=binascii.unhexlify(vector["oaeplabel"])
            )
        )
        # Drops the first and last character of the expected output;
        # presumably surrounding quotes in the vector file -- confirm.
        assert vector["output"][1:-1] == decrypted

    @pytest.mark.parametrize(
        ("msg", "label"),
        [
            (b"amazing encrypted msg", b"some label"),
            (b"amazing encrypted msg", b""),
        ]
    )
    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=b"label"
            )
        ),
        skip_message="Does not support RSA OAEP labels"
    )
    def test_oaep_label_roundtrip(self, msg, label, backend):
        """Encrypting then decrypting with the same label returns ``msg``.

        Both a non-empty and an empty label are exercised.
        """
        private_key = RSA_KEY_2048.private_key(backend)
        ct = private_key.public_key().encrypt(
            msg,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=label
            )
        )
        pt = private_key.decrypt(
            ct,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=label
            )
        )
        assert pt == msg

    @pytest.mark.parametrize(
        ("enclabel", "declabel"),
        [
            (b"label1", b"label2"),
            (b"label3", b""),
            (b"", b"label4"),
        ]
    )
    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=b"label"
            )
        ),
        skip_message="Does not support RSA OAEP labels"
    )
    def test_oaep_wrong_label(self, enclabel, declabel, backend):
        """Decrypting with a label other than the encryption label fails.

        Covers both non-empty mismatches and mismatches against an empty
        label; each must raise ValueError.
        """
        private_key = RSA_KEY_2048.private_key(backend)
        msg = b"test"
        ct = private_key.public_key().encrypt(
            msg, padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=enclabel
            )
        )
        with pytest.raises(ValueError):
            private_key.decrypt(
                ct, padding.OAEP(
                    mgf=padding.MGF1(algorithm=hashes.SHA256()),
                    algorithm=hashes.SHA256(),
                    label=declabel
                )
            )

    # Inverted gate (note the `not`): runs only on backends that fail the
    # OAEP label probe.
    @pytest.mark.supported(
        only_if=lambda backend: not backend.rsa_padding_supported(
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=b"label"
            )
        ),
        skip_message="Requires backend without RSA OAEP label support"
    )
    def test_unsupported_oaep_label_decrypt(self, backend):
        """Backends without OAEP label support raise UNSUPPORTED_PADDING."""
        private_key = RSA_KEY_512.private_key(backend)
        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_PADDING):
            private_key.decrypt(
                b"0" * 64,
                padding.OAEP(
                    mgf=padding.MGF1(algorithm=hashes.SHA1()),
                    algorithm=hashes.SHA1(),
                    label=b"label"
                )
            )


def test_rsa_generate_invalid_backend():
    """Key generation with a non-RSABackend object reports the missing interface."""
    not_a_backend = object()

    with raises_unsupported_algorithm(_Reasons.BACKEND_MISSING_INTERFACE):
        rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
            backend=not_a_backend
        )


@pytest.mark.requires_backend_interface(interface=RSABackend)
class TestRSASignature(object):
    """RSA signing tests: PKCS#1 v1.5, PSS, prehashed input and error paths."""

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PKCS1v15()
        ),
        skip_message="Does not support PKCS1v1.5."
    )
    @pytest.mark.parametrize(
        "pkcs1_example",
        _flatten_pkcs1_examples(load_vectors_from_file(
            os.path.join(
                "asymmetric", "RSA", "pkcs1v15sign-vectors.txt"),
            load_pkcs1_vectors
        ))
    )
    def test_pkcs1v15_signing(self, pkcs1_example, backend):
        """SHA1 PKCS#1 v1.5 signatures match the reference vectors exactly."""
        private, public, example = pkcs1_example
        private_key = rsa.RSAPrivateNumbers(
            p=private["p"],
            q=private["q"],
            d=private["private_exponent"],
            dmp1=private["dmp1"],
            dmq1=private["dmq1"],
            iqmp=private["iqmp"],
            public_numbers=rsa.RSAPublicNumbers(
                e=private["public_exponent"],
                n=private["modulus"]
            )
        ).private_key(backend)
        signature = private_key.sign(
            binascii.unhexlify(example["message"]),
            padding.PKCS1v15(),
            hashes.SHA1()
        )
        # Exact byte comparison against the expected hex signature.
        assert binascii.hexlify(signature) == example["signature"]

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA1()),
                salt_length=padding.PSS.MAX_LENGTH
            )
        ),
        skip_message="Does not support PSS."
    )
    @pytest.mark.parametrize(
        "pkcs1_example",
        _flatten_pkcs1_examples(load_vectors_from_file(
            os.path.join(
                "asymmetric", "RSA", "pkcs-1v2-1d2-vec", "pss-vect.txt"),
            load_pkcs1_vectors
        ))
    )
    def test_pss_signing(self, pkcs1_example, backend):
        """SHA1 PSS signatures have modulus length and verify successfully."""
        private, public, example = pkcs1_example
        private_key = rsa.RSAPrivateNumbers(
            p=private["p"],
            q=private["q"],
            d=private["private_exponent"],
            dmp1=private["dmp1"],
            dmq1=private["dmq1"],
            iqmp=private["iqmp"],
            public_numbers=rsa.RSAPublicNumbers(
                e=private["public_exponent"],
                n=private["modulus"]
            )
        ).private_key(backend)
        public_key = rsa.RSAPublicNumbers(
            e=public["public_exponent"],
            n=public["modulus"]
        ).public_key(backend)
        signature = private_key.sign(
            binascii.unhexlify(example["message"]),
            padding.PSS(
                mgf=padding.MGF1(algorithm=hashes.SHA1()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            hashes.SHA1()
        )
        assert len(signature) == math.ceil(private_key.key_size / 8.0)
        # PSS signatures contain randomness so we can't do an exact
        # signature check. Instead we'll verify that the signature created
        # successfully verifies.
        public_key.verify(
            signature,
            binascii.unhexlify(example["message"]),
            padding.PSS(
                mgf=padding.MGF1(algorithm=hashes.SHA1()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            hashes.SHA1(),
        )

    @pytest.mark.parametrize(
        "hash_alg",
        [hashes.SHA224(), hashes.SHA256(), hashes.SHA384(), hashes.SHA512()]
    )
    def test_pss_signing_sha2(self, hash_alg, backend):
        """PSS sign/verify round-trip with each SHA-2 hash (also in MGF1).

        Skipped for hashes the backend does not support in MGF1.
        """
        _skip_pss_hash_algorithm_unsupported(backend, hash_alg)
        private_key = RSA_KEY_768.private_key(backend)
        public_key = private_key.public_key()
        pss = padding.PSS(
            mgf=padding.MGF1(hash_alg),
            salt_length=padding.PSS.MAX_LENGTH
        )
        msg = b"testing signature"
        signature = private_key.sign(msg, pss, hash_alg)
        public_key.verify(signature, msg, pss, hash_alg)

    @pytest.mark.supported(
        only_if=lambda backend: (
            backend.hash_supported(hashes.SHA512()) and
            backend.rsa_padding_supported(
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA1()),
                    salt_length=padding.PSS.MAX_LENGTH
                )
            )
        ),
        skip_message="Does not support SHA512."
    )
    def test_pss_minimum_key_size_for_digest(self, backend):
        """PSS signing with a SHA512 digest succeeds with RSA_KEY_522.

        Per the test name, this key is treated as the minimum size that can
        hold a SHA512 PSS signature; no exception is the pass condition.
        """
        private_key = RSA_KEY_522.private_key(backend)
        private_key.sign(
            b"no failure",
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA1()),
                salt_length=padding.PSS.MAX_LENGTH
            ),
            hashes.SHA512()
        )

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA1()),
                salt_length=padding.PSS.MAX_LENGTH
            )
        ),
        skip_message="Does not support PSS."
    )
    @pytest.mark.supported(
        only_if=lambda backend: backend.hash_supported(hashes.SHA512()),
        skip_message="Does not support SHA512."
    )
    def test_pss_signing_digest_too_large_for_key_size(self, backend):
        """PSS signing with a SHA512 digest and RSA_KEY_512 raises ValueError."""
        private_key = RSA_KEY_512.private_key(backend)
        with pytest.raises(ValueError):
            private_key.sign(
                b"msg",
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA1()),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                hashes.SHA512()
            )

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA1()),
                salt_length=padding.PSS.MAX_LENGTH
            )
        ),
        skip_message="Does not support PSS."
    )
    def test_pss_signing_salt_length_too_long(self, backend):
        """An oversized explicit PSS salt length (1000000) raises ValueError."""
        private_key = RSA_KEY_512.private_key(backend)
        with pytest.raises(ValueError):
            private_key.sign(
                b"failure coming",
                padding.PSS(
                    mgf=padding.MGF1(hashes.SHA1()),
                    salt_length=1000000
                ),
                hashes.SHA1()
            )

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PKCS1v15()
        ),
        skip_message="Does not support PKCS1v1.5."
    )
    def test_use_after_finalize(self, backend):
        """The deprecated signer() context is unusable after finalize().

        Both a second finalize() and a later update() raise
        AlreadyFinalized.
        """
        private_key = RSA_KEY_512.private_key(backend)
        # Creating a signer context emits a deprecation warning.
        with pytest.warns(CryptographyDeprecationWarning):
            signer = private_key.signer(padding.PKCS1v15(), hashes.SHA1())
        signer.update(b"sign me")
        signer.finalize()
        with pytest.raises(AlreadyFinalized):
            signer.finalize()
        with pytest.raises(AlreadyFinalized):
            signer.update(b"more data")

    def test_unsupported_padding(self, backend):
        """An unknown padding object raises UNSUPPORTED_PADDING."""
        private_key = RSA_KEY_512.private_key(backend)
        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_PADDING):
            private_key.sign(b"msg", DummyAsymmetricPadding(), hashes.SHA1())

    def test_padding_incorrect_type(self, backend):
        """Passing a non-padding object (a str) raises TypeError."""
        private_key = RSA_KEY_512.private_key(backend)
        with pytest.raises(TypeError):
            private_key.sign(b"msg", "notpadding", hashes.SHA1())

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PSS(mgf=padding.MGF1(hashes.SHA1()), salt_length=0)
        ),
        skip_message="Does not support PSS."
    )
    def test_unsupported_pss_mgf(self, backend):
        """PSS with DummyMGF as its mgf raises UNSUPPORTED_MGF."""
        private_key = RSA_KEY_512.private_key(backend)
        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_MGF):
            private_key.sign(
                b"msg",
                padding.PSS(
                    mgf=DummyMGF(),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                hashes.SHA1()
            )

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PKCS1v15()
        ),
        skip_message="Does not support PKCS1v1.5."
    )
    def test_pkcs1_digest_too_large_for_key_size(self, backend):
        """PKCS#1 v1.5 with SHA512 and RSA_KEY_599 raises ValueError."""
        private_key = RSA_KEY_599.private_key(backend)
        with pytest.raises(ValueError):
            private_key.sign(
                b"failure coming",
                padding.PKCS1v15(),
                hashes.SHA512()
            )

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PKCS1v15()
        ),
        skip_message="Does not support PKCS1v1.5."
    )
    def test_pkcs1_minimum_key_size(self, backend):
        """PKCS#1 v1.5 with SHA512 succeeds with RSA_KEY_745.

        No exception is the pass condition.
        """
        private_key = RSA_KEY_745.private_key(backend)
        private_key.sign(
            b"no failure",
            padding.PKCS1v15(),
            hashes.SHA512()
        )

    def test_sign(self, backend):
        """A one-shot PKCS#1 v1.5 SHA1 signature verifies with the public key."""
        private_key = RSA_KEY_512.private_key(backend)
        message = b"one little message"
        pkcs = padding.PKCS1v15()
        algorithm = hashes.SHA1()
        signature = private_key.sign(message, pkcs, algorithm)
        public_key = private_key.public_key()
        public_key.verify(signature, message, pkcs, algorithm)

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PSS(mgf=padding.MGF1(hashes.SHA1()), salt_length=0)
        ),
        skip_message="Does not support PSS."
    )
    def test_prehashed_sign(self, backend):
        """Signing a precomputed SHA1 digest via Prehashed(SHA1) verifies.

        Verification uses the original message with plain SHA1, so both
        paths must produce compatible signatures.
        """
        private_key = RSA_KEY_512.private_key(backend)
        message = b"one little message"
        h = hashes.Hash(hashes.SHA1(), backend)
        h.update(message)
        digest = h.finalize()
        pss = padding.PSS(mgf=padding.MGF1(hashes.SHA1()), salt_length=0)
        prehashed_alg = asym_utils.Prehashed(hashes.SHA1())
        signature = private_key.sign(digest, pss, prehashed_alg)
        public_key = private_key.public_key()
        public_key.verify(signature, message, pss, hashes.SHA1())

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PSS(mgf=padding.MGF1(hashes.SHA1()), salt_length=0)
        ),
        skip_message="Does not support PSS."
    )
    def test_prehashed_digest_mismatch(self, backend):
        """A SHA512 digest passed as Prehashed(SHA1) raises ValueError."""
        private_key = RSA_KEY_512.private_key(backend)
        message = b"one little message"
        h = hashes.Hash(hashes.SHA512(), backend)
        h.update(message)
        digest = h.finalize()
        pss = padding.PSS(mgf=padding.MGF1(hashes.SHA1()), salt_length=0)
        prehashed_alg = asym_utils.Prehashed(hashes.SHA1())
        with pytest.raises(ValueError):
            private_key.sign(digest, pss, prehashed_alg)

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PKCS1v15()
        ),
        skip_message="Does not support PKCS1v1.5."
    )
    def test_prehashed_unsupported_in_signer_ctx(self, backend):
        """The deprecated signer() API rejects Prehashed with TypeError."""
        private_key = RSA_KEY_512.private_key(backend)
        with pytest.raises(TypeError), \
                pytest.warns(CryptographyDeprecationWarning):
            private_key.signer(
                padding.PKCS1v15(),
                asym_utils.Prehashed(hashes.SHA1())
            )

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PKCS1v15()
        ),
        skip_message="Does not support PKCS1v1.5."
    )
    def test_prehashed_unsupported_in_verifier_ctx(self, backend):
        """The deprecated verifier() API rejects Prehashed with TypeError."""
        public_key = RSA_KEY_512.private_key(backend).public_key()
        with pytest.raises(TypeError), \
                pytest.warns(CryptographyDeprecationWarning):
            public_key.verifier(
                b"0" * 64,
                padding.PKCS1v15(),
                asym_utils.Prehashed(hashes.SHA1())
            )


@pytest.mark.requires_backend_interface(interface=RSABackend)
class TestRSAVerification(object):
    """Signature verification with RSA public keys (PKCS#1 v1.5 and PSS).

    Covers known-answer vectors, signatures that must be rejected with
    ``InvalidSignature``, argument validation (``TypeError``/``ValueError``),
    unsupported padding/MGF choices, prehashed digests and the deprecated
    streaming ``verifier()`` API.
    """

    # Skip markers shared by several tests below. They only exist while the
    # class body runs and are deleted at the end of the body.
    _needs_pkcs1v15 = pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PKCS1v15()
        ),
        skip_message="Does not support PKCS1v1.5."
    )
    _needs_pss_max_salt = pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA1()),
                salt_length=padding.PSS.MAX_LENGTH
            )
        ),
        skip_message="Does not support PSS."
    )

    @staticmethod
    def _vector_public_key(public, backend):
        # Public key built from a parsed vector's exponent and modulus.
        numbers = rsa.RSAPublicNumbers(
            e=public["public_exponent"],
            n=public["modulus"]
        )
        return numbers.public_key(backend)

    @staticmethod
    def _public_key_e65537(modulus, backend):
        # Public key with exponent 65537 and the given modulus.
        return rsa.RSAPublicNumbers(e=65537, n=modulus).public_key(backend)

    @staticmethod
    def _sha1_pss(salt_length):
        # PSS padding whose mask generation function is MGF1 over SHA-1.
        return padding.PSS(
            mgf=padding.MGF1(algorithm=hashes.SHA1()),
            salt_length=salt_length
        )

    @_needs_pkcs1v15
    @pytest.mark.parametrize(
        "pkcs1_example",
        _flatten_pkcs1_examples(load_vectors_from_file(
            os.path.join(
                "asymmetric", "RSA", "pkcs1v15sign-vectors.txt"),
            load_pkcs1_vectors
        ))
    )
    def test_pkcs1v15_verification(self, pkcs1_example, backend):
        # Every published PKCS#1 v1.5 / SHA-1 signature vector must verify.
        _, public, example = pkcs1_example
        signature = binascii.unhexlify(example["signature"])
        message = binascii.unhexlify(example["message"])
        verifying_key = self._vector_public_key(public, backend)
        verifying_key.verify(
            signature, message, padding.PKCS1v15(), hashes.SHA1()
        )

    @_needs_pkcs1v15
    def test_invalid_pkcs1v15_signature_wrong_data(self, backend):
        # A genuine signature must not verify against a different message.
        signing_key = RSA_KEY_512.private_key(backend)
        verifying_key = signing_key.public_key()
        signature = signing_key.sign(
            b"sign me", padding.PKCS1v15(), hashes.SHA1()
        )
        with pytest.raises(InvalidSignature):
            verifying_key.verify(
                signature, b"incorrect data", padding.PKCS1v15(),
                hashes.SHA1()
            )

    def test_invalid_signature_sequence_removed(self, backend):
        """Wycheproof vector: a malformed PKCS#1 v1.5 signature is rejected."""
        der_public_key = binascii.unhexlify(
            b"30820122300d06092a864886f70d01010105000382010f003082010a02820101"
            b"00a2b451a07d0aa5f96e455671513550514a8a5b462ebef717094fa1fee82224"
            b"e637f9746d3f7cafd31878d80325b6ef5a1700f65903b469429e89d6eac88450"
            b"97b5ab393189db92512ed8a7711a1253facd20f79c15e8247f3d3e42e46e48c9"
            b"8e254a2fe9765313a03eff8f17e1a029397a1fa26a8dce26f490ed81299615d9"
            b"814c22da610428e09c7d9658594266f5c021d0fceca08d945a12be82de4d1ece"
            b"6b4c03145b5d3495d4ed5411eb878daf05fd7afc3e09ada0f1126422f590975a"
            b"1969816f48698bcbba1b4d9cae79d460d8f9f85e7975005d9bc22c4e5ac0f7c1"
            b"a45d12569a62807d3b9a02e5a530e773066f453d1f5b4c2e9cf7820283f742b9"
            b"d50203010001"
        )
        signature = binascii.unhexlify(
            b"498209f59a0679a1f926eccf3056da2cba553d7ab3064e7c41ad1d739f038249"
            b"f02f5ad12ee246073d101bc3cdb563e8b6be61562056422b7e6c16ad53deb12a"
            b"f5de744197753a35859833f41bb59c6597f3980132b7478fd0b95fd27dfad64a"
            b"20fd5c25312bbd41a85286cd2a83c8df5efa0779158d01b0747ff165b055eb28"
            b"80ea27095700a295593196d8c5922cf6aa9d7e29b5056db5ded5eb20aeb31b89"
            b"42e26b15a5188a4934cd7e39cfe379a197f49a204343a493452deebca436ee61"
            b"4f4daf989e355544489f7e69ffa8ccc6a1e81cf0ab33c3e6d7591091485a6a31"
            b"bda3b33946490057b9a3003d3fd9daf7c4778b43fd46144d945d815f12628ff4"
        )
        message = binascii.unhexlify(b"313233343030")
        verifying_key = serialization.load_der_public_key(
            der_public_key, backend
        )
        with pytest.raises(InvalidSignature):
            verifying_key.verify(
                signature, message, padding.PKCS1v15(), hashes.SHA256()
            )

    @_needs_pkcs1v15
    def test_invalid_pkcs1v15_signature_wrong_key(self, backend):
        # A signature made with one key must not verify under another key.
        msg = b"sign me"
        signature = RSA_KEY_512.private_key(backend).sign(
            msg, padding.PKCS1v15(), hashes.SHA1()
        )
        other_key = RSA_KEY_512_ALT.private_key(backend).public_key()
        with pytest.raises(InvalidSignature):
            other_key.verify(
                signature, msg, padding.PKCS1v15(), hashes.SHA1()
            )

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA1()),
                salt_length=20
            )
        ),
        skip_message="Does not support PSS."
    )
    @pytest.mark.parametrize(
        "pkcs1_example",
        _flatten_pkcs1_examples(load_vectors_from_file(
            os.path.join(
                "asymmetric", "RSA", "pkcs-1v2-1d2-vec", "pss-vect.txt"),
            load_pkcs1_vectors
        ))
    )
    def test_pss_verification(self, pkcs1_example, backend):
        # Each PSS vector (MGF1/SHA-1, 20-byte salt) must verify.
        _, public, example = pkcs1_example
        verifying_key = self._vector_public_key(public, backend)
        verifying_key.verify(
            binascii.unhexlify(example["signature"]),
            binascii.unhexlify(example["message"]),
            self._sha1_pss(20),
            hashes.SHA1()
        )

    @_needs_pss_max_salt
    def test_invalid_pss_signature_wrong_data(self, backend):
        # The signature is not valid for the message passed to verify().
        modulus = int(
            b"dffc2137d5e810cde9e4b4612f5796447218bab913b3fa98bdf7982e4fa6"
            b"ec4d6653ef2b29fb1642b095befcbea6decc178fb4bed243d3c3592c6854"
            b"6af2d3f3", 16
        )
        verifying_key = self._public_key_e65537(modulus, backend)
        signature = binascii.unhexlify(
            b"0e68c3649df91c5bc3665f96e157efa75b71934aaa514d91e94ca8418d100f45"
            b"6f05288e58525f99666bab052adcffdf7186eb40f583bd38d98c97d3d524808b"
        )
        with pytest.raises(InvalidSignature):
            verifying_key.verify(
                signature,
                b"incorrect data",
                self._sha1_pss(padding.PSS.MAX_LENGTH),
                hashes.SHA1()
            )

    @_needs_pss_max_salt
    def test_invalid_pss_signature_wrong_key(self, backend):
        # The signature does not belong to this public key.
        signature = binascii.unhexlify(
            b"3a1880165014ba6eb53cc1449d13e5132ebcc0cfd9ade6d7a2494a0503bd0826"
            b"f8a46c431e0d7be0ca3e453f8b2b009e2733764da7927cc6dbe7a021437a242e"
        )
        modulus = int(
            b"381201f4905d67dfeb3dec131a0fbea773489227ec7a1448c3109189ac68"
            b"5a95441be90866a14c4d2e139cd16db540ec6c7abab13ffff91443fd46a8"
            b"960cbb7658ded26a5c95c86f6e40384e1c1239c63e541ba221191c4dd303"
            b"231b42e33c6dbddf5ec9a746f09bf0c25d0f8d27f93ee0ae5c0d723348f4"
            b"030d3581e13522e1", 16
        )
        verifying_key = self._public_key_e65537(modulus, backend)
        with pytest.raises(InvalidSignature):
            verifying_key.verify(
                signature,
                b"sign me",
                self._sha1_pss(padding.PSS.MAX_LENGTH),
                hashes.SHA1()
            )

    @_needs_pss_max_salt
    def test_invalid_pss_signature_data_too_large_for_modulus(self, backend):
        # The signature value is too large for this (shorter) modulus.
        signature = binascii.unhexlify(
            b"cb43bde4f7ab89eb4a79c6e8dd67e0d1af60715da64429d90c716a490b799c29"
            b"194cf8046509c6ed851052367a74e2e92d9b38947ed74332acb115a03fcc0222"
        )
        modulus = int(
            b"381201f4905d67dfeb3dec131a0fbea773489227ec7a1448c3109189ac68"
            b"5a95441be90866a14c4d2e139cd16db540ec6c7abab13ffff91443fd46a8"
            b"960cbb7658ded26a5c95c86f6e40384e1c1239c63e541ba221191c4dd303"
            b"231b42e33c6dbddf5ec9a746f09bf0c25d0f8d27f93ee0ae5c0d723348f4"
            b"030d3581e13522", 16
        )
        verifying_key = self._public_key_e65537(modulus, backend)
        with pytest.raises(InvalidSignature):
            verifying_key.verify(
                signature,
                b"sign me",
                self._sha1_pss(padding.PSS.MAX_LENGTH),
                hashes.SHA1()
            )

    @_needs_pkcs1v15
    def test_use_after_finalize(self, backend):
        # Once verify() has succeeded, both verify() and update() must raise.
        signing_key = RSA_KEY_512.private_key(backend)
        verifying_key = signing_key.public_key()
        signature = signing_key.sign(
            b"sign me", padding.PKCS1v15(), hashes.SHA1()
        )

        with pytest.warns(CryptographyDeprecationWarning):
            verifier = verifying_key.verifier(
                signature, padding.PKCS1v15(), hashes.SHA1()
            )
        verifier.update(b"sign me")
        verifier.verify()

        finalized_calls = (
            verifier.verify,
            lambda: verifier.update(b"more data"),
        )
        for call in finalized_calls:
            with pytest.raises(AlreadyFinalized):
                call()

    def test_unsupported_padding(self, backend):
        # A padding type unknown to the backend is reported as unsupported.
        verifying_key = RSA_KEY_512.private_key(backend).public_key()
        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_PADDING):
            verifying_key.verify(
                b"sig", b"msg", DummyAsymmetricPadding(), hashes.SHA1()
            )

    @_needs_pkcs1v15
    def test_signature_not_bytes(self, backend):
        # verifier() warns about deprecation and rejects a non-bytes signature.
        verifying_key = RSA_KEY_512.public_numbers.public_key(backend)
        not_bytes = 1234

        with pytest.raises(TypeError), \
                pytest.warns(CryptographyDeprecationWarning):
            verifying_key.verifier(
                not_bytes, padding.PKCS1v15(), hashes.SHA1()
            )

    def test_padding_incorrect_type(self, backend):
        # A plain string is not accepted in place of a padding object.
        verifying_key = RSA_KEY_512.private_key(backend).public_key()
        with pytest.raises(TypeError):
            verifying_key.verify(b"sig", b"msg", "notpadding", hashes.SHA1())

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PSS(mgf=padding.MGF1(hashes.SHA1()), salt_length=0)
        ),
        skip_message="Does not support PSS."
    )
    def test_unsupported_pss_mgf(self, backend):
        # PSS with an MGF other than MGF1 is reported as unsupported.
        verifying_key = RSA_KEY_512.private_key(backend).public_key()
        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_MGF):
            verifying_key.verify(
                b"sig",
                b"msg",
                padding.PSS(
                    mgf=DummyMGF(),
                    salt_length=padding.PSS.MAX_LENGTH
                ),
                hashes.SHA1()
            )

    @_needs_pss_max_salt
    @pytest.mark.supported(
        only_if=lambda backend: backend.hash_supported(hashes.SHA512()),
        skip_message="Does not support SHA512."
    )
    def test_pss_verify_digest_too_large_for_key_size(self, backend):
        # PSS with SHA-512 does not fit a 512-bit key; expect ValueError.
        signature = binascii.unhexlify(
            b"8b9a3ae9fb3b64158f3476dd8d8a1f1425444e98940e0926378baa9944d219d8"
            b"534c050ef6b19b1bdc6eb4da422e89161106a6f5b5cc16135b11eb6439b646bd"
        )
        verifying_key = RSA_KEY_512.private_key(backend).public_key()
        with pytest.raises(ValueError):
            verifying_key.verify(
                signature,
                b"msg doesn't matter",
                self._sha1_pss(padding.PSS.MAX_LENGTH),
                hashes.SHA512()
            )

    @_needs_pss_max_salt
    def test_pss_verify_salt_length_too_long(self, backend):
        # An absurd salt length makes verification fail rather than crash.
        signature = binascii.unhexlify(
            b"8b9a3ae9fb3b64158f3476dd8d8a1f1425444e98940e0926378baa9944d219d8"
            b"534c050ef6b19b1bdc6eb4da422e89161106a6f5b5cc16135b11eb6439b646bd"
        )
        modulus = int(
            b"d309e4612809437548b747d7f9eb9cd3340f54fe42bb3f84a36933b0839c"
            b"11b0c8b7f67e11f7252370161e31159c49c784d4bc41c42a78ce0f0b40a3"
            b"ca8ffb91", 16
        )
        verifying_key = self._public_key_e65537(modulus, backend)
        with pytest.raises(InvalidSignature):
            verifying_key.verify(
                signature,
                b"sign me",
                self._sha1_pss(1000000),
                hashes.SHA1()
            )

    def test_verify(self, backend):
        # Round trip: sign with the private key, verify with the public key.
        message = b"one little message"
        pkcs = padding.PKCS1v15()
        algorithm = hashes.SHA1()
        signing_key = RSA_KEY_512.private_key(backend)
        signature = signing_key.sign(message, pkcs, algorithm)
        signing_key.public_key().verify(signature, message, pkcs, algorithm)

    def test_prehashed_verify(self, backend):
        # A signature over the message verifies against its SHA-1 digest.
        message = b"one little message"
        pkcs = padding.PKCS1v15()
        signing_key = RSA_KEY_512.private_key(backend)
        signature = signing_key.sign(message, pkcs, hashes.SHA1())

        hasher = hashes.Hash(hashes.SHA1(), backend)
        hasher.update(message)
        digest = hasher.finalize()

        signing_key.public_key().verify(
            signature, digest, pkcs, asym_utils.Prehashed(hashes.SHA1())
        )

    def test_prehashed_digest_mismatch(self, backend):
        # A SHA-1 digest passed under a SHA-512 Prehashed label is rejected.
        verifying_key = RSA_KEY_512.private_key(backend).public_key()
        hasher = hashes.Hash(hashes.SHA1(), backend)
        hasher.update(b"one little message")
        sha1_digest = hasher.finalize()
        wrong_label = asym_utils.Prehashed(hashes.SHA512())
        pkcs = padding.PKCS1v15()
        with pytest.raises(ValueError):
            verifying_key.verify(b"\x00" * 64, sha1_digest, pkcs, wrong_label)

    del _needs_pkcs1v15, _needs_pss_max_salt


@pytest.mark.requires_backend_interface(interface=RSABackend)
class TestRSAPSSMGF1Verification(object):
    """NIST FIPS 186-2/186-3 PSS vectors, verified with MGF1 per hash."""

    def _pss_mgf1_test(hash_cls, label):
        # Class-body factory, removed with ``del`` below so it never becomes
        # a method. Builds the vector-driven verification test for
        # ``hash_cls`` and marks it to be skipped unless the backend supports
        # PSS with MGF1 over that hash. ``label`` is the hash name used in
        # the skip message.
        requires_pss = pytest.mark.supported(
            only_if=lambda backend: backend.rsa_padding_supported(
                padding.PSS(
                    mgf=padding.MGF1(hash_cls()),
                    salt_length=padding.PSS.MAX_LENGTH
                )
            ),
            skip_message="Does not support PSS using MGF1 with {0}.".format(
                label
            )
        )
        return requires_pss(generate_rsa_verification_test(
            load_rsa_nist_vectors,
            os.path.join("asymmetric", "RSA", "FIPS_186-2"),
            [
                "SigGenPSS_186-2.rsp",
                "SigGenPSS_186-3.rsp",
                "SigVerPSS_186-3.rsp",
            ],
            hash_cls(),
            lambda params, hash_alg: padding.PSS(
                mgf=padding.MGF1(algorithm=hash_alg),
                salt_length=params["salt_length"]
            )
        ))

    test_rsa_pss_mgf1_sha1 = _pss_mgf1_test(hashes.SHA1, "SHA1")
    test_rsa_pss_mgf1_sha224 = _pss_mgf1_test(hashes.SHA224, "SHA224")
    test_rsa_pss_mgf1_sha256 = _pss_mgf1_test(hashes.SHA256, "SHA256")
    test_rsa_pss_mgf1_sha384 = _pss_mgf1_test(hashes.SHA384, "SHA384")
    test_rsa_pss_mgf1_sha512 = _pss_mgf1_test(hashes.SHA512, "SHA512")

    del _pss_mgf1_test


@pytest.mark.requires_backend_interface(interface=RSABackend)
class TestRSAPKCS1Verification(object):
    """NIST FIPS 186-2/186-3 PKCS#1 v1.5 vectors, verified for each hash."""

    def _pkcs1v15_test(hash_cls, label):
        # Class-body factory, removed with ``del`` below so it never becomes
        # a method. Builds the vector-driven verification test for
        # ``hash_cls`` and marks it to be skipped unless the backend supports
        # both that hash and PKCS#1 v1.5 padding. ``label`` is the hash name
        # used in the skip message.
        requires_support = pytest.mark.supported(
            only_if=lambda backend: (
                backend.hash_supported(hash_cls()) and
                backend.rsa_padding_supported(padding.PKCS1v15())
            ),
            skip_message="Does not support {0} and PKCS1v1.5.".format(label)
        )
        return requires_support(generate_rsa_verification_test(
            load_rsa_nist_vectors,
            os.path.join("asymmetric", "RSA", "FIPS_186-2"),
            [
                "SigGen15_186-2.rsp",
                "SigGen15_186-3.rsp",
                "SigVer15_186-3.rsp",
            ],
            hash_cls(),
            lambda params, hash_alg: padding.PKCS1v15()
        ))

    test_rsa_pkcs1v15_verify_sha1 = _pkcs1v15_test(hashes.SHA1, "SHA1")
    test_rsa_pkcs1v15_verify_sha224 = _pkcs1v15_test(hashes.SHA224, "SHA224")
    test_rsa_pkcs1v15_verify_sha256 = _pkcs1v15_test(hashes.SHA256, "SHA256")
    test_rsa_pkcs1v15_verify_sha384 = _pkcs1v15_test(hashes.SHA384, "SHA384")
    test_rsa_pkcs1v15_verify_sha512 = _pkcs1v15_test(hashes.SHA512, "SHA512")

    del _pkcs1v15_test


class TestPSS(object):
    """Argument validation for ``padding.PSS`` and its salt-length helper."""

    def test_calculate_max_pss_salt_length(self):
        # Anything that is not an RSA key is rejected with TypeError.
        not_a_key = object()
        with pytest.raises(TypeError):
            padding.calculate_max_pss_salt_length(not_a_key, hashes.SHA256())

    def test_invalid_salt_length_not_integer(self):
        # A bytes salt length is a type error.
        mgf = padding.MGF1(hashes.SHA1())
        with pytest.raises(TypeError):
            padding.PSS(mgf=mgf, salt_length=b"not_a_length")

    def test_invalid_salt_length_negative_integer(self):
        # A negative salt length is a value error.
        mgf = padding.MGF1(hashes.SHA1())
        with pytest.raises(ValueError):
            padding.PSS(mgf=mgf, salt_length=-1)

    def test_valid_pss_parameters(self):
        # The constructor keeps exactly the MGF and salt length it was given.
        sha1 = hashes.SHA1()
        mgf = padding.MGF1(sha1)
        salt_length = sha1.digest_size
        pss = padding.PSS(mgf=mgf, salt_length=salt_length)
        assert (pss._mgf, pss._salt_length) == (mgf, salt_length)

    def test_valid_pss_parameters_maximum(self):
        # The MAX_LENGTH sentinel is stored unchanged.
        mgf = padding.MGF1(hashes.SHA1())
        pss = padding.PSS(mgf=mgf, salt_length=padding.PSS.MAX_LENGTH)
        assert pss._mgf == mgf
        assert pss._salt_length == padding.PSS.MAX_LENGTH


class TestMGF1(object):
    """Argument validation for ``padding.MGF1``."""

    def test_invalid_hash_algorithm(self):
        # Raw bytes are not accepted as a hash algorithm.
        not_a_hash = b"not_a_hash"
        with pytest.raises(TypeError):
            padding.MGF1(not_a_hash)

    def test_valid_mgf1_parameters(self):
        # The hash algorithm passed in is stored as-is.
        sha1 = hashes.SHA1()
        assert padding.MGF1(sha1)._algorithm == sha1


class TestOAEP(object):
    """Argument validation for ``padding.OAEP``."""

    def test_invalid_algorithm(self):
        # ``algorithm`` must be a hash algorithm; empty bytes are rejected.
        sha1_mgf = padding.MGF1(hashes.SHA1())
        with pytest.raises(TypeError):
            padding.OAEP(mgf=sha1_mgf, algorithm=b"", label=None)


@pytest.mark.requires_backend_interface(interface=RSABackend)
class TestRSADecryption(object):
    """Decryption with RSA private keys using PKCS#1 v1.5 and OAEP padding.

    Covers known-answer vectors, malformed or wrongly sized ciphertexts
    (``ValueError``), and unsupported padding/MGF choices.
    """

    # Skip markers shared by several tests below. They only exist while the
    # class body runs and are deleted at the end of the body.
    _needs_pkcs1v15 = pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PKCS1v15()
        ),
        skip_message="Does not support PKCS1v1.5."
    )
    _needs_oaep_sha1 = pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA1()),
                algorithm=hashes.SHA1(),
                label=None
            )
        ),
        skip_message="Does not support OAEP."
    )

    @staticmethod
    def _vector_private_key(private, backend):
        # Private key built from a parsed vector's CRT components.
        public_numbers = rsa.RSAPublicNumbers(
            e=private["public_exponent"],
            n=private["modulus"]
        )
        private_numbers = rsa.RSAPrivateNumbers(
            p=private["p"],
            q=private["q"],
            d=private["private_exponent"],
            dmp1=private["dmp1"],
            dmq1=private["dmq1"],
            iqmp=private["iqmp"],
            public_numbers=public_numbers
        )
        return private_numbers.private_key(backend)

    @staticmethod
    def _oaep(mgf1_alg, hash_alg):
        # Unlabelled OAEP padding with the given MGF1 and message hashes.
        return padding.OAEP(
            mgf=padding.MGF1(algorithm=mgf1_alg),
            algorithm=hash_alg,
            label=None
        )

    @_needs_pkcs1v15
    @pytest.mark.parametrize(
        "vector",
        _flatten_pkcs1_examples(load_vectors_from_file(
            os.path.join(
                "asymmetric", "RSA", "pkcs1v15crypt-vectors.txt"),
            load_pkcs1_vectors
        ))
    )
    def test_decrypt_pkcs1v15_vectors(self, vector, backend):
        # Each vector's ciphertext is modulus-sized and decrypts to its
        # plaintext.
        private, _, example = vector
        decrypting_key = self._vector_private_key(private, backend)
        ciphertext = binascii.unhexlify(example["encryption"])
        assert len(ciphertext) == math.ceil(decrypting_key.key_size / 8.0)
        plaintext = decrypting_key.decrypt(ciphertext, padding.PKCS1v15())
        assert plaintext == binascii.unhexlify(example["message"])

    def test_unsupported_padding(self, backend):
        # A padding type unknown to the backend is reported as unsupported.
        decrypting_key = RSA_KEY_512.private_key(backend)
        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_PADDING):
            decrypting_key.decrypt(b"0" * 64, DummyAsymmetricPadding())

    @_needs_pkcs1v15
    def test_decrypt_invalid_decrypt(self, backend):
        # An all-zero, modulus-sized block is not a valid PKCS#1 v1.5
        # ciphertext.
        decrypting_key = RSA_KEY_512.private_key(backend)
        with pytest.raises(ValueError):
            decrypting_key.decrypt(b"\x00" * 64, padding.PKCS1v15())

    @_needs_pkcs1v15
    def test_decrypt_ciphertext_too_large(self, backend):
        # A ciphertext one byte longer than the 64 bytes used for this key
        # elsewhere in the class is rejected.
        decrypting_key = RSA_KEY_512.private_key(backend)
        with pytest.raises(ValueError):
            decrypting_key.decrypt(b"\x00" * 65, padding.PKCS1v15())

    @_needs_pkcs1v15
    def test_decrypt_ciphertext_too_small(self, backend):
        # A ciphertext shorter than the modulus is rejected.
        decrypting_key = RSA_KEY_512.private_key(backend)
        short_ciphertext = binascii.unhexlify(
            b"50b4c14136bd198c2f3c3ed243fce036e168d56517984a263cd66492b80804f1"
            b"69d210f2b9bdfb48b12f9ea05009c77da257cc600ccefe3a6283789d8ea0"
        )
        with pytest.raises(ValueError):
            decrypting_key.decrypt(short_ciphertext, padding.PKCS1v15())

    @_needs_oaep_sha1
    @pytest.mark.parametrize(
        "vector",
        _flatten_pkcs1_examples(load_vectors_from_file(
            os.path.join(
                "asymmetric", "RSA", "pkcs-1v2-1d2-vec", "oaep-vect.txt"),
            load_pkcs1_vectors
        ))
    )
    def test_decrypt_oaep_vectors(self, vector, backend):
        # Each OAEP (SHA-1 / MGF1-SHA-1) vector decrypts to its plaintext.
        private, _, example = vector
        decrypting_key = self._vector_private_key(private, backend)
        plaintext = decrypting_key.decrypt(
            binascii.unhexlify(example["encryption"]),
            self._oaep(hashes.SHA1(), hashes.SHA1())
        )
        assert plaintext == binascii.unhexlify(example["message"])

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA224()),
                algorithm=hashes.SHA224(),
                label=None
            )
        ),
        skip_message="Does not support OAEP using SHA224 MGF1 and SHA224 hash."
    )
    @pytest.mark.parametrize(
        "vector",
        _build_oaep_sha2_vectors()
    )
    def test_decrypt_oaep_sha2_vectors(self, vector, backend):
        # Each vector carries its own MGF1 and OAEP hash algorithms.
        private, _, example, mgf1_alg, hash_alg = vector
        decrypting_key = self._vector_private_key(private, backend)
        plaintext = decrypting_key.decrypt(
            binascii.unhexlify(example["encryption"]),
            self._oaep(mgf1_alg, hash_alg)
        )
        assert plaintext == binascii.unhexlify(example["message"])

    @_needs_oaep_sha1
    def test_invalid_oaep_decryption(self, backend):
        # Decrypting under the wrong key makes newer OpenSSL releases report
        # RSA_R_OAEP_DECODING_ERROR; this checks that case surfaces as
        # ValueError (other backends are expected to raise ValueError too).
        ciphertext = RSA_KEY_512.private_key(backend).public_key().encrypt(
            b'secure data',
            self._oaep(hashes.SHA1(), hashes.SHA1())
        )

        wrong_key = RSA_KEY_512_ALT.private_key(backend)

        with pytest.raises(ValueError):
            wrong_key.decrypt(
                ciphertext,
                self._oaep(hashes.SHA1(), hashes.SHA1())
            )

    @_needs_oaep_sha1
    def test_invalid_oaep_decryption_data_to_large_for_modulus(self, backend):
        # The ciphertext's integer value exceeds this key's modulus.
        decrypting_key = RSA_KEY_2048_ALT.private_key(backend)

        oversized_ciphertext = (
            b'\xb1ph\xc0\x0b\x1a|\xe6\xda\xea\xb5\xd7%\x94\x07\xf96\xfb\x96'
            b'\x11\x9b\xdc4\xea.-\x91\x80\x13S\x94\x04m\xe9\xc5/F\x1b\x9b:\\'
            b'\x1d\x04\x16ML\xae\xb32J\x01yuA\xbb\x83\x1c\x8f\xf6\xa5\xdbp\xcd'
            b'\nx\xc7\xf6\x15\xb2/\xdcH\xae\xe7\x13\x13by\r4t\x99\x0fc\x1f\xc1'
            b'\x1c\xb1\xdd\xc5\x08\xd1\xee\xa1XQ\xb8H@L5v\xc3\xaf\xf2\r\x97'
            b'\xed\xaa\xe7\xf1\xd4xai\xd3\x83\xd9\xaa9\xbfx\xe1\x87F \x01\xff'
            b'L\xccv}ae\xb3\xfa\xf2B\xb8\xf9\x04H\x94\x85\xcb\x86\xbb\\ghx!W31'
            b'\xc7;t\na_E\xc2\x16\xb0;\xa1\x18\t\x1b\xe1\xdb\x80>)\x15\xc6\x12'
            b'\xcb\xeeg`\x8b\x9b\x1b\x05y4\xb0\x84M6\xcd\xa1\x827o\xfd\x96\xba'
            b'Z#\x8d\xae\x01\xc9\xf2\xb6\xde\x89{8&eQ\x1e8\x03\x01#?\xb66\\'
            b'\xad.\xe9\xfa!\x95 c{\xcaz\xe0*\tP\r\x91\x9a)B\xb5\xadN\xf4$\x83'
            b'\t\xb5u\xab\x19\x99'
        )

        with pytest.raises(ValueError):
            decrypting_key.decrypt(
                oversized_ciphertext,
                self._oaep(hashes.SHA1(), hashes.SHA1())
            )

    def test_unsupported_oaep_mgf(self, backend):
        # OAEP with an MGF other than MGF1 is reported as unsupported.
        decrypting_key = RSA_KEY_512.private_key(backend)
        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_MGF):
            decrypting_key.decrypt(
                b"0" * 64,
                padding.OAEP(
                    mgf=DummyMGF(),
                    algorithm=hashes.SHA1(),
                    label=None
                )
            )

    del _needs_pkcs1v15, _needs_oaep_sha1


@pytest.mark.requires_backend_interface(interface=RSABackend)
class TestRSAEncryption(object):
    """RSA public-key encryption: round trips, size limits, bad paddings."""

    @staticmethod
    def _check_round_trip(private_key, plaintext, pad):
        """Encrypt ``plaintext`` with the public half and decrypt it back.

        Asserts that the ciphertext differs from the plaintext, that its
        length equals the key size rounded up to whole bytes, and that
        decrypting it with the private key yields ``plaintext`` again.
        """
        public_key = private_key.public_key()
        ciphertext = public_key.encrypt(plaintext, pad)
        assert ciphertext != plaintext
        assert len(ciphertext) == math.ceil(public_key.key_size / 8.0)
        assert private_key.decrypt(ciphertext, pad) == plaintext

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA1()),
                algorithm=hashes.SHA1(),
                label=None
            )
        ),
        skip_message="Does not support OAEP."
    )
    @pytest.mark.parametrize(
        ("key_data", "pad"),
        itertools.product(
            (RSA_KEY_1024, RSA_KEY_1025, RSA_KEY_1026, RSA_KEY_1027,
             RSA_KEY_1028, RSA_KEY_1029, RSA_KEY_1030, RSA_KEY_1031,
             RSA_KEY_1536, RSA_KEY_2048),
            [
                padding.OAEP(
                    mgf=padding.MGF1(algorithm=hashes.SHA1()),
                    algorithm=hashes.SHA1(),
                    label=None
                )
            ]
        )
    )
    def test_rsa_encrypt_oaep(self, key_data, pad, backend):
        self._check_round_trip(
            key_data.private_key(backend), b"encrypt me!", pad
        )

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA512(),
                label=None
            )
        ),
        skip_message="Does not support OAEP using SHA256 MGF1 and SHA512 hash."
    )
    @pytest.mark.parametrize(
        ("mgf1hash", "oaephash"),
        itertools.product([
            hashes.SHA1(),
            hashes.SHA224(),
            hashes.SHA256(),
            hashes.SHA384(),
            hashes.SHA512(),
        ], [
            hashes.SHA1(),
            hashes.SHA224(),
            hashes.SHA256(),
            hashes.SHA384(),
            hashes.SHA512(),
        ])
    )
    def test_rsa_encrypt_oaep_sha2(self, mgf1hash, oaephash, backend):
        # The MGF1 digest and the OAEP label digest vary independently.
        oaep = padding.OAEP(
            mgf=padding.MGF1(algorithm=mgf1hash),
            algorithm=oaephash,
            label=None
        )
        self._check_round_trip(
            RSA_KEY_2048.private_key(backend),
            b"encrypt me using sha2 hashes!",
            oaep
        )

    @pytest.mark.supported(
        only_if=lambda backend: backend.rsa_padding_supported(
            padding.PKCS1v15()
        ),
        skip_message="Does not support PKCS1v1.5."
    )
    @pytest.mark.parametrize(
        ("key_data", "pad"),
        itertools.product(
            (RSA_KEY_1024, RSA_KEY_1025, RSA_KEY_1026, RSA_KEY_1027,
             RSA_KEY_1028, RSA_KEY_1029, RSA_KEY_1030, RSA_KEY_1031,
             RSA_KEY_1536, RSA_KEY_2048),
            [padding.PKCS1v15()]
        )
    )
    def test_rsa_encrypt_pkcs1v15(self, key_data, pad, backend):
        self._check_round_trip(
            key_data.private_key(backend), b"encrypt me!", pad
        )

    @pytest.mark.parametrize(
        ("key_data", "pad"),
        itertools.product(
            (RSA_KEY_1024, RSA_KEY_1025, RSA_KEY_1026, RSA_KEY_1027,
             RSA_KEY_1028, RSA_KEY_1029, RSA_KEY_1030, RSA_KEY_1031,
             RSA_KEY_1536, RSA_KEY_2048),
            (
                padding.OAEP(
                    mgf=padding.MGF1(algorithm=hashes.SHA1()),
                    algorithm=hashes.SHA1(),
                    label=None
                ),
                padding.PKCS1v15()
            )
        )
    )
    def test_rsa_encrypt_key_too_small(self, key_data, pad, backend):
        private_key = key_data.private_key(backend)
        public_key = private_key.public_key()
        key_bytes = private_key.key_size // 8
        # One byte short of the key size still leaves no room for the
        # padding overhead; five bytes over the key size is plainly too big.
        for message_len in (key_bytes - 1, key_bytes + 5):
            with pytest.raises(ValueError):
                public_key.encrypt(b"\x00" * message_len, pad)

    def test_unsupported_padding(self, backend):
        public_key = RSA_KEY_512.private_key(backend).public_key()

        # A padding object the backend does not know is unsupported...
        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_PADDING):
            public_key.encrypt(b"somedata", DummyAsymmetricPadding())
        # ...while something that is not a padding at all is a TypeError.
        with pytest.raises(TypeError):
            public_key.encrypt(b"somedata", padding=object())

    def test_unsupported_oaep_mgf(self, backend):
        public_key = RSA_KEY_512.private_key(backend).public_key()

        with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_MGF):
            unknown_mgf_padding = padding.OAEP(
                mgf=DummyMGF(),
                algorithm=hashes.SHA1(),
                label=None
            )
            public_key.encrypt(b"ciphertext", unknown_mgf_padding)


@pytest.mark.requires_backend_interface(interface=RSABackend)
class TestRSANumbers(object):
    """Construction, type checks and bounds checks of RSA number objects."""

    def test_rsa_public_numbers(self):
        public_numbers = rsa.RSAPublicNumbers(e=1, n=15)
        assert public_numbers.e == 1
        assert public_numbers.n == 15

    def test_rsa_private_numbers(self):
        public_numbers = rsa.RSAPublicNumbers(e=1, n=15)
        private_numbers = rsa.RSAPrivateNumbers(
            p=3,
            q=5,
            d=1,
            dmp1=1,
            dmq1=1,
            iqmp=2,
            public_numbers=public_numbers
        )

        assert private_numbers.p == 3
        assert private_numbers.q == 5
        assert private_numbers.d == 1
        assert private_numbers.dmp1 == 1
        assert private_numbers.dmq1 == 1
        assert private_numbers.iqmp == 2
        assert private_numbers.public_numbers == public_numbers

    def test_rsa_private_numbers_create_key(self, backend):
        private_key = RSA_KEY_1024.private_key(backend)
        assert private_key

    def test_rsa_public_numbers_create_key(self, backend):
        public_key = RSA_KEY_1024.public_numbers.public_key(backend)
        assert public_key

    def test_public_numbers_invalid_types(self):
        with pytest.raises(TypeError):
            rsa.RSAPublicNumbers(e=None, n=15)

        with pytest.raises(TypeError):
            rsa.RSAPublicNumbers(e=1, n=None)

    def test_private_numbers_invalid_types(self):
        """Each RSAPrivateNumbers argument set to None raises TypeError."""
        valid_kwargs = dict(
            p=3, q=5, d=1, dmp1=1, dmq1=1, iqmp=2,
            public_numbers=rsa.RSAPublicNumbers(e=1, n=15),
        )
        # A fixed tuple (not dict iteration) keeps the case order stable.
        for name in ("p", "q", "d", "dmp1", "dmq1", "iqmp", "public_numbers"):
            kwargs = dict(valid_kwargs)
            kwargs[name] = None
            with pytest.raises(TypeError):
                rsa.RSAPrivateNumbers(**kwargs)

    def test_invalid_public_numbers_argument_values(self, backend):
        """Out-of-bounds public numbers are rejected with ValueError."""
        # Start with public_exponent=7, modulus=15. Then change one value at a
        # time to test the bounds.
        invalid_e_n_pairs = [
            (7, 2),  # modulus < 3
            (1, 15),  # public_exponent < 3
            (17, 15),  # public_exponent > modulus
            (14, 15),  # public_exponent is not odd
        ]
        for e, n in invalid_e_n_pairs:
            with pytest.raises(ValueError):
                rsa.RSAPublicNumbers(e=e, n=n).public_key(backend)

    def test_invalid_private_numbers_argument_values(self, backend):
        """Out-of-bounds private numbers are rejected with ValueError.

        The baseline p=3, q=11, d=3, dmp1=1, dmq1=3, iqmp=2, e=7, n=33 is a
        consistent RSA key. Each case overrides exactly ONE value, so each
        case targets a single bound and cannot pass by tripping another one.
        """
        baseline = dict(p=3, q=11, d=3, dmp1=1, dmq1=3, iqmp=2, e=7, n=33)
        invalid_overrides = [
            {"n": 2},  # modulus < 3
            {"n": 35},  # modulus != p * q
            {"p": 37},  # p > modulus
            {"q": 37},  # q > modulus
            {"dmp1": 35},  # dmp1 > modulus
            {"dmq1": 35},  # dmq1 > modulus
            {"iqmp": 35},  # iqmp > modulus
            {"d": 37},  # private_exponent > modulus
            {"e": 1},  # public_exponent < 3
            # public_exponent > modulus; iqmp stays at its valid value so
            # the iqmp bound cannot be what raises here.
            {"e": 65537},
            {"e": 6},  # public_exponent is not odd
            {"dmp1": 2},  # dmp1 is not odd
            {"dmq1": 4},  # dmq1 is not odd
        ]
        for override in invalid_overrides:
            values = dict(baseline, **override)
            with pytest.raises(ValueError):
                rsa.RSAPrivateNumbers(
                    p=values["p"],
                    q=values["q"],
                    d=values["d"],
                    dmp1=values["dmp1"],
                    dmq1=values["dmq1"],
                    iqmp=values["iqmp"],
                    public_numbers=rsa.RSAPublicNumbers(
                        e=values["e"],
                        n=values["n"]
                    )
                ).private_key(backend)

    def test_public_number_repr(self):
        num = RSAPublicNumbers(1, 1)
        assert repr(num) == "<RSAPublicNumbers(e=1, n=1)>"


class TestRSANumbersEquality(object):
    """Equality, inequality and hashing of RSA public/private numbers."""

    def test_public_numbers_eq(self):
        left = RSAPublicNumbers(1, 2)
        right = RSAPublicNumbers(1, 2)
        assert left == right

    def test_public_numbers_ne(self):
        reference = RSAPublicNumbers(1, 2)
        # Differ in e, differ in n, and an unrelated type.
        for other in (RSAPublicNumbers(2, 2), RSAPublicNumbers(1, 3),
                      object()):
            assert reference != other

    def test_private_numbers_eq(self):
        left = RSAPrivateNumbers(1, 2, 3, 4, 5, 6, RSAPublicNumbers(1, 2))
        right = RSAPrivateNumbers(1, 2, 3, 4, 5, 6, RSAPublicNumbers(1, 2))
        assert left == right

    def test_private_numbers_ne(self):
        reference = RSAPrivateNumbers(
            1, 2, 3, 4, 5, 6, RSAPublicNumbers(1, 2)
        )
        # Each entry changes exactly one private or public component:
        # (p, q, d, dmp1, dmq1, iqmp), (e, n).
        single_field_changes = [
            ((1, 2, 3, 4, 5, 7), (1, 2)),
            ((1, 2, 3, 4, 4, 6), (1, 2)),
            ((1, 2, 3, 5, 5, 6), (1, 2)),
            ((1, 2, 4, 4, 5, 6), (1, 2)),
            ((1, 3, 3, 4, 5, 6), (1, 2)),
            ((2, 2, 3, 4, 5, 6), (1, 2)),
            ((1, 2, 3, 4, 5, 6), (2, 2)),
            ((1, 2, 3, 4, 5, 6), (1, 3)),
        ]
        for private_parts, public_parts in single_field_changes:
            other = RSAPrivateNumbers(
                *(private_parts + (RSAPublicNumbers(*public_parts),))
            )
            assert reference != other
        assert reference != object()

    def test_public_numbers_hash(self):
        first = RSAPublicNumbers(3, 17)
        twin = RSAPublicNumbers(3, 17)
        different = RSAPublicNumbers(7, 21)

        assert hash(first) == hash(twin)
        assert hash(first) != hash(different)

    def test_private_numbers_hash(self):
        first = RSAPrivateNumbers(1, 2, 3, 4, 5, 6, RSAPublicNumbers(1, 2))
        twin = RSAPrivateNumbers(1, 2, 3, 4, 5, 6, RSAPublicNumbers(1, 2))
        different = RSAPrivateNumbers(
            1, 2, 3, 4, 5, 6, RSAPublicNumbers(1, 3)
        )

        assert hash(first) == hash(twin)
        assert hash(first) != hash(different)


class TestRSAPrimeFactorRecovery(object):
    """Recovering the primes p and q from (modulus, e, d)."""

    @pytest.mark.parametrize(
        "vector",
        _flatten_pkcs1_examples(load_vectors_from_file(
            os.path.join(
                "asymmetric", "RSA", "pkcs1v15crypt-vectors.txt"),
            load_pkcs1_vectors
        ))
    )
    def test_recover_prime_factors(self, vector):
        key_vector, _public, _example = vector
        recovered = rsa.rsa_recover_prime_factors(
            key_vector["modulus"],
            key_vector["public_exponent"],
            key_vector["private_exponent"]
        )
        larger, smaller = recovered
        # The vectors do not order p and q consistently, so compare the
        # primes as an unordered pair; the recovery function itself is
        # expected to return the larger prime first.
        expected = sorted([key_vector["p"], key_vector["q"]])
        assert sorted([larger, smaller]) == expected
        assert larger > smaller

    def test_invalid_recover_prime_factors(self):
        # d=7 is not a valid private exponent for n=34, e=3.
        with pytest.raises(ValueError):
            rsa.rsa_recover_prime_factors(34, 3, 7)


@pytest.mark.requires_backend_interface(interface=RSABackend)
@pytest.mark.requires_backend_interface(interface=PEMSerializationBackend)
class TestRSAPrivateKeySerialization(object):
    """Round trips and argument validation for ``private_bytes``."""

    @pytest.mark.parametrize(
        ("fmt", "password"),
        itertools.product(
            [
                serialization.PrivateFormat.TraditionalOpenSSL,
                serialization.PrivateFormat.PKCS8
            ],
            [
                b"s",
                b"longerpassword",
                b"!*$&(@#$*&($T@%_somesymbols",
                b"\x01" * 1000,
            ]
        )
    )
    def test_private_bytes_encrypted_pem(self, backend, fmt, password):
        original = RSA_KEY_2048.private_key(backend)
        encrypted_pem = original.private_bytes(
            serialization.Encoding.PEM,
            fmt,
            serialization.BestAvailableEncryption(password)
        )
        reloaded = serialization.load_pem_private_key(
            encrypted_pem, password, backend
        )
        assert reloaded.private_numbers() == original.private_numbers()

    @pytest.mark.parametrize(
        ("fmt", "password"),
        [
            [serialization.PrivateFormat.PKCS8, b"s"],
            [serialization.PrivateFormat.PKCS8, b"longerpassword"],
            [serialization.PrivateFormat.PKCS8, b"!*$&(@#$*&($T@%_somesymbol"],
            [serialization.PrivateFormat.PKCS8, b"\x01" * 1000]
        ]
    )
    def test_private_bytes_encrypted_der(self, backend, fmt, password):
        original = RSA_KEY_2048.private_key(backend)
        encrypted_der = original.private_bytes(
            serialization.Encoding.DER,
            fmt,
            serialization.BestAvailableEncryption(password)
        )
        reloaded = serialization.load_der_private_key(
            encrypted_der, password, backend
        )
        assert reloaded.private_numbers() == original.private_numbers()

    @pytest.mark.parametrize(
        ("encoding", "fmt", "loader_func"),
        [
            [
                serialization.Encoding.PEM,
                serialization.PrivateFormat.TraditionalOpenSSL,
                serialization.load_pem_private_key
            ],
            [
                serialization.Encoding.DER,
                serialization.PrivateFormat.TraditionalOpenSSL,
                serialization.load_der_private_key
            ],
            [
                serialization.Encoding.PEM,
                serialization.PrivateFormat.PKCS8,
                serialization.load_pem_private_key
            ],
            [
                serialization.Encoding.DER,
                serialization.PrivateFormat.PKCS8,
                serialization.load_der_private_key
            ],
        ]
    )
    def test_private_bytes_unencrypted(self, backend, encoding, fmt,
                                       loader_func):
        original = RSA_KEY_2048.private_key(backend)
        plain = original.private_bytes(
            encoding, fmt, serialization.NoEncryption()
        )
        reloaded = loader_func(plain, None, backend)
        assert reloaded.private_numbers() == original.private_numbers()

    @pytest.mark.parametrize(
        ("key_path", "encoding", "loader_func"),
        [
            [
                os.path.join(
                    "asymmetric",
                    "Traditional_OpenSSL_Serialization",
                    "testrsa.pem"
                ),
                serialization.Encoding.PEM,
                serialization.load_pem_private_key
            ],
            [
                os.path.join("asymmetric", "DER_Serialization", "testrsa.der"),
                serialization.Encoding.DER,
                serialization.load_der_private_key
            ],
        ]
    )
    def test_private_bytes_traditional_openssl_unencrypted(
        self, backend, key_path, encoding, loader_func
    ):
        # Loading then re-serializing must reproduce the file byte-for-byte.
        on_disk = load_vectors_from_file(
            key_path, lambda pemfile: pemfile.read(), mode="rb"
        )
        loaded = loader_func(on_disk, None, backend)
        reserialized = loaded.private_bytes(
            encoding,
            serialization.PrivateFormat.TraditionalOpenSSL,
            serialization.NoEncryption()
        )
        assert reserialized == on_disk

    def test_private_bytes_traditional_der_encrypted_invalid(self, backend):
        # Encrypted output is not accepted for TraditionalOpenSSL + DER.
        rsa_key = RSA_KEY_2048.private_key(backend)
        with pytest.raises(ValueError):
            rsa_key.private_bytes(
                serialization.Encoding.DER,
                serialization.PrivateFormat.TraditionalOpenSSL,
                serialization.BestAvailableEncryption(b"password")
            )

    def test_private_bytes_invalid_encoding(self, backend):
        rsa_key = RSA_KEY_2048.private_key(backend)
        with pytest.raises(TypeError):
            rsa_key.private_bytes(
                "notencoding",
                serialization.PrivateFormat.PKCS8,
                serialization.NoEncryption()
            )

    def test_private_bytes_invalid_format(self, backend):
        rsa_key = RSA_KEY_2048.private_key(backend)
        with pytest.raises(TypeError):
            rsa_key.private_bytes(
                serialization.Encoding.PEM,
                "invalidformat",
                serialization.NoEncryption()
            )

    def test_private_bytes_invalid_encryption_algorithm(self, backend):
        rsa_key = RSA_KEY_2048.private_key(backend)
        with pytest.raises(TypeError):
            rsa_key.private_bytes(
                serialization.Encoding.PEM,
                serialization.PrivateFormat.TraditionalOpenSSL,
                "notanencalg"
            )

    def test_private_bytes_unsupported_encryption_type(self, backend):
        # An encryption object of an unknown kind is a ValueError, unlike a
        # non-encryption object (TypeError, see the test above).
        rsa_key = RSA_KEY_2048.private_key(backend)
        with pytest.raises(ValueError):
            rsa_key.private_bytes(
                serialization.Encoding.PEM,
                serialization.PrivateFormat.TraditionalOpenSSL,
                DummyKeySerializationEncryption()
            )


@pytest.mark.requires_backend_interface(interface=RSABackend)
@pytest.mark.requires_backend_interface(interface=PEMSerializationBackend)
class TestRSAPEMPublicKeySerialization(object):
    @pytest.mark.parametrize(
        ("key_path", "loader_func", "encoding", "format"),
        [
            (
                os.path.join("asymmetric", "public", "PKCS1", "rsa.pub.pem"),
                serialization.load_pem_public_key,
                serialization.Encoding.PEM,
                serialization.PublicFormat.PKCS1,
            ), (
                os.path.join("asymmetric", "public", "PKCS1", "rsa.pub.der"),
                serialization.load_der_public_key,
                serialization.Encoding.DER,
                serialization.PublicFormat.PKCS1,
            ), (
                os.path.join("asymmetric", "PKCS8", "unenc-rsa-pkcs8.pub.pem"),
                serialization.load_pem_public_key,
                serialization.Encoding.PEM,
                serialization.PublicFormat.SubjectPublicKeyInfo,
            ), (
                os.path.join(
                    "asymmetric",
                    "DER_Serialization",
                    "unenc-rsa-pkcs8.pub.der"
                ),
                serialization.load_der_public_key,
                serialization.Encoding.DER,
                serialization.PublicFormat.SubjectPublicKeyInfo,
            )
        ]
    )
    def test_public_bytes_match(self, key_path, loader_func, encoding, format,
                                backend):
        """Re-serializing a loaded public key reproduces the vector bytes."""
        expected = load_vectors_from_file(
            key_path, lambda pemfile: pemfile.read(), mode="rb"
        )
        public_key = loader_func(expected, backend)
        assert public_key.public_bytes(encoding, format) == expected

    def test_public_bytes_openssh(self, backend):
        """OpenSSH output is exact and only valid with the OpenSSH pairing."""
        pem_bytes = load_vectors_from_file(
            os.path.join("asymmetric", "public", "PKCS1", "rsa.pub.pem"),
            lambda pemfile: pemfile.read(), mode="rb"
        )
        key = serialization.load_pem_public_key(pem_bytes, backend)

        expected_ssh = (
            b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAgQC7JHoJfg6yNzLMOWet8Z49a4KD"
            b"0dCspMAYvo2YAMB7/wdEycocujbhJ2n/seONi+5XqTqqFkM5VBl8rmkkFPZk/7x0"
            b"xmdsTPECSWnHK+HhoaNDFPR3j8jQhVo1laxiqcEhAHegi5cwtFosuJAvSKAFKEvy"
            b"D43si00DQnXWrYHAEQ=="
        )
        assert key.public_bytes(
            serialization.Encoding.OpenSSH, serialization.PublicFormat.OpenSSH
        ) == expected_ssh

        # OpenSSH encoding and OpenSSH format must be used together; any
        # mixed pairing is rejected.
        mismatched_pairs = [
            (serialization.Encoding.PEM, serialization.PublicFormat.OpenSSH),
            (serialization.Encoding.DER, serialization.PublicFormat.OpenSSH),
            (serialization.Encoding.OpenSSH, serialization.PublicFormat.PKCS1),
            (
                serialization.Encoding.OpenSSH,
                serialization.PublicFormat.SubjectPublicKeyInfo,
            ),
        ]
        for encoding, fmt in mismatched_pairs:
            with pytest.raises(ValueError):
                key.public_bytes(encoding, fmt)

    def test_public_bytes_invalid_encoding(self, backend):
        """A non-Encoding value for the encoding argument is a TypeError."""
        public_key = RSA_KEY_2048.private_key(backend).public_key()
        bogus_encoding = "notencoding"
        with pytest.raises(TypeError):
            public_key.public_bytes(
                bogus_encoding, serialization.PublicFormat.PKCS1
            )

    def test_public_bytes_invalid_format(self, backend):
        key = RSA_KEY_2048.private_key(backend).public_key()
        with pytest.raises(TypeError):
            key.public_bytes(serialization.Encoding.PEM, "invalidformat")