xref: /src/crypto/openssl/include/internal/safe_math.h (revision f25b8c9fb4f58cf61adb47d7570abe7caa6d385d)
1 /*
2  * Copyright 2021-2022 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #ifndef OSSL_INTERNAL_SAFE_MATH_H
11 #define OSSL_INTERNAL_SAFE_MATH_H
12 #pragma once
13 
#include <limits.h> /* For CHAR_BIT */
#include <openssl/e_os2.h> /* For 'ossl_inline' */
15 
/*
 * has(func) is meant for use in #if: it selects between the compiler's
 * overflow-checking builtins and the portable fallbacks below.  Defining
 * OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING forces the fallbacks.  Otherwise
 * __has_builtin is used when available; without it, GCC 6 or later is
 * assumed to provide the builtins.  Anything left undecided means 0.
 */
#if !defined(OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING) && defined(__has_builtin)
#define has(func) __has_builtin(func)
#elif !defined(OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING) && defined(__GNUC__) \
    && __GNUC__ > 5
#define has(func) 1
#endif

#ifndef has
#define has(func) 0
#endif
29 
30 /*
31  * Safe addition helpers
32  */
33 #if has(__builtin_add_overflow)
/*
 * safe_add_<type_name>() for signed types (compiler builtin variant):
 * returns a + b.  If the sum does not fit, *err is flagged and the result
 * saturates to min when a is negative and to max otherwise.
 */
#define OSSL_SAFE_MATH_ADDS(type_name, type, min, max)               \
    static ossl_inline ossl_unused type safe_add_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        type sum;                                                    \
                                                                     \
        if (__builtin_add_overflow(a, b, &sum)) {                    \
            *err |= 1;                                               \
            sum = a < 0 ? min : max;                                 \
        }                                                            \
        return sum;                                                  \
    }
46 
/*
 * safe_add_<type_name>() for unsigned types (compiler builtin variant):
 * returns a + b.  On wrap-around *err is flagged and the wrapped (modular)
 * sum stored by the builtin is returned.
 */
#define OSSL_SAFE_MATH_ADDU(type_name, type, max)                    \
    static ossl_inline ossl_unused type safe_add_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        type sum;                                                    \
                                                                     \
        if (__builtin_add_overflow(a, b, &sum))                      \
            *err |= 1;                                               \
        return sum;                                                  \
    }
59 
60 #else /* has(__builtin_add_overflow) */
/*
 * Portable safe_add_<type_name>() for signed types: returns a + b.  A zero
 * a or operands on opposite sides of zero can never overflow; otherwise the
 * room left between a and the bound it is heading towards is checked first,
 * so no overflowing addition is evaluated.  On overflow *err is flagged and
 * the result saturates to min (a negative) or max.
 */
#define OSSL_SAFE_MATH_ADDS(type_name, type, min, max)               \
    static ossl_inline ossl_unused type safe_add_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        /* a is zero or exactly one operand is negative */           \
        if (a == 0 || (a < 0) != (b < 0))                            \
            return a + b;                                            \
        /* same non-zero sign: b must fit in the remaining room */   \
        if (a > 0 ? b <= max - a : b >= min - a)                     \
            return a + b;                                            \
        *err |= 1;                                                   \
        return a < 0 ? min : max;                                    \
    }
74 
/*
 * Portable safe_add_<type_name>() for unsigned types: returns the wrapped
 * sum a + b, flagging *err when the true sum exceeds max.
 */
#define OSSL_SAFE_MATH_ADDU(type_name, type, max)                    \
    static ossl_inline ossl_unused type safe_add_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        const int wraps = b > max - a;                               \
                                                                     \
        if (wraps)                                                   \
            *err |= 1;                                               \
        return a + b;                                                \
    }
84 #endif /* has(__builtin_add_overflow) */
85 
86 /*
87  * Safe subtraction helpers
88  */
89 #if has(__builtin_sub_overflow)
/*
 * safe_sub_<type_name>() for signed types (compiler builtin variant):
 * returns a - b.  If the difference does not fit, *err is flagged and the
 * result saturates to min when a is negative and to max otherwise.
 */
