1#!/usr/bin/env python3
2# SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4import argparse
5import collections
6import filecmp
7import pathlib
8import os
9import re
10import shutil
11import sys
12import tempfile
13import yaml
14
15sys.path.append(pathlib.Path(__file__).resolve().parent.as_posix())
16from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
17
18
def c_upper(name):
    """Return *name* upper-cased with every '-' turned into '_' (C macro style)."""
    return '_'.join(name.upper().split('-'))
21
22
def c_lower(name):
    """Return *name* lower-cased with every '-' turned into '_' (C identifier style)."""
    return '_'.join(name.lower().split('-'))
25
26
def limit_to_number(name):
    """
    Turn a string limit like u32-max or s64-min into its numerical value.

    The accepted form is ``<u|s><width>-<min|max>``, e.g. ``u8-max`` or
    ``s32-min``.  Unsigned minimums are 0; signed limits follow the two's
    complement range of the given bit width.

    Raises:
        ValueError: if *name* is not of the form above, rather than silently
            misparsing it into an unrelated number.
    """
    match = re.fullmatch(r'([us])([0-9]+)-(min|max)', name)
    if match is None:
        raise ValueError(f'Invalid limit "{name}", expected e.g. u32-max or s64-min')
    sign, width, bound = match.group(1), int(match.group(2)), match.group(3)

    if sign == 'u' and bound == 'min':
        return 0
    if sign == 's':
        width -= 1  # the top bit holds the sign
    value = (1 << width) - 1
    if sign == 's' and bound == 'min':
        value = -value - 1
    return value
40
41
class BaseNlLib:
    """Provides C expression snippets for the generated library code."""

    def get_family_id(self):
        """Return the C expression that evaluates to the family id."""
        family_id_expr = 'ys->family_id'
        return family_id_expr
45
46
class Type(SpecAttr):
    """
    Code-generation wrapper around a single attribute of an attribute set.

    Subclasses override the hook methods (``_attr_policy``, ``_attr_typol``,
    ``_attr_get``, ``_setter_lines``, ``attr_put``, ...) to emit type-specific
    C code; the base versions of several hooks raise to flag a subclass that
    does not implement them.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        # NOTE(review): when the spec provides 'checks' this is the spec's own
        # dict object, so in-place additions (e.g. by TypeScalar) are also
        # visible through attr['checks'].
        self.checks = attr.get('checks', {})

        # Flipped by set_request() / set_reply().
        self.request = False
        self.reply = False

        if 'len' in attr:
            self.len = attr['len']

        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = c_lower(f"{family.ident_name}")
            else:
                self.nested_render_name = c_lower(f"{family.ident_name}_{self.nested_attrs}")

            # A trailing '_' is appended when the nest name is also a family
            # const, presumably to avoid clashing with that definition.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        # Make the name usable as a C member: avoid keywords and leading digits.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        if self.c_name[0].isdigit():
            self.c_name = '_' + self.c_name

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def _get_real_attr(self):
        # if the attr is for a subset return the "real" attr (just one down, does not recurse)
        return self.family.attr_sets[self.attr_set.subset_of][self.name]

    def set_request(self):
        """Mark the attr as used in a request (also marks the real attr of a subset)."""
        self.request = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_request()

    def set_reply(self):
        """Mark the attr as used in a reply (also marks the real attr of a subset)."""
        self.reply = True
        if self.attr_set.subset_of:
            self._get_real_attr().set_reply()

    def get_limit(self, limit, default=None):
        """
        Return the numeric value of check *limit* (or *default*), None if unset.

        Ints are returned as-is, names of family consts resolve to the const's
        value, any other string is parsed by limit_to_number().
        """
        value = self.checks.get(limit, default)
        if value is None:
            return value
        if isinstance(value, int):
            return value
        if value in self.family.consts:
            return self.family.consts[value]["value"]
        return limit_to_number(value)

    def get_limit_str(self, limit, default=None, suffix=''):
        """
        Return check *limit* (or *default*) rendered as C source text.

        Returns '' when unset; ints get *suffix* appended; const names and
        other strings become upper-case C identifiers, with non-header consts
        prefixed by the family name.
        """
        value = self.checks.get(limit, default)
        if value is None:
            return ''
        if isinstance(value, int):
            return str(value) + suffix
        if value in self.family.consts:
            const = self.family.consts[value]
            if const.get('header'):
                return c_upper(value)
            return c_upper(f"{self.family['name']}-{value}")
        return c_upper(value)

    def resolve(self):
        """Compute enum_name; reject subset attrs whose checks differ from the real attr."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

        if self.attr_set.subset_of:
            if self.checks != self._get_real_attr().checks:
                raise Exception("Overriding checks not supported by codegen, yet")

    def is_multi_val(self):
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def is_recursive(self):
        return False

    def is_recursive_for_op(self, ri):
        # Recursive attrs are handled differently when ri.op is not set.
        return self.is_recursive() and not ri.op

    def presence_type(self):
        return 'bit'

    def presence_member(self, space, type_filter):
        """
        Return the `_present` struct member declaration for this attr.

        Returns None unless presence_type() equals *type_filter* and is
        'bit' or 'len'. User-space members use the '__u32' spelling.
        """
        if self.presence_type() != type_filter:
            return

        pfx = '__' if space == 'user' else ''
        if self.presence_type() == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            return f"{pfx}u32 {self.c_name}_len;"

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        return False

    def _free_lines(self, ri, var, ref):
        # Multi-value and length-tracked members are heap allocated.
        if self.is_multi_val() or self.presence_type() == 'len':
            return [f'free({var}->{ref}{self.c_name});']
        return []

    def free(self, ri, var, ref):
        """Emit the C statements releasing this member of *var*."""
        for line in self._free_lines(ri, var, ref):
            ri.cw.p(line)

    def arg_member(self, ri):
        """Return C parameter declarations for this attr (used by setters)."""
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the C struct member declaration(s) for this attr."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            if self.is_recursive_for_op(ri):
                ptr = '*'
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the kernel nla_policy entry for this attr."""
        policy = f'NLA_{c_upper(self.type)}'
        if self.attr.get('byte-order') == 'big-endian':
            if self.type in {'u16', 'u32'}:
                policy = f'NLA_BE{self.type[1:]}'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the user-space type-policy entry for this attr."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with the presence member, when there is one.
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"ynl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """
        Emit the parse branch for this attr and return True.

        Uses 'if' for the *first* branch, 'else if' otherwise. _attr_get()
        returns (lines, init_lines, local_vars); either line group may be a
        single string.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return YNL_PARSE_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """
        Emit a static inline setter for this attr.

        *ref* is the path of enclosing nest member names. Presence bits are
        set for each enclosing nest and, for 'bit' presence, the attr itself;
        any old value is freed first. The function name gets a '__' prefix
        when the body frees memory but does not allocate.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        local_vars = []
        if self.free_needs_iter():
            local_vars += ['unsigned int i;']

        code = []
        presence = ''
        last = len(ref) - 1
        for i, part in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{part}"
            # Every layer below last is a nest, so we know it uses bit presence
            # last layer is "self" and may be a complex type
            if i == last and self.presence_type() != 'bit':
                continue
            code.append(presence + ' = 1;')
        ref_path = '.'.join(ref[:-1])
        if ref_path:
            ref_path += '.'
        code += self._free_lines(ri, var, ref_path)
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        has_free = any('free(' in x for x in code)
        has_alloc = any('alloc(' in x for x in code)
        if has_free and not has_alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, local_vars=local_vars,
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
299
300
class TypeUnused(Type):
    """
    Attribute of type 'unused'.

    Only the type-policy entry is generated (YNL_PT_REJECT); the kernel
    policy, put, parse and setter hooks emit nothing.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def _attr_get(self, ri, var):
        get_lines = ['return YNL_PARSE_CB_ERROR;']
        return get_lines, None, None

    # The hooks below intentionally generate no code.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
