aboutsummaryrefslogtreecommitdiffstats
path: root/target/linux/generic/backport-5.4/080-wireguard-0108-wireguard-device-avoid-circular-netns-references.patch
blob: 8021b9bf23a9639a408213ad806213b1d6b503ef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
From 40d881393cfc6953778691444ab27a29d51d24aa Mon Sep 17 00:00:00 2001
From: "Jason A. Donenfeld" <Jason@zx2c4.com>
Date: Tue, 23 Jun 2020 03:59:45 -0600
Subject: [PATCH 108/124] wireguard: device: avoid circular netns references

commit 900575aa33a3eaaef802b31de187a85c4a4b4bd0 upstream.

Before, we took a reference to the creating netns if the new netns was
different. This caused issues with circular references, with two
wireguard interfaces swapping namespaces. The solution is to rather not
take any extra references at all, but instead simply invalidate the
creating netns pointer when that netns is deleted.

In order to prevent this from happening again, this commit improves the
rough object leak tracking by allowing it to account for created and
destroyed interfaces, aside from just peers and keys. That then makes it
possible to check for the object leak when having two interfaces take a
reference to each others' namespaces.

Fixes: e7096c131e51 ("net: WireGuard secure network tunnel")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
---
 drivers/net/wireguard/device.c             | 58 ++++++++++------------
 drivers/net/wireguard/device.h             |  3 +-
 drivers/net/wireguard/netlink.c            | 14 ++++--
 drivers/net/wireguard/socket.c             | 25 +++++++---
 tools/testing/selftests/wireguard/netns.sh | 13 ++++-
 5 files changed, 67 insertions(+), 46 deletions(-)

--- a/drivers/net/wireguard/device.c
+++ b/drivers/net/wireguard/device.c
@@ -45,17 +45,18 @@ static int wg_open(struct net_device *de
 	if (dev_v6)
 		dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE;
 
+	mutex_lock(&wg->device_update_lock);
 	ret = wg_socket_init(wg, wg->incoming_port);
 	if (ret < 0)
-		return ret;
-	mutex_lock(&wg->device_update_lock);
+		goto out;
 	list_for_each_entry(peer, &wg->peer_list, peer_list) {
 		wg_packet_send_staged_packets(peer);
 		if (peer->persistent_keepalive_interval)
 			wg_packet_send_keepalive(peer);
 	}
+out:
 	mutex_unlock(&wg->device_update_lock);
-	return 0;
+	return ret;
 }
 
 #ifdef CONFIG_PM_SLEEP
