1#!/usr/bin/env python3
2# SPDX-License-Identifier: GPL-2.0-or-later
3#
4# Script that generates constants for computing the given CRC variant(s).
5#
6# Copyright 2025 Google LLC
7#
8# Author: Eric Biggers <ebiggers@google.com>
9
10import sys
11
12# XOR (add) an iterable of polynomials.
def xor(iterable):
    """Return the XOR of every value produced by iterable.

    XOR is addition of polynomials over GF(2).  An empty iterable yields 0.
    """
    total = 0
    for term in iterable:
        total = total ^ term
    return total
18
19# Multiply two polynomials.
def clmul(a, b):
    """Return the carryless (GF(2) polynomial) product of a and b.

    Polynomials are ints in which bit i holds the coefficient of x^i.
    """
    product = 0
    # For each x^shift term present in b, add (XOR in) a * x^shift.
    for shift in range(b.bit_length()):
        if (b >> shift) & 1:
            product ^= a << shift
    return product
22
23# Polynomial division floor(a / b).
def div(a, b):
    """Return the polynomial quotient floor(a / b) over GF(2).

    Polynomials are ints in which bit i holds the coefficient of x^i.

    Raises:
        ZeroDivisionError: if b is the zero polynomial.  Without this check
            the long-division loop below would never terminate, because
            XORing in shifted copies of 0 never lowers the degree of a.
    """
    if b == 0:
        raise ZeroDivisionError('polynomial division by zero')
    b_len = b.bit_length()
    q = 0
    # Schoolbook long division: repeatedly cancel the leading term of a.
    while a.bit_length() >= b_len:
        shift = a.bit_length() - b_len
        q ^= 1 << shift
        a ^= b << shift
    return q
30
31# Reduce the polynomial 'a' modulo the polynomial 'b'.
def reduce(a, b):
    """Return the remainder of polynomial a modulo polynomial b over GF(2).

    The remainder is what is left of a once long division can no longer
    cancel its leading term, so it is computed directly rather than by
    building the quotient and multiplying it back by b.

    Raises:
        ZeroDivisionError: if b is the zero polynomial (the division loop
            would otherwise never terminate).
    """
    if b == 0:
        raise ZeroDivisionError('polynomial modulo by zero')
    b_len = b.bit_length()
    while a.bit_length() >= b_len:
        # Cancel the leading term of a with a suitably shifted copy of b.
        a ^= b << (a.bit_length() - b_len)
    return a
34
35# Reflect the bits of a polynomial.
def bitreflect(poly, num_bits):
    """Return poly with the order of its low num_bits bits reversed.

    Bit i of the input becomes bit (num_bits - 1 - i) of the result.  The
    input must fit in num_bits bits; this is checked with an assertion.
    """
    assert poly.bit_length() <= num_bits
    mirrored = 0
    for pos in range(num_bits):
        if (poly >> pos) & 1:
            mirrored |= 1 << (num_bits - 1 - pos)
    return mirrored
39
40# Format a polynomial as hex.  Bit-reflect it if the CRC is lsb-first.
def fmt_poly(variant, poly, num_bits):
    """Return poly as a zero-padded '0x...' hex string num_bits wide.

    If variant.lsb is set, the polynomial is bit-reflected within num_bits
    bits first.

    The digit count is num_bits / 4 rounded up, the same rounding that
    CrcVariant uses when building its name, so that widths which are not a
    multiple of 4 still get consistent fixed-width padding.  For widths that
    are a multiple of 4 the result is the same as with rounding down.
    """
    if variant.lsb:
        poly = bitreflect(poly, num_bits)
    num_digits = (2 * num_bits + 7) // 8
    return f'0x{poly:0{num_digits}x}'