325
326
class TypePad(Type):
    """
    Padding attribute.

    Only the type-policy entry is generated (YNL_PT_IGNORE); the kernel
    policy, put, parse and setter hooks emit nothing.
    """

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    # The hooks below intentionally generate no code.
    def attr_policy(self, cw):
        return None

    def attr_put(self, ri, var):
        return None

    def attr_get(self, ri, var, first):
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        return None
348
349
class TypeScalar(Type):
    """
    Integer attribute (u8 ... s64 and the other names in ``scalars``).

    Range checks are derived from the spec's 'checks' and, when an 'enum' is
    referenced, default to that enum's value range.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = (f" /* {attr['byte-order']} */"
                                   if 'byte-order' in attr else '')

        # Default missing min/max checks to the span of the backing enum
        # (min only when it is non-zero or the type is signed).
        if 'enum' in self.attr:
            enum_lo, enum_hi = self.family.consts[self.attr['enum']].value_range()
            if 'min' not in self.checks and (enum_lo != 0 or self.type[0] == 's'):
                self.checks['min'] = enum_lo
            if 'max' not in self.checks:
                self.checks['max'] = enum_hi

        if 'min' in self.checks and 'max' in self.checks:
            lim_min = self.get_limit('min')
            lim_max = self.get_limit('max')
            if lim_min > lim_max:
                raise Exception(f'Invalid limit for "{self.name}" min: {lim_min} max: {lim_max}')
            self.checks['range'] = True

        bounds = (self.get_limit('min', 0), self.get_limit('max', 0))
        low, high = min(bounds), max(bounds)
        if low < 0 and self.type[0] == 'u':
            raise Exception(f'Invalid limit for "{self.name}" negative limit for unsigned type')
        # Limits outside the s16 range are flagged as 'full-range'.
        if low < -32768 or high > 32767:
            self.checks['full-range'] = True

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Resolve the base attr, then decide bitfield-ness and the C type name."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        if 'enum' in self.attr and not self.is_bitfield:
            self.type_name = self.family.consts[self.attr['enum']].user_type
        elif self.is_auto_scalar:
            self.type_name = '__' + self.type[0] + '64'
        else:
            self.type_name = '__' + self.type

    def _attr_policy(self, policy):
        if self.is_bitfield:
            mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'flags-mask' in self.checks:
            entries = self.family.consts[self.checks['flags-mask']]['entries']
            mask = (1 << len(entries)) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'full-range' in self.checks:
            return f"NLA_POLICY_FULL_RANGE({policy}, &{c_lower(self.enum_name)}_range)"
        if 'range' in self.checks:
            return f"NLA_POLICY_RANGE({policy}, {self.get_limit_str('min')}, {self.get_limit_str('max')})"
        if 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.get_limit_str('min')})"
        if 'max' in self.checks:
            return f"NLA_POLICY_MAX({policy}, {self.get_limit_str('max')})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        return f'.type = YNL_PT_U{c_upper(self.type[1:])}, '

    def arg_member(self, ri):
        decl = f'{self.type_name} {self.c_name}{self.byte_order_comment}'
        return [decl]

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self.type)

    def _attr_get(self, ri, var):
        line = f"{var}->{self.c_name} = ynl_attr_get_{self.type}(attr);"
        return line, None, None

    def _setter_lines(self, ri, member, presence):
        assignment = f"{member} = {self.c_name};"
        return [assignment]
436
437
class TypeFlag(Type):
    """Flag attribute: carries no payload, only bit presence."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        put_call = f"ynl_attr_put(nlh, {self.enum_name}, NULL, 0)"
        self._attr_put_line(ri, var, put_call)

    # No argument, parse lines or setter body beyond the presence bit.
    def arg_member(self, ri):
        return []

    def _attr_get(self, ri, var):
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        return []
453
454
class TypeString(Type):
    """String attribute, stored as a heap copy with its length in `_present`."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        if 'exact-len' in self.checks:
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        parts = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            parts.append(', .len = ' + self.get_limit_str('max-len'))
        parts.append(', }')
        return ''.join(parts)

    def attr_policy(self, cw):
        # 'unterminated-ok' relaxes the policy from NLA_NUL_STRING to NLA_STRING.
        if self.checks.get('unterminated-ok', False):
            nla_type = 'NLA_STRING'
        else:
            nla_type = 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(nla_type)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'str')

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [
            f"{len_mem} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, ynl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        init_lines = ['len = strnlen(ynl_attr_get_str(attr), ynl_attr_data_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        str_len = f"{presence}_len"
        return [
            f"{str_len} = strlen({self.c_name});",
            f"{member} = malloc({str_len} + 1);",
            f'memcpy({member}, {self.c_name}, {str_len});',
            f'{member}[{str_len}] = 0;',
        ]
504
505
class TypeBinary(Type):
    """Binary blob attribute, stored as a heap copy with its length in `_present`."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        # At most one length check is supported.
        if len(self.checks) > 1:
            raise Exception('More than one check for binary type not implemented, yet')
        if not self.checks:
            return '{ .type = NLA_BINARY, }'

        check_name = next(iter(self.checks))
        if check_name == 'exact-len':
            return 'NLA_POLICY_EXACT_LEN(' + self.get_limit_str('exact-len') + ')'
        if check_name == 'min-len':
            return '{ .len = ' + self.get_limit_str('min-len') + ', }'
        if check_name == 'max-len':
            return 'NLA_POLICY_MAX_LEN(' + self.get_limit_str('max-len') + ')'
        raise Exception('Unsupported check for binary type: ' + check_name)

    def attr_put(self, ri, var):
        data = f"{var}->{self.c_name}"
        size = f"{var}->_present.{self.c_name}_len"
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, {data}, {size})")

    def _attr_get(self, ri, var):
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [
            f"{len_mem} = len;",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, ynl_attr_data(attr), len);",
        ]
        return get_lines, ['len = ynl_attr_data_len(attr);'], ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        blob_len = f"{presence}_len"
        return [
            f"{blob_len} = len;",
            f"{member} = malloc({blob_len});",
            f'memcpy({member}, {self.c_name}, {blob_len});',
        ]
556
557
class TypeBitfield32(Type):
    """Attribute carried as a `struct nla_bitfield32`; requires an 'enum'."""

    def _complex_member_type(self, ri):
        return "struct nla_bitfield32"

    def _attr_typol(self):
        return '.type = YNL_PT_BITFIELD32, '

    def _attr_policy(self, policy):
        if 'enum' not in self.attr:
            raise Exception('Enum required for bitfield32 attr')
        mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
        return f"NLA_POLICY_BITFIELD32({mask})"

    def attr_put(self, ri, var):
        size = "sizeof(struct nla_bitfield32)"
        self._attr_put_line(ri, var, f"ynl_attr_put(nlh, {self.enum_name}, &{var}->{self.c_name}, {size})")

    def _attr_get(self, ri, var):
        copy = f"memcpy(&{var}->{self.c_name}, ynl_attr_data(attr), sizeof(struct nla_bitfield32));"
        return copy, None, None

    def _setter_lines(self, ri, member, presence):
        copy = f"memcpy(&{member}, {self.c_name}, sizeof(struct nla_bitfield32));"
        return [copy]
581
582
class TypeNest(Type):
    """Nested attribute whose contents are described by 'nested-attributes'."""

    def is_recursive(self):
        return self.family.pure_nested_structs[self.nested_attrs].recursive

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _free_lines(self, ri, var, ref):
        member = f"{var}->{ref}{self.c_name}"
        # Recursive nests are held by pointer, so NULL-check and pass as-is.
        if self.is_recursive_for_op(ri):
            return [f'if ({member})', f'{self.nested_render_name}_free({member});']
        return [f'{self.nested_render_name}_free(&{member});']

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        addr_of = '' if self.is_recursive_for_op(ri) else '&'
        call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, {addr_of}{var}->{self.c_name})"
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return YNL_PARSE_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit setters for each non-recursive member of the nested struct."""
        path = (ref if ref else []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _, child in nested.member_list():
            if not child.is_recursive():
                child.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
624
625
class TypeMultiAttr(Type):
    """
    Repeated (multi-attr) attribute, stored as an array plus an element count.

    *base_type* is the single-instance type object that provides the policy
    and type-policy entries.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        super().__init__(family, attr_set, attr, value)
        self.base_type = base_type

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            return self.nested_struct_type
        if self.attr['type'] not in scalars:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")
        prefix = '__' if ri.ku_space == 'user' else ''
        return prefix + self.attr['type']

    def free_needs_iter(self):
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def _free_lines(self, ri, var, ref):
        array = f"{var}->{ref}{self.c_name}"
        if self.attr['type'] in scalars:
            return [f"free({array});"]
        if 'type' not in self.attr or self.attr['type'] == 'nest':
            # Free each nested element before the array itself.
            return [
                f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)",
                f'{self.nested_render_name}_free(&{array}[i]);',
                f"free({array});",
            ]
        raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        loop = f"for (i = 0; i < {var}->n_{self.c_name}; i++)"
        if self.attr['type'] in scalars:
            ri.cw.p(loop)
            ri.cw.p(f"ynl_attr_put_{self.type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif 'type' not in self.attr or self.attr['type'] == 'nest':
            ri.cw.p(loop)
            call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{var}->{self.c_name}[i])"
            self._attr_put_line(ri, var, call)
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence, hack up the presence
        cut = len('_present.') + len(self.c_name)
        count = presence[:-cut] + "n_" + self.c_name
        return [f"{member} = {self.c_name};",
                f"{count} = n_{self.c_name};"]
