xref: /linux/tools/net/sunrpc/xdrgen/xdr_ast.py (revision 2831fa8b8bcf1083f9526aa0c41fafb0796cf874)
1#!/usr/bin/env python3
2# ex: set filetype=python:
3
4"""Define and implement the Abstract Syntax Tree for the XDR language."""
5
6import sys
7from typing import List
8from dataclasses import dataclass
9
10from lark import ast_utils, Transformer
11from lark.tree import Meta
12
# Handle to this module object; it is handed to
# ast_utils.create_transformer() further down in this file.
this_module = sys.modules[__name__]

# State collected from pragma directives by ParseToAst.pragma_def():
# three lists of symbols plus the header name, which stays "none" until
# a header directive replaces it.
big_endian = []
excluded_apis = []
header_name = "none"
public_apis = []
# Type names recorded by the struct, pointer, union, and optional-data
# nodes as they are instantiated. Typedefs of pass-by-reference types
# are added to pass_by_reference as well.
structs = set()
pass_by_reference = set()

# Maps the name of a constant or enumerator to its integer value.
constants = {}
23
24
def xdr_quadlen(val: str) -> int:
    """Return the number of 4-octet XDR units needed to hold an octet count

    Args:
        val: Either a name recorded in the module-level ``constants``
            dictionary, or a string that ``int()`` can convert.

    Returns:
        The octet count rounded up to a whole number of 4-octet units.

    Raises:
        ValueError: If ``val`` is not a known constant name and cannot
            be converted by ``int()``.
    """
    octets = constants[val] if val in constants else int(val)
    # Floor division keeps the rounding in exact integer arithmetic;
    # true division would go through a float and lose precision for
    # very large counts.
    return (octets + 3) // 4
32
33
# Symbolic XDR widths, keyed by type_name. Each value is a list of
# strings naming the components of the type's width; code in this
# module joins such a list with " + " when it needs one expression.
# Only the built-in types are listed here. AST nodes add an entry for
# each type they describe as they are instantiated.
symbolic_widths = {
    "void": ["XDR_void"],
    "bool": ["XDR_bool"],
    "short": ["XDR_short"],
    "unsigned_short": ["XDR_unsigned_short"],
    "int": ["XDR_int"],
    "unsigned_int": ["XDR_unsigned_int"],
    "long": ["XDR_long"],
    "unsigned_long": ["XDR_unsigned_long"],
    "hyper": ["XDR_hyper"],
    "unsigned_hyper": ["XDR_unsigned_hyper"],
}
46
# Numeric XDR widths are tracked in a dictionary that is keyed
# by type_name because sometimes a caller has nothing more than
# the type_name to use to figure out the numeric width.
#
# Widths are counted in XDR_UNITS (see xdr_quadlen(), which rounds an
# octet count up to 4-octet units). Only the built-in types are listed
# here. AST nodes add an entry for each type they describe as they are
# instantiated.
max_widths = {
    "void": 0,
    "bool": 1,
    "short": 1,
    "unsigned_short": 1,
    "int": 1,
    "unsigned_int": 1,
    "long": 1,
    "unsigned_long": 1,
    "hyper": 2,
    "unsigned_hyper": 2,
}
62
63
@dataclass
class _XdrAst(ast_utils.Ast):
    """Base class for the XDR abstract syntax tree

    The AST node classes defined in this module derive from this class,
    which itself derives from lark's ast_utils.Ast.
    """
67
68
@dataclass
class _XdrIdentifier(_XdrAst):
    """Corresponds to 'identifier' in the XDR language grammar"""

    # Text of the identifier; ParseToAst.identifier() fills it from the
    # value of the matched token.
    symbol: str
74
75
@dataclass
class _XdrValue(_XdrAst):
    """Corresponds to 'value' in the XDR language grammar"""

    # Kept as a string: ParseToAst.value() stores either an identifier's
    # symbol or the unconverted text of a constant token.
    value: str
81
82
@dataclass
class _XdrConstantValue(_XdrAst):
    """Corresponds to 'constant' in the XDR language grammar"""

    # Integer already converted by ParseToAst.constant() using base 10,
    # 16, or 8 according to the kind of constant matched.
    value: int