#define OSSL_SAFE_MATH_SUBS(type_name, type, min, max)               \
    static ossl_inline ossl_unused type safe_sub_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        type diff;                                                   \
                                                                     \
        if (__builtin_sub_overflow(a, b, &diff)) {                   \
            *err |= 1;                                               \
            diff = a < 0 ? min : max;                                \
        }                                                            \
        return diff;                                                 \
    }
102 
103 #else /* has(__builtin_sub_overflow) */
/*
 * Portable safe_sub_<type_name>() for signed types: returns a - b.  A zero
 * b or operands on the same side of zero can never overflow; for opposite
 * signs a is checked against the bound shifted by b before subtracting.
 * On overflow *err is flagged and the result saturates to min (a negative)
 * or max.
 */
#define OSSL_SAFE_MATH_SUBS(type_name, type, min, max)               \
    static ossl_inline ossl_unused type safe_sub_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        /* b is zero or both operands share a sign */                \
        if (b == 0 || (a < 0) == (b < 0))                            \
            return a - b;                                            \
        /* opposite signs: a must stay within b of the bound */      \
        if (b > 0 ? a >= min + b : a <= max + b)                     \
            return a - b;                                            \
        *err |= 1;                                                   \
        return a < 0 ? min : max;                                    \
    }
117 
118 #endif /* has(__builtin_sub_overflow) */
119 
/*
 * safe_sub_<type_name>() for unsigned types: returns the wrapped difference
 * a - b, flagging *err when b is larger than a (the true result would be
 * negative).
 */
#define OSSL_SAFE_MATH_SUBU(type_name, type)                         \
    static ossl_inline ossl_unused type safe_sub_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        const type diff = a - b;                                     \
                                                                     \
        if (a < b)                                                   \
            *err |= 1;                                               \
        return diff;                                                 \
    }
129 
130 /*
131  * Safe multiplication helpers
132  */
133 #if has(__builtin_mul_overflow)
/*
 * safe_mul_<type_name>() for signed types (compiler builtin variant):
 * returns a * b.  If the product does not fit, *err is flagged and the
 * result saturates to min when the operand signs differ, max otherwise.
 */
#define OSSL_SAFE_MATH_MULS(type_name, type, min, max)               \
    static ossl_inline ossl_unused type safe_mul_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        type prod;                                                   \
                                                                     \
        if (__builtin_mul_overflow(a, b, &prod)) {                   \
            *err |= 1;                                               \
            prod = (a < 0) != (b < 0) ? min : max;                   \
        }                                                            \
        return prod;                                                 \
    }
146 
/*
 * safe_mul_<type_name>() for unsigned types (compiler builtin variant):
 * returns a * b.  On overflow *err is flagged and the product wrapped
 * modulo 2^N is returned, which is exactly what __builtin_mul_overflow()
 * stored in r.  Returning r instead of recomputing a * b matters for
 * unsigned types narrower than int: their operands are promoted to int, and
 * a large product would then be signed overflow (undefined behaviour).
 */
#define OSSL_SAFE_MATH_MULU(type_name, type, max)                    \
    static ossl_inline ossl_unused type safe_mul_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        type r;                                                      \
                                                                     \
        if (__builtin_mul_overflow(a, b, &r))                        \
            *err |= 1;                                               \
        return r;                                                    \
    }
159 
160 #else /* has(__builtin_mul_overflow) */
/*
 * Portable safe_mul_<type_name>() for signed types, used when
 * __builtin_mul_overflow is unavailable.  Products with a 0 or 1 operand are
 * answered directly.  Otherwise the operand magnitudes are compared against
 * max before multiplying, so no overflowing multiplication is evaluated.
 * On overflow *err is flagged and the result saturates to min when the
 * operand signs differ, max otherwise.
 *
 * NOTE(review): products that are exactly min (e.g. 2 * (min / 2)) are
 * reported as overflow here even though they are representable; the builtin
 * variant accepts them.
 */