690
691
class TypeArrayNest(Type):
    """
    Legacy array-nest attribute: a nest whose children are array entries.

    Entries are nested structs unless 'sub-type' names a scalar type; a
    missing 'sub-type' is treated as nested entries.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
            return self.nested_struct_type
        elif self.attr['sub-type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['sub-type']
        else:
            raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")

    def _attr_typol(self):
        # Use .get() so a missing 'sub-type' falls through to the nested case,
        # consistent with _complex_member_type() (indexing raised KeyError),
        # and render the width from the same value that was checked.
        sub_type = self.attr.get('sub-type')
        if sub_type in scalars:
            return f'.type = YNL_PT_U{c_upper(sub_type[1:])}, '
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the nest and count its (validated) entries.
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [f'attr_{self.c_name} = attr;',
                     'ynl_attr_for_each_nested(attr2, attr) {',
                     '\tif (ynl_attr_validate(yarg, attr2))',
                     '\t\treturn YNL_PARSE_CB_ERROR;',
                     f'\t{var}->n_{self.c_name}++;',
                     '}']
        return get_lines, None, local_vars
723
724
class TypeNestTypeValue(Type):
    """
    Nest whose entries' attribute types carry values ('type-value' levels).

    Each listed level descends one nest and records that nest's attr type,
    which is then passed as an extra argument to the nested parser.
    """

    def _complex_member_type(self, ri):
        return self.nested_struct_type

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        src = 'attr'
        tv_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            for level in levels:
                get_lines.append(f'attr_{level} = ynl_attr_data({src});')
                get_lines.append(f'{level} = ynl_attr_type(attr_{level});')
                src = 'attr_' + level
            tv_args = f", {', '.join(levels)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {src}{tv_args});")
        return get_lines, init_lines, local_vars
753
754
class Struct:
    """
    C struct generated for an attribute set.

    With *type_list* the struct holds only the named attributes; without it
    the struct covers the whole set and is marked as nested.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            base_name = family.ident_name
        else:
            base_name = family.ident_name + '-' + space_name
        self.render_name = c_lower(base_name)

        self.struct_name = 'struct ' + self.render_name
        # Nested structs sharing a name with a family const get a trailing '_'.
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'
        # All attr sets this one contains, directly or multiple levels down
        self.child_nests = set()

        self.request = False
        self.reply = False
        self.recursive = False
        self.in_multi_val = False  # used by a MultiAttr or and legacy arrays

        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(name, self.attr_set[name]) for name in names]
        self.attrs = dict(self.attr_list)

        # Track the attr with the highest value (last one wins on ties).
        max_val = 0
        self.attr_max_val = None
        for _, attr in self.attr_list:
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members; they must match those given at construction."""
        if new_inherited != self._inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = list(map(c_lower, sorted(self._inherited)))
811
812
class EnumEntry(SpecEnumEntry):
    """Enum entry that records whether its value breaks the implicit +1 sequence."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # Value expected if the entry simply followed its predecessor.
        expected = prev.value + 1 if prev else 0
        # Entries of 'flags' sets are always treated as a value change.
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Resolve the base entry, then build its upper-case C name."""
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
831
832
class EnumSet(SpecEnumSet):
    """Code-generation view of an enum/flags definition from the spec."""

    def __init__(self, family, yaml):
        self.render_name = c_lower(family.ident_name + '-' + yaml['name'])

        # 'enum-name' overrides the C enum name; a falsy value means there is
        # no named enum type (enum_name is None).
        if 'enum-name' in yaml:
            if yaml['enum-name']:
                self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
            else:
                self.enum_name = None
        else:
            self.enum_name = 'enum ' + self.render_name

        # C type used to hold values of this enum ('int' without a named enum).
        if self.enum_name:
            self.user_type = self.enum_name
        else:
            self.user_type = 'int'

        self.value_pfx = yaml.get('name-prefix', f"{family.ident_name}-{yaml['name']}-")
        self.header = yaml.get('header', None)
        self.enum_cnt_name = yaml.get('enum-cnt-name', None)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """
        Return (lowest, highest) entry value.

        Raises:
            ValueError: if the enum has no entries.
            Exception: if the entry values are not contiguous.
        """
        values = [entry.value for entry in self.entries.values()]
        if not values:
            raise ValueError(f"Can't get value range for an empty enum ({self.render_name})")

        low = min(values)
        high = max(values)

        if high - low + 1 != len(values):
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
868
869
class AttrSet(SpecAttrSet):
    """Attribute set carrying the C names (prefix, max/count macros) it renders with."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # A subset reuses the C naming of the set it is carved out of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
            self.cnt_name = parent.cnt_name
        else:
            if 'name-prefix' in yaml:
                prefix = yaml['name-prefix']
            elif self.name == family.name:
                # The set named after the family gets the short prefix
                prefix = family.ident_name + '-a-'
            else:
                prefix = f"{family.ident_name}-a-{self.name}-"
            self.name_prefix = c_upper(prefix)
            default_max = f"{self.name_prefix}max"
            default_cnt = f"__{self.name_prefix}max"
            self.max_name = c_upper(self.yaml.get('attr-max-name', default_max))
            self.cnt_name = c_upper(self.yaml.get('attr-cnt-name', default_cnt))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        name = c_lower(self.name)
        if name in _C_KW:
            # Never emit a bare C keyword as an identifier
            name += '_'
        # The set named like the family itself renders with an empty c_name
        self.c_name = '' if name == self.family.c_name else name

    def new_attr(self, elem, value):
        """Wrap the spec attribute *elem* in the matching Type subclass."""
        kind = elem['type']
        plain_kinds = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'bitfield32': TypeBitfield32,
            'nest': TypeNest,
            'nest-type-value': TypeNestTypeValue,
        }

        if kind in scalars:
            cls = TypeScalar
        elif kind in plain_kinds:
            cls = plain_kinds[kind]
        elif kind == 'indexed-array' and 'sub-type' in elem:
            if elem["sub-type"] not in ['nest', 'u32']:
                raise Exception(f'new_attr: unsupported sub-type {elem["sub-type"]}')
            cls = TypeArrayNest
        else:
            raise Exception(f"No typed class for type {kind}")

        attr = cls(self.family, self, elem, value)

        # Multi-attrs wrap the already-built single-instance type
        if 'multi-attr' in elem and elem['multi-attr']:
            attr = TypeMultiAttr(self.family, self, elem, value, attr)

        return attr