88
89
@dataclass
class _XdrTypeSpecifier(_XdrAst):
    """Corresponds to 'type_specifier' in the XDR language grammar"""

    # Name of the type; also the key used in max_widths and
    # symbolic_widths.
    type_name: str
    # Empty by default. _XdrDefinedType sets it to "struct " when the
    # type name has been recorded in the module-level structs set.
    c_classifier: str = ""
96
97
@dataclass
class _XdrDefinedType(_XdrTypeSpecifier):
    """A type specifier that names a type from the input specification"""

    def symbolic_width(self) -> List:
        """Return a one-item list holding this type's "<HEADER>_<type>_sz" name"""
        prefix = get_header_name().upper()
        return ["_".join([prefix, self.type_name, "sz"])]

    def __post_init__(self):
        # Types already recorded as structs need the "struct " classifier.
        known_struct = self.type_name in structs
        if known_struct:
            self.c_classifier = "struct "
        # Referring to the type by name (re)registers its symbolic width.
        width = self.symbolic_width()
        symbolic_widths[self.type_name] = width
110
111
@dataclass
class _XdrBuiltInType(_XdrTypeSpecifier):
    """A type specifier that names one of the built-in XDR types"""

    def symbolic_width(self) -> List:
        """Return the symbolic_widths entry registered for this type name"""
        registered = symbolic_widths[self.type_name]
        return registered
119
120
@dataclass
class _XdrDeclaration(_XdrAst):
    """Base class of XDR type declarations

    Adds no fields of its own. The declaration subclasses in this module
    each provide max_width() and symbolic_width().
    """
124
125
@dataclass
class _XdrFixedLengthOpaque(_XdrDeclaration):
    """Declaration of an opaque field with a fixed octet count"""

    name: str
    size: str
    template: str = "fixed_length_opaque"

    def max_width(self) -> int:
        """Return the size, rounded up to whole XDR_UNITS"""
        return xdr_quadlen(self.size)

    def symbolic_width(self) -> List:
        """Return a one-item list wrapping the size in XDR_QUADLEN()"""
        return [f"XDR_QUADLEN({self.size})"]

    def __post_init__(self):
        # Publish both widths under the declared name.
        numeric = self.max_width()
        max_widths[self.name] = numeric
        symbolic = self.symbolic_width()
        symbolic_widths[self.name] = symbolic
145
146
@dataclass
class _XdrVariableLengthOpaque(_XdrDeclaration):
    """Declaration of an opaque field with a variable octet count"""

    name: str
    maxsize: str
    template: str = "variable_length_opaque"

    def max_width(self) -> int:
        """Return one unit for the length word plus the rounded-up maxsize"""
        return xdr_quadlen(self.maxsize) + 1

    def symbolic_width(self) -> List:
        """Return the length word, then the payload unless maxsize is "0" """
        length_word = ["XDR_unsigned_int"]
        if self.maxsize == "0":
            return length_word
        return length_word + [f"XDR_QUADLEN({self.maxsize})"]

    def __post_init__(self):
        # Publish both widths under the declared name.
        numeric = self.max_width()
        max_widths[self.name] = numeric
        symbolic = self.symbolic_width()
        symbolic_widths[self.name] = symbolic
169
170
@dataclass
class _XdrString(_XdrDeclaration):
    """Declaration of a (NUL-terminated) variable-length string"""

    name: str
    maxsize: str
    template: str = "string"

    def max_width(self) -> int:
        """Return one unit for the length word plus the rounded-up maxsize"""
        payload_units = xdr_quadlen(self.maxsize)
        return 1 + payload_units

    def symbolic_width(self) -> List:
        """Return the length word, then the payload unless maxsize is "0" """
        if self.maxsize == "0":
            return ["XDR_unsigned_int"]
        return ["XDR_unsigned_int", f"XDR_QUADLEN({self.maxsize})"]

    def __post_init__(self):
        # Publish both widths under the declared name.
        numeric = self.max_width()
        max_widths[self.name] = numeric
        symbolic = self.symbolic_width()
        symbolic_widths[self.name] = symbolic