#define OSSL_SAFE_MATH_MULS(type_name, type, min, max)               \
    static ossl_inline ossl_unused type safe_mul_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        if (a == 0 || b == 0)                                        \
            return 0;                                                \
        if (a == 1)                                                  \
            return b;                                                \
        if (b == 1)                                                  \
            return a;                                                \
        /* min cannot be negated, so it never takes the fast path */ \
        if (a != min && b != min) {                                  \
            const type x = a < 0 ? -a : a;                           \
            const type y = b < 0 ? -b : b;                           \
                                                                     \
            /* |a| * |b| <= max, tested without multiplying */       \
            if (x <= max / y)                                        \
                return a * b;                                        \
        }                                                            \
        *err |= 1;                                                   \
        return (a < 0) ^ (b < 0) ? min : max;                        \
    }
182 
/*
 * Portable safe_mul_<type_name>() for unsigned types: returns a * b wrapped
 * modulo 2^N, flagging *err when the true product exceeds max.  The 1u
 * factor forces the multiplication to be done in an unsigned type.  Without
 * it, operands of types narrower than int are promoted to int, and a large
 * product (e.g. 0xffff * 0xffff for unsigned short) would be signed overflow,
 * which is undefined behaviour.
 */
#define OSSL_SAFE_MATH_MULU(type_name, type, max)                    \
    static ossl_inline ossl_unused type safe_mul_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        if (b != 0 && a > max / b)                                   \
            *err |= 1;                                               \
        return (type)(1u * a * b);                                   \
    }
192 #endif /* has(__builtin_mul_overflow) */
193 
194 /*
195  * Safe division helpers
196  */
/*
 * safe_div_<type_name>() for signed types: returns a / b (truncating).
 * Division by zero flags *err and returns min for negative a, max otherwise;
 * min / -1 (whose result is not representable) flags *err and returns max.
 */
#define OSSL_SAFE_MATH_DIVS(type_name, type, min, max)               \
    static ossl_inline ossl_unused type safe_div_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        if (b != 0 && (a != min || b != -1))                         \
            return a / b;                                            \
        *err |= 1;                                                   \
        return (b == 0 && a < 0) ? min : max;                        \
    }
212 
/*
 * safe_div_<type_name>() for unsigned types: returns a / b, or flags *err
 * and returns max when b is zero.
 */
#define OSSL_SAFE_MATH_DIVU(type_name, type, max)                    \
    static ossl_inline ossl_unused type safe_div_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        if (b == 0) {                                                \
            *err |= 1;                                               \
            return max;                                              \
        }                                                            \
        return a / b;                                                \
    }
223 
224 /*
225  * Safe modulus helpers
226  */
/*
 * safe_mod_<type_name>() for signed types: returns a % b (sign follows a,
 * as with C's %).  A zero divisor flags *err and yields 0.
 * Every value modulo -1 is 0, and that case is answered directly: in C,
 * min % -1 is undefined behaviour because the implied min / -1 overflows,
 * even though the remainder itself (0) is perfectly representable.
 * min and max are kept in the signature for the wrapper macros.
 */
#define OSSL_SAFE_MATH_MODS(type_name, type, min, max)               \
    static ossl_inline ossl_unused type safe_mod_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        if (b == 0) {                                                \
            *err |= 1;                                               \
            return 0;                                                \
        }                                                            \
        /* x % -1 == 0 for all x; avoids UB for min % -1 */          \
        if (b == -1)                                                 \
            return 0;                                                \
        return a % b;                                                \
    }
242 
/*
 * safe_mod_<type_name>() for unsigned types: returns a % b, or flags *err
 * and returns 0 when b is zero.
 */
#define OSSL_SAFE_MATH_MODU(type_name, type)                         \
    static ossl_inline ossl_unused type safe_mod_##type_name(type a, \
        type b,                                                      \
        int *err)                                                    \
    {                                                                \
        if (b == 0) {                                                \
            *err |= 1;                                               \
            return 0;                                                \
        }                                                            \
        return a % b;                                                \
    }
