diff options
Diffstat (limited to 'target/linux/generic/backport-5.4/080-wireguard-0135-wireguard-allowedips-free-empty-intermediate-nodes-w.patch')
-rw-r--r-- | target/linux/generic/backport-5.4/080-wireguard-0135-wireguard-allowedips-free-empty-intermediate-nodes-w.patch | 521 |
1 files changed, 521 insertions, 0 deletions
diff --git a/target/linux/generic/backport-5.4/080-wireguard-0135-wireguard-allowedips-free-empty-intermediate-nodes-w.patch b/target/linux/generic/backport-5.4/080-wireguard-0135-wireguard-allowedips-free-empty-intermediate-nodes-w.patch new file mode 100644 index 0000000000..c044ad25af --- /dev/null +++ b/target/linux/generic/backport-5.4/080-wireguard-0135-wireguard-allowedips-free-empty-intermediate-nodes-w.patch @@ -0,0 +1,521 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: "Jason A. Donenfeld" <Jason@zx2c4.com> +Date: Fri, 4 Jun 2021 17:17:38 +0200 +Subject: [PATCH] wireguard: allowedips: free empty intermediate nodes when + removing single node + +commit bf7b042dc62a31f66d3a41dd4dfc7806f267b307 upstream. + +When removing single nodes, it's possible that that node's parent is an +empty intermediate node, in which case, it too should be removed. +Otherwise the trie fills up and never is fully emptied, leading to +gradual memory leaks over time for tries that are modified often. There +was originally code to do this, but it was removed during refactoring in +2016 and never reworked. Now that we have proper parent pointers from +the previous commits, we can implement this properly. + +In order to reduce branching and expensive comparisons, we want to keep +the double pointer for parent assignment (which lets us easily chain up +to the root), but we still need to actually get the parent's base +address. So encode the bit number into the last two bits of the pointer, +and pack and unpack it as needed. This is a little bit clumsy but is the +fastest and least memory wasteful of the compromises. Note that we align +the root struct here to a minimum of 4, because it's embedded into a +larger struct, and we're relying on having the bottom two bits for our +flag, which would only be 16-bit aligned on m68k. 
+ +The existing macro-based helpers were a bit unwieldy for adding the bit +packing to, so this commit replaces them with safer and clearer ordinary +functions. + +We add a test to the randomized/fuzzer part of the selftests, to free +the randomized tries by-peer, refuzz it, and repeat, until it's supposed +to be empty, and then see if that actually resulted in the whole +thing being emptied. That combined with kmemcheck should hopefully make +sure this commit is doing what it should. Along the way this resulted in +various other cleanups of the tests and fixes for recent graphviz. + +Fixes: e7096c131e51 ("net: WireGuard secure network tunnel") +Cc: stable@vger.kernel.org +Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com> +Signed-off-by: David S. Miller <davem@davemloft.net> +Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com> +--- + drivers/net/wireguard/allowedips.c | 102 ++++++------ + drivers/net/wireguard/allowedips.h | 4 +- + drivers/net/wireguard/selftest/allowedips.c | 162 ++++++++++---------- + 3 files changed, 137 insertions(+), 131 deletions(-) + +--- a/drivers/net/wireguard/allowedips.c ++++ b/drivers/net/wireguard/allowedips.c +@@ -30,8 +30,11 @@ static void copy_and_assign_cidr(struct + node->bitlen = bits; + memcpy(node->bits, src, bits / 8U); + } +-#define CHOOSE_NODE(parent, key) \ +- parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] ++ ++static inline u8 choose(struct allowedips_node *node, const u8 *key) ++{ ++ return (key[node->bit_at_a] >> node->bit_at_b) & 1; ++} + + static void push_rcu(struct allowedips_node **stack, + struct allowedips_node __rcu *p, unsigned int *len) +@@ -112,7 +115,7 @@ static struct allowedips_node *find_node + found = node; + if (node->cidr == bits) + break; +- node = rcu_dereference_bh(CHOOSE_NODE(node, key)); ++ node = rcu_dereference_bh(node->bit[choose(node, key)]); + } + return found; + } +@@ -144,8 +147,7 @@ static bool node_placement(struct allowe + u8 cidr, u8 bits, struct allowedips_node 
**rnode, + struct mutex *lock) + { +- struct allowedips_node *node = rcu_dereference_protected(trie, +- lockdep_is_held(lock)); ++ struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock)); + struct allowedips_node *parent = NULL; + bool exact = false; + +@@ -155,13 +157,24 @@ static bool node_placement(struct allowe + exact = true; + break; + } +- node = rcu_dereference_protected(CHOOSE_NODE(parent, key), +- lockdep_is_held(lock)); ++ node = rcu_dereference_protected(parent->bit[choose(parent, key)], lockdep_is_held(lock)); + } + *rnode = parent; + return exact; + } + ++static inline void connect_node(struct allowedips_node **parent, u8 bit, struct allowedips_node *node) ++{ ++ node->parent_bit_packed = (unsigned long)parent | bit; ++ rcu_assign_pointer(*parent, node); ++} ++ ++static inline void choose_and_connect_node(struct allowedips_node *parent, struct allowedips_node *node) ++{ ++ u8 bit = choose(parent, node->bits); ++ connect_node(&parent->bit[bit], bit, node); ++} ++ + static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, + u8 cidr, struct wg_peer *peer, struct mutex *lock) + { +@@ -177,8 +190,7 @@ static int add(struct allowedips_node __ + RCU_INIT_POINTER(node->peer, peer); + list_add_tail(&node->peer_list, &peer->allowedips_list); + copy_and_assign_cidr(node, key, cidr, bits); +- rcu_assign_pointer(node->parent_bit, trie); +- rcu_assign_pointer(*trie, node); ++ connect_node(trie, 2, node); + return 0; + } + if (node_placement(*trie, key, cidr, bits, &node, lock)) { +@@ -197,10 +209,10 @@ static int add(struct allowedips_node __ + if (!node) { + down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); + } else { +- down = rcu_dereference_protected(CHOOSE_NODE(node, key), lockdep_is_held(lock)); ++ const u8 bit = choose(node, key); ++ down = rcu_dereference_protected(node->bit[bit], lockdep_is_held(lock)); + if (!down) { +- rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, key)); +- 
rcu_assign_pointer(CHOOSE_NODE(node, key), newnode); ++ connect_node(&node->bit[bit], bit, newnode); + return 0; + } + } +@@ -208,15 +220,11 @@ static int add(struct allowedips_node __ + parent = node; + + if (newnode->cidr == cidr) { +- rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(newnode, down->bits)); +- rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down); +- if (!parent) { +- rcu_assign_pointer(newnode->parent_bit, trie); +- rcu_assign_pointer(*trie, newnode); +- } else { +- rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(parent, newnode->bits)); +- rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits), newnode); +- } ++ choose_and_connect_node(newnode, down); ++ if (!parent) ++ connect_node(trie, 2, newnode); ++ else ++ choose_and_connect_node(parent, newnode); + return 0; + } + +@@ -229,17 +237,12 @@ static int add(struct allowedips_node __ + INIT_LIST_HEAD(&node->peer_list); + copy_and_assign_cidr(node, newnode->bits, cidr, bits); + +- rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(node, down->bits)); +- rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down); +- rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, newnode->bits)); +- rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode); +- if (!parent) { +- rcu_assign_pointer(node->parent_bit, trie); +- rcu_assign_pointer(*trie, node); +- } else { +- rcu_assign_pointer(node->parent_bit, &CHOOSE_NODE(parent, node->bits)); +- rcu_assign_pointer(CHOOSE_NODE(parent, node->bits), node); +- } ++ choose_and_connect_node(node, down); ++ choose_and_connect_node(node, newnode); ++ if (!parent) ++ connect_node(trie, 2, node); ++ else ++ choose_and_connect_node(parent, node); + return 0; + } + +@@ -297,7 +300,8 @@ int wg_allowedips_insert_v6(struct allow + void wg_allowedips_remove_by_peer(struct allowedips *table, + struct wg_peer *peer, struct mutex *lock) + { +- struct allowedips_node *node, *child, *tmp; ++ struct allowedips_node *node, *child, **parent_bit, *parent, *tmp; ++ bool 
free_parent; + + if (list_empty(&peer->allowedips_list)) + return; +@@ -307,19 +311,29 @@ void wg_allowedips_remove_by_peer(struct + RCU_INIT_POINTER(node->peer, NULL); + if (node->bit[0] && node->bit[1]) + continue; +- child = rcu_dereference_protected( +- node->bit[!rcu_access_pointer(node->bit[0])], +- lockdep_is_held(lock)); ++ child = rcu_dereference_protected(node->bit[!rcu_access_pointer(node->bit[0])], ++ lockdep_is_held(lock)); + if (child) +- child->parent_bit = node->parent_bit; +- *rcu_dereference_protected(node->parent_bit, lockdep_is_held(lock)) = child; ++ child->parent_bit_packed = node->parent_bit_packed; ++ parent_bit = (struct allowedips_node **)(node->parent_bit_packed & ~3UL); ++ *parent_bit = child; ++ parent = (void *)parent_bit - ++ offsetof(struct allowedips_node, bit[node->parent_bit_packed & 1]); ++ free_parent = !rcu_access_pointer(node->bit[0]) && ++ !rcu_access_pointer(node->bit[1]) && ++ (node->parent_bit_packed & 3) <= 1 && ++ !rcu_access_pointer(parent->peer); ++ if (free_parent) ++ child = rcu_dereference_protected( ++ parent->bit[!(node->parent_bit_packed & 1)], ++ lockdep_is_held(lock)); + call_rcu(&node->rcu, node_free_rcu); +- +- /* TODO: Note that we currently don't walk up and down in order to +- * free any potential filler nodes. This means that this function +- * doesn't free up as much as it could, which could be revisited +- * at some point. +- */ ++ if (!free_parent) ++ continue; ++ if (child) ++ child->parent_bit_packed = parent->parent_bit_packed; ++ *(struct allowedips_node **)(parent->parent_bit_packed & ~3UL) = child; ++ call_rcu(&parent->rcu, node_free_rcu); + } + } + +--- a/drivers/net/wireguard/allowedips.h ++++ b/drivers/net/wireguard/allowedips.h +@@ -19,7 +19,7 @@ struct allowedips_node { + u8 bits[16] __aligned(__alignof(u64)); + + /* Keep rarely used members at bottom to be beyond cache line. 
*/ +- struct allowedips_node *__rcu *parent_bit; ++ unsigned long parent_bit_packed; + union { + struct list_head peer_list; + struct rcu_head rcu; +@@ -30,7 +30,7 @@ struct allowedips { + struct allowedips_node __rcu *root4; + struct allowedips_node __rcu *root6; + u64 seq; +-}; ++} __aligned(4); /* We pack the lower 2 bits of &root, but m68k only gives 16-bit alignment. */ + + void wg_allowedips_init(struct allowedips *table); + void wg_allowedips_free(struct allowedips *table, struct mutex *mutex); +--- a/drivers/net/wireguard/selftest/allowedips.c ++++ b/drivers/net/wireguard/selftest/allowedips.c +@@ -19,32 +19,22 @@ + + #include <linux/siphash.h> + +-static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, +- u8 cidr) +-{ +- swap_endian(dst, src, bits); +- memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8); +- if (cidr) +- dst[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8); +-} +- + static __init void print_node(struct allowedips_node *node, u8 bits) + { + char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n"; +- char *fmt_declaration = KERN_DEBUG +- "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; ++ char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; ++ u8 ip1[16], ip2[16], cidr1, cidr2; + char *style = "dotted"; +- u8 ip1[16], ip2[16]; + u32 color = 0; + ++ if (node == NULL) ++ return; + if (bits == 32) { + fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n"; +- fmt_declaration = KERN_DEBUG +- "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; ++ fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; + } else if (bits == 128) { + fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n"; +- fmt_declaration = KERN_DEBUG +- "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; ++ fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; + } + if (node->peer) { + hsiphash_key_t key = { { 0 } }; +@@ -55,24 +45,20 @@ static __init void 
print_node(struct all + hsiphash_1u32(0xabad1dea, &key) % 200; + style = "bold"; + } +- swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr); +- printk(fmt_declaration, ip1, node->cidr, style, color); ++ wg_allowedips_read_node(node, ip1, &cidr1); ++ printk(fmt_declaration, ip1, cidr1, style, color); + if (node->bit[0]) { +- swap_endian_and_apply_cidr(ip2, +- rcu_dereference_raw(node->bit[0])->bits, bits, +- node->cidr); +- printk(fmt_connection, ip1, node->cidr, ip2, +- rcu_dereference_raw(node->bit[0])->cidr); +- print_node(rcu_dereference_raw(node->bit[0]), bits); ++ wg_allowedips_read_node(rcu_dereference_raw(node->bit[0]), ip2, &cidr2); ++ printk(fmt_connection, ip1, cidr1, ip2, cidr2); + } + if (node->bit[1]) { +- swap_endian_and_apply_cidr(ip2, +- rcu_dereference_raw(node->bit[1])->bits, +- bits, node->cidr); +- printk(fmt_connection, ip1, node->cidr, ip2, +- rcu_dereference_raw(node->bit[1])->cidr); +- print_node(rcu_dereference_raw(node->bit[1]), bits); ++ wg_allowedips_read_node(rcu_dereference_raw(node->bit[1]), ip2, &cidr2); ++ printk(fmt_connection, ip1, cidr1, ip2, cidr2); + } ++ if (node->bit[0]) ++ print_node(rcu_dereference_raw(node->bit[0]), bits); ++ if (node->bit[1]) ++ print_node(rcu_dereference_raw(node->bit[1]), bits); + } + + static __init void print_tree(struct allowedips_node __rcu *top, u8 bits) +@@ -121,8 +107,8 @@ static __init inline union nf_inet_addr + { + union nf_inet_addr mask; + +- memset(&mask, 0x00, 128 / 8); +- memset(&mask, 0xff, cidr / 8); ++ memset(&mask, 0, sizeof(mask)); ++ memset(&mask.all, 0xff, cidr / 8); + if (cidr % 32) + mask.all[cidr / 32] = (__force u32)htonl( + (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL); +@@ -149,42 +135,36 @@ horrible_mask_self(struct horrible_allow + } + + static __init inline bool +-horrible_match_v4(const struct horrible_allowedips_node *node, +- struct in_addr *ip) ++horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip) + { + return (ip->s_addr 
& node->mask.ip) == node->ip.ip; + } + + static __init inline bool +-horrible_match_v6(const struct horrible_allowedips_node *node, +- struct in6_addr *ip) ++horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip) + { +- return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == +- node->ip.ip6[0] && +- (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == +- node->ip.ip6[1] && +- (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == +- node->ip.ip6[2] && ++ return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] && ++ (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] && ++ (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] && + (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3]; + } + + static __init void +-horrible_insert_ordered(struct horrible_allowedips *table, +- struct horrible_allowedips_node *node) ++horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node) + { + struct horrible_allowedips_node *other = NULL, *where = NULL; + u8 my_cidr = horrible_mask_to_cidr(node->mask); + + hlist_for_each_entry(other, &table->head, table) { +- if (!memcmp(&other->mask, &node->mask, +- sizeof(union nf_inet_addr)) && +- !memcmp(&other->ip, &node->ip, +- sizeof(union nf_inet_addr)) && +- other->ip_version == node->ip_version) { ++ if (other->ip_version == node->ip_version && ++ !memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) && ++ !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr))) { + other->value = node->value; + kfree(node); + return; + } ++ } ++ hlist_for_each_entry(other, &table->head, table) { + where = other; + if (horrible_mask_to_cidr(other->mask) <= my_cidr) + break; +@@ -201,8 +181,7 @@ static __init int + horrible_allowedips_insert_v4(struct horrible_allowedips *table, + struct in_addr *ip, u8 cidr, void *value) + { +- struct horrible_allowedips_node *node = kzalloc(sizeof(*node), +- GFP_KERNEL); ++ struct horrible_allowedips_node 
*node = kzalloc(sizeof(*node), GFP_KERNEL); + + if (unlikely(!node)) + return -ENOMEM; +@@ -219,8 +198,7 @@ static __init int + horrible_allowedips_insert_v6(struct horrible_allowedips *table, + struct in6_addr *ip, u8 cidr, void *value) + { +- struct horrible_allowedips_node *node = kzalloc(sizeof(*node), +- GFP_KERNEL); ++ struct horrible_allowedips_node *node = kzalloc(sizeof(*node), GFP_KERNEL); + + if (unlikely(!node)) + return -ENOMEM; +@@ -234,39 +212,43 @@ horrible_allowedips_insert_v6(struct hor + } + + static __init void * +-horrible_allowedips_lookup_v4(struct horrible_allowedips *table, +- struct in_addr *ip) ++horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip) + { + struct horrible_allowedips_node *node; +- void *ret = NULL; + + hlist_for_each_entry(node, &table->head, table) { +- if (node->ip_version != 4) +- continue; +- if (horrible_match_v4(node, ip)) { +- ret = node->value; +- break; +- } ++ if (node->ip_version == 4 && horrible_match_v4(node, ip)) ++ return node->value; + } +- return ret; ++ return NULL; + } + + static __init void * +-horrible_allowedips_lookup_v6(struct horrible_allowedips *table, +- struct in6_addr *ip) ++horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip) + { + struct horrible_allowedips_node *node; +- void *ret = NULL; + + hlist_for_each_entry(node, &table->head, table) { +- if (node->ip_version != 6) ++ if (node->ip_version == 6 && horrible_match_v6(node, ip)) ++ return node->value; ++ } ++ return NULL; ++} ++ ++ ++static __init void ++horrible_allowedips_remove_by_value(struct horrible_allowedips *table, void *value) ++{ ++ struct horrible_allowedips_node *node; ++ struct hlist_node *h; ++ ++ hlist_for_each_entry_safe(node, h, &table->head, table) { ++ if (node->value != value) + continue; +- if (horrible_match_v6(node, ip)) { +- ret = node->value; +- break; +- } ++ hlist_del(&node->table); ++ kfree(node); + } +- return ret; ++ + } + + static __init bool 
randomized_test(void) +@@ -397,23 +379,33 @@ static __init bool randomized_test(void) + print_tree(t.root6, 128); + } + +- for (i = 0; i < NUM_QUERIES; ++i) { +- prandom_bytes(ip, 4); +- if (lookup(t.root4, 32, ip) != +- horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { +- pr_err("allowedips random self-test: FAIL\n"); +- goto free; ++ for (j = 0;; ++j) { ++ for (i = 0; i < NUM_QUERIES; ++i) { ++ prandom_bytes(ip, 4); ++ if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { ++ horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip); ++ pr_err("allowedips random v4 self-test: FAIL\n"); ++ goto free; ++ } ++ prandom_bytes(ip, 16); ++ if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { ++ pr_err("allowedips random v6 self-test: FAIL\n"); ++ goto free; ++ } + } ++ if (j >= NUM_PEERS) ++ break; ++ mutex_lock(&mutex); ++ wg_allowedips_remove_by_peer(&t, peers[j], &mutex); ++ mutex_unlock(&mutex); ++ horrible_allowedips_remove_by_value(&h, peers[j]); + } + +- for (i = 0; i < NUM_QUERIES; ++i) { +- prandom_bytes(ip, 16); +- if (lookup(t.root6, 128, ip) != +- horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { +- pr_err("allowedips random self-test: FAIL\n"); +- goto free; +- } ++ if (t.root4 || t.root6) { ++ pr_err("allowedips random self-test removal: FAIL\n"); ++ goto free; + } ++ + ret = true; + + free: |