193
194
@dataclass
class _XdrFixedLengthArray(_XdrDeclaration):
    """A fixed-length array declaration

    Instantiating the node registers its numeric and symbolic widths
    under the declared name.
    """

    name: str
    spec: _XdrTypeSpecifier
    # Number of elements: a name found in constants, or an integer string.
    size: str
    template: str = "fixed_length_array"

    def max_width(self) -> int:
        """Return width of type in XDR_UNITS

        The width is the element count times the element type's width.
        The size counts elements, not octets, so it is used as-is
        rather than being rounded to 4-octet units. This matches the
        "size * item" expression built by symbolic_width().

        Raises:
            KeyError: If the element type has no entry in max_widths.
            ValueError: If size is not a known constant name and cannot
                be converted by int().
        """
        if self.size in constants:
            count = constants[self.size]
        else:
            count = int(self.size)
        return count * max_widths[self.spec.type_name]

    def symbolic_width(self) -> List:
        """Return list containing XDR width of type's components"""
        item_width = " + ".join(symbolic_widths[self.spec.type_name])
        return ["(" + self.size + " * (" + item_width + "))"]

    def __post_init__(self):
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
216
217
@dataclass
class _XdrVariableLengthArray(_XdrDeclaration):
    """A variable-length array declaration

    Instantiating the node registers its numeric and symbolic widths
    under the declared name.
    """

    name: str
    spec: _XdrTypeSpecifier
    # Maximum number of elements: a name found in constants, or an
    # integer string. ParseToAst passes "0" when no maximum is declared.
    maxsize: str
    template: str = "variable_length_array"

    def max_width(self) -> int:
        """Return width of type in XDR_UNITS

        The width is one unit for the element count, plus the maximum
        number of elements times the element type's width. The maxsize
        counts elements, not octets, so it is used as-is rather than
        being rounded to 4-octet units. This matches the
        "maxsize * item" expression built by symbolic_width().

        Raises:
            KeyError: If the element type has no entry in max_widths.
            ValueError: If maxsize is not a known constant name and
                cannot be converted by int().
        """
        if self.maxsize in constants:
            count = constants[self.maxsize]
        else:
            count = int(self.maxsize)
        return 1 + (count * max_widths[self.spec.type_name])

    def symbolic_width(self) -> List:
        """Return list containing XDR width of type's components"""
        widths = ["XDR_unsigned_int"]
        # A maxsize of "0" contributes only the count word.
        if self.maxsize != "0":
            item_width = " + ".join(symbolic_widths[self.spec.type_name])
            widths.append("(" + self.maxsize + " * (" + item_width + "))")
        return widths

    def __post_init__(self):
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
242
243
@dataclass
class _XdrOptionalData(_XdrDeclaration):
    """Declaration of an 'optional_data' field"""

    name: str
    spec: _XdrTypeSpecifier
    template: str = "optional_data"

    def max_width(self) -> int:
        """Return 1; only a single XDR_UNIT is counted for this field"""
        return 1

    def symbolic_width(self) -> List:
        """Return a one-item list naming XDR_bool"""
        return ["XDR_bool"]

    def __post_init__(self):
        declared = self.name
        # The field is recorded both as a struct and as pass-by-reference.
        structs.add(declared)
        pass_by_reference.add(declared)
        max_widths[declared] = self.max_width()
        symbolic_widths[declared] = self.symbolic_width()
265
266
@dataclass
class _XdrBasic(_XdrDeclaration):
    """Declaration of a single named field of some type"""

    name: str
    spec: _XdrTypeSpecifier
    template: str = "basic"

    def max_width(self) -> int:
        """Return the numeric width registered for the field's type"""
        type_name = self.spec.type_name
        return max_widths[type_name]

    def symbolic_width(self) -> List:
        """Return the symbolic width list registered for the field's type"""
        type_name = self.spec.type_name
        return symbolic_widths[type_name]

    def __post_init__(self):
        # Make the widths reachable under the field's own name as well.
        declared = self.name
        max_widths[declared] = self.max_width()
        symbolic_widths[declared] = self.symbolic_width()
286
287
@dataclass
class _XdrVoid(_XdrDeclaration):
    """Declaration of a void (empty) item"""

    name: str = "void"
    template: str = "void"

    def max_width(self) -> int:
        """Return 0; a void item occupies no XDR_UNITS"""
        return 0

    def symbolic_width(self) -> List:
        """Return an empty list; a void item has no width components"""
        no_components = []
        return no_components
