1 // SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2 #include <errno.h>
3 #include <poll.h>
4 #include <string.h>
5 #include <stdlib.h>
6 #include <linux/types.h>
7 
8 #include <libmnl/libmnl.h>
9 #include <linux/genetlink.h>
10 
11 #include "ynl.h"
12 
/* Number of elements of a true array (not valid for pointers). */
#define ARRAY_SIZE(arr)		(sizeof(arr) / sizeof((arr)[0]))
14 
/*
 * Error-reporting helpers.  All of them store into a struct ynl_error and
 * do nothing when the target pointer is NULL.
 *
 * __yerr_msg(yse, fmt, ...)   - format a message into yse->msg
 * __yerr_code(yse, code)      - store an error code in yse->code
 * __yerr(yse, code, fmt, ...) - both of the above
 * __perr(yse, msg)            - both, taking the code from errno
 * yerr_msg/yerr/perr(ys, ...) - same, targeting the socket's ys->err
 */
#define __yerr_msg(yse, _msg...)					\
	({								\
		struct ynl_error *_yse = (yse);				\
									\
		if (_yse) {						\
			snprintf(_yse->msg, sizeof(_yse->msg) - 1,  _msg); \
			_yse->msg[sizeof(_yse->msg) - 1] = 0;		\
		}							\
	})

#define __yerr_code(yse, _code...)		\
	({					\
		struct ynl_error *_yse = (yse);	\
						\
		if (_yse) {			\
			_yse->code = _code;	\
		}				\
	})

/*
 * @yse is evaluated once, and @_code is evaluated *before* the message is
 * formatted.  When the code is errno, snprintf() or the message arguments
 * could otherwise overwrite it before it gets recorded.
 */
#define __yerr(yse, _code, _msg...)					\
	({								\
		struct ynl_error *_yerr_dst = (yse);			\
		int _yerr_val = (_code);				\
									\
		__yerr_msg(_yerr_dst, _msg);				\
		__yerr_code(_yerr_dst, _yerr_val);			\
	})

#define __perr(yse, _msg)		__yerr(yse, errno, _msg)

#define yerr_msg(_ys, _msg...)		__yerr_msg(&(_ys)->err, _msg)
#define yerr(_ys, _code, _msg...)	__yerr(&(_ys)->err, _code, _msg)
#define perr(_ys, _msg)			__yerr(&(_ys)->err, errno, _msg)
45 
/* -- Netlink boilerplate */
47 static int
ynl_err_walk_report_one(struct ynl_policy_nest * policy,unsigned int type,char * str,int str_sz,int * n)48 ynl_err_walk_report_one(struct ynl_policy_nest *policy, unsigned int type,
49 			char *str, int str_sz, int *n)
50 {
51 	if (!policy) {
52 		if (*n < str_sz)
53 			*n += snprintf(str, str_sz, "!policy");
54 		return 1;
55 	}
56 
57 	if (type > policy->max_attr) {
58 		if (*n < str_sz)
59 			*n += snprintf(str, str_sz, "!oob");
60 		return 1;
61 	}
62 
63 	if (!policy->table[type].name) {
64 		if (*n < str_sz)
65 			*n += snprintf(str, str_sz, "!name");
66 		return 1;
67 	}
68 
69 	if (*n < str_sz)
70 		*n += snprintf(str, str_sz - *n,
71 			       ".%s", policy->table[type].name);
72 	return 0;
73 }
74 
75 static int
ynl_err_walk(struct ynl_sock * ys,void * start,void * end,unsigned int off,struct ynl_policy_nest * policy,char * str,int str_sz,struct ynl_policy_nest ** nest_pol)76 ynl_err_walk(struct ynl_sock *ys, void *start, void *end, unsigned int off,
77 	     struct ynl_policy_nest *policy, char *str, int str_sz,
78 	     struct ynl_policy_nest **nest_pol)
79 {
80 	unsigned int astart_off, aend_off;
81 	const struct nlattr *attr;
82 	unsigned int data_len;
83 	unsigned int type;
84 	bool found = false;
85 	int n = 0;
86 
87 	if (!policy) {
88 		if (n < str_sz)
89 			n += snprintf(str, str_sz, "!policy");
90 		return n;
91 	}
92 
93 	data_len = end - start;
94 
95 	mnl_attr_for_each_payload(start, data_len) {
96 		astart_off = (char *)attr - (char *)start;
97 		aend_off = astart_off + mnl_attr_get_payload_len(attr);
98 		if (aend_off <= off)
99 			continue;
100 
101 		found = true;
102 		break;
103 	}
104 	if (!found)
105 		return 0;
106 
107 	off -= astart_off;
108 
109 	type = mnl_attr_get_type(attr);
110 
111 	if (ynl_err_walk_report_one(policy, type, str, str_sz, &n))
112 		return n;
113 
114 	if (!off) {
115 		if (nest_pol)
116 			*nest_pol = policy->table[type].nest;
117 		return n;
118 	}
119 
120 	if (!policy->table[type].nest) {
121 		if (n < str_sz)
122 			n += snprintf(str, str_sz, "!nest");
123 		return n;
124 	}
125 
126 	off -= sizeof(struct nlattr);
127 	start =  mnl_attr_get_payload(attr);
128 	end = start + mnl_attr_get_payload_len(attr);
129 
130 	return n + ynl_err_walk(ys, start, end, off, policy->table[type].nest,
131 				&str[n], str_sz - n, nest_pol);
132 }
133 
/*
 * Local definitions of the extended-ACK attribute types that follow
 * NLMSGERR_ATTR_POLICY.  Presumably these are for building against older
 * uapi headers that lack them (TODO confirm).
 *
 * NLMSGERR_ATTR_MAX is widened by two so that tb[] in ynl_ext_ack_check()
 * can hold these types.  The preprocessor does not re-expand a macro inside
 * its own definition, so the inner NLMSGERR_ATTR_MAX refers to the value
 * from the included headers.  If those headers already include the MISS_*
 * types, tb[] is simply a little larger than needed.
 */