253 
254 /*
255  * Safe negation helpers
256  */
/*
 * safe_neg_<type_name>() for signed types: returns -a.  The only input
 * whose negation is not representable is min; it flags *err and is
 * returned unchanged.
 */
#define OSSL_SAFE_MATH_NEGS(type_name, type, min)                    \
    static ossl_inline ossl_unused type safe_neg_##type_name(type a, \
        int *err)                                                    \
    {                                                                \
        if (a == min) {                                              \
            *err |= 1;                                               \
            return min;                                              \
        }                                                            \
        return -a;                                                   \
    }
266 
/*
 * safe_neg_<type_name>() for unsigned types: only 0 can be negated without
 * leaving the type's range.  Any other value flags *err and yields the
 * two's complement negation 1 + ~a (i.e. the value modulo 2^N).
 */
#define OSSL_SAFE_MATH_NEGU(type_name, type)                         \
    static ossl_inline ossl_unused type safe_neg_##type_name(type a, \
        int *err)                                                    \
    {                                                                \
        if (a != 0) {                                                \
            *err |= 1;                                               \
            return 1 + ~a;                                           \
        }                                                            \
        return a;                                                    \
    }
276 
277 /*
278  * Safe absolute value helpers
279  */
/*
 * safe_abs_<type_name>() for signed types: returns |a|.  |min| is not
 * representable, so min flags *err and is returned unchanged.
 */
#define OSSL_SAFE_MATH_ABSS(type_name, type, min)                    \
    static ossl_inline ossl_unused type safe_abs_##type_name(type a, \
        int *err)                                                    \
    {                                                                \
        if (a >= 0)                                                  \
            return a;                                                \
        if (a == min) {                                              \
            *err |= 1;                                               \
            return min;                                              \
        }                                                            \
        return -a;                                                   \
    }
289 
/*
 * safe_abs_<type_name>() for unsigned types: every unsigned value is its own
 * absolute value, so a is returned as-is and *err is never touched.  err is
 * kept only for signature parity with the signed variant; it is explicitly
 * discarded to avoid -Wunused-parameter warnings.
 */
#define OSSL_SAFE_MATH_ABSU(type_name, type)                         \
    static ossl_inline ossl_unused type safe_abs_##type_name(type a, \
        int *err)                                                    \
    {                                                                \
        (void)err;                                                   \
        return a;                                                    \
    }
296 
297 /*
298  * Safe fused multiply divide helpers
299  *
300  * These are a bit obscure:
301  *    . They begin by checking the denominator for zero and getting rid of this
302  *      corner case.
303  *
 *    . Second, the multiplication is attempted directly; if it does not
 *      overflow, the quotient is returned (for signed values there is a
 *      potential problem here which isn't present for unsigned).
307  *
308  *    . Finally, the multiplication/division is transformed so that the larger
309  *      of the numerators is divided first.  This requires a remainder
310  *      correction:
311  *
 *          a b / c = (a / c) b + (a mod c) b / c, where a >= b
313  *
314  *      The individual operations need to be overflow checked (again signed
315  *      being more problematic).
316  *
317  * The algorithm used is not perfect but it should be "good enough".
318  */
/*
 * safe_muldiv_<type_name>() for signed types: computes a * b / c, OR-ing
 * any failure of an intermediate step into *err.
 *   - c == 0 flags *err and returns 0 if a or b is 0, otherwise max
 *     (NOTE(review): max regardless of the sign of a * b).
 *   - If a * b fits, it is divided with safe_div_<type_name>().
 *   - Otherwise the operands are swapped so that a >= b (compared by value,
 *     not magnitude) and the result is assembled as
 *     (a / c) * b + ((a % c) * b) / c with the overflow-checked helpers.
 *     An intermediate product may still overflow and flag *err even when
 *     the final quotient would be representable.
 * safe_mul/div/mod/add_<type_name>() must be instantiated before this.
 */