45
46# Print a pair of 64-bit polynomial multipliers.  They are always passed in the
47# order [HI64_TERMS, LO64_TERMS] but will be printed in the appropriate order.
def print_mult_pair(variant, mults):
    """Print a pair of 64-bit multipliers as two commented C initializers.

    mults holds two dicts with 'val' and 'desc' keys, given in the order
    [HI64_TERMS, LO64_TERMS].  That order is kept for lsb-first variants and
    swapped for msb-first ones; each line is labelled accordingly.
    """
    if variant.lsb:
        ordered = list(mults)
        labels = ['HI64_TERMS', 'LO64_TERMS']
    else:
        ordered = list(reversed(mults))
        labels = ['LO64_TERMS', 'HI64_TERMS']
    for idx, label in enumerate(labels):
        entry = ordered[idx]
        hex_val = fmt_poly(variant, entry["val"], 64)
        print(f'\t\t{hex_val},\t/* {label}: {entry["desc"]} */')
53
54# Pretty-print a polynomial.
def pprint_poly(prefix, poly):
    """Print poly as a sum of x^i terms, highest power first.

    The first line starts with prefix.  A line stops taking terms once it
    has reached 73 characters; continuation lines start with ' * ' padded
    with spaces to the width of prefix.  Nothing is printed for poly == 0.
    """
    exponents = [e for e in range(poly.bit_length() - 1, -1, -1)
                 if (poly >> e) & 1]
    # Every term except the last carries a trailing ' +'.
    pieces = ([f'x^{e} +' for e in exponents[:-1]] +
              [f'x^{e}' for e in exponents[-1:]])
    idx = 0
    while idx < len(pieces):
        line = prefix + pieces[idx]
        idx += 1
        while idx < len(pieces) and len(line) < 73:
            line += ' ' + pieces[idx]
            idx += 1
        print(line)
        prefix = ' * ' + ' ' * (len(prefix) - 3)
67
68# Print a comment describing constants generated for the given CRC variant.
def print_header(variant, what):
    """Print a C block comment introducing the constants that follow.

    The comment names what was generated (what), the bit order and width of
    the CRC variant, and its generator polynomial G(x) written out in full.
    """
    print('/*')
    order = 'least' if variant.lsb else 'most'
    crc_desc = f'{order}-significant-bit-first CRC-{variant.bits}'
    print(f' * {what} generated for {crc_desc} using')
    pprint_poly(' * G(x) = ', variant.G)
    print(' */')
75
class CrcVariant:
    """A CRC variant: a width, a generator polynomial, and a bit order."""

    def __init__(self, bits, generator_poly, bit_order):
        """Initialize the variant.

        Args:
            bits: Degree of the generator polynomial (width of the CRC).
            generator_poly: Generator polynomial without its x^bits term,
                expressed in the given bit order.
            bit_order: Either 'lsb' or 'msb'.

        Raises:
            ValueError: if bit_order is invalid, or if generator_poly is
                negative or does not fit in 'bits' bits.  An oversized
                msb-first polynomial would otherwise have its top bit
                silently cleared by the XOR with x^bits below, yielding a
                wrong G.
        """
        if bit_order not in ('lsb', 'msb'):
            raise ValueError('Invalid value for bit_order')
        if generator_poly < 0 or generator_poly.bit_length() > bits:
            raise ValueError(
                f'generator_poly must be non-negative and fit in {bits} bits '
                f'(it must exclude the x^{bits} term)')
        # Width of the CRC in bits.
        self.bits = bits
        # True for least-significant-bit-first CRCs.
        self.lsb = bit_order == 'lsb'
        # Identifier built from the width, bit order and polynomial exactly
        # as they were passed in.
        self.name = f'crc{bits}_{bit_order}_0x{generator_poly:0{(2*bits+7)//8}x}'
        if self.lsb:
            generator_poly = bitreflect(generator_poly, bits)
        # Full generator polynomial, including the x^bits term, with bit i
        # holding the coefficient of x^i.
        self.G = generator_poly ^ (1 << bits)