931
932
class Operation(SpecOperation):
    """Spec operation with its C names and notification bookkeeping."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = c_lower(family.ident_name + '_' + self.name)

        # True only when both 'do' and 'dump' declare a request
        # (presumably meaning each needs its own request policy).
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        # Set via mark_has_ntf() when some other op names this one in 'notify'
        self.has_ntf = False

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        # Async ops (notifications/events) may use their own command prefix
        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        self.has_ntf = True
958
959
class Family(SpecFamily):
    """C code-generation view of a netlink spec family.

    On top of SpecFamily this computes the C names (family/version key
    macros, uapi header, op prefixes) and, in resolve(), the usage analysis
    the renderers consume: which attribute sets are used directly by
    messages (root sets) versus only as nests (pure nested structs), which
    members appear in requests/replies, the do/dump pre/post hooks, and -
    for ``kernel-policy: global`` - the single shared policy.
    """
    def __init__(self, file_name, exclude_ops):
        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name, exclude_ops=exclude_ops)

        # Upper-cased identifiers for the family name / version keys,
        # overridable via c-family-name / c-version-name
        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        # Default header is linux/<ident>.h; uapi_header_name is the header
        # path with 'linux/' and '.h' stripped when both are present,
        # otherwise the family ident name.
        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.ident_name}.h"
        if self.uapi_header.startswith("linux/") and self.uapi_header.endswith('.h'):
            self.uapi_header_name = self.uapi_header[6:-2]
        else:
            self.uapi_header_name = self.ident_name

    def resolve(self):
        """Resolve the spec, then compute C prefixes, hooks and attr-set usage.

        Raises:
            Exception: if the protocol is not one of the genetlink variants.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.ident_name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        # Async ops fall back to the regular command prefix
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] -> {'set': names seen, 'list': names in order}
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> Struct, for attr sets used only as nests
        self.pure_nested_structs = dict()

        # Order matters: events are mocked up as 'do' replies before the
        # root/nested analysis looks at message modes.
        self._mark_notify()
        self._mock_up_events()

        self._load_root_sets()
        self._load_nested_sets()
        self._load_attr_use()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        return Operation(self, elem, req_value, rsp_value)

    def _mark_notify(self):
        """Flag every op named by another message's 'notify' key."""
        for op in self.msgs.values():
            if 'notify' in op:
                self.ops[op['notify']].mark_has_ntf()

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _load_root_sets(self):
        """Record, per attr set used directly by a message, its request/reply attrs.

        Event attributes are counted as reply attributes.
        """
        for op_name, op in self.msgs.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
            if 'event' in op:
                rsp_attrs.update(set(op['event']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _sort_pure_types(self):
        """Reorder pure_nested_structs so each struct follows the nests it uses.

        Nests marked recursive are skipped (per the inline comment they are
        rendered as pointers). Bounded to len(structs)**2 rounds; if that is
        exhausted the order is left as-is without an error.
        """
        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list) ** 2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    # If the unknown nest we hit is recursive it's fine, it'll be a pointer
                    if self.pure_nested_structs[nested].recursive:
                        continue
                    if nested not in pns_key_seen:
                        # Dicts keep insertion order: pop + re-insert moves this struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)

    def _load_nested_sets(self):
        """Discover all attr sets reachable as nests and build their Structs.

        Walks 'nested-attributes' breadth-first from the root sets, records
        inherited members ('type-value' list, or 'idx' for indexed arrays),
        marks request/reply/multi-val usage from the roots, propagates it
        down to children, detects recursion and sorts the result.

        Raises:
            Exception: if a set is used both as a root and as a nest, or if
                members are inherited into a root set.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                # NOTE(review): `inherit` is handed to Struct() before being
                # filled below; presumably Struct keeps this same set object
                # as _inherited, so set_inherited() sees equal sets the first
                # time - confirm in Struct.__init__.
                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'indexed-array':
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Seed request/reply/multi-val flags from direct use in root sets
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True
                    if spec.is_multi_val():
                        child = self.pure_nested_structs.get(nested)
                        child.in_multi_val = True

        self._sort_pure_types()

        # Propagate the request / reply / recursive
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child_name = spec['nested-attributes']
                    struct.child_nests.add(child_name)
                    child = self.pure_nested_structs.get(child_name)
                    if child:
                        if not child.recursive:
                            struct.child_nests.update(child.child_nests)
                        child.request |= struct.request
                        child.reply |= struct.reply
                        if spec.is_multi_val():
                            child.in_multi_val = True
                # A set that (transitively) nests itself is recursive
                if attr_set in struct.child_nests:
                    struct.recursive = True

        self._sort_pure_types()

    def _load_attr_use(self):
        """Mark each attribute as used in requests and/or replies."""
        for _, struct in self.pure_nested_structs.items():
            if struct.request:
                for _, arg in struct.member_list():
                    arg.set_request()
            if struct.reply:
                for _, arg in struct.member_list():
                    arg.set_reply()

        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if attr in rs_members['request']:
                    spec.set_request()
                if attr in rs_members['reply']:
                    spec.set_reply()

    def _load_global_policy(self):
        """Build the single request policy used with ``kernel-policy: global``.

        All ops with an attribute-set must share the same set; the policy
        lists the do/dump request attributes in attr-set order.

        Raises:
            Exception: if ops use different attribute sets.
        """
        # NOTE(review): if no op declares an attribute-set, attr_set_name
        # stays None and self.attr_sets[None] below presumably raises KeyError.
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect unique pre/post hook names per op mode, in first-seen order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
1226
1227
class RenderInfo:
    """Context for rendering one operation in one mode, or a bare attr set.

    Args:
        cw: CodeWriter receiving the output (its ``nlib`` is kept as ``nl``).
        family: Family being rendered.
        ku_space: stored as-is (presumably kernel vs. user side - confirm).
        op: Operation to render, or a falsy value when rendering an attr set
            without an operation (then *attr_set* names the set).
        op_mode: 'do', 'dump', 'notify', 'event', or falsy.
        attr_set: attr set name; defaults to ``op['attribute-set']``.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        self.fixed_hdr = None
        if op and op.fixed_header:
            self.fixed_hdr = 'struct ' + c_lower(op.fixed_header)

        # 'do' and 'dump' response parsing is identical: True unless a dump
        # is rendered whose reply differs from (or exists without) the 'do' one.
        self.type_consistent = True
        # Check `op` first: `'dump' in op` raises for op=None and for op=""
        # only works by accident (it becomes a substring test).
        if op and op_mode != 'do' and 'dump' in op:
            if 'do' in op:
                if ('reply' in op['do']) != ('reply' in op["dump"]):
                    self.type_consistent = False
                elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                    self.type_consistent = False
            else:
                self.type_consistent = False

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        self.type_name_conflict = False
        if op:
            self.type_name = c_lower(op.name)
        else:
            self.type_name = c_lower(attr_set)
            if attr_set in family.consts:
                self.type_name_conflict = True

        self.cw = cw

        # Request/reply member structs; notifications reuse the 'do' layout
        self.struct = dict()
        if op_mode == 'notify':
            op_mode = 'do'
        for op_dir in ['request', 'reply']:
            if op:
                type_list = []
                if op_dir in op[op_mode]:
                    type_list = op[op_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=type_list)
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1276
1277
class CodeWriter:
    """Emit C source text with automatic indentation and brace handling.

    Output goes to stdout, or - when *out_file* is given - to a temporary
    file that close_out_file() copies to *out_file*. With
    ``overwrite=False`` an existing *out_file* with identical contents is
    left untouched.
    """

    def __init__(self, nlib, out_file=None, overwrite=True):
        self.nlib = nlib
        self._overwrite = overwrite

        self._nl = False            # emit a blank line before the next line
        self._block_end = False     # '}' held back in case an "else" follows
        self._silent_block = False  # previous line was a brace-less if/while/for
        self._ind = 0
        self._ifdef_block = None
        # Destination path; None when writing to stdout or already closed
        self._out_file = None
        if out_file is None:
            self._out = sys.stdout
        else:
            self._out = tempfile.NamedTemporaryFile('w+')
            self._out_file = out_file

    def __del__(self):
        self.close_out_file()

    def close_out_file(self):
        """Copy buffered output to the destination file and release the temp file.

        Safe to call repeatedly; afterwards output goes to stdout. The temp
        file is closed on every path, including when the destination was
        left unchanged.
        """
        # getattr: __del__ may run on a partially constructed object
        if getattr(self, '_out_file', None) is None:
            return
        tmp = self._out
        try:
            tmp.flush()
            # Avoid modifying the file if contents didn't change
            unchanged = (not self._overwrite and
                         os.path.isfile(self._out_file) and
                         filecmp.cmp(tmp.name, self._out_file, shallow=False))
            if not unchanged:
                tmp.seek(0)
                with open(self._out_file, 'w') as out_file:
                    shutil.copyfileobj(tmp, out_file)
        finally:
            tmp.close()
            self._out = sys.stdout
            self._out_file = None

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        A pending closing brace is flushed first (or merged into an "else"
        line). Lines ending in ':' are outdented one level, lines starting
        with '#' go to column 0, and the line after a brace-less condition is
        indented one extra level. An empty *line* writes a bare newline.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if line.startswith('#'):
            ind = 0
        if add_ind:
            ind += add_ind
        # No trailing indentation on blank lines
        self._out.write(('\t' * ind + line if line else '') + '\n')

    def nl(self):
        self._nl = True

    def block_start(self, line=''):
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block, optionally with a trailing ';'/','/text."""
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write *doc* as ' *'-prefixed comment lines wrapped before column 79."""
        words = doc.split()
        line = ' *'
        for word in words:
            # Never flush a line that holds no words yet (avoids empty ' *' lines)
            if line != ' *' and len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        The caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write ``(name, value)`` pairs as tab-aligned #defines; str values are quoted."""
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write ``.name = value,`` designated initializers, tab-aligned."""
        if not members:
            return
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + str(one[1]) + ','
            self.p(line)

    def ifdef_block(self, config):
        """Switch the active ``#ifdef CONFIG_<config>`` block (falsy closes it)."""
        config_option = None
        if config:
            config_option = 'CONFIG_' + c_upper(config)
        if self._ifdef_block == config_option:
            return

        if self._ifdef_block:
            self.p('#endif /* ' + self._ifdef_block + ' */')
        if config_option:
            self.p('#ifdef ' + config_option)
        self._ifdef_block = config_option
1473
1474
# Spec attribute types rendered as plain integers (built as TypeScalar)
scalars = {'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'uint', 'sint'}

# Type-name suffix per message direction
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Reply type-name suffix per op mode (used by op_prefix)
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords: generated identifiers that collide get a trailing '_'.
# Includes 'restrict' (a C99 keyword); without it an attribute named
# "restrict" would produce an identifier the C compiler rejects.
_C_KW = {
    'auto',
    'bool',
    'break',
    'case',
    'char',
    'const',
    'continue',
    'default',
    'do',
    'double',
    'else',
    'enum',
    'extern',
    'float',
    'for',
    'goto',
    'if',
    'inline',
    'int',
    'long',
    'register',
    'restrict',
    'return',
    'short',
    'signed',
    'sizeof',
    'static',
    'struct',
    'switch',
    'typedef',
    'union',
    'unsigned',
    'void',
    'volatile',
    'while'
}
1526
1527
def rdir(direction):
    """Return the opposite message direction; other values pass through unchanged."""
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1534
1535
def op_prefix(ri, direction, deref=False):
    """Build the C name stem ``<family>_<type>[suffix]`` for *ri* and *direction*.

    'do' (or no mode) uses the plain direction suffix. Other modes use
    '_req_dump' for requests. Replies use the per-mode wrapper suffix, or the
    direction suffix when *deref* is set. When do/dump replies differ
    (``type_consistent`` False) they get '_rsp_dump' / '_rsp_list' instead.
    """
    base = f"{ri.family.c_name}_{ri.type_name}"

    if not ri.op_mode or ri.op_mode == 'do':
        return base + direction_to_suffix[direction]
    if direction == 'request':
        return base + '_req_dump'
    if not ri.type_consistent:
        return base + ('_rsp_dump' if deref else '_rsp_list')
    if deref:
        return base + direction_to_suffix[direction]
    return base + op_mode_to_wrapper[ri.op_mode]
1555
1556
def type_name(ri, direction, deref=False):
    """Return the C ``struct`` type for *ri* in *direction* (name from op_prefix)."""
    stem = op_prefix(ri, direction, deref=deref)
    return "struct " + stem
1559
1560
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the user-space call for *ri*.

    Parameters are the socket plus, when the mode has a request, a pointer to
    the request struct. The return type is a pointer to the reply struct when
    the mode has a reply, otherwise ``int``. With *terminate* a ';' is appended.
    """
    mode_spec = ri.op[ri.op_mode]
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    params = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        arg_name = direction_to_suffix[direction][1:]
        params.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, params, doc=doc, suffix=';' if terminate else '')
1577
1578
def print_req_prototype(ri):
    """Write the terminated request prototype, with the op's 'doc' as a comment."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1581
1582
def print_dump_prototype(ri):
    """Write the terminated request prototype without a doc comment."""
    print_prototype(ri, direction="request")
1585
1586
def put_typol_fwd(cw, struct):
    """Declare the ``<struct>_nest`` policy object as extern."""
    decl = 'extern const struct ynl_policy_nest ' + struct.render_name + '_nest;'
    cw.p(decl)
1589
1590
def put_typol(cw, struct):
    """Emit ``<struct>_policy`` (one entry per member) and the ``<struct>_nest`` wrapper."""
    policy = f'{struct.render_name}_policy'
    max_attr = struct.attr_set.max_name

    cw.block_start(line=f'const struct ynl_policy_attr {policy}[{max_attr} + 1] =')
    for _name, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'const struct ynl_policy_nest {struct.render_name}_nest =')
    for field in (f'.max_attr = {max_attr},', f'.table = {policy},'):
        cw.p(field)
    cw.block_end(line=';')
    cw.nl()
1606
1607
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Emit ``<render_name>_str()``: a bounds-checked lookup in *map_name*.

    The parameter is typed with ``enum.user_type`` when *enum* is given,
    else ``int``. For a 'flags' enum the value is first replaced by
    ``ffs(value) - 1`` (the 0-based position of its lowest set bit).
    Out-of-range values return NULL.
    """
    param_type = enum.user_type if enum else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{param_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = (
        f'if ({arg_name} < 0 || {arg_name} >= (int)YNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    )
    for stmt in body:
        cw.p(stmt)
    cw.block_end()
    cw.nl()
1621
1622
def put_op_name_fwd(family, cw):
    """Declare ``<family>_op_str()`` (prototype only)."""
    fn_name = family.c_name + '_op_str'
    cw.write_func_prot('const char *', fn_name, ['int op'], suffix=';')
1625
1626
def put_op_name(family, cw):
    """Emit the op-value -> name string map and its ``<family>_op_str()`` accessor.

    Only messages with a reply value get an entry. Entries are indexed by
    the op enum name when request and reply values match, else by the raw
    reply value.
    """
    map_name = f'{family.c_name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        # Make sure we don't add duplicated entries, if multiple commands
        # produce the same response in legacy families.
        if family.rsp_by_value[op.rsp_value] != op:
            cw.p(f'// skip "{op_name}", duplicate reply value')
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.c_name + '_op', map_name, 'op')
1646
1647
def put_enum_to_str_fwd(family, cw, enum):
    """Declare ``<enum>_str()`` (prototype only)."""
    fn_name = f'{enum.render_name}_str'
    cw.write_func_prot('const char *', fn_name, [f'{enum.user_type} value'], suffix=';')
1651
1652
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string map for *enum* and its ``_str()`` accessor."""
    map_name = enum.render_name + '_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for item in enum.entries.values():
        cw.p(f'[{item.value}] = "{item.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1662
1663
def put_req_nested_prototype(ri, struct, suffix=';'):
    """Write the prototype of ``<struct>_put()`` (terminated by *suffix*)."""
    fn_name = struct.render_name + '_put'
    params = [
        'struct nlmsghdr *nlh',
        'unsigned int attr_type',
        f'{struct.ptr_name}obj',
    ]
    ri.cw.write_func_prot('int', fn_name, params, suffix=suffix)
1671
1672
def put_req_nested(ri, struct):
    """Emit ``<struct>_put()``, which writes *struct* as one nested attribute.

    The generated body opens the nest, lets every member write itself from
    ``obj``, closes the nest and returns 0.
    """
    local_vars = ['struct nlattr *nest;']
    # A loop counter is declared only if some member has 'count' presence
    if any(member.presence_type() == 'count' for _, member in struct.member_list()):
        local_vars.append('unsigned int i;')

    put_req_nested_prototype(ri, struct, suffix='')
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p("nest = ynl_attr_nest_start(nlh, attr_type);")

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    ri.cw.p("ynl_attr_nest_end(nlh, nest);")

    ri.cw.nl()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
1701
1702
def _multi_parse(ri, struct, init_lines, local_vars):
    """Emit the body of a C parse function for ``struct``.

    Shared by nested-attribute parsers (``struct.nested`` true) and
    top-level message parsers.  The caller has already emitted the
    prototype.  It passes seeded ``init_lines`` and ``local_vars`` lists,
    and this function appends to both in place.

    The generated body:
      1. copies inherited values and, if ``ri.fixed_hdr`` is set, the fixed
         header into ``dst``;
      2. returns a parse error if any array member of ``dst`` is already set;
      3. walks all attributes once and lets each member's ``attr_get()``
         emit its handling (for array members this presumably only counts
         entries into ``n_<name>``; confirm in attr_get);
      4. for indexed-array and multi-attr members, calloc()s the array and
         walks the attributes a second time to fill it;
      5. returns 0 (nested) or YNL_PARSE_CB_OK (message).
    """
    # Nested parsers iterate inside the nest attribute.  Message parsers
    # iterate the payload that follows the family header.
    if struct.nested:
        iter_line = "ynl_attr_for_each_nested(attr, nested)"
    else:
        if ri.fixed_hdr:
            local_vars += ['void *hdr;']
        iter_line = "ynl_attr_for_each(attr, nlh, yarg->ys->family->hdr_len)"

    # Classify members that need a second pass:
    #  - indexed arrays: the nest attribute is remembered in attr_<name>;
    #  - multi-attr members: repeated top-level attributes.
    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'indexed-array' and 'sub-type' in aspec:
            if aspec["sub-type"] == 'nest':
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            elif aspec['sub-type'] in scalars:
                local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
                array_nests.add(arg)
            else:
                raise Exception(f'Not supported sub-type {aspec["sub-type"]}')
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        # Any member with nested attributes calls a sub-parser, which needs
        # its own ynl_parse_arg.
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per array member.  sorted() keeps output stable.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    if ri.fixed_hdr:
        ri.cw.p('hdr = ynl_nlmsg_data_offset(nlh, sizeof(struct genlmsghdr));')
        ri.cw.p(f"memcpy(&dst->_hdr, hdr, sizeof({ri.fixed_hdr}));")
    # Refuse to parse into a dst whose arrays are already populated
    # (presumably to avoid overwriting/leaking an earlier allocation).
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    # First pass over all attributes.
    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = ynl_attr_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for indexed arrays: iterate the entries of the recorded
    # nest attribute.
    # NOTE(review): the emitted calloc() result is not checked for NULL.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"ynl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        if 'nested-attributes' in aspec:
            # The entry's attribute type (presumably its array index) is
            # passed to the sub-parser as an extra argument.
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, ynl_attr_type(attr)))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.sub_type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.sub_type}(attr);")
        else:
            raise Exception(f"Nest parsing type not supported in {aspec['name']}")
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attr members: re-walk the attributes and pick
    # out the ones whose type matches this member.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (ynl_attr_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return YNL_PARSE_CB_ERROR;')
        elif aspec.type in scalars:
            ri.cw.p(f"dst->{aspec.c_name}[i] = ynl_attr_get_{aspec.type}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1823
1824
def parse_rsp_nested_prototype(ri, struct, suffix=';'):
    """Emit the C prototype of ``<struct>_parse()`` for a nested attribute set.

    Besides the parse argument and the nest attribute, the function takes
    one ``__u32`` parameter for each name in ``struct.inherited``.
    """
    params = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    params.extend(f'__u32 {name}' for name in struct.inherited)
    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', params,
                          suffix=suffix)