#define NLMSGERR_ATTR_MISS_TYPE (NLMSGERR_ATTR_POLICY + 1)
#define NLMSGERR_ATTR_MISS_NEST (NLMSGERR_ATTR_POLICY + 2)
#define NLMSGERR_ATTR_MAX (NLMSGERR_ATTR_MAX + 2)
137 
138 static int
ynl_ext_ack_check(struct ynl_sock * ys,const struct nlmsghdr * nlh,unsigned int hlen)139 ynl_ext_ack_check(struct ynl_sock *ys, const struct nlmsghdr *nlh,
140 		  unsigned int hlen)
141 {
142 	const struct nlattr *tb[NLMSGERR_ATTR_MAX + 1] = {};
143 	char miss_attr[sizeof(ys->err.msg)];
144 	char bad_attr[sizeof(ys->err.msg)];
145 	const struct nlattr *attr;
146 	const char *str = NULL;
147 
148 	if (!(nlh->nlmsg_flags & NLM_F_ACK_TLVS)) {
149 		yerr_msg(ys, "%s", strerror(ys->err.code));
150 		return MNL_CB_OK;
151 	}
152 
153 	mnl_attr_for_each(attr, nlh, hlen) {
154 		unsigned int len, type;
155 
156 		len = mnl_attr_get_payload_len(attr);
157 		type = mnl_attr_get_type(attr);
158 
159 		if (type > NLMSGERR_ATTR_MAX)
160 			continue;
161 
162 		tb[type] = attr;
163 
164 		switch (type) {
165 		case NLMSGERR_ATTR_OFFS:
166 		case NLMSGERR_ATTR_MISS_TYPE:
167 		case NLMSGERR_ATTR_MISS_NEST:
168 			if (len != sizeof(__u32))
169 				return MNL_CB_ERROR;
170 			break;
171 		case NLMSGERR_ATTR_MSG:
172 			str = mnl_attr_get_payload(attr);
173 			if (str[len - 1])
174 				return MNL_CB_ERROR;
175 			break;
176 		default:
177 			break;
178 		}
179 	}
180 
181 	bad_attr[0] = '\0';
182 	miss_attr[0] = '\0';
183 
184 	if (tb[NLMSGERR_ATTR_OFFS]) {
185 		unsigned int n, off;
186 		void *start, *end;
187 
188 		ys->err.attr_offs = mnl_attr_get_u32(tb[NLMSGERR_ATTR_OFFS]);
189 
190 		n = snprintf(bad_attr, sizeof(bad_attr), "%sbad attribute: ",
191 			     str ? " (" : "");
192 
193 		start = mnl_nlmsg_get_payload_offset(ys->nlh,
194 						     ys->family->hdr_len);
195 		end = mnl_nlmsg_get_payload_tail(ys->nlh);
196 
197 		off = ys->err.attr_offs;
198 		off -= sizeof(struct nlmsghdr);
199 		off -= ys->family->hdr_len;
200 
201 		n += ynl_err_walk(ys, start, end, off, ys->req_policy,
202 				  &bad_attr[n], sizeof(bad_attr) - n, NULL);
203 
204 		if (n >= sizeof(bad_attr))
205 			n = sizeof(bad_attr) - 1;
206 		bad_attr[n] = '\0';
207 	}
208 	if (tb[NLMSGERR_ATTR_MISS_TYPE]) {
209 		struct ynl_policy_nest *nest_pol = NULL;
210 		unsigned int n, off, type;
211 		void *start, *end;
212 		int n2;
213 
214 		type = mnl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_TYPE]);
215 
216 		n = snprintf(miss_attr, sizeof(miss_attr), "%smissing attribute: ",
217 			     bad_attr[0] ? ", " : (str ? " (" : ""));
218 
219 		start = mnl_nlmsg_get_payload_offset(ys->nlh,
220 						     ys->family->hdr_len);
221 		end = mnl_nlmsg_get_payload_tail(ys->nlh);
222 
223 		nest_pol = ys->req_policy;
224 		if (tb[NLMSGERR_ATTR_MISS_NEST]) {
225 			off = mnl_attr_get_u32(tb[NLMSGERR_ATTR_MISS_NEST]);
226 			off -= sizeof(struct nlmsghdr);
227 			off -= ys->family->hdr_len;
228 
229 			n += ynl_err_walk(ys, start, end, off, ys->req_policy,
230 					  &miss_attr[n], sizeof(miss_attr) - n,
231 					  &nest_pol);
232 		}
233 
234 		n2 = 0;
235 		ynl_err_walk_report_one(nest_pol, type, &miss_attr[n],
236 					sizeof(miss_attr) - n, &n2);
237 		n += n2;
238 
239 		if (n >= sizeof(miss_attr))
240 			n = sizeof(miss_attr) - 1;
241 		miss_attr[n] = '\0';
242 	}
243 
244 	/* Implicitly depend on ys->err.code already set */
245 	if (str)
246 		yerr_msg(ys, "Kernel %s: '%s'%s%s%s",
247 			 ys->err.code ? "error" : "warning",
248 			 str, bad_attr, miss_attr,
249 			 bad_attr[0] || miss_attr[0] ? ")" : "");
250 	else if (bad_attr[0] || miss_attr[0])
251 		yerr_msg(ys, "Kernel %s: %s%s",
252 			 ys->err.code ? "error" : "warning",
253 			 bad_attr, miss_attr);
254 	else
255 		yerr_msg(ys, "%s", strerror(ys->err.code));
256 
257 	return MNL_CB_OK;
258 }
259 
ynl_cb_error(const struct nlmsghdr * nlh,void * data)260 static int ynl_cb_error(const struct nlmsghdr *nlh, void *data)
261 {
262 	const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh);
263 	struct ynl_parse_arg *yarg = data;
264 	unsigned int hlen;
265 	int code;
266 
267 	code = err->error >= 0 ? err->error : -err->error;
268 	yarg->ys->err.code = code;
269 	errno = code;
270 
271 	hlen = sizeof(*err);
272 	if (!(nlh->nlmsg_flags & NLM_F_CAPPED))
273 		hlen += mnl_nlmsg_get_payload_len(&err->msg);
274 
275 	ynl_ext_ack_check(yarg->ys, nlh, hlen);
276 
277 	return code ? MNL_CB_ERROR : MNL_CB_STOP;
278 }
279 
ynl_cb_done(const struct nlmsghdr * nlh,void * data)280 static int ynl_cb_done(const struct nlmsghdr *nlh, void *data)
281 {
282 	struct ynl_parse_arg *yarg = data;
283 	int err;
284 
285 	err = *(int *)NLMSG_DATA(nlh);
286 	if (err < 0) {
287 		yarg->ys->err.code = -err;
288 		errno = -err;
289 
290 		ynl_ext_ack_check(yarg->ys, nlh, sizeof(int));
291 
292 		return MNL_CB_ERROR;
293 	}
294 	return MNL_CB_STOP;
295 }
296 
ynl_cb_noop(const struct nlmsghdr * nlh,void * data)297 static int ynl_cb_noop(const struct nlmsghdr *nlh, void *data)
298 {
299 	return MNL_CB_OK;
300 }
301 
/*
 * Control-message callbacks passed to mnl_cb_run2(), indexed by netlink
 * message type (all types below NLMSG_MIN_TYPE).  NOOP and OVERRUN are
 * accepted and ignored.  ERROR and DONE are decoded by ynl_cb_error() and
 * ynl_cb_done(), which fill in ys->err.
 */