302
303
@dataclass
class _XdrConstant(_XdrAst):
    """Corresponds to 'constant_def' in the grammar

    Instantiating the node records the constant's integer value in the
    module-level constants dictionary.
    """

    name: str
    # Either an integer literal in a form int(value, 0) accepts, or the
    # name of a constant that has already been recorded.
    value: str

    def __post_init__(self):
        if self.value in constants:
            # The value names an earlier constant. Record this name
            # under the same integer so that later lookups, such as
            # xdr_quadlen(), can resolve it too.
            constants[self.name] = constants[self.value]
        else:
            # Base 0 lets int() pick the radix from the literal's prefix.
            constants[self.name] = int(self.value, 0)
314
315
@dataclass
class _XdrEnumerator(_XdrAst):
    """An 'identifier = value' enumerator

    Instantiating the node records the enumerator's integer value in
    the module-level constants dictionary.
    """

    name: str
    # Either an integer literal in a form int(value, 0) accepts, or the
    # name of a constant that has already been recorded.
    value: str

    def __post_init__(self):
        if self.value in constants:
            # The value names an earlier constant. Record this
            # enumerator under the same integer so that later lookups
            # by name can resolve it too.
            constants[self.name] = constants[self.value]
        else:
            # Base 0 lets int() pick the radix from the literal's prefix.
            constants[self.name] = int(self.value, 0)
326
327
@dataclass
class _XdrEnum(_XdrAst):
    """Definition of an XDR enum and its enumerators"""

    name: str
    enumerators: List[_XdrEnumerator]

    def max_width(self) -> int:
        """Return 1; an enum is counted as a single XDR_UNIT"""
        return 1

    def symbolic_width(self) -> List:
        """Return a one-item list naming XDR_int"""
        return ["XDR_int"]

    def __post_init__(self):
        # Publish both widths under the enum's name.
        enum_name = self.name
        max_widths[enum_name] = self.max_width()
        symbolic_widths[enum_name] = self.symbolic_width()
346
347
@dataclass
class _XdrStruct(_XdrAst):
    """Definition of an XDR struct and its fields"""

    name: str
    fields: List[_XdrDeclaration]

    def max_width(self) -> int:
        """Return the sum of the numeric widths of all fields"""
        return sum(member.max_width() for member in self.fields)

    def symbolic_width(self) -> List:
        """Return the fields' symbolic width lists, concatenated in order"""
        components = []
        for member in self.fields:
            components.extend(member.symbolic_width())
        return components

    def __post_init__(self):
        struct_name = self.name
        # Record the type as a struct that is passed by reference.
        structs.add(struct_name)
        pass_by_reference.add(struct_name)
        max_widths[struct_name] = self.max_width()
        symbolic_widths[struct_name] = self.symbolic_width()
374
375
@dataclass
class _XdrPointer(_XdrAst):
    """Definition of an XDR pointer type

    Both width methods count one leading item and then every field
    except the last one.
    """

    name: str
    fields: List[_XdrDeclaration]

    def max_width(self) -> int:
        """Return 1 plus the numeric widths of all fields but the last"""
        leading = self.fields[:-1]
        return 1 + sum(member.max_width() for member in leading)

    def symbolic_width(self) -> List:
        """Return XDR_bool followed by the widths of all fields but the last"""
        components = ["XDR_bool"]
        for member in self.fields[:-1]:
            components.extend(member.symbolic_width())
        return components

    def __post_init__(self):
        pointer_name = self.name
        # Record the type as a struct that is passed by reference.
        structs.add(pointer_name)
        pass_by_reference.add(pointer_name)
        max_widths[pointer_name] = self.max_width()
        symbolic_widths[pointer_name] = self.symbolic_width()
403
404
@dataclass
class _XdrTypedef(_XdrAst):
    """Definition of an XDR typedef wrapping one declaration"""

    declaration: _XdrDeclaration

    def max_width(self) -> int:
        """Return the numeric width of the wrapped declaration"""
        wrapped = self.declaration
        return wrapped.max_width()

    def symbolic_width(self) -> List:
        """Return the symbolic width list of the wrapped declaration"""
        wrapped = self.declaration
        return wrapped.symbolic_width()

    def __post_init__(self):
        alias = self.declaration
        # Only a basic declaration of a defined type is handled here.
        if not isinstance(alias, _XdrBasic):
            return
        if not isinstance(alias.spec, _XdrDefinedType):
            return
        # An alias of a pass-by-reference type is pass-by-reference too.
        if alias.spec.type_name in pass_by_reference:
            pass_by_reference.add(alias.name)
        max_widths[alias.name] = self.max_width()
        symbolic_widths[alias.name] = self.symbolic_width()
