aboutsummaryrefslogtreecommitdiffstats
path: root/target/linux/generic/backport-5.4/080-wireguard-0107-wireguard-noise-separate-receive-counter-from-send-c.patch
diff options
context:
space:
mode:
Diffstat (limited to 'target/linux/generic/backport-5.4/080-wireguard-0107-wireguard-noise-separate-receive-counter-from-send-c.patch')
-rw-r--r--target/linux/generic/backport-5.4/080-wireguard-0107-wireguard-noise-separate-receive-counter-from-send-c.patch330
1 files changed, 330 insertions, 0 deletions
diff --git a/target/linux/generic/backport-5.4/080-wireguard-0107-wireguard-noise-separate-receive-counter-from-send-c.patch b/target/linux/generic/backport-5.4/080-wireguard-0107-wireguard-noise-separate-receive-counter-from-send-c.patch
new file mode 100644
index 0000000000..87d38d36fe
--- /dev/null
+++ b/target/linux/generic/backport-5.4/080-wireguard-0107-wireguard-noise-separate-receive-counter-from-send-c.patch
@@ -0,0 +1,330 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: "Jason A. Donenfeld" <Jason@zx2c4.com>
+Date: Tue, 19 May 2020 22:49:30 -0600
+Subject: [PATCH] wireguard: noise: separate receive counter from send counter
+
+commit a9e90d9931f3a474f04bab782ccd9d77904941e9 upstream.
+
+In "wireguard: queueing: preserve flow hash across packet scrubbing", we
+were required to slightly increase the size of the receive replay
+counter to something still fairly small, but an increase nonetheless.
+It turns out that we can recoup some of the additional memory overhead
+by splitting up the prior union type into two distinct types. Before, we
+used the same "noise_counter" union for both sending and receiving, with
+sending just using a simple atomic64_t, while receiving used the full
+replay counter checker. This meant that most of the memory being
+allocated for the sending counter was being wasted. Since the old
+"noise_counter" type increased in size in the prior commit, now is a
+good time to split up that union type into a distinct "noise_replay_
+counter" for receiving and a boring atomic64_t for sending, each using
+neither more nor less memory than required.
+
+Also, since sometimes the replay counter is accessed without
+necessitating additional accesses to the bitmap, we can reduce cache
+misses by hoisting the always-necessary lock above the bitmap in the
+struct layout. We also change a "noise_replay_counter" stack allocation
+to kmalloc in a -DDEBUG selftest so that KASAN doesn't trigger a stack
+frame warning.
+
+All and all, removing a bit of abstraction in this commit makes the code
+simpler and smaller, in addition to the motivating memory usage
+recuperation. For example, passing around raw "noise_symmetric_key"
+structs is something that really only makes sense within noise.c, in the
+one place where the sending and receiving keys can safely be thought of
+as the same type of object; subsequent to that, it's important that we
+uniformly access these through keypair->{sending,receiving}, where their
+distinct roles are always made explicit. So this patch allows us to draw
+that distinction clearly as well.
+
+Fixes: e7096c131e51 ("net: WireGuard secure network tunnel")
+Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
+Signed-off-by: David S. Miller <davem@davemloft.net>
+Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
+---
+ drivers/net/wireguard/noise.c | 16 +++------
+ drivers/net/wireguard/noise.h | 14 ++++----
+ drivers/net/wireguard/receive.c | 42 ++++++++++++------------
+ drivers/net/wireguard/selftest/counter.c | 17 +++++++---
+ drivers/net/wireguard/send.c | 12 +++----
+ 5 files changed, 48 insertions(+), 53 deletions(-)
+
+--- a/drivers/net/wireguard/noise.c
++++ b/drivers/net/wireguard/noise.c
+@@ -104,6 +104,7 @@ static struct noise_keypair *keypair_cre
+
+ if (unlikely(!keypair))
+ return NULL;
++ spin_lock_init(&keypair->receiving_counter.lock);
+ keypair->internal_id = atomic64_inc_return(&keypair_counter);
+ keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
+ keypair->entry.peer = peer;
+@@ -358,25 +359,16 @@ out:
+ memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
+ }
+
+-static void symmetric_key_init(struct noise_symmetric_key *key)
+-{
+- spin_lock_init(&key->counter.receive.lock);
+- atomic64_set(&key->counter.counter, 0);
+- memset(key->counter.receive.backtrack, 0,
+- sizeof(key->counter.receive.backtrack));
+- key->birthdate = ktime_get_coarse_boottime_ns();
+- key->is_valid = true;
+-}
+-
+ static void derive_keys(struct noise_symmetric_key *first_dst,
+ struct noise_symmetric_key *second_dst,
+ const u8 chaining_key[NOISE_HASH_LEN])
+ {
++ u64 birthdate = ktime_get_coarse_boottime_ns();
+ kdf(first_dst->key, second_dst->key, NULL, NULL,
+ NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
+ chaining_key);
+- symmetric_key_init(first_dst);
+- symmetric_key_init(second_dst);
++ first_dst->birthdate = second_dst->birthdate = birthdate;
++ first_dst->is_valid = second_dst->is_valid = true;
+ }
+
+ static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
+--- a/drivers/net/wireguard/noise.h
++++ b/drivers/net/wireguard/noise.h
+@@ -15,18 +15,14 @@
+ #include <linux/mutex.h>
+ #include <linux/kref.h>
+
+-union noise_counter {
+- struct {
+- u64 counter;
+- unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
+- spinlock_t lock;
+- } receive;
+- atomic64_t counter;
++struct noise_replay_counter {
++ u64 counter;
++ spinlock_t lock;
++ unsigned long backtrack[COUNTER_BITS_TOTAL / BITS_PER_LONG];
+ };
+
+ struct noise_symmetric_key {
+ u8 key[NOISE_SYMMETRIC_KEY_LEN];
+- union noise_counter counter;
+ u64 birthdate;
+ bool is_valid;
+ };
+@@ -34,7 +30,9 @@ struct noise_symmetric_key {
+ struct noise_keypair {
+ struct index_hashtable_entry entry;
+ struct noise_symmetric_key sending;
++ atomic64_t sending_counter;
+ struct noise_symmetric_key receiving;
++ struct noise_replay_counter receiving_counter;
+ __le32 remote_index;
+ bool i_am_the_initiator;
+ struct kref refcount;
+--- a/drivers/net/wireguard/receive.c
++++ b/drivers/net/wireguard/receive.c
+@@ -245,20 +245,20 @@ static void keep_key_fresh(struct wg_pee
+ }
+ }
+
+-static bool decrypt_packet(struct sk_buff *skb, struct noise_symmetric_key *key)
++static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
+ {
+ struct scatterlist sg[MAX_SKB_FRAGS + 8];
+ struct sk_buff *trailer;
+ unsigned int offset;
+ int num_frags;
+
+- if (unlikely(!key))
++ if (unlikely(!keypair))
+ return false;
+
+- if (unlikely(!READ_ONCE(key->is_valid) ||
+- wg_birthdate_has_expired(key->birthdate, REJECT_AFTER_TIME) ||
+- key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
+- WRITE_ONCE(key->is_valid, false);
++ if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
++ wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
++ keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
++ WRITE_ONCE(keypair->receiving.is_valid, false);
+ return false;
+ }
+
+@@ -283,7 +283,7 @@ static bool decrypt_packet(struct sk_buf
+
+ if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
+ PACKET_CB(skb)->nonce,
+- key->key))
++ keypair->receiving.key))
+ return false;
+
+ /* Another ugly situation of pushing and pulling the header so as to
+@@ -298,41 +298,41 @@ static bool decrypt_packet(struct sk_buf
+ }
+
+ /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
+-static bool counter_validate(union noise_counter *counter, u64 their_counter)
++static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
+ {
+ unsigned long index, index_current, top, i;
+ bool ret = false;
+
+- spin_lock_bh(&counter->receive.lock);
++ spin_lock_bh(&counter->lock);
+
+- if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
++ if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
+ their_counter >= REJECT_AFTER_MESSAGES))
+ goto out;
+
+ ++their_counter;
+
+ if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
+- counter->receive.counter))
++ counter->counter))
+ goto out;
+
+ index = their_counter >> ilog2(BITS_PER_LONG);
+
+- if (likely(their_counter > counter->receive.counter)) {
+- index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
++ if (likely(their_counter > counter->counter)) {
++ index_current = counter->counter >> ilog2(BITS_PER_LONG);
+ top = min_t(unsigned long, index - index_current,
+ COUNTER_BITS_TOTAL / BITS_PER_LONG);
+ for (i = 1; i <= top; ++i)
+- counter->receive.backtrack[(i + index_current) &
++ counter->backtrack[(i + index_current) &
+ ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
+- counter->receive.counter = their_counter;
++ counter->counter = their_counter;
+ }
+
+ index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
+ ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
+- &counter->receive.backtrack[index]);
++ &counter->backtrack[index]);
+
+ out:
+- spin_unlock_bh(&counter->receive.lock);
++ spin_unlock_bh(&counter->lock);
+ return ret;
+ }
+
+@@ -472,12 +472,12 @@ int wg_packet_rx_poll(struct napi_struct
+ if (unlikely(state != PACKET_STATE_CRYPTED))
+ goto next;
+
+- if (unlikely(!counter_validate(&keypair->receiving.counter,
++ if (unlikely(!counter_validate(&keypair->receiving_counter,
+ PACKET_CB(skb)->nonce))) {
+ net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
+ peer->device->dev->name,
+ PACKET_CB(skb)->nonce,
+- keypair->receiving.counter.receive.counter);
++ keypair->receiving_counter.counter);
+ goto next;
+ }
+
+@@ -511,8 +511,8 @@ void wg_packet_decrypt_worker(struct wor
+ struct sk_buff *skb;
+
+ while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
+- enum packet_state state = likely(decrypt_packet(skb,
+- &PACKET_CB(skb)->keypair->receiving)) ?
++ enum packet_state state =
++ likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
+ PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
+ wg_queue_enqueue_per_peer_napi(skb, state);
+ if (need_resched())
+--- a/drivers/net/wireguard/selftest/counter.c
++++ b/drivers/net/wireguard/selftest/counter.c
+@@ -6,18 +6,24 @@
+ #ifdef DEBUG
+ bool __init wg_packet_counter_selftest(void)
+ {
++ struct noise_replay_counter *counter;
+ unsigned int test_num = 0, i;
+- union noise_counter counter;
+ bool success = true;
+
+-#define T_INIT do { \
+- memset(&counter, 0, sizeof(union noise_counter)); \
+- spin_lock_init(&counter.receive.lock); \
++ counter = kmalloc(sizeof(*counter), GFP_KERNEL);
++ if (unlikely(!counter)) {
++ pr_err("nonce counter self-test malloc: FAIL\n");
++ return false;
++ }
++
++#define T_INIT do { \
++ memset(counter, 0, sizeof(*counter)); \
++ spin_lock_init(&counter->lock); \
+ } while (0)
+ #define T_LIM (COUNTER_WINDOW_SIZE + 1)
+ #define T(n, v) do { \
+ ++test_num; \
+- if (counter_validate(&counter, n) != (v)) { \
++ if (counter_validate(counter, n) != (v)) { \
+ pr_err("nonce counter self-test %u: FAIL\n", \
+ test_num); \
+ success = false; \
+@@ -99,6 +105,7 @@ bool __init wg_packet_counter_selftest(v
+
+ if (success)
+ pr_info("nonce counter self-tests: pass\n");
++ kfree(counter);
+ return success;
+ }
+ #endif
+--- a/drivers/net/wireguard/send.c
++++ b/drivers/net/wireguard/send.c
+@@ -129,7 +129,7 @@ static void keep_key_fresh(struct wg_pee
+ rcu_read_lock_bh();
+ keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
+ send = keypair && READ_ONCE(keypair->sending.is_valid) &&
+- (atomic64_read(&keypair->sending.counter.counter) > REKEY_AFTER_MESSAGES ||
++ (atomic64_read(&keypair->sending_counter) > REKEY_AFTER_MESSAGES ||
+ (keypair->i_am_the_initiator &&
+ wg_birthdate_has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME)));
+ rcu_read_unlock_bh();
+@@ -349,7 +349,6 @@ void wg_packet_purge_staged_packets(stru
+
+ void wg_packet_send_staged_packets(struct wg_peer *peer)
+ {
+- struct noise_symmetric_key *key;
+ struct noise_keypair *keypair;
+ struct sk_buff_head packets;
+ struct sk_buff *skb;
+@@ -369,10 +368,9 @@ void wg_packet_send_staged_packets(struc
+ rcu_read_unlock_bh();
+ if (unlikely(!keypair))
+ goto out_nokey;
+- key = &keypair->sending;
+- if (unlikely(!READ_ONCE(key->is_valid)))
++ if (unlikely(!READ_ONCE(keypair->sending.is_valid)))
+ goto out_nokey;
+- if (unlikely(wg_birthdate_has_expired(key->birthdate,
++ if (unlikely(wg_birthdate_has_expired(keypair->sending.birthdate,
+ REJECT_AFTER_TIME)))
+ goto out_invalid;
+
+@@ -387,7 +385,7 @@ void wg_packet_send_staged_packets(struc
+ */
+ PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb);
+ PACKET_CB(skb)->nonce =
+- atomic64_inc_return(&key->counter.counter) - 1;
++ atomic64_inc_return(&keypair->sending_counter) - 1;
+ if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES))
+ goto out_invalid;
+ }
+@@ -399,7 +397,7 @@ void wg_packet_send_staged_packets(struc
+ return;
+
+ out_invalid:
+- WRITE_ONCE(key->is_valid, false);
++ WRITE_ONCE(keypair->sending.is_valid, false);
+ out_nokey:
+ wg_noise_keypair_put(keypair, false);
+