xref: /linux/kernel/bpf/tnum.c (revision 32e940f2bd3b16551f23ea44be47f6f5d1746d64) !
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* tnum: tracked (or tristate) numbers
3  *
4  * A tnum tracks knowledge about the bits of a value.  Each bit can be either
5  * known (0 or 1), or unknown (x).  Arithmetic operations on tnums will
6  * propagate the unknown bits such that the tnum result represents all the
7  * possible results for possible values of the operands.
8  */
9 #include <linux/kernel.h>
10 #include <linux/tnum.h>
11 #include <linux/swab.h>
12 
13 #define TNUM(_v, _m)	(struct tnum){.value = _v, .mask = _m}
14 /* A completely unknown value */
15 const struct tnum tnum_unknown = { .value = 0, .mask = -1 };
16 
/* Return the tnum for a constant: @value with an empty unknown-bit mask. */
struct tnum tnum_const(u64 value)
{
	struct tnum known = TNUM(value, 0);

	return known;
}
21 
/* Build a tnum from the bounds @min and @max: every bit from the highest
 * bit in which the two bounds differ down to bit 0 is marked unknown, and
 * the remaining high-order bits are taken from @min.
 */
struct tnum tnum_range(u64 min, u64 max)
{
	u64 differing = min ^ max;
	u8 nbits = fls64(differing);
	u64 low_mask;

	/* The top bit differs.  Handled separately because 1ULL << 64 is
	 * undefined.
	 */
	if (nbits >= 64)
		return tnum_unknown;

	/* e.g. differing = 4 gives nbits = 3 and low_mask = (1<<3) - 1 = 7.
	 * differing = 0 gives nbits = 0 and low_mask = (1<<0) - 1 = 0, so the
	 * result is the constant min (since min == max).
	 */
	low_mask = (1ULL << nbits) - 1;
	return TNUM(min & ~low_mask, low_mask);
}
37 
/* Shift both the value and the unknown-bit mask of @a left by @shift bits. */
struct tnum tnum_lshift(struct tnum a, u8 shift)
{
	u64 value = a.value << shift;
	u64 mask = a.mask << shift;

	return TNUM(value, mask);
}
42 
/* Logically shift both the value and the unknown-bit mask of @a right by
 * @shift bits.
 */
struct tnum tnum_rshift(struct tnum a, u8 shift)
{
	u64 value = a.value >> shift;
	u64 mask = a.mask >> shift;

	return TNUM(value, mask);
}
47 
/* Arithmetic (sign-propagating) right shift of @a by @min_shift.
 *
 * The minimum shift is used because it gives the result furthest from
 * zero: for a negative a.value it has a larger negative offset compared to
 * more shifting, and for a nonnegative a.value it has a larger positive
 * offset compared to more shifting.
 *
 * When @insn_bitness is 32 the shift is performed on the low 32 bits
 * interpreted as s32 and the result is converted back through u32; for any
 * other @insn_bitness the full 64 bits are shifted as s64.
 */
struct tnum tnum_arshift(struct tnum a, u8 min_shift, u8 insn_bitness)
{
	if (insn_bitness != 32) {
		s64 value64 = (s64)a.value >> min_shift;
		s64 mask64 = (s64)a.mask >> min_shift;

		return TNUM(value64, mask64);
	}