#define OSSL_SAFE_MATH_MULDIVS(type_name, type, max)                    \
    static ossl_inline ossl_unused type safe_muldiv_##type_name(type a, \
        type b,                                                         \
        type c,                                                         \
        int *err)                                                       \
    {                                                                   \
        /* private flag: overflow here only selects the slow path */    \
        int e2 = 0;                                                     \
        type q, r, x, y;                                                \
                                                                        \
        if (c == 0) {                                                   \
            *err |= 1;                                                  \
            return a == 0 || b == 0 ? 0 : max;                          \
        }                                                               \
        x = safe_mul_##type_name(a, b, &e2);                            \
        if (!e2)                                                        \
            return safe_div_##type_name(x, c, err);                     \
        /* ensure a >= b so that the larger operand is divided first */ \
        if (b > a) {                                                    \
            x = b;                                                      \
            b = a;                                                      \
            a = x;                                                      \
        }                                                               \
        q = safe_div_##type_name(a, c, err);                            \
        r = safe_mod_##type_name(a, c, err);                            \
        x = safe_mul_##type_name(r, b, err);                            \
        y = safe_mul_##type_name(q, b, err);                            \
        q = safe_div_##type_name(x, c, err);                            \
        return safe_add_##type_name(y, q, err);                         \
    }
347 
/*
 * safe_muldiv_<type_name>() for unsigned types: computes floor(a * b / c),
 * OR-ing any failure of an intermediate step into *err.
 *   - c == 0 flags *err and returns 0 if a or b is 0, otherwise max.
 *   - If a * b fits, x / c is returned directly.
 *   - Otherwise the operands are swapped so that a >= b and the identity
 *     a * b / c = (a / c) * b + ((a % c) * b) / c is used; it is exact when
 *     no intermediate step overflows.  (a % c) * b can still overflow (for
 *     example when c is close to max) and flag *err even though the final
 *     quotient would be representable.
 * safe_mul/add_<type_name>() must be instantiated before this.
 */
#define OSSL_SAFE_MATH_MULDIVU(type_name, type, max)                    \
    static ossl_inline ossl_unused type safe_muldiv_##type_name(type a, \
        type b,                                                         \
        type c,                                                         \
        int *err)                                                       \
    {                                                                   \
        /* private flag: overflow here only selects the slow path */    \
        int e2 = 0;                                                     \
        type x, y;                                                      \
                                                                        \
        if (c == 0) {                                                   \
            *err |= 1;                                                  \
            return a == 0 || b == 0 ? 0 : max;                          \
        }                                                               \
        x = safe_mul_##type_name(a, b, &e2);                            \
        if (!e2)                                                        \
            return x / c;                                               \
        /* ensure a >= b so that the larger operand is divided first */ \
        if (b > a) {                                                    \
            x = b;                                                      \
            b = a;                                                      \
            a = x;                                                      \
        }                                                               \
        x = safe_mul_##type_name(a % c, b, err);                        \
        y = safe_mul_##type_name(a / c, b, err);                        \
        return safe_add_##type_name(y, x / c, err);                     \
    }
373 
/*
 * Calculate a / b rounded up, i.e. towards positive infinity (the ceiling
 * of the exact quotient).  For non-negative operands this is
 *     a / b + (a % b != 0)
 * which is usually (less safely) written as (a + b - 1) / b.  When the
 * operands have opposite signs the exact quotient is negative, and C's
 * truncating division has already rounded it up, so no correction is added
 * (e.g. -7 / 2 rounds up to -3, not -2).
 * errp may be NULL to ignore errors; that is only safe when b != 0 and, for
 * signed types, the division is not min / -1.  Division by zero yields 0
 * when a == 0 and max otherwise.
 */