mnl_cb_t ynl_cb_array[NLMSG_MIN_TYPE] = {
	[NLMSG_NOOP]	= ynl_cb_noop,
	[NLMSG_ERROR]	= ynl_cb_error,
	[NLMSG_DONE]	= ynl_cb_done,
	[NLMSG_OVERRUN]	= ynl_cb_noop,
};
308 
309 /* Attribute validation */
310 
ynl_attr_validate(struct ynl_parse_arg * yarg,const struct nlattr * attr)311 int ynl_attr_validate(struct ynl_parse_arg *yarg, const struct nlattr *attr)
312 {
313 	struct ynl_policy_attr *policy;
314 	unsigned int type, len;
315 	unsigned char *data;
316 
317 	data = mnl_attr_get_payload(attr);
318 	len = mnl_attr_get_payload_len(attr);
319 	type = mnl_attr_get_type(attr);
320 	if (type > yarg->rsp_policy->max_attr) {
321 		yerr(yarg->ys, YNL_ERROR_INTERNAL,
322 		     "Internal error, validating unknown attribute");
323 		return -1;
324 	}
325 
326 	policy = &yarg->rsp_policy->table[type];
327 
328 	switch (policy->type) {
329 	case YNL_PT_REJECT:
330 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
331 		     "Rejected attribute (%s)", policy->name);
332 		return -1;
333 	case YNL_PT_IGNORE:
334 		break;
335 	case YNL_PT_U8:
336 		if (len == sizeof(__u8))
337 			break;
338 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
339 		     "Invalid attribute (u8 %s)", policy->name);
340 		return -1;
341 	case YNL_PT_U16:
342 		if (len == sizeof(__u16))
343 			break;
344 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
345 		     "Invalid attribute (u16 %s)", policy->name);
346 		return -1;
347 	case YNL_PT_U32:
348 		if (len == sizeof(__u32))
349 			break;
350 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
351 		     "Invalid attribute (u32 %s)", policy->name);
352 		return -1;
353 	case YNL_PT_U64:
354 		if (len == sizeof(__u64))
355 			break;
356 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
357 		     "Invalid attribute (u64 %s)", policy->name);
358 		return -1;
359 	case YNL_PT_UINT:
360 		if (len == sizeof(__u32) || len == sizeof(__u64))
361 			break;
362 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
363 		     "Invalid attribute (uint %s)", policy->name);
364 		return -1;
365 	case YNL_PT_FLAG:
366 		/* Let flags grow into real attrs, why not.. */
367 		break;
368 	case YNL_PT_NEST:
369 		if (!len || len >= sizeof(*attr))
370 			break;
371 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
372 		     "Invalid attribute (nest %s)", policy->name);
373 		return -1;
374 	case YNL_PT_BINARY:
375 		if (!policy->len || len == policy->len)
376 			break;
377 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
378 		     "Invalid attribute (binary %s)", policy->name);
379 		return -1;
380 	case YNL_PT_NUL_STR:
381 		if ((!policy->len || len <= policy->len) && !data[len - 1])
382 			break;
383 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
384 		     "Invalid attribute (string %s)", policy->name);
385 		return -1;
386 	case YNL_PT_BITFIELD32:
387 		if (len == sizeof(struct nla_bitfield32))
388 			break;
389 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
390 		     "Invalid attribute (bitfield32 %s)", policy->name);
391 		return -1;
392 	default:
393 		yerr(yarg->ys, YNL_ERROR_ATTR_INVALID,
394 		     "Invalid attribute (unknown %s)", policy->name);
395 		return -1;
396 	}
397 
398 	return 0;
399 }
400 
401 /* Generic code */
402 
ynl_err_reset(struct ynl_sock * ys)403 static void ynl_err_reset(struct ynl_sock *ys)
404 {
405 	ys->err.code = 0;
406 	ys->err.attr_offs = 0;
407 	ys->err.msg[0] = 0;
408 }
409 
ynl_msg_start(struct ynl_sock * ys,__u32 id,__u16 flags)410 struct nlmsghdr *ynl_msg_start(struct ynl_sock *ys, __u32 id, __u16 flags)
411 {
412 	struct nlmsghdr *nlh;
413 
414 	ynl_err_reset(ys);
415 
416 	nlh = ys->nlh = mnl_nlmsg_put_header(ys->tx_buf);
417 	nlh->nlmsg_type	= id;
418 	nlh->nlmsg_flags = flags;
419 	nlh->nlmsg_seq = ++ys->seq;
420 
421 	return nlh;
422 }
423 
424 struct nlmsghdr *
ynl_gemsg_start(struct ynl_sock * ys,__u32 id,__u16 flags,__u8 cmd,__u8 version)425 ynl_gemsg_start(struct ynl_sock *ys, __u32 id, __u16 flags,
426 		__u8 cmd, __u8 version)
427 {
428 	struct genlmsghdr gehdr;
429 	struct nlmsghdr *nlh;
430 	void *data;
431 
432 	nlh = ynl_msg_start(ys, id, flags);
433 
434 	memset(&gehdr, 0, sizeof(gehdr));
435 	gehdr.cmd = cmd;
436 	gehdr.version = version;
437 
438 	data = mnl_nlmsg_put_extra_header(nlh, sizeof(gehdr));
439 	memcpy(data, &gehdr, sizeof(gehdr));
440 
441 	return nlh;
442 }
443 
ynl_msg_start_req(struct ynl_sock * ys,__u32 id)444 void ynl_msg_start_req(struct ynl_sock *ys, __u32 id)
445 {
446 	ynl_msg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK);
447 }
448 
ynl_msg_start_dump(struct ynl_sock * ys,__u32 id)449 void ynl_msg_start_dump(struct ynl_sock *ys, __u32 id)
450 {
451 	ynl_msg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP);
452 }
453 
454 struct nlmsghdr *
ynl_gemsg_start_req(struct ynl_sock * ys,__u32 id,__u8 cmd,__u8 version)455 ynl_gemsg_start_req(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version)
456 {
457 	return ynl_gemsg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK, cmd, version);
458 }
459 
460 struct nlmsghdr *
ynl_gemsg_start_dump(struct ynl_sock * ys,__u32 id,__u8 cmd,__u8 version)461 ynl_gemsg_start_dump(struct ynl_sock *ys, __u32 id, __u8 cmd, __u8 version)
462 {
463 	return ynl_gemsg_start(ys, id, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP,
464 			       cmd, version);
465 }
466 
ynl_recv_ack(struct ynl_sock * ys,int ret)467 int ynl_recv_ack(struct ynl_sock *ys, int ret)
468 {
469 	struct ynl_parse_arg yarg = { .ys = ys, };
470 
471 	if (!ret) {
472 		yerr(ys, YNL_ERROR_EXPECT_ACK,
473 		     "Expecting an ACK but nothing received");
474 		return -1;
475 	}
476 
477 	ret = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);
478 	if (ret < 0) {
479 		perr(ys, "Socket receive failed");
480 		return ret;
481 	}
482 	return mnl_cb_run(ys->rx_buf, ret, ys->seq, ys->portid,
483 			  ynl_cb_null, &yarg);
484 }
485 
ynl_cb_null(const struct nlmsghdr * nlh,void * data)486 int ynl_cb_null(const struct nlmsghdr *nlh, void *data)
487 {
488 	struct ynl_parse_arg *yarg = data;
489 
490 	yerr(yarg->ys, YNL_ERROR_UNEXPECT_MSG,
491 	     "Received a message when none were expected");
492 
493 	return MNL_CB_ERROR;
494 }
495 
/* Init/fini and genetlink boilerplate */
497 static int
ynl_get_family_info_mcast(struct ynl_sock * ys,const struct nlattr * mcasts)498 ynl_get_family_info_mcast(struct ynl_sock *ys, const struct nlattr *mcasts)
499 {
500 	const struct nlattr *entry, *attr;
501 	unsigned int i;
502 
503 	mnl_attr_for_each_nested(attr, mcasts)
504 		ys->n_mcast_groups++;
505 
506 	if (!ys->n_mcast_groups)
507 		return 0;
508 
509 	ys->mcast_groups = calloc(ys->n_mcast_groups,
510 				  sizeof(*ys->mcast_groups));
511 	if (!ys->mcast_groups)
512 		return MNL_CB_ERROR;
513 
514 	i = 0;
515 	mnl_attr_for_each_nested(entry, mcasts) {
516 		mnl_attr_for_each_nested(attr, entry) {
517 			if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GRP_ID)
518 				ys->mcast_groups[i].id = mnl_attr_get_u32(attr);
519 			if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GRP_NAME) {
520 				strncpy(ys->mcast_groups[i].name,
521 					mnl_attr_get_str(attr),
522 					GENL_NAMSIZ - 1);
523 				ys->mcast_groups[i].name[GENL_NAMSIZ - 1] = 0;
524 			}
525 		}
526 		i++;
527 	}
528 
529 	return 0;
530 }
531 
ynl_get_family_info_cb(const struct nlmsghdr * nlh,void * data)532 static int ynl_get_family_info_cb(const struct nlmsghdr *nlh, void *data)
533 {
534 	struct ynl_parse_arg *yarg = data;
535 	struct ynl_sock *ys = yarg->ys;
536 	const struct nlattr *attr;
537 	bool found_id = true;
538 
539 	mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr)) {
540 		if (mnl_attr_get_type(attr) == CTRL_ATTR_MCAST_GROUPS)
541 			if (ynl_get_family_info_mcast(ys, attr))
542 				return MNL_CB_ERROR;
543 
544 		if (mnl_attr_get_type(attr) != CTRL_ATTR_FAMILY_ID)
545 			continue;
546 
547 		if (mnl_attr_get_payload_len(attr) != sizeof(__u16)) {
548 			yerr(ys, YNL_ERROR_ATTR_INVALID, "Invalid family ID");
549 			return MNL_CB_ERROR;
550 		}
551 
552 		ys->family_id = mnl_attr_get_u16(attr);
553 		found_id = true;
554 	}
555 
556 	if (!found_id) {
557 		yerr(ys, YNL_ERROR_ATTR_MISSING, "Family ID missing");
558 		return MNL_CB_ERROR;
559 	}
560 	return MNL_CB_OK;
561 }
562 
ynl_sock_read_family(struct ynl_sock * ys,const char * family_name)563 static int ynl_sock_read_family(struct ynl_sock *ys, const char *family_name)
564 {
565 	struct ynl_parse_arg yarg = { .ys = ys, };
566 	struct nlmsghdr *nlh;
567 	int err;
568 
569 	nlh = ynl_gemsg_start_req(ys, GENL_ID_CTRL, CTRL_CMD_GETFAMILY, 1);
570 	mnl_attr_put_strz(nlh, CTRL_ATTR_FAMILY_NAME, family_name);
571 
572 	err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);
573 	if (err < 0) {
574 		perr(ys, "failed to request socket family info");
575 		return err;
576 	}
577 
578 	err = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);
579 	if (err <= 0) {
580 		perr(ys, "failed to receive the socket family info");
581 		return err;
582 	}
583 	err = mnl_cb_run2(ys->rx_buf, err, ys->seq, ys->portid,
584 			  ynl_get_family_info_cb, &yarg,
585 			  ynl_cb_array, ARRAY_SIZE(ynl_cb_array));
586 	if (err < 0) {
587 		free(ys->mcast_groups);
588 		perr(ys, "failed to receive the socket family info - no such family?");
589 		return err;
590 	}
591 
592 	err = ynl_recv_ack(ys, err);
593 	if (err < 0) {
594 		free(ys->mcast_groups);
595 		return err;
596 	}
597 
598 	return 0;
599 }
600 
601 struct ynl_sock *
ynl_sock_create(const struct ynl_family * yf,struct ynl_error * yse)602 ynl_sock_create(const struct ynl_family *yf, struct ynl_error *yse)
603 {
604 	struct ynl_sock *ys;
605 	int one = 1;
606 
607 	ys = malloc(sizeof(*ys) + 2 * MNL_SOCKET_BUFFER_SIZE);
608 	if (!ys)
609 		return NULL;
610 	memset(ys, 0, sizeof(*ys));
611 
612 	ys->family = yf;
613 	ys->tx_buf = &ys->raw_buf[0];
614 	ys->rx_buf = &ys->raw_buf[MNL_SOCKET_BUFFER_SIZE];
615 	ys->ntf_last_next = &ys->ntf_first;
616 
617 	ys->sock = mnl_socket_open(NETLINK_GENERIC);
618 	if (!ys->sock) {
619 		__perr(yse, "failed to create a netlink socket");
620 		goto err_free_sock;
621 	}
622 
623 	if (mnl_socket_setsockopt(ys->sock, NETLINK_CAP_ACK,
624 				  &one, sizeof(one))) {
625 		__perr(yse, "failed to enable netlink ACK");
626 		goto err_close_sock;
627 	}
628 	if (mnl_socket_setsockopt(ys->sock, NETLINK_EXT_ACK,
629 				  &one, sizeof(one))) {
630 		__perr(yse, "failed to enable netlink ext ACK");
631 		goto err_close_sock;
632 	}
633 
634 	ys->seq = random();
635 	ys->portid = mnl_socket_get_portid(ys->sock);
636 
637 	if (ynl_sock_read_family(ys, yf->name)) {
638 		if (yse)
639 			memcpy(yse, &ys->err, sizeof(*yse));
640 		goto err_close_sock;
641 	}
642 
643 	return ys;
644 
645 err_close_sock:
646 	mnl_socket_close(ys->sock);
647 err_free_sock:
648 	free(ys);
649 	return NULL;
650 }
651 
ynl_sock_destroy(struct ynl_sock * ys)652 void ynl_sock_destroy(struct ynl_sock *ys)
653 {
654 	struct ynl_ntf_base_type *ntf;
655 
656 	mnl_socket_close(ys->sock);
657 	while ((ntf = ynl_ntf_dequeue(ys)))
658 		ynl_ntf_free(ntf);
659 	free(ys->mcast_groups);
660 	free(ys);
661 }
662 
663 /* YNL multicast handling */
664 
ynl_ntf_free(struct ynl_ntf_base_type * ntf)665 void ynl_ntf_free(struct ynl_ntf_base_type *ntf)
666 {
667 	ntf->free(ntf);
668 }
669 
ynl_subscribe(struct ynl_sock * ys,const char * grp_name)670 int ynl_subscribe(struct ynl_sock *ys, const char *grp_name)
671 {
672 	unsigned int i;
673 	int err;
674 
675 	for (i = 0; i < ys->n_mcast_groups; i++)
676 		if (!strcmp(ys->mcast_groups[i].name, grp_name))
677 			break;
678 	if (i == ys->n_mcast_groups) {
679 		yerr(ys, ENOENT, "Multicast group '%s' not found", grp_name);
680 		return -1;
681 	}
682 
683 	err = mnl_socket_setsockopt(ys->sock, NETLINK_ADD_MEMBERSHIP,
684 				    &ys->mcast_groups[i].id,
685 				    sizeof(ys->mcast_groups[i].id));
686 	if (err < 0) {
687 		perr(ys, "Subscribing to multicast group failed");
688 		return -1;
689 	}
690 
691 	return 0;
692 }
693 
ynl_socket_get_fd(struct ynl_sock * ys)694 int ynl_socket_get_fd(struct ynl_sock *ys)
695 {
696 	return mnl_socket_get_fd(ys->sock);
697 }
698 
ynl_ntf_dequeue(struct ynl_sock * ys)699 struct ynl_ntf_base_type *ynl_ntf_dequeue(struct ynl_sock *ys)
700 {
701 	struct ynl_ntf_base_type *ntf;
702 
703 	if (!ynl_has_ntf(ys))
704 		return NULL;
705 
706 	ntf = ys->ntf_first;
707 	ys->ntf_first = ntf->next;
708 	if (ys->ntf_last_next == &ntf->next)
709 		ys->ntf_last_next = &ys->ntf_first;
710 
711 	return ntf;
712 }
713 
ynl_ntf_parse(struct ynl_sock * ys,const struct nlmsghdr * nlh)714 static int ynl_ntf_parse(struct ynl_sock *ys, const struct nlmsghdr *nlh)
715 {
716 	struct ynl_parse_arg yarg = { .ys = ys, };
717 	const struct ynl_ntf_info *info;
718 	struct ynl_ntf_base_type *rsp;
719 	struct genlmsghdr *gehdr;
720 	int ret;
721 
722 	gehdr = mnl_nlmsg_get_payload(nlh);
723 	if (gehdr->cmd >= ys->family->ntf_info_size)
724 		return MNL_CB_ERROR;
725 	info = &ys->family->ntf_info[gehdr->cmd];
726 	if (!info->cb)
727 		return MNL_CB_ERROR;
728 
729 	rsp = calloc(1, info->alloc_sz);
730 	rsp->free = info->free;
731 	yarg.data = rsp->data;
732 	yarg.rsp_policy = info->policy;
733 
734 	ret = info->cb(nlh, &yarg);
735 	if (ret <= MNL_CB_STOP)
736 		goto err_free;
737 
738 	rsp->family = nlh->nlmsg_type;
739 	rsp->cmd = gehdr->cmd;
740 
741 	*ys->ntf_last_next = rsp;
742 	ys->ntf_last_next = &rsp->next;
743 
744 	return MNL_CB_OK;
745 
746 err_free:
747 	info->free(rsp);
748 	return MNL_CB_ERROR;
749 }
750 
ynl_ntf_trampoline(const struct nlmsghdr * nlh,void * data)751 static int ynl_ntf_trampoline(const struct nlmsghdr *nlh, void *data)
752 {
753 	struct ynl_parse_arg *yarg = data;
754 
755 	return ynl_ntf_parse(yarg->ys, nlh);
756 }
757 
ynl_ntf_check(struct ynl_sock * ys)758 int ynl_ntf_check(struct ynl_sock *ys)
759 {
760 	struct ynl_parse_arg yarg = { .ys = ys, };
761 	ssize_t len;
762 	int err;
763 
764 	do {
765 		/* libmnl doesn't let us pass flags to the recv to make
766 		 * it non-blocking so we need to poll() or peek() :|
767 		 */
768 		struct pollfd pfd = { };
769 
770 		pfd.fd = mnl_socket_get_fd(ys->sock);
771 		pfd.events = POLLIN;
772 		err = poll(&pfd, 1, 1);
773 		if (err < 1)
774 			return err;
775 
776 		len = mnl_socket_recvfrom(ys->sock, ys->rx_buf,
777 					  MNL_SOCKET_BUFFER_SIZE);
778 		if (len < 0)
779 			return len;
780 
781 		err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,
782 				  ynl_ntf_trampoline, &yarg,
783 				  ynl_cb_array, NLMSG_MIN_TYPE);
784 		if (err < 0)
785 			return err;
786 	} while (err > 0);
787 
788 	return 0;
789 }
790 
791 /* YNL specific helpers used by the auto-generated code */
792 
/*
 * Sentinel that terminates dump result lists.  ynl_dump_end() stores it in
 * the last element's next pointer, and returns it as the list head for an
 * empty dump.  The value is only a marker, presumably chosen to be an
 * invalid address (TODO confirm).
 */