	return TNUM((u32)((s32)a.value >> min_shift),
		    (u32)((s32)a.mask >> min_shift));
}
62 
tnum_add(struct tnum a,struct tnum b)63 struct tnum tnum_add(struct tnum a, struct tnum b)
64 {
65 	u64 sm, sv, sigma, chi, mu;
66 
67 	sm = a.mask + b.mask;
68 	sv = a.value + b.value;
69 	sigma = sm + sv;
70 	chi = sigma ^ sv;
71 	mu = chi | a.mask | b.mask;
72 	return TNUM(sv & ~mu, mu);
73 }
74 
tnum_sub(struct tnum a,struct tnum b)75 struct tnum tnum_sub(struct tnum a, struct tnum b)
76 {
77 	u64 dv, alpha, beta, chi, mu;
78 
79 	dv = a.value - b.value;
80 	alpha = dv + a.mask;
81 	beta = dv - b.mask;
82 	chi = alpha ^ beta;
83 	mu = chi | a.mask | b.mask;
84 	return TNUM(dv & ~mu, mu);
85 }
86 
tnum_neg(struct tnum a)87 struct tnum tnum_neg(struct tnum a)
88 {
89 	return tnum_sub(TNUM(0, 0), a);
90 }
91 
tnum_and(struct tnum a,struct tnum b)92 struct tnum tnum_and(struct tnum a, struct tnum b)
93 {
94 	u64 alpha, beta, v;
95 
96 	alpha = a.value | a.mask;
97 	beta = b.value | b.mask;
98 	v = a.value & b.value;
99 	return TNUM(v, alpha & beta & ~v);
100 }
101 
tnum_or(struct tnum a,struct tnum b)102 struct tnum tnum_or(struct tnum a, struct tnum b)
103 {
104 	u64 v, mu;
105 
106 	v = a.value | b.value;
107 	mu = a.mask | b.mask;
108 	return TNUM(v, mu & ~v);
109 }
110 
tnum_xor(struct tnum a,struct tnum b)111 struct tnum tnum_xor(struct tnum a, struct tnum b)
112 {
113 	u64 v, mu;
114 
115 	v = a.value ^ b.value;
116 	mu = a.mask | b.mask;
117 	return TNUM(v & ~mu, mu);
118 }
119 
120 /* Perform long multiplication, iterating through the bits in a using rshift:
121  * - if LSB(a) is a known 0, keep current accumulator
122  * - if LSB(a) is a known 1, add b to current accumulator
123  * - if LSB(a) is unknown, take a union of the above cases.
124  *
125  * For example:
126  *
127  *               acc_0:        acc_1:
128  *
129  *     11 *  ->      11 *  ->      11 *  -> union(0011, 1001) == x0x1
130  *     x1            01            11
131  * ------        ------        ------
132  *     11            11            11
133  *    xx            00            11
134  * ------        ------        ------
135  *   ????          0011          1001
136  */
tnum_mul(struct tnum a,struct tnum b)137 struct tnum tnum_mul(struct tnum a, struct tnum b)
138 {
139 	struct tnum acc = TNUM(0, 0);
140 
141 	while (a.value || a.mask) {
142 		/* LSB of tnum a is a certain 1 */
143 		if (a.value & 1)
144 			acc = tnum_add(acc, b);
145 		/* LSB of tnum a is uncertain */
146 		else if (a.mask & 1) {
147 			/* acc = tnum_union(acc_0, acc_1), where acc_0 and
148 			 * acc_1 are partial accumulators for cases
149 			 * LSB(a) = certain 0 and LSB(a) = certain 1.
150 			 * acc_0 = acc + 0 * b = acc.
151 			 * acc_1 = acc + 1 * b = tnum_add(acc, b).
152 			 */
153 
154 			acc = tnum_union(acc, tnum_add(acc, b));
155 		}
156 		/* Note: no case for LSB is certain 0 */
157 		a = tnum_rshift(a, 1);
158 		b = tnum_lshift(b, 1);
159 	}
160 	return acc;
161 }
162 
tnum_overlap(struct tnum a,struct tnum b)163 bool tnum_overlap(struct tnum a, struct tnum b)
164 {
165 	u64 mu;
166 
167 	mu = ~a.mask & ~b.mask;
168 	return (a.value & mu) == (b.value & mu);
169 }
170 
171 /* Note that if a and b disagree - i.e. one has a 'known 1' where the other has
172  * a 'known 0' - this will return a 'known 1' for that bit.
173  */
tnum_intersect(struct tnum a,struct tnum b)174 struct tnum tnum_intersect(struct tnum a, struct tnum b)
175 {
176 	u64 v, mu;
177 
178 	v = a.value | b.value;
179 	mu = a.mask & b.mask;
180 	return TNUM(v & ~mu, mu);
181 }
182 
183 /* Returns a tnum with the uncertainty from both a and b, and in addition, new
184  * uncertainty at any position that a and b disagree. This represents a
185  * superset of the union of the concrete sets of both a and b. Despite the
186  * overapproximation, it is optimal.
187  */
tnum_union(struct tnum a,struct tnum b)188 struct tnum tnum_union(struct tnum a, struct tnum b)
189 {
190 	u64 v = a.value & b.value;
191 	u64 mu = (a.value ^ b.value) | a.mask | b.mask;
192 
193 	return TNUM(v & ~mu, mu);
194 }
195 
/* Truncate @a to its low @size bytes: value and mask bits above bit
 * (@size * 8 - 1) are cleared.
 *
 * A @size of 8 or more covers the whole 64-bit tnum, so @a is returned
 * unchanged.  This case is handled explicitly because 1ULL << 64 is
 * undefined.
 */
struct tnum tnum_cast(struct tnum a, u8 size)
{
	u64 keep;

	if (size >= sizeof(u64))
		return a;