1833
1834
def parse_rsp_nested(ri, struct):
    """Emit the full C definition of ``<struct>_parse()``.

    Structs with members delegate the body to ``_multi_parse()``.  An empty
    nest gets a stub body that just returns 0.
    """
    parse_rsp_nested_prototype(ri, struct, suffix='')

    lvars = ['const struct nlattr *attr;',
             f'{struct.ptr_name}dst = yarg->data;']

    if not struct.member_list():
        # Empty nest: nothing to walk.
        cw = ri.cw
        cw.block_start()
        cw.p('return 0;')
        cw.block_end()
        cw.nl()
        return

    _multi_parse(ri, struct, [], lvars)
1850
1851
def parse_rsp_msg(ri, deref=False):
    """Emit the C ``*_parse()`` callback for a reply or event message.

    Nothing is emitted when the current op mode has no reply and is not an
    event.  Replies with members delegate to ``_multi_parse()``.  An empty
    reply gets a body that returns ``YNL_PARSE_CB_OK``.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    dst_type = type_name(ri, "reply", deref=deref)
    lvars = [f'{dst_type} *dst;', 'const struct nlattr *attr;']

    fn_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    params = ['const struct nlmsghdr *nlh', 'struct ynl_parse_arg *yarg']
    ri.cw.write_func_prot('int', fn_name, params)

    reply = ri.struct["reply"]
    if reply.member_list():
        _multi_parse(ri, reply, ['dst = yarg->data;'], lvars)
        return

    # Empty reply: nothing to walk, just report success.
    ri.cw.block_start()
    ri.cw.p('return YNL_PARSE_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1873
1874
def print_req(ri):
    """Emit the C definition of a user-space request ("do") function.

    The generated function starts a request for ``ri.op``, copies the fixed
    header (if any) and every request attribute from ``req``, and runs it
    with ``ynl_exec()``.  When the op mode has a reply, the function
    calloc()s a response and hands it to the op's reply parser.  It returns
    the response, or frees it and returns NULL on error.  Without a reply
    it returns 0 on success and -1 on error.
    """
    # C return expressions for the reply-less case.  They are overridden
    # below when the op has a reply.
    ret_ok = '0'
    ret_err = '-1'
    direction = "request"
    local_vars = ['struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if 'reply' in ri.op[ri.op_mode]:
        ret_ok = 'rsp'
        ret_err = 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    # Count-presence (array) members need a loop index in the generated code.
    for _, attr in ri.struct["request"].member_list():
        if attr.presence_type() == 'count':
            local_vars += ['unsigned int i;']
            break

    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.nl()

    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    ri.cw.nl()

    if 'reply' in ri.op[ri.op_mode]:
        # NOTE(review): the emitted calloc() result is not checked for NULL.
        ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
        ri.cw.p('yrs.yarg.data = rsp;')
        ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Expected reply command: the op's enum when ri.op.value is set,
        # otherwise its separate reply value (presumably the directional-ID
        # case; confirm against the spec loader).
        if ri.op.value is not None:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        ri.cw.nl()
    ri.cw.p("err = ynl_exec(ys, nlh, &yrs);")
    ri.cw.p('if (err < 0)')
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('goto err_free;')
    else:
        ri.cw.p('return -1;')
    ri.cw.nl()

    ri.cw.p(f"return {ret_ok};")
    ri.cw.nl()

    # Error path: release the partially parsed response.
    if 'reply' in ri.op[ri.op_mode]:
        ri.cw.p('err_free:')
        ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        ri.cw.p(f"return {ret_err};")

    ri.cw.block_end()
1944
1945
def print_dump(ri):
    """Emit the C definition of a user-space dump function.

    The generated function configures a ``ynl_dump_state`` (reply policy,
    per-entry allocation size, reply parser, expected reply command).  It
    then starts a dump message, adds the fixed header and (if the dump mode
    has a request) its attributes, and runs ``ynl_exec_dump()``.  It
    returns ``yds.first`` on success.  On error it frees that list and
    returns NULL.
    """
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    ri.cw.block_start()
    local_vars = ['struct ynl_dump_state yds = {};',
                  'struct nlmsghdr *nlh;',
                  'int err;']

    if ri.fixed_hdr:
        local_vars += ['size_t hdr_len;',
                       'void *hdr;']

    ri.cw.write_func_lvar(local_vars)

    ri.cw.p('yds.yarg.ys = ys;')
    ri.cw.p(f"yds.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    ri.cw.p("yds.yarg.data = NULL;")
    # Each dump entry is allocated as the (wrapped) reply type.
    ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    # Same reply-command selection as print_req().
    if ri.op.value is not None:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    else:
        ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    ri.cw.nl()
    ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if ri.fixed_hdr:
        ri.cw.p("hdr_len = sizeof(req->_hdr);")
        ri.cw.p("hdr = ynl_nlmsg_put_extra_header(nlh, hdr_len);")
        ri.cw.p("memcpy(hdr, &req->_hdr, hdr_len);")
        ri.cw.nl()

    # Dumps may carry an optional filter request.
    if "request" in ri.op[ri.op_mode]:
        ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        ri.cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    ri.cw.nl()

    ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
    ri.cw.p('if (err < 0)')
    ri.cw.p('goto free_list;')
    ri.cw.nl()

    ri.cw.p('return yds.first;')
    ri.cw.nl()
    ri.cw.p('free_list:')
    ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
    ri.cw.p('return NULL;')
    ri.cw.block_end()
1996
1997
def call_free(ri, direction, var):
    """Return the C statement that frees ``var`` with the op's free helper."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