struct ynl_dump_list_type *YNL_LIST_END = (void *)(0xb4d123);
794 
ynl_error_unknown_notification(struct ynl_sock * ys,__u8 cmd)795 void ynl_error_unknown_notification(struct ynl_sock *ys, __u8 cmd)
796 {
797 	yerr(ys, YNL_ERROR_UNKNOWN_NTF,
798 	     "Unknown notification message type '%d'", cmd);
799 }
800 
ynl_error_parse(struct ynl_parse_arg * yarg,const char * msg)801 int ynl_error_parse(struct ynl_parse_arg *yarg, const char *msg)
802 {
803 	yerr(yarg->ys, YNL_ERROR_INV_RESP, "Error parsing response: %s", msg);
804 	return MNL_CB_ERROR;
805 }
806 
807 static int
ynl_check_alien(struct ynl_sock * ys,const struct nlmsghdr * nlh,__u32 rsp_cmd)808 ynl_check_alien(struct ynl_sock *ys, const struct nlmsghdr *nlh, __u32 rsp_cmd)
809 {
810 	struct genlmsghdr *gehdr;
811 
812 	if (mnl_nlmsg_get_payload_len(nlh) < sizeof(*gehdr)) {
813 		yerr(ys, YNL_ERROR_INV_RESP,
814 		     "Kernel responded with truncated message");
815 		return -1;
816 	}
817 
818 	gehdr = mnl_nlmsg_get_payload(nlh);
819 	if (gehdr->cmd != rsp_cmd)
820 		return ynl_ntf_parse(ys, nlh);
821 
822 	return 0;
823 }
824 
ynl_req_trampoline(const struct nlmsghdr * nlh,void * data)825 static int ynl_req_trampoline(const struct nlmsghdr *nlh, void *data)
826 {
827 	struct ynl_req_state *yrs = data;
828 	int ret;
829 
830 	ret = ynl_check_alien(yrs->yarg.ys, nlh, yrs->rsp_cmd);
831 	if (ret)
832 		return ret < 0 ? MNL_CB_ERROR : MNL_CB_OK;
833 
834 	return yrs->cb(nlh, &yrs->yarg);
835 }
836 
ynl_exec(struct ynl_sock * ys,struct nlmsghdr * req_nlh,struct ynl_req_state * yrs)837 int ynl_exec(struct ynl_sock *ys, struct nlmsghdr *req_nlh,
838 	     struct ynl_req_state *yrs)
839 {
840 	ssize_t len;
841 	int err;
842 
843 	err = mnl_socket_sendto(ys->sock, req_nlh, req_nlh->nlmsg_len);
844 	if (err < 0)
845 		return err;
846 
847 	do {
848 		len = mnl_socket_recvfrom(ys->sock, ys->rx_buf,
849 					  MNL_SOCKET_BUFFER_SIZE);
850 		if (len < 0)
851 			return len;
852 
853 		err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,
854 				  ynl_req_trampoline, yrs,
855 				  ynl_cb_array, NLMSG_MIN_TYPE);
856 		if (err < 0)
857 			return err;
858 	} while (err > 0);
859 
860 	return 0;
861 }
862 
ynl_dump_trampoline(const struct nlmsghdr * nlh,void * data)863 static int ynl_dump_trampoline(const struct nlmsghdr *nlh, void *data)
864 {
865 	struct ynl_dump_state *ds = data;
866 	struct ynl_dump_list_type *obj;
867 	struct ynl_parse_arg yarg = {};
868 	int ret;
869 
870 	ret = ynl_check_alien(ds->ys, nlh, ds->rsp_cmd);
871 	if (ret)
872 		return ret < 0 ? MNL_CB_ERROR : MNL_CB_OK;
873 
874 	obj = calloc(1, ds->alloc_sz);
875 	if (!obj)
876 		return MNL_CB_ERROR;
877 
878 	if (!ds->first)
879 		ds->first = obj;
880 	if (ds->last)
881 		ds->last->next = obj;
882 	ds->last = obj;
883 
884 	yarg.ys = ds->ys;
885 	yarg.rsp_policy = ds->rsp_policy;
886 	yarg.data = &obj->data;
887 
888 	return ds->cb(nlh, &yarg);
889 }
890 
ynl_dump_end(struct ynl_dump_state * ds)891 static void *ynl_dump_end(struct ynl_dump_state *ds)
892 {
893 	if (!ds->first)
894 		return YNL_LIST_END;
895 
896 	ds->last->next = YNL_LIST_END;
897 	return ds->first;
898 }
899 
ynl_exec_dump(struct ynl_sock * ys,struct nlmsghdr * req_nlh,struct ynl_dump_state * yds)900 int ynl_exec_dump(struct ynl_sock *ys, struct nlmsghdr *req_nlh,
901 		  struct ynl_dump_state *yds)
902 {
903 	ssize_t len;
904 	int err;
905 
906 	err = mnl_socket_sendto(ys->sock, req_nlh, req_nlh->nlmsg_len);
907 	if (err < 0)
908 		return err;
909 
910 	do {
911 		len = mnl_socket_recvfrom(ys->sock, ys->rx_buf,
912 					  MNL_SOCKET_BUFFER_SIZE);
913 		if (len < 0)
914 			goto err_close_list;
915 
916 		err = mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,
917 				  ynl_dump_trampoline, yds,
918 				  ynl_cb_array, NLMSG_MIN_TYPE);
919 		if (err < 0)
920 			goto err_close_list;
921 	} while (err > 0);
922 
923 	yds->first = ynl_dump_end(yds);
924 	return 0;
925 
926 err_close_list:
927 	yds->first = ynl_dump_end(yds);
928 	return -1;
929 }
930