427
428
@dataclass
class _XdrCaseSpec(_XdrAst):
    """One case in an XDR union"""

    # Value strings of the case labels; ParseToAst.case_spec() collects
    # them from every child except the last.
    values: List[str]
    # Declaration selected by those labels (the last child).
    arm: _XdrDeclaration
    # NOTE(review): presumably names the template used to render this
    # node -- the consumer is not visible in this file; confirm.
    template: str = "case_spec"
436
437
@dataclass
class _XdrDefaultSpec(_XdrAst):
    """Default case in an XDR union"""

    # Declaration used for the default case.
    arm: _XdrDeclaration
    # NOTE(review): presumably names the template used to render this
    # node -- the consumer is not visible in this file; confirm.
    template: str = "default_spec"
444
445
@dataclass
class _XdrUnion(_XdrAst):
    """An XDR union

    Instantiating the node records the union as a pass-by-reference
    struct and registers its numeric and symbolic widths.
    """

    name: str
    discriminant: _XdrDeclaration
    cases: List[_XdrCaseSpec]
    # Node whose .arm is the default declaration. The methods below
    # treat a false value as "no default arm".
    default: _XdrDeclaration

    def max_width(self) -> int:
        """Return width of type in XDR_UNITS

        The width is one unit for the discriminant plus the width of
        the widest arm, or zero when no arm has any width.
        """
        widest = 0
        for case in self.cases:
            widest = max(widest, case.arm.max_width())
        if self.default:
            widest = max(widest, self.default.arm.max_width())
        return 1 + widest

    def symbolic_width(self) -> List:
        """Return list containing XDR width of type's components

        The list is the discriminant's symbolic width followed by the
        symbolic width of the widest arm. The first arm to reach the
        largest width wins. When every arm is zero-width, for example
        all void, only the discriminant's width is returned.
        """
        widest = 0
        # Start empty so a union of zero-width arms still has a value.
        width = []
        for case in self.cases:
            arm_width = case.arm.max_width()
            if arm_width > widest:
                widest = arm_width
                width = case.arm.symbolic_width()
        if self.default:
            arm_width = self.default.arm.max_width()
            if arm_width > widest:
                widest = arm_width
                width = self.default.arm.symbolic_width()
        return symbolic_widths[self.discriminant.name] + width

    def __post_init__(self):
        structs.add(self.name)
        pass_by_reference.add(self.name)
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
484
485
@dataclass
class _RpcProcedure(_XdrAst):
    """RPC procedure definition"""

    # Procedure name (an identifier's symbol).
    name: str
    # Procedure number, kept as the value string from the parse tree.
    number: str
    # Type specifiers of the procedure's argument and result.
    argument: _XdrTypeSpecifier
    result: _XdrTypeSpecifier
494
495
@dataclass
class _RpcVersion(_XdrAst):
    """RPC version definition"""

    # Version name (an identifier's symbol).
    name: str
    # Version number, kept as the value string from the parse tree.
    number: str
    # Procedures declared in this version, in source order.
    procedures: List[_RpcProcedure]
503
504
@dataclass
class _RpcProgram(_XdrAst):
    """RPC program definition"""

    # Program name (an identifier's symbol).
    name: str
    # Program number, kept as the value string from the parse tree.
    number: str
    # Versions declared in this program, in source order.
    versions: List[_RpcVersion]
512
513
@dataclass
class _Pragma(_XdrAst):
    """Empty class for pragma directives

    The node carries no data: ParseToAst.pragma_def() stores each
    directive's effect in module-level state and returns an instance of
    this class as a placeholder.
    """
517
518
@dataclass
class _XdrPassthru(_XdrAst):
    """Passthrough line to emit verbatim in output"""

    # Text of the line with the token's first character removed. After
    # _merge_consecutive_passthru() runs, this may hold several lines
    # joined by "\n".
    content: str
524
525
@dataclass
class Definition(_XdrAst, ast_utils.WithMeta):
    """Corresponds to 'definition' in the grammar"""

    # lark Meta object for this definition.
    # NOTE(review): presumably supplied by lark because the class
    # derives from ast_utils.WithMeta -- confirm against the lark docs.
    meta: Meta
    # The AST node that this definition wraps.
    value: _XdrAst