2000
2001
def free_arg_name(direction):
    """Return the parameter name used by a generated ``*_free()`` function.

    Nested types (empty or None ``direction``) use ``obj``.  Message types
    use their ``direction_to_suffix`` entry with its first character
    (presumably a leading underscore) dropped.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
2006
2007
def print_alloc_wrapper(ri, direction):
    """Emit a ``static inline`` ``<op>_alloc()`` helper for the op's struct.

    The generated function takes no arguments and returns a zero-filled
    ``struct <op>`` from ``calloc()``.
    """
    name = op_prefix(ri, direction)
    ri.cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof(struct {name}));')
    ri.cw.block_end()
2014
2015
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of ``<op>_free()`` for ``direction``.

    When ``ri.type_name_conflict`` is set, the struct type gets a trailing
    underscore while the function name keeps the plain prefix.
    """
    prefix = op_prefix(ri, direction)
    struct_name = f"{prefix}_" if ri.type_name_conflict else prefix
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{prefix}_free", [param], suffix=suffix)
2023
2024
def _print_type(ri, direction, struct):
    """Emit the C struct definition for ``struct``.

    The struct is named ``<family><suffix>``.  The suffix is built from the
    render info's type name and the direction suffix, plus a trailing
    ``_`` for a direction-less type whose name conflicts and ``_dump`` in
    dump mode.  Members appear in this order: the optional fixed header
    ``_hdr``; a ``struct { ... } _present`` block holding presence fields
    (only if some member has one); the inherited ``__u32`` values; then the
    members themselves.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'

    if ri.op_mode == 'dump':
        suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family.c_name}{suffix}")

    if ri.fixed_hdr:
        ri.cw.p(ri.fixed_hdr + ' _hdr;')
        ri.cw.nl()

    # Presence metadata ('len' then 'bit' per member) goes into a single
    # sub-struct that is opened lazily, on the first member that needs it.
    meta_started = False
    for _, attr in struct.member_list():
        for type_filter in ['len', 'bit']:
            line = attr.presence_member(ri.ku_space, type_filter)
            if line:
                if not meta_started:
                    ri.cw.block_start(line="struct")
                    meta_started = True
                ri.cw.p(line)
    if meta_started:
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
2060
2061
def print_type(ri, direction):
    """Emit the C struct for the current op's request or reply."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
2064
2065
def print_type_full(ri, struct):
    """Emit the C struct for a standalone (direction-less) attribute set."""
    _print_type(ri, direction="", struct=struct)
2068
2069
def print_type_helpers(ri, direction, deref=False):
    """Emit the free() prototype and, for user-space requests, member setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
2078
2079
def print_req_type_helpers(ri):
    """Emit alloc/free/setter helpers for the request struct, if non-empty."""
    if not ri.struct["request"].attr_list:
        return
    print_alloc_wrapper(ri, "request")
    print_type_helpers(ri, "request")
2085
2086
def print_rsp_type_helpers(ri):
    """Emit reply-struct helpers when the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
2091
2092
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of ``<op>_{rsp,req}_parse()``.

    The function takes the attribute table ``tb`` and a pointer to the
    matching ``struct <op>_{rsp,req}``.  A ``;`` is appended only when
    ``terminate`` is true.
    """
    kind = "_rsp" if direction == "reply" else "_req"
    base = f"{ri.op.render_name}{kind}"
    params = ['const struct nlattr **tb', f"struct {base} *req"]
    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
2101
2102
def print_req_type(ri):
    """Emit the request struct unless it has no attributes."""
    if ri.struct["request"].attr_list:
        print_type(ri, "request")
2107
2108
def print_req_free(ri):
    """Emit the request struct's free() definition if the mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
2113
2114
def print_rsp_type(ri):
    """Emit the reply struct for do/dump ops with a reply, and for events."""
    mode = ri.op_mode
    if mode == 'event':
        has_reply = True
    elif mode in ('do', 'dump'):
        has_reply = 'reply' in ri.op[mode]
    else:
        has_reply = False

    if has_reply:
        print_type(ri, 'reply')
2123
2124
def print_wrapped_type(ri):
    """Emit the wrapper struct that carries a reply object.

    In dump mode the wrapper has a ``next`` pointer for the reply list.  In
    notify/event mode it has ``family``, ``cmd``, a generic ``next`` link
    and a ``free`` callback.  The reply itself is the 8-byte-aligned
    ``obj`` member.  The wrapper's free() prototype follows.
    """
    wrapper = type_name(ri, 'reply')
    cw = ri.cw

    cw.block_start(line=wrapper)
    mode = ri.op_mode
    if mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p('struct ynl_ntf_base_type *next;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    inner = type_name(ri, 'reply', deref=True)
    cw.p(f"{inner} obj __attribute__((aligned(8)));")
    cw.block_end(line=';')
    cw.nl()

    print_free_prototype(ri, 'reply')
    cw.nl()
2139
2140
2141def _free_type_members_iter(ri, struct):
2142    for _, attr in struct.member_list():
2143        if attr.free_needs_iter():
2144            ri.cw.p('unsigned int i;')
2145            ri.cw.nl()
2146            break
2147
2148
2149def _free_type_members(ri, var, struct, ref=''):
2150    for _, attr in struct.member_list():
2151        attr.free(ri, var, ref)
2152
2153
def _free_type(ri, direction, struct):
    """Emit the full ``*_free()`` definition for ``struct``.

    Members are always released.  The object itself is ``free()``d only
    for request/reply types (non-empty ``direction``).  Nested types only
    have their members released.
    """
    param = free_arg_name(direction)
    cw = ri.cw

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, param, struct)
    if direction:
        cw.p(f'free({param});')
    cw.block_end()
    cw.nl()
2165
2166
def free_rsp_nested_prototype(ri):
    """Emit the free() prototype for a nested (direction-less) type."""
    print_free_prototype(ri, "")
2169
2170
def free_rsp_nested(ri, struct):
    """Emit the free() definition for a nested (direction-less) type."""
    _free_type(ri, direction="", struct=struct)
2173
2174
def print_rsp_free(ri):
    """Emit the reply struct's free() definition if the mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
2179
2180
def print_dump_type_free(ri):
    """Emit the free() definition for a dump's linked list of replies.

    The generated code walks the list until ``YNL_LIST_END``.  For each
    node it frees the members of the embedded ``obj`` and then the node.
    """
    sub_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    ri.cw.p(f"{sub_type} *next = rsp;")
    ri.cw.nl()
    ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, ri.struct['reply'])
    # Advance before freeing, since rsp is released at the end of the loop.
    ri.cw.p('rsp = next;')
    ri.cw.p('next = rsp->next;')
    ri.cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.block_end()
    ri.cw.nl()
2199
2200
def print_ntf_type_free(ri):
    """Emit the free() definition for a notification wrapper.

    The members of the embedded ``obj`` are released first, then the
    wrapper itself.
    """
    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, ri.struct['reply'])
    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
2209
2210
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit a kernel ``nla_policy`` array declaration, or a definition opener.

    With ``terminate`` an ``extern`` declaration is written.  If the policy
    is op-specific and ``policy_should_be_static()`` holds, nothing is
    written instead.  Without ``terminate`` the definition's first line
    (ending in ``= {``) is written, marked ``static`` when applicable.
    The array is named after the op (with ``_<mode>`` for dual-policy ops)
    or, when ``ri`` is absent, after the struct.
    """
    is_static = policy_should_be_static(struct.family) if ri else False

    if terminate:
        if is_static:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
2233
2234
def print_req_policy(cw, struct, ri=None):
    """Emit the full kernel ``nla_policy`` array definition for ``struct``.

    When rendering for an op, the definition is wrapped in that op's
    ``config-cond`` #ifdef (if it has one).
    """
    op = ri.op if ri else None
    if op:
        cw.ifdef_block(op.get('config-cond', None))

    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")

    cw.ifdef_block(None)
    cw.nl()
2244
2245
def kernel_can_gen_family_struct(family):
    """Return True when the ``genl_family`` struct can be generated.

    Only the plain ``genetlink`` protocol qualifies.
    """
    proto = family.proto
    return proto == 'genetlink'