86
87# Generate tables for CRC computation using the "slice-by-N" method.
88# N=1 corresponds to the traditional byte-at-a-time table.
def gen_slicebyN_tables(variants, n):
    """Print a slice-by-n CRC lookup table as a C array for each variant.

    Each table has 256*n entries.  Entry (256*k + b) is the CRC of the
    message made of byte b followed by k zero bytes, so n == 1 gives the
    traditional byte-at-a-time table.  Entries are packed several per output
    line.
    """
    for v in variants:
        print('')
        print_header(v, f'Slice-by-{n} CRC table')
        print(f'static const u{v.bits} __maybe_unused {v.name}_table[{256*n}] = {{')
        line = ''
        for num_zero_bytes in range(n):
            for byte in range(256):
                # Build the message polynomial: the byte (reflected for
                # lsb-first CRCs), shifted up past the trailing zero bytes
                # and the CRC width.
                msg_byte = bitreflect(byte, 8) if v.lsb else byte
                msg_poly = msg_byte << (v.bits + 8 * num_zero_bytes)
                entry = fmt_poly(v, reduce(msg_poly, v.G), v.bits) + ','
                # Flush the pending line once adding this entry would make
                # it longer than 71 characters.
                if len(line) + len(entry) > 71:
                    print(f'\t{line}')
                    line = ''
                line = f'{line} {entry}' if line else entry
        if line:
            print(f'\t{line}')
        print('};')
107
def print_riscv_const(v, bits_per_long, name, val, desc):
    """Print one designated initializer '.name = value, /* desc */'.

    The value is formatted bits_per_long bits wide for variant v.
    """
    hex_val = fmt_poly(v, val, bits_per_long)
    print(f'\t.{name} = {hex_val}, /* {desc} */')
110
def do_gen_riscv_clmul_consts(v, bits_per_long):
    """Print the four crc_clmul_consts initializers for one long size.

    Emits the two fold-across-2-longs constants followed by the two Barrett
    reduction constants for variant v, each bits_per_long bits wide.
    """
    G, n, lsb = v.G, v.bits, v.lsb
    # lsb-first constants use a power of x that is one lower.
    lsb_adjust = 1 if lsb else 0

    fold_consts = (('fold_across_2_longs_const_hi', 3),
                   ('fold_across_2_longs_const_lo', 2))
    for const_name, num_longs in fold_consts:
        exponent = num_longs * bits_per_long - lsb_adjust
        print_riscv_const(v, bits_per_long, const_name,
                          reduce(1 << exponent, G), f'x^{exponent} mod G')

    exponent = bits_per_long - 1 + n
    print_riscv_const(v, bits_per_long, 'barrett_reduction_const_1',
                      div(1 << exponent, G), f'floor(x^{exponent} / G)')

    # G without its leading x^n term; shifted up to the top of the long for
    # lsb-first CRCs.
    low_terms = G - (1 << n)
    if lsb:
        shift = bits_per_long - n
        low_terms <<= shift
        low_desc = f'(G - x^{n}) * x^{shift}'
    else:
        low_desc = f'G - x^{n}'
    print_riscv_const(v, bits_per_long, 'barrett_reduction_const_2',
                      low_terms, low_desc)