532
533
@dataclass
class Specification(_XdrAst, ast_utils.AsList):
    """Corresponds to 'specification' in the grammar"""

    # All definitions in the specification, in source order.
    # transform_parse_tree() replaces this list after merging runs of
    # consecutive passthru definitions.
    definitions: List[Definition]
539
540
541class ParseToAst(Transformer):
542    """Functions that transform productions into AST nodes"""
543
544    def identifier(self, children):
545        """Instantiate one _XdrIdentifier object"""
546        return _XdrIdentifier(children[0].value)
547
548    def value(self, children):
549        """Instantiate one _XdrValue object"""
550        if isinstance(children[0], _XdrIdentifier):
551            return _XdrValue(children[0].symbol)
552        return _XdrValue(children[0].children[0].value)
553
554    def constant(self, children):
555        """Instantiate one _XdrConstantValue object"""
556        match children[0].data:
557            case "decimal_constant":
558                value = int(children[0].children[0].value, base=10)
559            case "hexadecimal_constant":
560                value = int(children[0].children[0].value, base=16)
561            case "octal_constant":
562                value = int(children[0].children[0].value, base=8)
563        return _XdrConstantValue(value)
564
565    def type_specifier(self, children):
566        """Instantiate one _XdrTypeSpecifier object"""
567        if isinstance(children[0], _XdrIdentifier):
568            name = children[0].symbol
569            return _XdrDefinedType(type_name=name)
570
571        name = children[0].data.value
572        return _XdrBuiltInType(type_name=name)
573
574    def constant_def(self, children):
575        """Instantiate one _XdrConstant object"""
576        name = children[0].symbol
577        value = children[1].value
578        return _XdrConstant(name, value)
579
580    def enum(self, children):
581        """Instantiate one _XdrEnum object"""
582        enum_name = children[0].symbol
583
584        i = 0
585        enumerators = []
586        body = children[1]
587        while i < len(body.children):
588            name = body.children[i].symbol
589            value = body.children[i + 1].value
590            enumerators.append(_XdrEnumerator(name, value))
591            i = i + 2
592
593        return _XdrEnum(enum_name, enumerators)
594
595    def fixed_length_opaque(self, children):
596        """Instantiate one _XdrFixedLengthOpaque declaration object"""
597        name = children[0].symbol
598        size = children[1].value
599
600        return _XdrFixedLengthOpaque(name, size)
601
602    def variable_length_opaque(self, children):
603        """Instantiate one _XdrVariableLengthOpaque declaration object"""
604        name = children[0].symbol
605        if children[1] is not None:
606            maxsize = children[1].value
607        else:
608            maxsize = "0"
609
610        return _XdrVariableLengthOpaque(name, maxsize)
611
612    def string(self, children):
613        """Instantiate one _XdrString declaration object"""
614        name = children[0].symbol
615        if children[1] is not None:
616            maxsize = children[1].value
617        else:
618            maxsize = "0"
619
620        return _XdrString(name, maxsize)
621
622    def fixed_length_array(self, children):
623        """Instantiate one _XdrFixedLengthArray declaration object"""
624        spec = children[0]
625        name = children[1].symbol
626        size = children[2].value
627
628        return _XdrFixedLengthArray(name, spec, size)
629
630    def variable_length_array(self, children):
631        """Instantiate one _XdrVariableLengthArray declaration object"""
632        spec = children[0]
633        name = children[1].symbol
634        if children[2] is not None:
635            maxsize = children[2].value
636        else:
637            maxsize = "0"
638
639        return _XdrVariableLengthArray(name, spec, maxsize)
640
641    def optional_data(self, children):
642        """Instantiate one _XdrOptionalData declaration object"""
643        spec = children[0]
644        name = children[1].symbol
645
646        return _XdrOptionalData(name, spec)
647
648    def basic(self, children):
649        """Instantiate one _XdrBasic object"""
650        spec = children[0]
651        name = children[1].symbol
652
653        return _XdrBasic(name, spec)
654
655    def void(self, children):
656        """Instantiate one _XdrVoid declaration object"""
657
658        return _XdrVoid()
659
660    def struct(self, children):
661        """Instantiate one _XdrStruct object"""
662        name = children[0].symbol
663        fields = children[1].children
664
665        last_field = fields[-1]
666        if (
667            isinstance(last_field, _XdrOptionalData)
668            and name == last_field.spec.type_name
669        ):
670            return _XdrPointer(name, fields)
671
672        return _XdrStruct(name, fields)
673
674    def typedef(self, children):
675        """Instantiate one _XdrTypedef object"""
676        new_type = children[0]
677
678        return _XdrTypedef(new_type)
679
680    def case_spec(self, children):
681        """Instantiate one _XdrCaseSpec object"""
682        values = []
683        for item in children[0:-1]:
684            values.append(item.value)
685        arm = children[-1]
686
687        return _XdrCaseSpec(values, arm)
688
689    def default_spec(self, children):
690        """Instantiate one _XdrDefaultSpec object"""
691        arm = children[0]
692
693        return _XdrDefaultSpec(arm)
694
695    def union(self, children):
696        """Instantiate one _XdrUnion object"""
697        name = children[0].symbol
698
699        body = children[1]
700        discriminant = body.children[0].children[0]
701        cases = body.children[1:-1]
702        default = body.children[-1]
703
704        return _XdrUnion(name, discriminant, cases, default)
705
706    def procedure_def(self, children):
707        """Instantiate one _RpcProcedure object"""
708        result = children[0]
709        name = children[1].symbol
710        argument = children[2]
711        number = children[3].value
712
713        return _RpcProcedure(name, number, argument, result)
714
715    def version_def(self, children):
716        """Instantiate one _RpcVersion object"""
717        name = children[0].symbol
718        number = children[-1].value
719        procedures = children[1:-1]
720
721        return _RpcVersion(name, number, procedures)
722
723    def program_def(self, children):
724        """Instantiate one _RpcProgram object"""
725        name = children[0].symbol
726        number = children[-1].value
727        versions = children[1:-1]
728
729        return _RpcProgram(name, number, versions)
730
731    def pragma_def(self, children):
732        """Instantiate one _Pragma object"""
733        directive = children[0].children[0].data
734        match directive:
735            case "big_endian_directive":
736                big_endian.append(children[1].symbol)
737            case "exclude_directive":
738                excluded_apis.append(children[1].symbol)
739            case "header_directive":
740                global header_name
741                header_name = children[1].symbol
742            case "public_directive":
743                public_apis.append(children[1].symbol)
744            case _:
745                raise NotImplementedError("Directive not supported")
746        return _Pragma()
747
748    def passthru_def(self, children):
749        """Instantiate one _XdrPassthru object"""
750        token = children[0]
751        content = token.value[1:]
752        return _XdrPassthru(content)
753
754
# Module-level transformer used by transform_parse_tree(). It is built
# by lark's ast_utils.create_transformer() from this module and an
# instance of ParseToAst.
transformer = ast_utils.create_transformer(this_module, ParseToAst())
756
757
def _merge_consecutive_passthru(definitions: List[Definition]) -> List[Definition]:
    """Collapse each run of adjacent passthru definitions into one node

    Every run, even a run of one, is replaced by a new Definition that
    carries the first member's meta and a new _XdrPassthru whose content
    is the members' contents joined by newlines. Other definitions are
    passed through unchanged and in order.
    """
    output = []
    run_lines = []
    run_meta = None

    for definition in definitions:
        node = definition.value
        if isinstance(node, _XdrPassthru):
            # The first member of a run supplies the meta for the result.
            if not run_lines:
                run_meta = definition.meta
            run_lines.append(node.content)
            continue

        # A non-passthru definition ends any run in progress.
        if run_lines:
            output.append(Definition(run_meta, _XdrPassthru("\n".join(run_lines))))
            run_lines = []
        output.append(definition)

    # Flush a run that extends to the end of the list.
    if run_lines:
        output.append(Definition(run_meta, _XdrPassthru("\n".join(run_lines))))

    return output
777
778
def transform_parse_tree(parse_tree):
    """Build the abstract syntax tree for a parse tree

    The module-level transformer converts the parse tree, then runs of
    consecutive passthru definitions are merged before the result is
    returned.
    """
    specification = transformer.transform(parse_tree)
    merged = _merge_consecutive_passthru(specification.definitions)
    specification.definitions = merged
    return specification
784
785
def get_header_name() -> str:
    """Return header name set by pragma header directive

    This is the current value of the module-level header_name, which is
    "none" until ParseToAst.pragma_def() processes a header directive.
    """
    return header_name
789