2248
2249
def policy_should_be_static(family):
    """Return True if op policies may be ``static`` in the generated source.

    That is the case for ``split`` kernel policies, and for families whose
    ``genl_family`` struct is generated by this tool.
    """
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
2252
2253
def print_kernel_policy_ranges(family, cw):
    """Emit ``netlink_range_validation`` structs for ``full-range`` checks.

    One struct is emitted per request attribute (outside attribute-set
    subsets) that has a ``full-range`` check.  Unsigned types use the plain
    variant with ``ULL`` limits; signed types use the ``_signed`` variant
    with ``LL`` limits.  A header comment precedes the first struct.
    """
    need_header = True
    for attr_set in family.attr_sets.values():
        if attr_set.subset_of:
            continue

        for _, attr in attr_set.items():
            if not attr.request or 'full-range' not in attr.checks:
                continue

            if need_header:
                cw.p('/* Integer value ranges */')
                need_header = False

            is_unsigned = attr.type[0] == 'u'
            sign, suffix = ('', 'ULL') if is_unsigned else ('_signed', 'LL')
            cw.block_start(line=f'static const struct netlink_range_validation{sign} {c_lower(attr.enum_name)}_range =')
            members = [(limit, attr.get_limit_str(limit, suffix=suffix))
                       for limit in ('min', 'max') if limit in attr.checks]
            cw.write_struct_init(members)
            cw.block_end(line=';')
            cw.nl()
2281
2282
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops-table declaration (header) or its opening line (source).

    With ``terminate=False`` (source), this opens the ``<family>_nl_ops``
    array definition and returns; the caller emits the entries.  With
    ``terminate=True`` (header), it emits an ``extern`` declaration when
    the table is exported, followed by prototypes for all pre/post hooks
    and for every op's doit/dumpit handler.

    The table is ``static`` and unsized when the family struct is
    generated here.  Otherwise it is exported with an explicit size: one
    entry per do and per dump for split policies, one per op otherwise.
    """
    # The table can stay static only when the genl_family struct that
    # references it is generated in the same file.
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.ident_name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        # Exported tables get an explicit size so the header's extern
        # declaration and the source definition agree.
        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # Split ops: one entry per do handler and one per dump handler.
            cnt = 0
            for op in family.ops.values():
                if 'do' in op:
                    cnt += 1
                if 'dump' in op:
                    cnt += 1
        else:
            cnt = len(family.ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.c_name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    # Header only: prototypes of user-provided hooks referenced by the table.
    cw.nl()
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    # Header only: prototypes of the per-op handlers.
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.ident_name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
2348
2349
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side ops-table declaration and handler prototypes."""
    return print_kernel_op_table_fwd(family, cw, True)