@@ -225,6 +226,7 @@ static void wg_destruct(struct net_devic
 	list_del(&wg->device_list);
 	rtnl_unlock();
 	mutex_lock(&wg->device_update_lock);
+	rcu_assign_pointer(wg->creating_net, NULL);
 	wg->incoming_port = 0;
 	wg_socket_reinit(wg, NULL, NULL);
 	/* The final references are cleared in the below calls to destroy_workqueue. */
@@ -240,13 +242,11 @@ static void wg_destruct(struct net_devic
 	skb_queue_purge(&wg->incoming_handshakes);
 	free_percpu(dev->tstats);
 	free_percpu(wg->incoming_handshakes_worker);
-	if (wg->have_creating_net_ref)
-		put_net(wg->creating_net);
 	kvfree(wg->index_hashtable);
 	kvfree(wg->peer_hashtable);
 	mutex_unlock(&wg->device_update_lock);
 
-	pr_debug("%s: Interface deleted\n", dev->name);
+	pr_debug("%s: Interface destroyed\n", dev->name);
 	free_netdev(dev);
 }
 
@@ -292,7 +292,7 @@ static int wg_newlink(struct net *src_ne
 	struct wg_device *wg = netdev_priv(dev);
 	int ret = -ENOMEM;
 
-	wg->creating_net = src_net;
+	rcu_assign_pointer(wg->creating_net, src_net);
 	init_rwsem(&wg->static_identity.lock);
 	mutex_init(&wg->socket_update_lock);
 	mutex_init(&wg->device_update_lock);
@@ -393,30 +393,26 @@ static struct rtnl_link_ops link_ops __r
 	.newlink		= wg_newlink,
 };
 
-static int wg_netdevice_notification(struct notifier_block *nb,
-				     unsigned long action, void *data)
+static void wg_netns_pre_exit(struct net *net)
 {
-	struct net_device *dev = ((struct netdev_notifier_info *)data)->dev;
-	struct wg_device *wg = netdev_priv(dev);
-
-	ASSERT_RTNL();
-
-	if (action != NETDEV_REGISTER || dev->netdev_ops != &netdev_ops)
-		return 0;
+	struct wg_device *wg;
 
-	if (dev_net(dev) == wg->creating_net && wg->have_creating_net_ref) {
-		put_net(wg->creating_net);
-		wg->have_creating_net_ref = false;
-	} else if (dev_net(dev) != wg->creating_net &&
-		   !wg->have_creating_net_ref) {
-		wg->have_creating_net_ref = true;
-		get_net(wg->creating_net);
+	rtnl_lock();
+	list_for_each_entry(wg, &device_list, device_list) {
+		if (rcu_access_pointer(wg->creating_net) == net) {
+			pr_debug("%s: Creating namespace exiting\n", wg->dev->name);
+			netif_carrier_off(wg->dev);
+			mutex_lock(&wg->device_update_lock);
+			rcu_assign_pointer(wg->creating_net, NULL);
+			wg_socket_reinit(wg, NULL, NULL);
+			mutex_unlock(&wg->device_update_lock);
+		}
 	}
-	return 0;
+	rtnl_unlock();
 }
 
-static struct notifier_block netdevice_notifier = {
-	.notifier_call = wg_netdevice_notification
+static struct pernet_operations pernet_ops = {
+	.pre_exit = wg_netns_pre_exit
 };
 
 int __init wg_device_init(void)
@@ -429,18 +425,18 @@ int __init wg_device_init(void)
 		return ret;
 #endif
 
-	ret = register_netdevice_notifier(&netdevice_notifier);
+	ret = register_pernet_device(&pernet_ops);
 	if (ret)
 		goto error_pm;
 
 	ret = rtnl_link_register(&link_ops);
 	if (ret)
-		goto error_netdevice;
+		goto error_pernet;
 
 	return 0;
 
-error_netdevice:
-	unregister_netdevice_notifier(&netdevice_notifier);
+error_pernet:
+	unregister_pernet_device(&pernet_ops);
 error_pm:
 #ifdef CONFIG_PM_SLEEP
 	unregister_pm_notifier(&pm_notifier);
@@ -451,7 +447,7 @@ error_pm:
 void wg_device_uninit(void)
 {
 	rtnl_link_unregister(&link_ops);
-	unregister_netdevice_notifier(&netdevice_notifier);
+	unregister_pernet_device(&pernet_ops);
 #ifdef CONFIG_PM_SLEEP
 	unregister_pm_notifier(&pm_notifier);
 #endif
--- a/drivers/net/wireguard/device.h
+++ b/drivers/net/wireguard/device.h
@@ -40,7 +40,7 @@ struct wg_device {
 	struct net_device *dev;
 	struct crypt_queue encrypt_queue, decrypt_queue;
 	struct sock __rcu *sock4, *sock6;
-	struct net *creating_net;
+	struct net __rcu *creating_net;
 	struct noise_static_identity static_identity;
 	struct workqueue_struct *handshake_receive_wq, *handshake_send_wq;
 	struct workqueue_struct *packet_crypt_wq;
@@ -56,7 +56,6 @@ struct wg_device {
 	unsigned int num_peers, device_update_gen;
 	u32 fwmark;
 	u16 incoming_port;
-	bool have_creating_net_ref;
 };
 
 int wg_device_init(void);
--- a/drivers/net/wireguard/netlink.c
+++ b/drivers/net/wireguard/netlink.c
@@ -517,11 +517,15 @@ static int wg_set_device(struct sk_buff
 	if (flags & ~__WGDEVICE_F_ALL)
 		goto out;
 
-	ret = -EPERM;
-	if ((info->attrs[WGDEVICE_A_LISTEN_PORT] ||
-	     info->attrs[WGDEVICE_A_FWMARK]) &&
-	    !ns_capable(wg->creating_net->user_ns, CAP_NET_ADMIN))
-		goto out;
+	if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
+		struct net *net;
+		rcu_read_lock();
+		net = rcu_dereference(wg->creating_net);
+		ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
+		rcu_read_unlock();
+		if (ret)
+			goto out;
+	}
 
 	++wg->device_update_gen;
 
--- a/drivers/net/wireguard/socket.c
+++ b/drivers/net/wireguard/socket.c
@@ -347,6 +347,7 @@ static void set_sock_opts(struct socket
 
 int wg_socket_init(struct wg_device *wg, u16 port)
 {
+	struct net *net;
 	int ret;
 	struct udp_tunnel_sock_cfg cfg = {
 		.sk_user_data = wg,
@@ -371,37 +372,47 @@ int wg_socket_init(struct wg_device *wg,
 	};
 #endif
 
+	rcu_read_lock();
+	net = rcu_dereference(wg->creating_net);
+	net = net ? maybe_get_net(net) : NULL;
+	rcu_read_unlock();
+	if (unlikely(!net))
+		return -ENONET;
+
 #if IS_ENABLED(CONFIG_IPV6)
 retry:
 #endif
 
-	ret = udp_sock_create(wg->creating_net, &port4, &new4);
+	ret = udp_sock_create(net, &port4, &new4);
 	if (ret < 0) {
 		pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
-		return ret;
+		goto out;
 	}
 	set_sock_opts(new4);
-	setup_udp_tunnel_sock(wg->creating_net, new4, &cfg);
+	setup_udp_tunnel_sock(net, new4, &cfg);
 
 #if IS_ENABLED(CONFIG_IPV6)
 	if (ipv6_mod_enabled()) {
 		port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
-		ret = udp_sock_create(wg->creating_net, &port6, &new6);
+		ret = udp_sock_create(net, &port6, &new6);
 		if (ret < 0) {
 			udp_tunnel_sock_release(new4);
 			if (ret == -EADDRINUSE && !port && retries++ < 100)
 				goto retry;
 			pr_err("%s: Could not create IPv6 socket\n",
 			       wg->dev->name);
-			return ret;
+			goto out;
 		}
 		set_sock_opts(new6);
-		setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
+		setup_udp_tunnel_sock(net, new6, &cfg);
 	}
 #endif
 
 	wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL);
-	return 0;
+	ret = 0;
+out:
+	put_net(net);
+	return ret;
 }
 
 void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
--- a/tools/testing/selftests/wireguard/netns.sh
+++ b/tools/testing/selftests/wireguard/netns.sh
@@ -587,9 +587,20 @@ ip0 link set wg0 up
 kill $ncat_pid
 ip0 link del wg0
 
+# Ensure there aren't circular reference loops
+ip1 link add wg1 type wireguard
+ip2 link add wg2 type wireguard
+ip1 link set wg1 netns $netns2
+ip2 link set wg2 netns $netns1
+pp ip netns delete $netns1
+pp ip netns delete $netns2
+pp ip netns add $netns1
+pp ip netns add $netns2
+
+sleep 2 # Wait for cleanup and grace periods
 declare -A objects
 while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
-	[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
+	[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ ?[0-9]*)\ .*(created|destroyed).* ]] || continue
 	objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
 done < /dev/kmsg
 alldeleted=1