#define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max)                                        \
    static ossl_inline ossl_unused type safe_div_round_up_##type_name(type a, type b, int *errp) \
    {                                                                                            \
        type x;                                                                                  \
        int *err, err_local = 0;                                                                 \
                                                                                                 \
        /* Allow errors to be ignored by callers */                                              \
        err = errp != NULL ? errp : &err_local;                                                  \
        /* Fast path, both positive */                                                           \
        if (b > 0 && a > 0) {                                                                    \
            /* Faster path: no overflow concerns */                                              \
            if (a < max - b)                                                                     \
                return (a + b - 1) / b;                                                          \
            return a / b + (a % b != 0);                                                         \
        }                                                                                        \
        if (b == 0) {                                                                            \
            *err |= 1;                                                                           \
            return a == 0 ? 0 : max;                                                             \
        }                                                                                        \
        if (a == 0)                                                                              \
            return 0;                                                                            \
        /* Rather slow path because there are negatives involved (signed types only) */          \
        x = safe_mod_##type_name(a, b, err);                                                     \
        /* Bump only an inexact positive quotient (operands of equal sign); */                   \
        /* truncation has already rounded a negative quotient towards +inf. */                   \
        return safe_add_##type_name(safe_div_##type_name(a, b, err),                             \
            x != 0 && (a > 0) == (b > 0), err);                                                  \
    }
406 
/*
 * Calculate ranges of types.
 *
 * The signed bounds are built without ever shifting a 1 into the sign bit,
 * which is undefined behaviour for signed types.  Every bound is cast back
 * to `type`, so types narrower than int (whose operands are promoted to
 * int) get bounds of their own type, e.g. MINS(short) == SHRT_MIN and
 * MAXU(unsigned short) == USHRT_MAX rather than int-sized values.  Two's
 * complement representation is assumed.
 */
#define OSSL_SAFE_MATH_MAXS(type) \
    ((type)((((type)1 << (sizeof(type) * CHAR_BIT - 2)) - 1) * 2 + 1))
#define OSSL_SAFE_MATH_MINS(type) ((type)(-OSSL_SAFE_MATH_MAXS(type) - 1))
#define OSSL_SAFE_MATH_MAXU(type) ((type)~(type)0)
411 
412 /*
413  * Wrapper macros to create all the functions of a given type
414  */
/*
 * OSSL_SAFE_MATH_SIGNED(type_name, type) defines, for a signed integer type,
 * safe_{add,sub,mul,div,mod,neg,abs}_<type_name>() followed by
 * safe_div_round_up_<type_name>() and safe_muldiv_<type_name>().  The two
 * composite helpers come last because they call the primitive ones.
 */
#define OSSL_SAFE_MATH_SIGNED(type_name, type)                              \
    OSSL_SAFE_MATH_ADDS(type_name, type,                                    \
        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))               \
    OSSL_SAFE_MATH_SUBS(type_name, type,                                    \
        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))               \
    OSSL_SAFE_MATH_MULS(type_name, type,                                    \
        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))               \
    OSSL_SAFE_MATH_DIVS(type_name, type,                                    \
        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))               \
    OSSL_SAFE_MATH_MODS(type_name, type,                                    \
        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type))               \
    OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type))         \
    OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type))         \
    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type))
431 
/*
 * OSSL_SAFE_MATH_UNSIGNED(type_name, type) defines, for an unsigned integer
 * type, safe_{add,sub,mul,div,mod,neg,abs}_<type_name>() followed by
 * safe_div_round_up_<type_name>() and safe_muldiv_<type_name>().  The two
 * composite helpers come last because they call the primitive ones.
 */
#define OSSL_SAFE_MATH_UNSIGNED(type_name, type)                            \
    OSSL_SAFE_MATH_ADDU(type_name, type, OSSL_SAFE_MATH_MAXU(type))         \
    OSSL_SAFE_MATH_SUBU(type_name, type)                                    \
    OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type))         \
    OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))         \
    OSSL_SAFE_MATH_MODU(type_name, type)                                    \
    OSSL_SAFE_MATH_NEGU(type_name, type)                                    \
    OSSL_SAFE_MATH_ABSU(type_name, type)                                    \
    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
    OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))
443 
444 #endif /* OSSL_INTERNAL_SAFE_MATH_H */
445