	keep = (1ULL << (size * 8)) - 1;
	a.value &= keep;
	a.mask &= keep;
	return a;
}
202 
tnum_is_aligned(struct tnum a,u64 size)203 bool tnum_is_aligned(struct tnum a, u64 size)
204 {
205 	if (!size)
206 		return true;
207 	return !((a.value | a.mask) & (size - 1));
208 }
209 
tnum_in(struct tnum a,struct tnum b)210 bool tnum_in(struct tnum a, struct tnum b)
211 {
212 	if (b.mask & ~a.mask)
213 		return false;
214 	b.value &= ~a.mask;
215 	return a.value == b.value;
216 }
217 
/* Format @a as a NUL-terminated string of binary digits in @str, most
 * significant bit first: 'x' for an unknown bit, otherwise '1' or '0'.
 *
 * @size is the size of the buffer at @str.  A full rendering needs 65
 * bytes; with a smaller buffer only the leading @size - 1 characters are
 * stored, followed by the terminator.  With @size == 0 nothing is written.
 *
 * Always returns 64, the number of characters in a full rendering.
 */
int tnum_sbin(char *str, size_t size, struct tnum a)
{
	size_t n;

	/* No room even for the terminator.  Without this check size - 1
	 * wraps around and the NUL below would land on str[64].
	 */
	if (!size)
		return 64;

	/* Walk from the least significant bit, which belongs at str[63] */
	for (n = 64; n; n--) {
		if (n < size) {
			if (a.mask & 1)
				str[n - 1] = 'x';
			else if (a.value & 1)
				str[n - 1] = '1';
			else
				str[n - 1] = '0';
		}
		a.mask >>= 1;
		a.value >>= 1;
	}
	str[min(size - 1, (size_t)64)] = 0;
	return 64;
}
237 
tnum_subreg(struct tnum a)238 struct tnum tnum_subreg(struct tnum a)
239 {
240 	return tnum_cast(a, 4);
241 }
242 
tnum_clear_subreg(struct tnum a)243 struct tnum tnum_clear_subreg(struct tnum a)
244 {
245 	return tnum_lshift(tnum_rshift(a, 32), 32);
246 }
247 
tnum_with_subreg(struct tnum reg,struct tnum subreg)248 struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg)
249 {
250 	return tnum_or(tnum_clear_subreg(reg), tnum_subreg(subreg));
251 }
252 
/* Return @a with its subreg replaced by the constant @value, using
 * tnum_with_subreg().
 */
struct tnum tnum_const_subreg(struct tnum a, u32 value)
{
	struct tnum low = tnum_const(value);

	return tnum_with_subreg(a, low);
}
257 
tnum_bswap16(struct tnum a)258 struct tnum tnum_bswap16(struct tnum a)
259 {
260 	return TNUM(swab16(a.value & 0xFFFF), swab16(a.mask & 0xFFFF));
261 }
262 
tnum_bswap32(struct tnum a)263 struct tnum tnum_bswap32(struct tnum a)
264 {
265 	return TNUM(swab32(a.value & 0xFFFFFFFF), swab32(a.mask & 0xFFFFFFFF));
266 }
267 
tnum_bswap64(struct tnum a)268 struct tnum tnum_bswap64(struct tnum a)
269 {
270 	return TNUM(swab64(a.value), swab64(a.mask));
271 }
272 
273 /* Given tnum t, and a number z such that tmin <= z < tmax, where tmin
274  * is the smallest member of the t (= t.value) and tmax is the largest
275  * member of t (= t.value | t.mask), returns the smallest member of t
276  * larger than z.
277  *
278  * For example,
279  * t      = x11100x0
280  * z      = 11110001 (241)
281  * result = 11110010 (242)
282  *
283  * Note: if this function is called with z >= tmax, it just returns
284  * early with tmax; if this function is called with z < tmin, the
285  * algorithm already returns tmin.
286  */
tnum_step(struct tnum t,u64 z)287 u64 tnum_step(struct tnum t, u64 z)
288 {
289 	u64 tmax, d, carry_mask, filled, inc;
290 
291 	tmax = t.value | t.mask;
292 
293 	/* if z >= largest member of t, return largest member of t */
294 	if (z >= tmax)
295 		return tmax;
296 
297 	/* if z < smallest member of t, return smallest member of t */
298 	if (z < t.value)
299 		return t.value;
300 
301 	/*
302 	 * Let r be the result tnum member, z = t.value + d.
303 	 * Every tnum member is t.value | s for some submask s of t.mask,
304 	 * and since t.value & t.mask == 0, t.value | s == t.value + s.
305 	 * So r > z becomes s > d where d = z - t.value.
306 	 *
307 	 * Find the smallest submask s of t.mask greater than d by
308 	 * "incrementing d within the mask": fill every non-mask
309 	 * position with 1 (`filled`) so +1 ripples through the gaps,
310 	 * then keep only mask bits. `carry_mask` additionally fills
311 	 * positions below the highest non-mask 1 in d, preventing
312 	 * it from trapping the carry.
313 	 */
314 	d = z - t.value;
315 	carry_mask = (1ULL << fls64(d & ~t.mask)) - 1;
316 	filled = d | carry_mask | ~t.mask;
317 	inc = (filled + 1) & t.mask;
318 	return t.value | inc;
319 }
320