2352
2353
def print_kernel_op_table(family, cw):
    """Emit the kernel ``genl_*ops`` table definition for the family.

    ``global``/``per-op`` policies get one entry per op.  ``split``
    policies get one entry per do/dump mode, each flagged with the
    matching ``GENL_CMD_CAP_*``.  Each entry is wrapped in its op's
    ``config-cond`` #ifdef (if any).
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.ifdef_block(op.get('config-cond', None))
            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): this assumes every op has do.request; an
                # op without one would raise KeyError here.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Split-ops field names for the spec's pre/post hooks, per mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.ifdef_block(op.get('config-cond', None))
                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the flags relevant to this entry's mode:
                    # dump-only flags are dropped from do entries, and
                    # 'strict' is dropped from dump entries.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.ident_name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Ops with dual policies get a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.ident_name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.ident_name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')
    # Presumably closes any #ifdef left open by the last entry.
    cw.ifdef_block(None)

    cw.block_end(line=';')
    cw.nl()
2432
2433
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum with one ``<FAMILY>_NLGRP_<NAME>`` per mcast group."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for group in groups:
        cw.p(c_upper(f"{family.ident_name}-nlgrp-{group['name']},"))
    cw.block_end(';')
    cw.nl()
2444
2445
def print_kernel_mcgrp_src(family, cw):
    """Emit the ``genl_multicast_group`` array, indexed by the group enum."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.c_name}_nl_mcgrps[] =')
    for group in groups:
        grp_name = group['name']
        enum_id = c_upper(f"{family.ident_name}-nlgrp-{grp_name}")
        cw.p(f'[{enum_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2457
2458
def print_kernel_family_struct_hdr(family, cw):
    """Emit the extern ``genl_family`` declaration and sock-priv hook prototypes.

    Only families whose struct is generated (see
    ``kernel_can_gen_family_struct()``) get anything.  The socket-private
    init/destroy prototypes are emitted only when the spec declares a
    ``sock-priv`` type.
    """
    if not kernel_can_gen_family_struct(family):
        return

    prefix = family.c_name
    cw.p(f"extern struct genl_family {prefix}_nl_family;")
    cw.nl()

    if 'sock-priv' not in family.kernel_family:
        return
    priv_type = family.kernel_family["sock-priv"]
    for hook in ('init', 'destroy'):
        cw.p(f'void {prefix}_nl_sock_priv_{hook}({priv_type} *priv);')
    cw.nl()
2469
2470
def print_kernel_family_struct_src(family, cw):
    """Emit the ``struct genl_family`` definition for plain genetlink families.

    When the family declares a ``sock-priv`` type, static ``void *``-taking
    trampolines around the typed init/destroy hooks are emitted first (see
    the CFI note below) and referenced from the struct.  The ops table
    fields depend on the kernel policy (``per-op`` vs ``split``).  Multicast
    groups are referenced only if the family has any.
    """
    if not kernel_can_gen_family_struct(family):
        return

    if 'sock-priv' in family.kernel_family:
        # Generate "trampolines" to make CFI happy
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_init",
                      [f"{family.c_name}_nl_sock_priv_init(priv);"],
                      ["void *priv"])
        cw.nl()
        cw.write_func("static void", f"__{family.c_name}_nl_sock_priv_destroy",
                      [f"{family.c_name}_nl_sock_priv_destroy(priv);"],
                      ["void *priv"])
        cw.nl()

    # Use c_name so the definition always matches the extern declaration
    # emitted by print_kernel_family_struct_hdr().
    cw.block_start(f"struct genl_family {family.c_name}_nl_family __ro_after_init =")
    cw.p('.name\t\t= ' + family.fam_key + ',')
    cw.p('.version\t= ' + family.ver_key + ',')
    cw.p('.netnsok\t= true,')
    cw.p('.parallel_ops\t= true,')
    cw.p('.module\t\t= THIS_MODULE,')
    if family.kernel_policy == 'per-op':
        cw.p(f'.ops\t\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    elif family.kernel_policy == 'split':
        cw.p(f'.split_ops\t= {family.c_name}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({family.c_name}_nl_ops),')
    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {family.c_name}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({family.c_name}_nl_mcgrps),')
    if 'sock-priv' in family.kernel_family:
        cw.p(f'.sock_priv_size\t= sizeof({family.kernel_family["sock-priv"]}),')
        cw.p(f'.sock_priv_init\t= __{family.c_name}_nl_sock_priv_init,')
        cw.p(f'.sock_priv_destroy = __{family.c_name}_nl_sock_priv_destroy,')
    cw.block_end(';')
2506
2507
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI ``enum`` block, naming it when the spec requests a name.

    If ``obj`` has the ``enum_name`` key, its value names the enum; an
    empty or false value keeps the enum anonymous.  Otherwise
    ``obj[ckey]`` (if given and present) is used, prefixed with the
    family's C name.  With neither, the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = f'enum {c_lower(explicit)}' if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = f'enum {family.c_name}_{c_lower(obj[ckey])}'
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
2516
2517
def render_uapi_unified(family, cw, max_by_define, separate_ntf):
    """Emit the uAPI command enum for families with unified message IDs.

    Commands are listed in spec order.  An explicit ``= value`` is written
    only when a command's value breaks the implicit sequence.  With
    ``separate_ntf``, ops with ``notify``/``event`` are skipped.  The enum
    ends with a count entry and a max alias.  With ``max_by_define``, the
    max is a ``#define`` after the enum instead of an enum entry.
    """
    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    expected = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        if op.value == expected:
            cw.p(f"{op.enum_name},")
        else:
            cw.p(f"{op.enum_name} = {op.value},")
            expected = op.value
        expected += 1
    cw.nl()

    if max_by_define:
        cw.p(cnt_name)
        cw.block_end(line=';')
        cw.p(f"#define {max_name} {max_value}")
    else:
        cw.p(cnt_name + ',')
        cw.p(f"{max_name} = {max_value}")
        cw.block_end(line=';')
    cw.nl()
2543
2544
def _uapi_directional_enum_tail(cw, cnt_name, max_name, max_by_define):
    """Close a directional uAPI message enum with its count and max alias.

    ``max_name`` is defined as ``(cnt_name - 1)``.  It is the last enum
    entry, or a ``#define`` after the enum when ``max_by_define`` is set.
    """
    max_value = f"({cnt_name} - 1)"
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()


def render_uapi_directional(family, cw, max_by_define):
    """Emit the uAPI message-ID enums for families with directional IDs.

    Two enums are written:

    * user-to-kernel: ops with ``do`` that are not events, after
      ``<FAMILY>_MSG_USER_NONE = 0``;
    * kernel-to-user: ``do`` replies, notifications and events, after
      ``<FAMILY>_MSG_KERNEL_NONE = 0``.  Plain replies get a ``_REPLY``
      suffix.

    Each enum ends with a ``__<prefix>{USER,KERNEL}_CNT`` entry and a
    matching ``_MAX`` alias.  An explicit ``= value`` is written only for
    ops with a truthy ``value`` that differs from the tracked ``val``.
    Because ``val`` starts at 0 while ``*_NONE`` already occupies 0, an
    op with value 1 at the start still gets an explicit ``= 1``.
    """
    prefix = family.op_prefix

    # Messages sent by user space.
    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_USER_NONE = 0,'))
    val = 0
    for op in family.msgs.values():
        if 'do' in op and 'event' not in op:
            suffix = ','
            if op.value and op.value != val:
                suffix = f" = {op.value},"
                val = op.value
            cw.p(op.enum_name + suffix)
            val += 1
    _uapi_directional_enum_tail(cw, f"__{prefix}USER_CNT",
                                f"{prefix}USER_MAX", max_by_define)

    # Messages sent by the kernel.
    cw.block_start(line='enum')
    cw.p(c_upper(f'{family.name}_MSG_KERNEL_NONE = 0,'))
    val = 0
    for op in family.msgs.values():
        if ('do' in op and 'reply' in op['do']) or 'notify' in op or 'event' in op:
            enum_name = op.enum_name
            if 'event' not in op and 'notify' not in op:
                enum_name = f'{enum_name}_REPLY'

            suffix = ','
            if op.value and op.value != val:
                suffix = f" = {op.value},"
                val = op.value
            cw.p(enum_name + suffix)
            val += 1
    _uapi_directional_enum_tail(cw, f"__{prefix}KERNEL_CNT",
                                f"{prefix}KERNEL_MAX", max_by_define)
2597
2598
def render_uapi(family, cw):
    """
    Render the complete UAPI header for @family into @cw.

    The output is, in order:
      - the include guard,
      - the family name and version #defines,
      - the definitions: constants as #defines, enums and flags as C enums
        with optional doc comments and max/mask trailers,
      - one enum per attribute set that is not a subset of another set,
      - the command enum(s), according to the family's message ID model,
      - an optional separate enum for notifications,
      - #defines for the multicast group names.

    Raises:
        Exception: if the family uses an unsupported message ID model.
    """
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.uapi_header_name)}_H"
    # '/' cannot appear in a preprocessor identifier, map it to '_'.
    hdr_prot = hdr_prot.replace('/', '_')
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Constants are collected here and written out as one group of #defines
    # whenever a non-const definition (or the end of the list) is reached.
    defines = []
    for const in family['definitions']:
        # Definitions carrying a 'header' key are not rendered here.
        if const.get('header'):
            continue

        # Flush pending constants first so the output follows spec order.
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.header:
                continue

            if enum.has_doc():
                if enum.has_entry_doc():
                    # "/**" comment: "<enum name> - <doc>" then "@entry: doc" lines.
                    cw.p('/**')
                    doc = ''
                    if 'doc' in enum:
                        doc = ' - ' + enum['doc']
                    cw.write_doc_line(enum.enum_name + doc)
                else:
                    cw.p('/*')
                    cw.write_doc_line(enum['doc'], indent=False)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.ident_name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    # Flags get a mask of all defined bits instead of a max.
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    cnt_name = enum.enum_cnt_name
                    max_name = c_upper(name_pfx + 'max')
                    if not cnt_name:
                        cnt_name = '__' + name_pfx + 'max'
                    cw.p(c_upper(cnt_name) + ',')
                    cw.p(max_name + ' = (' + c_upper(cnt_name) + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # The optional 'c-define-name' override belongs to the definition
            # itself, as it does for multicast groups below.  Looking it up on
            # the family would give every constant the same name.
            defines.append([c_upper(const.get('c-define-name',
                                              f"{family.ident_name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        # Subset attribute sets reuse the enum of the set they are a subset of.
        if attr_set.subset_of:
            continue

        max_value = f"({attr_set.cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            # Only emit an explicit value where numbering is not sequential.
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        if attr_set.items():
            cw.nl()
        cw.p(attr_set.cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    if family.msg_id_model == 'unified':
        render_uapi_unified(family, cw, max_by_define, separate_ntf)
    elif family.msg_id_model == 'directional':
        render_uapi_directional(family, cw, max_by_define)
    else:
        raise Exception(f'Unsupported message enum-model {family.msg_id_model}')

    # Notifications and events that were left out of the unified command enum.
    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.ident_name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2739
2740
def _render_user_ntf_entry(ri, op):
    """
    Emit one designated initializer of the ynl_ntf_info table for @op.

    The entry sets the allocation size of the event struct, the parse
    callback, the reply policy and the free callback.  Each field is
    computed just before it is printed.
    """
    writer = ri.cw
    writer.block_start(line=f"[{op.enum_name}] = ")
    event_type = type_name(ri, 'event')
    writer.p(f".alloc_sz\t= sizeof({event_type}),")
    parse_fn = f"{op_prefix(ri, 'reply', deref=True)}_parse"
    writer.p(f".cb\t\t= {parse_fn},")
    policy = f"{ri.struct['reply'].render_name}_nest"
    writer.p(f".policy\t\t= &{policy},")
    free_fn = f"{op_prefix(ri, 'notify')}_free"
    writer.p(f".free\t\t= (void *){free_fn},")
    writer.block_end(line=',')
2748
2749
def render_user_family(family, cw, prototype):
    """
    Emit the ynl_family descriptor of @family for the user space library.

    With @prototype set, only the extern declaration is written (for the
    header).  Otherwise the static notification info table (if the family
    has notifications) and the ynl_<c_name>_family definition are written.

    Raises:
        Exception: if an entry in family.ntfs is neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    # Name the table after family.c_name, the C-safe name already used for
    # the family symbol above.  The raw spec name (family['name']) may
    # contain characters such as '-' that are not valid in a C identifier.
    ntf_table = f"{family.c_name}_ntf_info"

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_table}[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                # A notification reuses the types of the op it refers to.
                op = family.ops[ntf_op['notify']]
                ri = RenderInfo(cw, family, "user", op, "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for op_name, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.c_name}",')
    if family.fixed_header:
        cw.p(f'.hdr_len\t= sizeof(struct genlmsghdr) + sizeof(struct {c_lower(family.fixed_header)}),')
    else:
        cw.p('.hdr_len\t= sizeof(struct genlmsghdr),')
    if family.ntfs:
        cw.p(f".ntf_info\t= {ntf_table},")
        cw.p(f".ntf_info_size\t= YNL_ARRAY_SIZE({ntf_table}),")
    cw.block_end(line=';')
2785
2786
def family_contains_bitfield32(family):
    """
    Tell whether any attribute of @family is of type "bitfield32".

    Attribute sets that are a subset of another set are not inspected.
    Returns True or False.
    """
    return any(
        attr.type == "bitfield32"
        for _, attr_set in family.attr_sets.items()
        if not attr_set.subset_of
        for _, attr in attr_set.items()
    )
2795
2796
def find_kernel_root(full_path):
    """
    Locate the source tree containing @full_path.

    Walks up from @full_path until a directory holding a MAINTAINERS file
    is found.

    Returns:
        A (root, sub_path) tuple.  root is the directory containing
        MAINTAINERS; it is '' when that is the current directory and
        @full_path is relative.  sub_path is @full_path relative to root.

    Raises:
        FileNotFoundError: if no ancestor of @full_path contains MAINTAINERS.
    """
    start = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # Each join above leaves a trailing separator; drop it.
            return parent, sub_path[:-1]
        # os.path.dirname() returns its argument unchanged at the filesystem
        # root ('/') and for ''.  Without this check the walk never ends.
        if parent == full_path:
            raise FileNotFoundError(f"No MAINTAINERS file found in any parent of {start}")
        full_path = parent
2805
2806
def main():
    """
    Command line entry point of the generator.

    Parses the arguments, loads the spec given by --spec and writes the
    requested output through a CodeWriter:
      --mode uapi   - the UAPI header (see render_uapi()),
      --mode kernel - the kernel policy / op table header or source,
      --mode user   - the user space library header or source.

    Exactly one of --header / --source must be given.  Exits with status 1
    if the spec fails to parse as YAML or does not carry the required
    license.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True,
                        choices=('user', 'kernel', 'uapi'))
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    # --header and --source share one destination; None means neither was given.
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--cmp-out', action='store_true', default=None,
                        help='Do not overwrite the output file if the new output is identical to the old')
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str, default=None)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        sys.exit(1)

    cw = CodeWriter(BaseNlLib(), args.out_file, overwrite=(not args.cmp_out))

    _, spec_kernel = find_kernel_root(args.spec)
    # Headers carry the SPDX tag in a /* */ comment, sources in a // comment.
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        # Record the extra arguments, e.g. " --user-header a.h --exclude-op x".
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.c_name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.out_file:
        # Header paired with the output file: swap its extension for ".h"
        # (e.g. "foo-user.c" -> "foo-user.h").  splitext() does not assume
        # a two-character extension.
        hdr_file = os.path.splitext(os.path.basename(args.out_file))[0] + ".h"
    else:
        hdr_file = "generated_header_file.h"

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{hdr_file}"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
        headers += parsed.kernel_family.get('headers', [])
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
            if family_contains_bitfield32(parsed):
                cw.p('#include <linux/netlink.h>')
        else:
            cw.p(f'#include "{hdr_file}"')
            cw.p('#include "ynl.h"')
        headers = []
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    if args.mode == 'user':
        headers.append(parsed.uapi_header)
    # Include each header once, keeping first-seen order.
    seen_header = []
    for one in headers:
        if one not in seen_header:
            cw.p(f"#include <{one}>")
            seen_header.append(one)
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            print_kernel_policy_ranges(parsed, cw)

            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)
                if struct.request and struct.in_multi_val:
                    free_rsp_nested_prototype(ri)
                    cw.nl()

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            has_recursive_nests = False
            cw.p('/* Policies */')
            # Recursive nests refer to their own policy, so forward-declare them.
            for struct in parsed.pure_nested_structs.values():
                if struct.recursive:
                    put_typol_fwd(cw, struct)
                    has_recursive_nests = True
            if has_recursive_nests:
                cw.nl()
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            # With recursive nests, emit prototypes before any definitions.
            if has_recursive_nests:
                for attr_set, struct in parsed.pure_nested_structs.items():
                    ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                    free_rsp_nested_prototype(ri)
                    if struct.request:
                        put_req_nested_prototype(ri, struct)
                    if struct.reply:
                        parse_rsp_nested_prototype(ri, struct)
                cw.nl()
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_req_free(ri)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
3118
3119
if __name__ == '__main__':
    # Run the generator only when executed as a script, not on import.
    main()
3122