131
def gen_riscv_clmul_consts(variants):
    """Print the crc_clmul_consts struct and one instance per variant.

    Variants wider than 32 bits get 64-bit constants only, with the whole
    definition wrapped in #ifdef CONFIG_64BIT.  Narrower variants get one
    definition whose initializers are chosen by #ifdef CONFIG_64BIT between
    64-bit and 32-bit constants.
    """
    struct_lines = (
        'struct crc_clmul_consts {',
        '\tunsigned long fold_across_2_longs_const_hi;',
        '\tunsigned long fold_across_2_longs_const_lo;',
        '\tunsigned long barrett_reduction_const_1;',
        '\tunsigned long barrett_reduction_const_2;',
        '};',
    )
    print('')
    for struct_line in struct_lines:
        print(struct_line)

    for v in variants:
        print('')
        print_header(v, 'Constants')
        decl = f'static const struct crc_clmul_consts {v.name}_consts __maybe_unused = {{'
        if v.bits > 32:
            # Only usable with 64-bit longs: guard the whole definition.
            print('#ifdef CONFIG_64BIT')
            print(decl)
            do_gen_riscv_clmul_consts(v, 64)
            print('};')
            print('#endif')
        else:
            # Usable with either long size: guard only the initializers.
            print(decl)
            print('#ifdef CONFIG_64BIT')
            do_gen_riscv_clmul_consts(v, 64)
            print('#else')
            do_gen_riscv_clmul_consts(v, 32)
            print('#endif')
            print('};')
158
159# Generate constants for carryless multiplication based CRC computation.
def gen_x86_pclmul_consts(variants):
    """Print constants for carryless-multiplication based CRC computation.

    For each variant this prints one C struct instance holding a byte-swap
    mask (msb-first CRCs only), a pair of folding constants for each fold
    distance, a shuffle table, and a pair of Barrett reduction constants.
    """
    # Distances, in bits, for which folding constants are generated.
    FOLD_DISTANCES = [2048, 1024, 512, 256, 128]

    for v in variants:
        G, n, lsb = v.G, v.bits, v.lsb
        msb = not lsb

        # Struct type declaration.
        print('')
        print_header(v, 'CRC folding constants')
        print('static const struct {')
        if msb:
            print('\tu8 bswap_mask[16];')
        for dist in FOLD_DISTANCES:
            print(f'\tu64 fold_across_{dist}_bits_consts[2];')
        print('\tu8 shuf_table[48];')
        print('\tu64 barrett_reduction_consts[2];')
        print(f'}} {v.name}_consts ____cacheline_aligned __maybe_unused = {{')

        # msb-first CRCs need a byte-reflection mask.
        if msb:
            byte_indices = ', '.join(str(k) for k in range(15, -1, -1))
            print('\t.bswap_mask = {' + byte_indices + '},')

        # Folding constants for every distance down to 128 bits.
        for dist in FOLD_DISTANCES:
            print(f'\t.fold_across_{dist}_bits_consts = {{')
            # With 64x64 => 128 bit carryless multiplies, each fold distance
            # needs two 64-bit constants: roughly x^(dist+64) mod G for
            # HI64_TERMS and x^dist mod G for LO64_TERMS.  Two adjustments
            # are applied to the exponent:
            #
            # 1. lsb-first only: the multiply yields a 127-bit product in
            #    physical bits 0 through 126, which in lsb-first order stand
            #    for x^1 through x^127 rather than x^0 through x^126.  Each
            #    multiply thus implicitly adds a factor of x, so one factor
            #    of x is removed from each constant.  For n < 64 it could
            #    come out of either the reduced or the unreduced part, but
            #    for n == 64 only the reduced part is possible, so the
            #    reduced part is always used.
            #
            # 2. A factor of x^(64-n) is applied unreduced rather than
            #    reduced, so the product only uses the x^(64-n) and higher
            #    terms and is always zero in the lower terms.  This usually
            #    makes no difference (the congruence class mod G is
            #    unchanged and the constant stays 64-bit), but part of the
            #    final reduction from 128 bits relies on it when it reuses
            #    one of the constants.
            pair = []
            for extra in (64, 0):
                exponent = dist + extra - (1 if lsb else 0) - (64 - n)
                pair.append({
                    'val': reduce(1 << exponent, G) << (64 - n),
                    'desc': f'(x^{exponent} mod G) * x^{64-n}',
                })
            print_mult_pair(v, pair)
            print('\t},')

        # Shuffle table used for handling 1..15 bytes at the end.
        all_minus_one = ' '.join(['-1,'] * 16)
        identity = ' '.join(f'{k:2},' for k in range(16))
        print('\t.shuf_table = {')
        print('\t\t' + all_minus_one)
        print('\t\t' + identity)
        print('\t\t' + all_minus_one)
        print('\t},')

        # Barrett reduction constants for reducing 128 bits to the final CRC.
        print('\t.barrett_reduction_consts = {')

        barrett_exp = 63 + n
        quotient = div(1 << barrett_exp, G)
        quotient_desc = f'floor(x^{barrett_exp} / G)'
        if msb:
            quotient = (quotient << 1) - (1 << 64)
            quotient_desc = f'({quotient_desc} * x) - x^64'

        low_terms = G - (1 << n)
        low_desc = f'G - x^{n}'
        if lsb and n == 64:
            # The x^0 term should always be nonzero.
            assert (low_terms & 1) != 0
            low_terms >>= 1
            low_desc = f'({low_desc} - x^0) / x'
        else:
            shift = 64 - n - (1 if lsb else 0)
            low_terms <<= shift
            low_desc = f'({low_desc}) * x^{shift}'

        print_mult_pair(v, [
            {'val': quotient, 'desc': quotient_desc},
            {'val': low_terms, 'desc': low_desc},
        ])
        print('\t},')

        print('};')
