1  // SPDX-License-Identifier: GPL-2.0-only
2  /*
3   * vsock sock_diag(7) module
4   *
5   * Copyright (C) 2017 Red Hat, Inc.
6   * Author: Stefan Hajnoczi <stefanha@redhat.com>
7   */
8  
9  #include <linux/module.h>
10  #include <linux/sock_diag.h>
11  #include <linux/vm_sockets_diag.h>
12  #include <net/af_vsock.h>
13  
14  static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
15  			u32 portid, u32 seq, u32 flags)
16  {
17  	struct vsock_sock *vsk = vsock_sk(sk);
18  	struct vsock_diag_msg *rep;
19  	struct nlmsghdr *nlh;
20  
21  	nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep),
22  			flags);
23  	if (!nlh)
24  		return -EMSGSIZE;
25  
26  	rep = nlmsg_data(nlh);
27  	rep->vdiag_family = AF_VSOCK;
28  
29  	/* Lock order dictates that sk_lock is acquired before
30  	 * vsock_table_lock, so we cannot lock here.  Simply don't take
31  	 * sk_lock; sk is guaranteed to stay alive since vsock_table_lock is
32  	 * held.
33  	 */
34  	rep->vdiag_type = sk->sk_type;
35  	rep->vdiag_state = sk->sk_state;
36  	rep->vdiag_shutdown = sk->sk_shutdown;
37  	rep->vdiag_src_cid = vsk->local_addr.svm_cid;
38  	rep->vdiag_src_port = vsk->local_addr.svm_port;
39  	rep->vdiag_dst_cid = vsk->remote_addr.svm_cid;
40  	rep->vdiag_dst_port = vsk->remote_addr.svm_port;
41  	rep->vdiag_ino = sock_i_ino(sk);
42  
43  	sock_diag_save_cookie(sk, rep->vdiag_cookie);
44  
45  	return 0;
46  }
47  
/*
 * Netlink dump callback: emit one vsock_diag_msg for each vsock socket in
 * the requester's network namespace whose state bit is set in
 * req->vdiag_states.
 *
 * The callback walks vsock_bind_table first and then
 * vsock_connected_table, holding vsock_table_lock throughout.  Output can
 * span several calls, so the position reached is saved in cb->args:
 *   args[0] - table (0 = bind table, 1 = connected table)
 *   args[1] - bucket index within that table
 *   args[2] - index of the next socket within that bucket.  Only sockets
 *             in the requester's netns are counted.  In the connected
 *             table, sockets that are also in the bind table are not
 *             counted either.
 *
 * Returns skb->len.  NOTE(review): this relies on the netlink dump
 * convention that returning 0 (nothing added) ends the dump.
 */
static int vsock_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
	struct vsock_diag_req *req;
	struct vsock_sock *vsk;
	unsigned int bucket;
	unsigned int last_i;
	unsigned int table;
	struct net *net;
	unsigned int i;

	req = nlmsg_data(cb->nlh);
	net = sock_net(skb->sk);

	/* State saved between calls: */
	table = cb->args[0];
	bucket = cb->args[1];
	/*
	 * last_i is the number of sockets to skip in the first bucket visited
	 * (the resume point).  i is also seeded here so the saved cursor is
	 * unchanged if no bucket is visited in this call.
	 */
	i = last_i = cb->args[2];

	/* TODO VMCI pending sockets? */

	spin_lock_bh(&vsock_table_lock);

	/* Bind table (locally created sockets) */
	if (table == 0) {
		while (bucket < ARRAY_SIZE(vsock_bind_table)) {
			struct list_head *head = &vsock_bind_table[bucket];

			i = 0;
			list_for_each_entry(vsk, head, bound_table) {
				struct sock *sk = sk_vsock(vsk);

				/* Other netns: skipped without advancing i. */
				if (!net_eq(sock_net(sk), net))
					continue;
				/* Handled by an earlier call: count it but emit nothing. */
				if (i < last_i)
					goto next_bind;
				/*
				 * Filter on the requested state mask.
				 * NOTE(review): assumes sk_state < 32 so the shift is
				 * well-defined.
				 */
				if (!(req->vdiag_states & (1 << sk->sk_state)))
					goto next_bind;
				/*
				 * skb is full: i still indexes this socket, so the
				 * next call resumes by retrying it.
				 */
				if (sk_diag_fill(sk, skb,
						 NETLINK_CB(cb->skb).portid,
						 cb->nlh->nlmsg_seq,
						 NLM_F_MULTI) < 0)
					goto done;
next_bind:
				i++;
			}
			/* Only the first bucket resumed into needs skipping. */
			last_i = 0;
			bucket++;
		}

		/* Bind table exhausted: move on to the connected table. */
		table++;
		bucket = 0;
	}

	/* Connected table (accepted connections) */
	while (bucket < ARRAY_SIZE(vsock_connected_table)) {
		struct list_head *head = &vsock_connected_table[bucket];

		i = 0;
		list_for_each_entry(vsk, head, connected_table) {
			struct sock *sk = sk_vsock(vsk);

			/* Skip sockets we've already seen above */
			if (__vsock_in_bound_table(vsk))
				continue;

			/* Other netns: skipped without advancing i. */
			if (!net_eq(sock_net(sk), net))
				continue;
			/* Handled by an earlier call: count it but emit nothing. */
			if (i < last_i)
				goto next_connected;
			if (!(req->vdiag_states & (1 << sk->sk_state)))
				goto next_connected;
			/* skb is full: the next call retries this socket. */
			if (sk_diag_fill(sk, skb,
					 NETLINK_CB(cb->skb).portid,
					 cb->nlh->nlmsg_seq,
					 NLM_F_MULTI) < 0)
				goto done;
next_connected:
			i++;
		}
		last_i = 0;
		bucket++;
	}

done:
	spin_unlock_bh(&vsock_table_lock);

	/* Save the cursor for the next invocation. */
	cb->args[0] = table;
	cb->args[1] = bucket;
	cb->args[2] = i;

	return skb->len;
}
140  
141  static int vsock_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
142  {
143  	int hdrlen = sizeof(struct vsock_diag_req);
144  	struct net *net = sock_net(skb->sk);
145  
146  	if (nlmsg_len(h) < hdrlen)
147  		return -EINVAL;
148  
149  	if (h->nlmsg_flags & NLM_F_DUMP) {
150  		struct netlink_dump_control c = {
151  			.dump = vsock_diag_dump,
152  		};
153  		return netlink_dump_start(net->diag_nlsk, skb, h, &c);
154  	}
155  
156  	return -EOPNOTSUPP;
157  }
158  
159  static const struct sock_diag_handler vsock_diag_handler = {
160  	.owner = THIS_MODULE,
161  	.family = AF_VSOCK,
162  	.dump = vsock_diag_handler_dump,
163  };
164  
165  static int __init vsock_diag_init(void)
166  {
167  	return sock_diag_register(&vsock_diag_handler);
168  }
169  
170  static void __exit vsock_diag_exit(void)
171  {
172  	sock_diag_unregister(&vsock_diag_handler);
173  }
174  
module_init(vsock_diag_init);
module_exit(vsock_diag_exit);
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("VMware Virtual Sockets monitoring via SOCK_DIAG");
/*
 * Module alias keyed on (PF_NETLINK, NETLINK_SOCK_DIAG, family 40), which
 * presumably lets sock_diag auto-load this module on the first AF_VSOCK
 * request.  NOTE(review): the family is written as the literal 40 rather
 * than AF_VSOCK, presumably because the macro stringifies its arguments
 * into the alias name.  Confirm against the macro definition.
 */
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG,
			       40 /* AF_VSOCK */);
181