xref: /qemu/tests/tcg/x86_64/fma.c (revision 70ce076fa6dff60585c229a4b641b13e64bf03cf)
1 /*
 * Test some fused multiply-add (FMA) corner cases.
3  *
4  * SPDX-License-Identifier: GPL-2.0-or-later
5  */
6 #include <stdio.h>
7 #include <stdint.h>
8 #include <stdbool.h>
9 #include <inttypes.h>
10 
/* Number of elements in a true array (not valid on a pointer). */
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof(*(arr)))
12 
/*
 * Perform one "n * m + a" operation using the vfmadd231sd insn and
 * return the result; on return *mxcsr_p is set to the bottom 6 bits of
 * MXCSR (the Flag bits). If ftz is true then we set MXCSR.FTZ while
 * doing the operation.
 *
 * The caller's MXCSR is saved on entry and reloaded before the asm
 * finishes. The flag bits are cleared before the operation, so
 * *mxcsr_p reflects only this single operation.
 *
 * We print the operation and its results to stdout.
 */
static uint64_t do_fmadd(uint64_t n, uint64_t m, uint64_t a,
                         bool ftz, uint32_t *mxcsr_p)
{
    uint64_t r;
    uint32_t mxcsr = 0;
    /* MXCSR bit 15 is FTZ (flush-to-zero). */
    uint32_t ftz_bit = ftz ? (1u << 15) : 0;
    uint32_t saved_mxcsr = 0;

    /*
     * __asm__/__volatile__ rather than asm/volatile so this also builds
     * in strict ISO C modes (e.g. -std=c11), where 'asm' is not a keyword.
     */
    __asm__ __volatile__(
        "stmxcsr %[saved_mxcsr]\n"
        "stmxcsr %[mxcsr]\n"
        /* Clear flag bits 0-5 and FTZ (bit 15), then OR in ftz_bit */
        "andl $0xffff7fc0, %[mxcsr]\n"
        "orl %[ftz_bit], %[mxcsr]\n"
        "ldmxcsr %[mxcsr]\n"
        "movq %[a], %%xmm0\n"
        "movq %[m], %%xmm1\n"
        "movq %[n], %%xmm2\n"
        /* xmm0 = xmm0 + xmm2 * xmm1 */
        "vfmadd231sd %%xmm1, %%xmm2, %%xmm0\n"
        "movq %%xmm0, %[r]\n"
        "stmxcsr %[mxcsr]\n"
        "ldmxcsr %[saved_mxcsr]\n"
        /*
         * r is written before the asm has finished using the memory
         * operands mxcsr and saved_mxcsr, so it must be early-clobber.
         * Otherwise the compiler may allocate r to a register that forms
         * part of one of those memory addresses.
         */
        : [r] "=&r" (r), [mxcsr] "=m" (mxcsr),
          [saved_mxcsr] "=m" (saved_mxcsr)
        : [n] "r" (n), [m] "r" (m), [a] "r" (a),
          [ftz_bit] "r" (ftz_bit)
        : "xmm0", "xmm1", "xmm2");
    *mxcsr_p = mxcsr & 0x3f;
    /* Print the mnemonic actually executed above. */
    printf("vfmadd231sd 0x%" PRIx64 " 0x%" PRIx64 " 0x%" PRIx64
           " = 0x%" PRIx64 " MXCSR flags 0x%" PRIx32 "\n",
           n, m, a, r, *mxcsr_p);
    return r;
}
51 }
52 
/*
 * One FMA test case: the operands of n * m + a, whether MXCSR.FTZ is
 * set for the operation, and the expected result bits and flag bits.
 * Field order matters: the test table uses positional initializers.
 */
struct testdata {
    uint64_t n;              /* first multiplicand (raw double bits) */
    uint64_t m;              /* second multiplicand (raw double bits) */
    uint64_t a;              /* addend (raw double bits) */
    bool ftz;                /* set MXCSR.FTZ during the operation */
    uint64_t expected_r;     /* expected result (raw double bits) */
    uint32_t expected_mxcsr; /* expected low 6 bits of MXCSR (Flag bits) */
};
typedef struct testdata testdata;
64 
65 static testdata tests[] = {
66     { 0, 0x7ff0000000000000, 0x7ff000000000aaaa, false, /* 0 * Inf + SNaN */
67       0x7ff800000000aaaa, 1 }, /* Should be QNaN and does raise Invalid */
68     { 0, 0x7ff0000000000000, 0x7ff800000000aaaa, false, /* 0 * Inf + QNaN */
69       0x7ff800000000aaaa, 0 }, /* Should be QNaN and does *not* raise Invalid */
70     /*
71      * These inputs give a result which is tiny before rounding but which
72      * becomes non-tiny after rounding. x86 is a "detect tininess after
73      * rounding" architecture, so it should give a non-denormal result and
74      * not set the Underflow flag (only the Precision flag for an inexact
75      * result).
76      */
77     { 0x3fdfffffffffffff, 0x001fffffffffffff, 0x801fffffffffffff, false,
78       0x8010000000000000, 0x20 },
79     /*
80      * Flushing of denormal outputs to zero should also happen after
81      * rounding, so setting FTZ should not affect the result or the flags.
82      * QEMU currently does not emulate this correctly because we do the
83      * flush-to-zero check before rounding, so we incorrectly produce a
84      * zero result and set Underflow as well as Precision.
85      */
86 #ifdef ENABLE_FAILING_TESTS
87     { 0x3fdfffffffffffff, 0x001fffffffffffff, 0x801fffffffffffff, true,
88       0x8010000000000000, 0x20 }, /* Enabling FTZ shouldn't change flags */
89 #endif
90 };
91 
92 int main(void)
93 {
94     bool passed = true;
95     for (int i = 0; i < ARRAY_SIZE(tests); i++) {
96         uint32_t mxcsr;
97         uint64_t r = do_fmadd(tests[i].n, tests[i].m, tests[i].a,
98                               tests[i].ftz, &mxcsr);
99         if (r != tests[i].expected_r) {
100             printf("expected result 0x%" PRIx64 "\n", tests[i].expected_r);
101             passed = false;
102         }
103         if (mxcsr != tests[i].expected_mxcsr) {
104             printf("expected MXCSR flags 0x%x\n", tests[i].expected_mxcsr);
105             passed = false;
106         }
107     }
108     return passed ? 0 : 1;
109 }
110