251
252def parse_crc_variants(vars_string):
253    variants = []
254    for var_string in vars_string.split(','):
255        bits, bit_order, generator_poly = var_string.split('_')
256        assert bits.startswith('crc')
257        bits = int(bits.removeprefix('crc'))
258        assert generator_poly.startswith('0x')
259        generator_poly = generator_poly.removeprefix('0x')
260        assert len(generator_poly) % 2 == 0
261        generator_poly = int(generator_poly, 16)
262        variants.append(CrcVariant(bits, generator_poly, bit_order))
263    return variants
264
if len(sys.argv) != 3:
    sys.stderr.write(f'Usage: {sys.argv[0]} CONSTS_TYPE[,CONSTS_TYPE]... CRC_VARIANT[,CRC_VARIANT]...\n')
    sys.stderr.write('  CONSTS_TYPE can be sliceby[1-8], riscv_clmul, or x86_pclmul\n')
    sys.stderr.write('  CRC_VARIANT is crc${num_bits}_${bit_order}_${generator_poly_as_hex}\n')
    sys.stderr.write('     E.g. crc16_msb_0x8bb7 or crc32_lsb_0xedb88320\n')
    sys.stderr.write('     Polynomial must use the given bit_order and exclude x^{num_bits}\n')
    sys.exit(1)

# Parse and validate every argument before printing anything, so that a bad
# argument cannot leave a truncated header on stdout.
consts_types = sys.argv[1].split(',')
variants = parse_crc_variants(sys.argv[2])
generators = []
for consts_type in consts_types:
    if consts_type.startswith('sliceby'):
        slice_count = int(consts_type.removeprefix('sliceby'))
        # A count below 1 would produce a zero or negative C array size.
        if slice_count < 1:
            raise ValueError(f'Invalid slice count in consts_type: {consts_type}')
        generators.append((gen_slicebyN_tables, (variants, slice_count)))
    elif consts_type == 'riscv_clmul':
        generators.append((gen_riscv_clmul_consts, (variants,)))
    elif consts_type == 'x86_pclmul':
        generators.append((gen_x86_pclmul_consts, (variants,)))
    else:
        raise ValueError(f'Unknown consts_type: {consts_type}')

print('/* SPDX-License-Identifier: GPL-2.0-or-later */')
print('/*')
print(' * CRC constants generated by:')
print(' *')
print(f' *\t{sys.argv[0]} {" ".join(sys.argv[1:])}')
print(' *')
print(' * Do not edit manually.')
print(' */')
# Emit the requested constants in the order given on the command line.
for generate, args in generators:
    generate(*args)
292