xref: /qemu/target/arm/tcg/sme_helper.c (revision df6fe2abf2e990f767ce755d426bc439c7bba336)
1 /*
2  * ARM SME Operations
3  *
4  * Copyright (c) 2022 Linaro, Ltd.
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2.1 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public
17  * License along with this library; if not, see <http://www.gnu.org/licenses/>.
18  */
19 
20 #include "qemu/osdep.h"
21 #include "cpu.h"
22 #include "internals.h"
23 #include "tcg/tcg-gvec-desc.h"
24 #include "exec/helper-proto.h"
25 #include "accel/tcg/cpu-ldst.h"
26 #include "accel/tcg/helper-retaddr.h"
27 #include "qemu/int128.h"
28 #include "fpu/softfloat.h"
29 #include "vec_internal.h"
30 #include "sve_ldst_internal.h"
31 
32 
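/*
 * Return true if the register ranges [x, x + nx) and [y, y + ny)
 * overlap, counting in units of ARMVectorReg.
 */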
33 static bool vectors_overlap(ARMVectorReg *x, unsigned nx,
34                             ARMVectorReg *y, unsigned ny)
35 {
36     return !(x + nx <= y || y + ny <= x);
37 }
38 
39 void helper_set_svcr(CPUARMState *env, uint32_t val, uint32_t mask)
40 {
41     aarch64_set_svcr(env, val, mask);
42 }
43 
44 void helper_sme_zero(CPUARMState *env, uint32_t imm, uint32_t svl)
45 {
46     uint32_t i;
47 
48     /*
49      * Special case clearing the entire ZArray.
50      * This falls into the CONSTRAINED UNPREDICTABLE zeroing of any
51      * parts of the ZA storage outside of SVL.
52      */
53     if (imm == 0xff) {
54         memset(env->za_state.za, 0, sizeof(env->za_state.za));
55         return;
56     }
57 
58     /*
59      * Recall that ZAnH.D[m] is spread across ZA[n+8*m],
60      * so each row is discontiguous within ZA[].
61      */
62     for (i = 0; i < svl; i++) {
63         if (imm & (1 << (i % 8))) {
64             memset(&env->za_state.za[i], 0, svl);
65         }
66     }
67 }
68 
69 
70 /*
71  * When considering the ZA storage as an array of elements of
72  * type T, the index within that array of the Nth element of
73  * a vertical slice of a tile can be calculated like this,
74  * regardless of the size of type T. This is because the tiles
75  * are interleaved, so if type T is size N bytes then row 1 of
76  * the tile is N rows away from row 0. The division by N to
77  * convert a byte offset into an array index and the multiplication
78  * by N to convert from vslice-index-within-the-tile to
79  * the index within the ZA storage cancel out.
80  */
81 #define tile_vslice_index(i) ((i) * sizeof(ARMVectorReg))
82 
83 /*
84  * When doing byte arithmetic on the ZA storage, the element
85  * byteoff bytes away in a tile vertical slice is always this
86  * many bytes away in the ZA storage, regardless of the
87  * size of the tile element, assuming that byteoff is a multiple
88  * of the element size. Again this is because of the interleaving
89  * of the tiles. For instance if we have 1 byte per element then
90  * each row of the ZA storage has one byte of the vslice data,
91  * and (counting from 0) byte 8 goes in row 8 of the storage
92  * at offset (8 * row-size-in-bytes).
93  * If we have 8 bytes per element then each row of the ZA storage
94  * has 8 bytes of the data, but there are 8 interleaved tiles and
95  * so byte 8 of the data goes into row 1 of the tile,
96  * which is again row 8 of the storage, so the offset is still
97  * (8 * row-size-in-bytes). Similarly for other element sizes.
98  */
99 #define tile_vslice_offset(byteoff) ((byteoff) * sizeof(ARMVectorReg))
100 
101 
102 /*
103  * Move Zreg vector to ZArray column.
104  */
105 #define DO_MOVA_C(NAME, TYPE, H)                                        \
106 void HELPER(NAME)(void *za, void *vn, void *vg, uint32_t desc)          \
107 {                                                                       \
108     int i, oprsz = simd_oprsz(desc);                                    \
109     for (i = 0; i < oprsz; ) {                                          \
110         uint16_t pg = *(uint16_t *)(vg + H1_2(i >> 3));                 \
111         do {                                                            \
112             if (pg & 1) {                                               \
113                 *(TYPE *)(za + tile_vslice_offset(i)) = *(TYPE *)(vn + H(i)); \
114             }                                                           \
115             i += sizeof(TYPE);                                          \
116             pg >>= sizeof(TYPE);                                        \
117         } while (i & 15);                                               \
118     }                                                                   \
119 }
120 
121 DO_MOVA_C(sme_mova_cz_b, uint8_t, H1)
122 DO_MOVA_C(sme_mova_cz_h, uint16_t, H1_2)
123 DO_MOVA_C(sme_mova_cz_s, uint32_t, H1_4)
124 
125 void HELPER(sme_mova_cz_d)(void *za, void *vn, void *vg, uint32_t desc)
126 {
127     int i, oprsz = simd_oprsz(desc) / 8;
128     uint8_t *pg = vg;
129     uint64_t *n = vn;
130     uint64_t *a = za;
131 
132     for (i = 0; i < oprsz; i++) {
133         if (pg[H1(i)] & 1) {
134             a[tile_vslice_index(i)] = n[i];
135         }
136     }
137 }
138 
139 void HELPER(sme_mova_cz_q)(void *za, void *vn, void *vg, uint32_t desc)
140 {
141     int i, oprsz = simd_oprsz(desc) / 16;
142     uint16_t *pg = vg;
143     Int128 *n = vn;
144     Int128 *a = za;
145 
146     /*
147      * Int128 is used here simply to copy 16 bytes, and to simplify
148      * the address arithmetic.
149      */
150     for (i = 0; i < oprsz; i++) {
151         if (pg[H2(i)] & 1) {
152             a[tile_vslice_index(i)] = n[i];
153         }
154     }
155 }
156 
157 #undef DO_MOVA_C
158 
159 /*
160  * Move ZArray column to Zreg vector.
161  */
162 #define DO_MOVA_Z(NAME, TYPE, H)                                        \
163 void HELPER(NAME)(void *vd, void *za, void *vg, uint32_t desc)          \
164 {                                                                       \
165     int i, oprsz = simd_oprsz(desc);                                    \
166     for (i = 0; i < oprsz; ) {                                          \
167         uint16_t pg = *(uint16_t *)(vg + H1_2(i >> 3));                 \
168         do {                                                            \
169             if (pg & 1) {                                               \
170                 *(TYPE *)(vd + H(i)) = *(TYPE *)(za + tile_vslice_offset(i)); \
171             }                                                           \
172             i += sizeof(TYPE);                                          \
173             pg >>= sizeof(TYPE);                                        \
174         } while (i & 15);                                               \
175     }                                                                   \
176 }
177 
178 DO_MOVA_Z(sme_mova_zc_b, uint8_t, H1)
179 DO_MOVA_Z(sme_mova_zc_h, uint16_t, H1_2)
180 DO_MOVA_Z(sme_mova_zc_s, uint32_t, H1_4)
181 
182 void HELPER(sme_mova_zc_d)(void *vd, void *za, void *vg, uint32_t desc)
183 {
184     int i, oprsz = simd_oprsz(desc) / 8;
185     uint8_t *pg = vg;
186     uint64_t *d = vd;
187     uint64_t *a = za;
188 
189     for (i = 0; i < oprsz; i++) {
190         if (pg[H1(i)] & 1) {
191             d[i] = a[tile_vslice_index(i)];
192         }
193     }
194 }
195 
196 void HELPER(sme_mova_zc_q)(void *vd, void *za, void *vg, uint32_t desc)
197 {
198     int i, oprsz = simd_oprsz(desc) / 16;
199     uint16_t *pg = vg;
200     Int128 *d = vd;
201     Int128 *a = za;
202 
203     /*
204      * Int128 is used here simply to copy 16 bytes, and to simplify
205      * the address arithmetic.
206      */
207     for (i = 0; i < oprsz; i++) {
208         if (pg[H2(i)] & 1) {
209             d[i] = a[tile_vslice_index(i)];
210         }
211     }
212 }
213 
214 #undef DO_MOVA_Z
215 
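/*
 * Unpredicated SME2 moves from a ZA tile vertical slice to a vector:
 * copy one element at a time, using tile_vslice_index() to step
 * through the interleaved ZA storage.
 */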
216 void HELPER(sme2_mova_zc_b)(void *vdst, void *vsrc, uint32_t desc)
217 {
218     const uint8_t *src = vsrc;
219     uint8_t *dst = vdst;
220     size_t i, n = simd_oprsz(desc);
221 
222     for (i = 0; i < n; ++i) {
223         dst[i] = src[tile_vslice_index(i)];
224     }
225 }
226 
227 void HELPER(sme2_mova_zc_h)(void *vdst, void *vsrc, uint32_t desc)
228 {
229     const uint16_t *src = vsrc;
230     uint16_t *dst = vdst;
231     size_t i, n = simd_oprsz(desc) / 2;
232 
233     for (i = 0; i < n; ++i) {
234         dst[i] = src[tile_vslice_index(i)];
235     }
236 }
237 
238 void HELPER(sme2_mova_zc_s)(void *vdst, void *vsrc, uint32_t desc)
239 {
240     const uint32_t *src = vsrc;
241     uint32_t *dst = vdst;
242     size_t i, n = simd_oprsz(desc) / 4;
243 
244     for (i = 0; i < n; ++i) {
245         dst[i] = src[tile_vslice_index(i)];
246     }
247 }
248 
249 void HELPER(sme2_mova_zc_d)(void *vdst, void *vsrc, uint32_t desc)
250 {
251     const uint64_t *src = vsrc;
252     uint64_t *dst = vdst;
253     size_t i, n = simd_oprsz(desc) / 8;
254 
255     for (i = 0; i < n; ++i) {
256         dst[i] = src[tile_vslice_index(i)];
257     }
258 }
259 
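/*
 * As above, but also zero each ZA element after it has been copied out
 * (move-and-zero).
 */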
260 void HELPER(sme2p1_movaz_zc_b)(void *vdst, void *vsrc, uint32_t desc)
261 {
262     uint8_t *src = vsrc;
263     uint8_t *dst = vdst;
264     size_t i, n = simd_oprsz(desc);
265 
266     for (i = 0; i < n; ++i) {
267         dst[i] = src[tile_vslice_index(i)];
268         src[tile_vslice_index(i)] = 0;
269     }
270 }
271 
272 void HELPER(sme2p1_movaz_zc_h)(void *vdst, void *vsrc, uint32_t desc)
273 {
274     uint16_t *src = vsrc;
275     uint16_t *dst = vdst;
276     size_t i, n = simd_oprsz(desc) / 2;
277 
278     for (i = 0; i < n; ++i) {
279         dst[i] = src[tile_vslice_index(i)];
280         src[tile_vslice_index(i)] = 0;
281     }
282 }
283 
284 void HELPER(sme2p1_movaz_zc_s)(void *vdst, void *vsrc, uint32_t desc)
285 {
286     uint32_t *src = vsrc;
287     uint32_t *dst = vdst;
288     size_t i, n = simd_oprsz(desc) / 4;
289 
290     for (i = 0; i < n; ++i) {
291         dst[i] = src[tile_vslice_index(i)];
292         src[tile_vslice_index(i)] = 0;
293     }
294 }
295 
296 void HELPER(sme2p1_movaz_zc_d)(void *vdst, void *vsrc, uint32_t desc)
297 {
298     uint64_t *src = vsrc;
299     uint64_t *dst = vdst;
300     size_t i, n = simd_oprsz(desc) / 8;
301 
302     for (i = 0; i < n; ++i) {
303         dst[i] = src[tile_vslice_index(i)];
304         src[tile_vslice_index(i)] = 0;
305     }
306 }
307 
308 void HELPER(sme2p1_movaz_zc_q)(void *vdst, void *vsrc, uint32_t desc)
309 {
310     Int128 *src = vsrc;
311     Int128 *dst = vdst;
312     size_t i, n = simd_oprsz(desc) / 16;
313 
314     for (i = 0; i < n; ++i) {
315         dst[i] = src[tile_vslice_index(i)];
316         memset(&src[tile_vslice_index(i)], 0, 16);
317     }
318 }
319 
320 /*
321  * Clear elements in a tile slice comprising len bytes.
322  */
323 
324 typedef void ClearFn(void *ptr, size_t off, size_t len);
325 
326 static void clear_horizontal(void *ptr, size_t off, size_t len)
327 {
328     memset(ptr + off, 0, len);
329 }
330 
331 static void clear_vertical_b(void *vptr, size_t off, size_t len)
332 {
333     for (size_t i = 0; i < len; ++i) {
334         *(uint8_t *)(vptr + tile_vslice_offset(i + off)) = 0;
335     }
336 }
337 
338 static void clear_vertical_h(void *vptr, size_t off, size_t len)
339 {
340     for (size_t i = 0; i < len; i += 2) {
341         *(uint16_t *)(vptr + tile_vslice_offset(i + off)) = 0;
342     }
343 }
344 
345 static void clear_vertical_s(void *vptr, size_t off, size_t len)
346 {
347     for (size_t i = 0; i < len; i += 4) {
348         *(uint32_t *)(vptr + tile_vslice_offset(i + off)) = 0;
349     }
350 }
351 
352 static void clear_vertical_d(void *vptr, size_t off, size_t len)
353 {
354     for (size_t i = 0; i < len; i += 8) {
355         *(uint64_t *)(vptr + tile_vslice_offset(i + off)) = 0;
356     }
357 }
358 
359 static void clear_vertical_q(void *vptr, size_t off, size_t len)
360 {
361     for (size_t i = 0; i < len; i += 16) {
362         memset(vptr + tile_vslice_offset(i + off), 0, 16);
363     }
364 }
365 
366 /*
367  * Copy elements from an array into a tile slice comprising len bytes.
368  */
369 
370 typedef void CopyFn(void *dst, const void *src, size_t len);
371 
372 static void copy_horizontal(void *dst, const void *src, size_t len)
373 {
374     memcpy(dst, src, len);
375 }
376 
377 static void copy_vertical_b(void *vdst, const void *vsrc, size_t len)
378 {
379     const uint8_t *src = vsrc;
380     uint8_t *dst = vdst;
381     size_t i;
382 
383     for (i = 0; i < len; ++i) {
384         dst[tile_vslice_index(i)] = src[i];
385     }
386 }
387 
388 static void copy_vertical_h(void *vdst, const void *vsrc, size_t len)
389 {
390     const uint16_t *src = vsrc;
391     uint16_t *dst = vdst;
392     size_t i;
393 
394     for (i = 0; i < len / 2; ++i) {
395         dst[tile_vslice_index(i)] = src[i];
396     }
397 }
398 
399 static void copy_vertical_s(void *vdst, const void *vsrc, size_t len)
400 {
401     const uint32_t *src = vsrc;
402     uint32_t *dst = vdst;
403     size_t i;
404 
405     for (i = 0; i < len / 4; ++i) {
406         dst[tile_vslice_index(i)] = src[i];
407     }
408 }
409 
410 static void copy_vertical_d(void *vdst, const void *vsrc, size_t len)
411 {
412     const uint64_t *src = vsrc;
413     uint64_t *dst = vdst;
414     size_t i;
415 
416     for (i = 0; i < len / 8; ++i) {
417         dst[tile_vslice_index(i)] = src[i];
418     }
419 }
420 
421 static void copy_vertical_q(void *vdst, const void *vsrc, size_t len)
422 {
423     for (size_t i = 0; i < len; i += 16) {
424         memcpy(vdst + tile_vslice_offset(i), vsrc + i, 16);
425     }
426 }
427 
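/*
 * Unpredicated SME2 moves from a vector into a ZA tile vertical slice,
 * implemented directly with the copy_vertical_* primitives above.
 */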
428 void HELPER(sme2_mova_cz_b)(void *vdst, void *vsrc, uint32_t desc)
429 {
430     copy_vertical_b(vdst, vsrc, simd_oprsz(desc));
431 }
432 
433 void HELPER(sme2_mova_cz_h)(void *vdst, void *vsrc, uint32_t desc)
434 {
435     copy_vertical_h(vdst, vsrc, simd_oprsz(desc));
436 }
437 
438 void HELPER(sme2_mova_cz_s)(void *vdst, void *vsrc, uint32_t desc)
439 {
440     copy_vertical_s(vdst, vsrc, simd_oprsz(desc));
441 }
442 
443 void HELPER(sme2_mova_cz_d)(void *vdst, void *vsrc, uint32_t desc)
444 {
445     copy_vertical_d(vdst, vsrc, simd_oprsz(desc));
446 }
447 
448 /*
449  * Host and TLB primitives for vertical tile slice addressing.
450  */
451 
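/*
 * Each DO_LD/DO_ST expansion provides a _v_host variant, which accesses
 * memory through a host pointer, and a _v_tlb variant, which goes through
 * the cpu_*_data_ra accessors; both convert the slice byte offset with
 * tile_vslice_offset().  DO_LDQ/DO_STQ wrap the existing 16-byte SVE
 * primitives in the same way.
 */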
452 #define DO_LD(NAME, TYPE, HOST, TLB)                                        \
453 static inline void sme_##NAME##_v_host(void *za, intptr_t off, void *host)  \
454 {                                                                           \
455     TYPE val = HOST(host);                                                  \
456     *(TYPE *)(za + tile_vslice_offset(off)) = val;                          \
457 }                                                                           \
458 static inline void sme_##NAME##_v_tlb(CPUARMState *env, void *za,           \
459                         intptr_t off, target_ulong addr, uintptr_t ra)      \
460 {                                                                           \
461     TYPE val = TLB(env, useronly_clean_ptr(addr), ra);                      \
462     *(TYPE *)(za + tile_vslice_offset(off)) = val;                          \
463 }
464 
465 #define DO_ST(NAME, TYPE, HOST, TLB)                                        \
466 static inline void sme_##NAME##_v_host(void *za, intptr_t off, void *host)  \
467 {                                                                           \
468     TYPE val = *(TYPE *)(za + tile_vslice_offset(off));                     \
469     HOST(host, val);                                                        \
470 }                                                                           \
471 static inline void sme_##NAME##_v_tlb(CPUARMState *env, void *za,           \
472                         intptr_t off, target_ulong addr, uintptr_t ra)      \
473 {                                                                           \
474     TYPE val = *(TYPE *)(za + tile_vslice_offset(off));                     \
475     TLB(env, useronly_clean_ptr(addr), val, ra);                            \
476 }
477 
478 #define DO_LDQ(HNAME, VNAME) \
479 static inline void VNAME##_v_host(void *za, intptr_t off, void *host)       \
480 {                                                                           \
481     HNAME##_host(za, tile_vslice_offset(off), host);                        \
482 }                                                                           \
483 static inline void VNAME##_v_tlb(CPUARMState *env, void *za, intptr_t off,  \
484                                target_ulong addr, uintptr_t ra)             \
485 {                                                                           \
486     HNAME##_tlb(env, za, tile_vslice_offset(off), addr, ra);                \
487 }
488 
489 #define DO_STQ(HNAME, VNAME) \
490 static inline void VNAME##_v_host(void *za, intptr_t off, void *host)       \
491 {                                                                           \
492     HNAME##_host(za, tile_vslice_offset(off), host);                        \
493 }                                                                           \
494 static inline void VNAME##_v_tlb(CPUARMState *env, void *za, intptr_t off,  \
495                                target_ulong addr, uintptr_t ra)             \
496 {                                                                           \
497     HNAME##_tlb(env, za, tile_vslice_offset(off), addr, ra);                \
498 }
499 
500 DO_LD(ld1b, uint8_t, ldub_p, cpu_ldub_data_ra)
501 DO_LD(ld1h_be, uint16_t, lduw_be_p, cpu_lduw_be_data_ra)
502 DO_LD(ld1h_le, uint16_t, lduw_le_p, cpu_lduw_le_data_ra)
503 DO_LD(ld1s_be, uint32_t, ldl_be_p, cpu_ldl_be_data_ra)
504 DO_LD(ld1s_le, uint32_t, ldl_le_p, cpu_ldl_le_data_ra)
505 DO_LD(ld1d_be, uint64_t, ldq_be_p, cpu_ldq_be_data_ra)
506 DO_LD(ld1d_le, uint64_t, ldq_le_p, cpu_ldq_le_data_ra)
507 
508 DO_LDQ(sve_ld1qq_be, sme_ld1q_be)
509 DO_LDQ(sve_ld1qq_le, sme_ld1q_le)
510 
511 DO_ST(st1b, uint8_t, stb_p, cpu_stb_data_ra)
512 DO_ST(st1h_be, uint16_t, stw_be_p, cpu_stw_be_data_ra)
513 DO_ST(st1h_le, uint16_t, stw_le_p, cpu_stw_le_data_ra)
514 DO_ST(st1s_be, uint32_t, stl_be_p, cpu_stl_be_data_ra)
515 DO_ST(st1s_le, uint32_t, stl_le_p, cpu_stl_le_data_ra)
516 DO_ST(st1d_be, uint64_t, stq_be_p, cpu_stq_be_data_ra)
517 DO_ST(st1d_le, uint64_t, stq_le_p, cpu_stq_le_data_ra)
518 
519 DO_STQ(sve_st1qq_be, sme_st1q_be)
520 DO_STQ(sve_st1qq_le, sme_st1q_le)
521 
522 #undef DO_LD
523 #undef DO_ST
524 #undef DO_LDQ
525 #undef DO_STQ
526 
527 /*
528  * Common helper for all contiguous predicated loads.
529  */
530 
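/*
 * The host_fn/tlb_fn pair performs the per-element memory access, while
 * clr_fn and cpy_fn abstract over horizontal (contiguous) versus vertical
 * (interleaved across ZA rows) tile slices, so one body serves both
 * orientations.
 */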
531 static inline QEMU_ALWAYS_INLINE
532 void sme_ld1(CPUARMState *env, void *za, uint64_t *vg,
533              const target_ulong addr, uint32_t desc, const uintptr_t ra,
534              const int esz, uint32_t mtedesc, bool vertical,
535              sve_ldst1_host_fn *host_fn,
536              sve_ldst1_tlb_fn *tlb_fn,
537              ClearFn *clr_fn,
538              CopyFn *cpy_fn)
539 {
540     const intptr_t reg_max = simd_oprsz(desc);
541     const intptr_t esize = 1 << esz;
542     intptr_t reg_off, reg_last;
543     SVEContLdSt info;
544     void *host;
545     int flags;
546 
547     /* Find the active elements.  */
548     if (!sve_cont_ldst_elements(&info, addr, vg, reg_max, esz, esize)) {
549         /* The entire predicate was false; no load occurs.  */
550         clr_fn(za, 0, reg_max);
551         return;
552     }
553 
554     /* Probe the page(s).  Exit with exception for any invalid page. */
555     sve_cont_ldst_pages(&info, FAULT_ALL, env, addr, MMU_DATA_LOAD, ra);
556 
557     /* Handle watchpoints for all active elements. */
558     sve_cont_ldst_watchpoints(&info, env, vg, addr, esize, esize,
559                               BP_MEM_READ, ra);
560 
561     /*
562      * Handle mte checks for all active elements.
563      * Since TBI must be set for MTE, !mtedesc => !mte_active.
564      */
565     if (mtedesc) {
566         sve_cont_ldst_mte_check(&info, env, vg, addr, esize, esize,
567                                 mtedesc, ra);
568     }
569 
570     flags = info.page[0].flags | info.page[1].flags;
571     if (unlikely(flags != 0)) {
572 #ifdef CONFIG_USER_ONLY
573         g_assert_not_reached();
574 #else
575         /*
576          * At least one page includes MMIO.
577          * Any bus operation can fail with cpu_transaction_failed,
578          * which for ARM will raise SyncExternal.  Perform the load
579          * into scratch memory to preserve register state until the end.
580          */
581         ARMVectorReg scratch = { };
582 
583         reg_off = info.reg_off_first[0];
584         reg_last = info.reg_off_last[1];
585         if (reg_last < 0) {
586             reg_last = info.reg_off_split;
587             if (reg_last < 0) {
588                 reg_last = info.reg_off_last[0];
589             }
590         }
591 
592         do {
593             uint64_t pg = vg[reg_off >> 6];
594             do {
595                 if ((pg >> (reg_off & 63)) & 1) {
596                     tlb_fn(env, &scratch, reg_off, addr + reg_off, ra);
597                 }
598                 reg_off += esize;
599             } while (reg_off & 63);
600         } while (reg_off <= reg_last);
601 
602         cpy_fn(za, &scratch, reg_max);
603         return;
604 #endif
605     }
606 
607     /* The entire operation is in RAM, on valid pages. */
608 
609     reg_off = info.reg_off_first[0];
610     reg_last = info.reg_off_last[0];
611     host = info.page[0].host;
612 
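    /*
     * Zero the inactive elements of the destination slice: a horizontal
     * slice is contiguous, so clear the whole row and then fill in the
     * active elements; a vertical slice is cleared piecemeal, here for
     * the leading elements and below as inactive elements are skipped.
     */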
613     if (!vertical) {
614         memset(za, 0, reg_max);
615     } else if (reg_off) {
616         clr_fn(za, 0, reg_off);
617     }
618 
619     set_helper_retaddr(ra);
620 
621     while (reg_off <= reg_last) {
622         uint64_t pg = vg[reg_off >> 6];
623         do {
624             if ((pg >> (reg_off & 63)) & 1) {
625                 host_fn(za, reg_off, host + reg_off);
626             } else if (vertical) {
627                 clr_fn(za, reg_off, esize);
628             }
629             reg_off += esize;
630         } while (reg_off <= reg_last && (reg_off & 63));
631     }
632 
633     clear_helper_retaddr();
634 
635     /*
636      * Use the slow path to manage the cross-page misalignment.
637      * But we know this is RAM and cannot trap.
638      */
639     reg_off = info.reg_off_split;
640     if (unlikely(reg_off >= 0)) {
641         tlb_fn(env, za, reg_off, addr + reg_off, ra);
642     }
643 
644     reg_off = info.reg_off_first[1];
645     if (unlikely(reg_off >= 0)) {
646         reg_last = info.reg_off_last[1];
647         host = info.page[1].host;
648 
649         set_helper_retaddr(ra);
650 
651         do {
652             uint64_t pg = vg[reg_off >> 6];
653             do {
654                 if ((pg >> (reg_off & 63)) & 1) {
655                     host_fn(za, reg_off, host + reg_off);
656                 } else if (vertical) {
657                     clr_fn(za, reg_off, esize);
658                 }
659                 reg_off += esize;
660             } while (reg_off & 63);
661         } while (reg_off <= reg_last);
662 
663         clear_helper_retaddr();
664     }
665 }
666 
667 static inline QEMU_ALWAYS_INLINE
668 void sme_ld1_mte(CPUARMState *env, void *za, uint64_t *vg,
669                  target_ulong addr, uint32_t desc, uintptr_t ra,
670                  const int esz, bool vertical,
671                  sve_ldst1_host_fn *host_fn,
672                  sve_ldst1_tlb_fn *tlb_fn,
673                  ClearFn *clr_fn,
674                  CopyFn *cpy_fn)
675 {
676     uint32_t mtedesc = desc >> (SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT);
677     int bit55 = extract64(addr, 55, 1);
678 
679     /* Remove mtedesc from the normal sve descriptor. */
680     desc = extract32(desc, 0, SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT);
681 
682     /* Perform gross MTE suppression early. */
683     if (!tbi_check(mtedesc, bit55) ||
684         tcma_check(mtedesc, bit55, allocation_tag_from_addr(addr))) {
685         mtedesc = 0;
686     }
687 
688     sme_ld1(env, za, vg, addr, desc, ra, esz, mtedesc, vertical,
689             host_fn, tlb_fn, clr_fn, cpy_fn);
690 }
691 
692 #define DO_LD(L, END, ESZ)                                                 \
693 void HELPER(sme_ld1##L##END##_h)(CPUARMState *env, void *za, void *vg,     \
694                                  target_ulong addr, uint32_t desc)         \
695 {                                                                          \
696     sme_ld1(env, za, vg, addr, desc, GETPC(), ESZ, 0, false,               \
697             sve_ld1##L##L##END##_host, sve_ld1##L##L##END##_tlb,           \
698             clear_horizontal, copy_horizontal);                            \
699 }                                                                          \
700 void HELPER(sme_ld1##L##END##_v)(CPUARMState *env, void *za, void *vg,     \
701                                  target_ulong addr, uint32_t desc)         \
702 {                                                                          \
703     sme_ld1(env, za, vg, addr, desc, GETPC(), ESZ, 0, true,                \
704             sme_ld1##L##END##_v_host, sme_ld1##L##END##_v_tlb,             \
705             clear_vertical_##L, copy_vertical_##L);                        \
706 }                                                                          \
707 void HELPER(sme_ld1##L##END##_h_mte)(CPUARMState *env, void *za, void *vg, \
708                                      target_ulong addr, uint32_t desc)     \
709 {                                                                          \
710     sme_ld1_mte(env, za, vg, addr, desc, GETPC(), ESZ, false,              \
711                 sve_ld1##L##L##END##_host, sve_ld1##L##L##END##_tlb,       \
712                 clear_horizontal, copy_horizontal);                        \
713 }                                                                          \
714 void HELPER(sme_ld1##L##END##_v_mte)(CPUARMState *env, void *za, void *vg, \
715                                      target_ulong addr, uint32_t desc)     \
716 {                                                                          \
717     sme_ld1_mte(env, za, vg, addr, desc, GETPC(), ESZ, true,               \
718                 sme_ld1##L##END##_v_host, sme_ld1##L##END##_v_tlb,         \
719                 clear_vertical_##L, copy_vertical_##L);                    \
720 }
721 
722 DO_LD(b, , MO_8)
723 DO_LD(h, _be, MO_16)
724 DO_LD(h, _le, MO_16)
725 DO_LD(s, _be, MO_32)
726 DO_LD(s, _le, MO_32)
727 DO_LD(d, _be, MO_64)
728 DO_LD(d, _le, MO_64)
729 DO_LD(q, _be, MO_128)
730 DO_LD(q, _le, MO_128)
731 
732 #undef DO_LD
733 
734 /*
735  * Common helper for all contiguous predicated stores.
736  */
737 
738 static inline QEMU_ALWAYS_INLINE
739 void sme_st1(CPUARMState *env, void *za, uint64_t *vg,
740              const target_ulong addr, uint32_t desc, const uintptr_t ra,
741              const int esz, uint32_t mtedesc, bool vertical,
742              sve_ldst1_host_fn *host_fn,
743              sve_ldst1_tlb_fn *tlb_fn)
744 {
745     const intptr_t reg_max = simd_oprsz(desc);
746     const intptr_t esize = 1 << esz;
747     intptr_t reg_off, reg_last;
748     SVEContLdSt info;
749     void *host;
750     int flags;
751 
752     /* Find the active elements.  */
753     if (!sve_cont_ldst_elements(&info, addr, vg, reg_max, esz, esize)) {
754         /* The entire predicate was false; no store occurs.  */
755         return;
756     }
757 
758     /* Probe the page(s).  Exit with exception for any invalid page. */
759     sve_cont_ldst_pages(&info, FAULT_ALL, env, addr, MMU_DATA_STORE, ra);
760 
761     /* Handle watchpoints for all active elements. */
762     sve_cont_ldst_watchpoints(&info, env, vg, addr, esize, esize,
763                               BP_MEM_WRITE, ra);
764 
765     /*
766      * Handle mte checks for all active elements.
767      * Since TBI must be set for MTE, !mtedesc => !mte_active.
768      */
769     if (mtedesc) {
770         sve_cont_ldst_mte_check(&info, env, vg, addr, esize, esize,
771                                 mtedesc, ra);
772     }
773 
774     flags = info.page[0].flags | info.page[1].flags;
775     if (unlikely(flags != 0)) {
776 #ifdef CONFIG_USER_ONLY
777         g_assert_not_reached();
778 #else
779         /*
780          * At least one page includes MMIO.
781          * Any bus operation can fail with cpu_transaction_failed,
782          * which for ARM will raise SyncExternal.  We cannot avoid
783          * this fault and will leave with the store incomplete.
784          */
785         reg_off = info.reg_off_first[0];
786         reg_last = info.reg_off_last[1];
787         if (reg_last < 0) {
788             reg_last = info.reg_off_split;
789             if (reg_last < 0) {
790                 reg_last = info.reg_off_last[0];
791             }
792         }
793 
794         do {
795             uint64_t pg = vg[reg_off >> 6];
796             do {
797                 if ((pg >> (reg_off & 63)) & 1) {
798                     tlb_fn(env, za, reg_off, addr + reg_off, ra);
799                 }
800                 reg_off += esize;
801             } while (reg_off & 63);
802         } while (reg_off <= reg_last);
803         return;
804 #endif
805     }
806 
807     reg_off = info.reg_off_first[0];
808     reg_last = info.reg_off_last[0];
809     host = info.page[0].host;
810 
811     set_helper_retaddr(ra);
812 
813     while (reg_off <= reg_last) {
814         uint64_t pg = vg[reg_off >> 6];
815         do {
816             if ((pg >> (reg_off & 63)) & 1) {
817                 host_fn(za, reg_off, host + reg_off);
818             }
819             reg_off += 1 << esz;
820         } while (reg_off <= reg_last && (reg_off & 63));
821     }
822 
823     clear_helper_retaddr();
824 
825     /*
826      * Use the slow path to manage the cross-page misalignment.
827      * But we know this is RAM and cannot trap.
828      */
829     reg_off = info.reg_off_split;
830     if (unlikely(reg_off >= 0)) {
831         tlb_fn(env, za, reg_off, addr + reg_off, ra);
832     }
833 
834     reg_off = info.reg_off_first[1];
835     if (unlikely(reg_off >= 0)) {
836         reg_last = info.reg_off_last[1];
837         host = info.page[1].host;
838 
839         set_helper_retaddr(ra);
840 
841         do {
842             uint64_t pg = vg[reg_off >> 6];
843             do {
844                 if ((pg >> (reg_off & 63)) & 1) {
845                     host_fn(za, reg_off, host + reg_off);
846                 }
847                 reg_off += 1 << esz;
848             } while (reg_off & 63);
849         } while (reg_off <= reg_last);
850 
851         clear_helper_retaddr();
852     }
853 }
854 
855 static inline QEMU_ALWAYS_INLINE
856 void sme_st1_mte(CPUARMState *env, void *za, uint64_t *vg, target_ulong addr,
857                  uint32_t desc, uintptr_t ra, int esz, bool vertical,
858                  sve_ldst1_host_fn *host_fn,
859                  sve_ldst1_tlb_fn *tlb_fn)
860 {
861     uint32_t mtedesc = desc >> (SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT);
862     int bit55 = extract64(addr, 55, 1);
863 
864     /* Remove mtedesc from the normal sve descriptor. */
865     desc = extract32(desc, 0, SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT);
866 
867     /* Perform gross MTE suppression early. */
868     if (!tbi_check(mtedesc, bit55) ||
869         tcma_check(mtedesc, bit55, allocation_tag_from_addr(addr))) {
870         mtedesc = 0;
871     }
872 
873     sme_st1(env, za, vg, addr, desc, ra, esz, mtedesc,
874             vertical, host_fn, tlb_fn);
875 }
876 
877 #define DO_ST(L, END, ESZ)                                                 \
878 void HELPER(sme_st1##L##END##_h)(CPUARMState *env, void *za, void *vg,     \
879                                  target_ulong addr, uint32_t desc)         \
880 {                                                                          \
881     sme_st1(env, za, vg, addr, desc, GETPC(), ESZ, 0, false,               \
882             sve_st1##L##L##END##_host, sve_st1##L##L##END##_tlb);          \
883 }                                                                          \
884 void HELPER(sme_st1##L##END##_v)(CPUARMState *env, void *za, void *vg,     \
885                                  target_ulong addr, uint32_t desc)         \
886 {                                                                          \
887     sme_st1(env, za, vg, addr, desc, GETPC(), ESZ, 0, true,                \
888             sme_st1##L##END##_v_host, sme_st1##L##END##_v_tlb);            \
889 }                                                                          \
890 void HELPER(sme_st1##L##END##_h_mte)(CPUARMState *env, void *za, void *vg, \
891                                      target_ulong addr, uint32_t desc)     \
892 {                                                                          \
893     sme_st1_mte(env, za, vg, addr, desc, GETPC(), ESZ, false,              \
894                 sve_st1##L##L##END##_host, sve_st1##L##L##END##_tlb);      \
895 }                                                                          \
896 void HELPER(sme_st1##L##END##_v_mte)(CPUARMState *env, void *za, void *vg, \
897                                      target_ulong addr, uint32_t desc)     \
898 {                                                                          \
899     sme_st1_mte(env, za, vg, addr, desc, GETPC(), ESZ, true,               \
900                 sme_st1##L##END##_v_host, sme_st1##L##END##_v_tlb);        \
901 }
902 
903 DO_ST(b, , MO_8)
904 DO_ST(h, _be, MO_16)
905 DO_ST(h, _le, MO_16)
906 DO_ST(s, _be, MO_32)
907 DO_ST(s, _le, MO_32)
908 DO_ST(d, _be, MO_64)
909 DO_ST(d, _le, MO_64)
910 DO_ST(q, _be, MO_128)
911 DO_ST(q, _le, MO_128)
912 
913 #undef DO_ST
914 
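/*
 * ADDHA adds zn[col] and ADDVA adds zn[row] to each active ZA element.
 * The governing predicates have one bit per byte: the 32-bit forms read
 * them as uint64_t and step 4 bits per element, while the 64-bit forms
 * read one predicate byte per element.
 */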
915 void HELPER(sme_addha_s)(void *vzda, void *vzn, void *vpn,
916                          void *vpm, uint32_t desc)
917 {
918     intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
919     uint64_t *pn = vpn, *pm = vpm;
920     uint32_t *zda = vzda, *zn = vzn;
921 
922     for (row = 0; row < oprsz; ) {
923         uint64_t pa = pn[row >> 4];
924         do {
925             if (pa & 1) {
926                 for (col = 0; col < oprsz; ) {
927                     uint64_t pb = pm[col >> 4];
928                     do {
929                         if (pb & 1) {
930                             zda[tile_vslice_index(row) + H4(col)] += zn[H4(col)];
931                         }
932                         pb >>= 4;
933                     } while (++col & 15);
934                 }
935             }
936             pa >>= 4;
937         } while (++row & 15);
938     }
939 }
940 
941 void HELPER(sme_addha_d)(void *vzda, void *vzn, void *vpn,
942                          void *vpm, uint32_t desc)
943 {
944     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
945     uint8_t *pn = vpn, *pm = vpm;
946     uint64_t *zda = vzda, *zn = vzn;
947 
948     for (row = 0; row < oprsz; ++row) {
949         if (pn[H1(row)] & 1) {
950             for (col = 0; col < oprsz; ++col) {
951                 if (pm[H1(col)] & 1) {
952                     zda[tile_vslice_index(row) + col] += zn[col];
953                 }
954             }
955         }
956     }
957 }
958 
959 void HELPER(sme_addva_s)(void *vzda, void *vzn, void *vpn,
960                          void *vpm, uint32_t desc)
961 {
962     intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
963     uint64_t *pn = vpn, *pm = vpm;
964     uint32_t *zda = vzda, *zn = vzn;
965 
966     for (row = 0; row < oprsz; ) {
967         uint64_t pa = pn[row >> 4];
968         do {
969             if (pa & 1) {
970                 uint32_t zn_row = zn[H4(row)];
971                 for (col = 0; col < oprsz; ) {
972                     uint64_t pb = pm[col >> 4];
973                     do {
974                         if (pb & 1) {
975                             zda[tile_vslice_index(row) + H4(col)] += zn_row;
976                         }
977                         pb >>= 4;
978                     } while (++col & 15);
979                 }
980             }
981             pa >>= 4;
982         } while (++row & 15);
983     }
984 }
985 
986 void HELPER(sme_addva_d)(void *vzda, void *vzn, void *vpn,
987                          void *vpm, uint32_t desc)
988 {
989     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
990     uint8_t *pn = vpn, *pm = vpm;
991     uint64_t *zda = vzda, *zn = vzn;
992 
993     for (row = 0; row < oprsz; ++row) {
994         if (pn[H1(row)] & 1) {
995             uint64_t zn_row = zn[row];
996             for (col = 0; col < oprsz; ++col) {
997                 if (pm[H1(col)] & 1) {
998                     zda[tile_vslice_index(row) + col] += zn_row;
999                 }
1000             }
1001         }
1002     }
1003 }
1004 
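/*
 * Non-widening FMOPA/FMOPS.  negx is XORed into the row operand (a plain
 * sign-bit flip, used by the FMOPS helpers); negf supplies float_muladd_*
 * flags (the FPCR.AH variants pass float_muladd_negate_product so the
 * negation happens inside the fused multiply-add rather than on the
 * input operand).
 */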
1005 static void do_fmopa_h(void *vza, void *vzn, void *vzm, uint16_t *pn,
1006                        uint16_t *pm, float_status *fpst, uint32_t desc,
1007                        uint16_t negx, int negf)
1008 {
1009     intptr_t row, col, oprsz = simd_maxsz(desc);
1010 
1011     for (row = 0; row < oprsz; ) {
1012         uint16_t pa = pn[H2(row >> 4)];
1013         do {
1014             if (pa & 1) {
1015                 void *vza_row = vza + tile_vslice_offset(row);
1016                 uint16_t n = *(uint16_t *)(vzn + H1_2(row)) ^ negx;
1017 
1018                 for (col = 0; col < oprsz; ) {
1019                     uint16_t pb = pm[H2(col >> 4)];
1020                     do {
1021                         if (pb & 1) {
1022                             uint16_t *a = vza_row + H1_2(col);
1023                             uint16_t *m = vzm + H1_2(col);
1024                             *a = float16_muladd(n, *m, *a, negf, fpst);
1025                         }
1026                         col += 2;
1027                         pb >>= 2;
1028                     } while (col & 15);
1029                 }
1030             }
1031             row += 2;
1032             pa >>= 2;
1033         } while (row & 15);
1034     }
1035 }
1036 
1037 void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn,
1038                          void *vpm, float_status *fpst, uint32_t desc)
1039 {
1040     do_fmopa_h(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
1041 }
1042 
1043 void HELPER(sme_fmops_h)(void *vza, void *vzn, void *vzm, void *vpn,
1044                          void *vpm, float_status *fpst, uint32_t desc)
1045 {
1046     do_fmopa_h(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 15, 0);
1047 }
1048 
1049 void HELPER(sme_ah_fmops_h)(void *vza, void *vzn, void *vzm, void *vpn,
1050                             void *vpm, float_status *fpst, uint32_t desc)
1051 {
1052     do_fmopa_h(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
1053                float_muladd_negate_product);
1054 }
1055 
1056 static void do_fmopa_s(void *vza, void *vzn, void *vzm, uint16_t *pn,
1057                        uint16_t *pm, float_status *fpst, uint32_t desc,
1058                        uint32_t negx, int negf)
1059 {
1060     intptr_t row, col, oprsz = simd_maxsz(desc);
1061 
1062     for (row = 0; row < oprsz; ) {
1063         uint16_t pa = pn[H2(row >> 4)];
1064         do {
1065             if (pa & 1) {
1066                 void *vza_row = vza + tile_vslice_offset(row);
1067                 uint32_t n = *(uint32_t *)(vzn + H1_4(row)) ^ negx;
1068 
1069                 for (col = 0; col < oprsz; ) {
1070                     uint16_t pb = pm[H2(col >> 4)];
1071                     do {
1072                         if (pb & 1) {
1073                             uint32_t *a = vza_row + H1_4(col);
1074                             uint32_t *m = vzm + H1_4(col);
1075                             *a = float32_muladd(n, *m, *a, negf, fpst);
1076                         }
1077                         col += 4;
1078                         pb >>= 4;
1079                     } while (col & 15);
1080                 }
1081             }
1082             row += 4;
1083             pa >>= 4;
1084         } while (row & 15);
1085     }
1086 }
1087 
1088 void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
1089                          void *vpm, float_status *fpst, uint32_t desc)
1090 {
1091     do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
1092 }
1093 
1094 void HELPER(sme_fmops_s)(void *vza, void *vzn, void *vzm, void *vpn,
1095                          void *vpm, float_status *fpst, uint32_t desc)
1096 {
1097     do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 31, 0);
1098 }
1099 
1100 void HELPER(sme_ah_fmops_s)(void *vza, void *vzn, void *vzm, void *vpn,
1101                             void *vpm, float_status *fpst, uint32_t desc)
1102 {
1103     do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
1104                float_muladd_negate_product);
1105 }
1106 
1107 static void do_fmopa_d(uint64_t *za, uint64_t *zn, uint64_t *zm, uint8_t *pn,
1108                        uint8_t *pm, float_status *fpst, uint32_t desc,
1109                        uint64_t negx, int negf)
1110 {
1111     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
1112 
1113     for (row = 0; row < oprsz; ++row) {
1114         if (pn[H1(row)] & 1) {
1115             uint64_t *za_row = &za[tile_vslice_index(row)];
1116             uint64_t n = zn[row] ^ negx;
1117 
1118             for (col = 0; col < oprsz; ++col) {
1119                 if (pm[H1(col)] & 1) {
1120                     uint64_t *a = &za_row[col];
1121                     *a = float64_muladd(n, zm[col], *a, negf, fpst);
1122                 }
1123             }
1124         }
1125     }
1126 }
1127 
1128 void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
1129                          void *vpm, float_status *fpst, uint32_t desc)
1130 {
1131     do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
1132 }
1133 
1134 void HELPER(sme_fmops_d)(void *vza, void *vzn, void *vzm, void *vpn,
1135                          void *vpm, float_status *fpst, uint32_t desc)
1136 {
1137     do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 1ull << 63, 0);
1138 }
1139 
1140 void HELPER(sme_ah_fmops_d)(void *vza, void *vzn, void *vzm, void *vpn,
1141                             void *vpm, float_status *fpst, uint32_t desc)
1142 {
1143     do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
1144                float_muladd_negate_product);
1145 }
1146 
1147 static void do_bfmopa(void *vza, void *vzn, void *vzm, uint16_t *pn,
1148                       uint16_t *pm, float_status *fpst, uint32_t desc,
1149                       uint16_t negx, int negf)
1150 {
1151     intptr_t row, col, oprsz = simd_maxsz(desc);
1152 
1153     for (row = 0; row < oprsz; ) {
1154         uint16_t pa = pn[H2(row >> 4)];
1155         do {
1156             if (pa & 1) {
1157                 void *vza_row = vza + tile_vslice_offset(row);
1158                 uint16_t n = *(uint16_t *)(vzn + H1_2(row)) ^ negx;
1159 
1160                 for (col = 0; col < oprsz; ) {
1161                     uint16_t pb = pm[H2(col >> 4)];
1162                     do {
1163                         if (pb & 1) {
1164                             uint16_t *a = vza_row + H1_2(col);
1165                             uint16_t *m = vzm + H1_2(col);
1166                             *a = bfloat16_muladd(n, *m, *a, negf, fpst);
1167                         }
1168                         col += 2;
1169                         pb >>= 2;
1170                     } while (col & 15);
1171                 }
1172             }
1173             row += 2;
1174             pa >>= 2;
1175         } while (row & 15);
1176     }
1177 }
1178 
1179 void HELPER(sme_bfmopa)(void *vza, void *vzn, void *vzm, void *vpn,
1180                         void *vpm, float_status *fpst, uint32_t desc)
1181 {
1182     do_bfmopa(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0);
1183 }
1184 
1185 void HELPER(sme_bfmops)(void *vza, void *vzn, void *vzm, void *vpn,
1186                         void *vpm, float_status *fpst, uint32_t desc)
1187 {
1188     do_bfmopa(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 15, 0);
1189 }
1190 
1191 void HELPER(sme_ah_bfmops)(void *vza, void *vzn, void *vzm, void *vpn,
1192                            void *vpm, float_status *fpst, uint32_t desc)
1193 {
1194     do_bfmopa(vza, vzn, vzm, vpn, vpm, fpst, desc, 0,
1195               float_muladd_negate_product);
1196 }
1197 
1198 /*
1199  * Alter PAIR as needed for controlling predicates being false,
1200  * and for NEG on an enabled row element.
1201  */
1202 static inline uint32_t f16mop_adj_pair(uint32_t pair, uint32_t pg, uint32_t neg)
1203 {
1204     /*
1205      * The pseudocode uses a conditional negate after the conditional zero.
1206      * It is simpler here to unconditionally negate before conditional zero.
1207      */
1208     pair ^= neg;
1209     if (!(pg & 1)) {
1210         pair &= 0xffff0000u;
1211     }
1212     if (!(pg & 4)) {
1213         pair &= 0x0000ffffu;
1214     }
1215     return pair;
1216 }
1217 
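/*
 * As above, for the FPCR.AH forms: negate each enabled element with
 * float16_ah_chs rather than a bare sign-bit XOR, and zero the elements
 * whose predicate bit is false.
 */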
1218 static inline uint32_t f16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg)
1219 {
1220     uint32_t l = pg & 1 ? float16_ah_chs(pair) : 0;
1221     uint32_t h = pg & 4 ? float16_ah_chs(pair >> 16) : 0;
1222     return l | (h << 16);
1223 }
1224 
1225 static inline uint32_t bf16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg)
1226 {
1227     uint32_t l = pg & 1 ? bfloat16_ah_chs(pair) : 0;
1228     uint32_t h = pg & 4 ? bfloat16_ah_chs(pair >> 16) : 0;
1229     return l | (h << 16);
1230 }
1231 
1232 static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
1233                           float_status *s_f16, float_status *s_std,
1234                           float_status *s_odd)
1235 {
1236     /*
1237      * We need three different float_status for different parts of this
1238      * operation:
1239      *  - the input conversion of the float16 values must use the
1240      *    f16-specific float_status, so that the FPCR.FZ16 control is applied
1241      *  - operations on float32 including the final accumulation must use
1242      *    the normal float_status, so that FPCR.FZ is applied
1243      *  - we have a pre-set-up copy of s_std which is set to round-to-odd,
1244      *    for the multiply (see below)
1245      */
1246     float16 h1r = e1 & 0xffff;
1247     float16 h1c = e1 >> 16;
1248     float16 h2r = e2 & 0xffff;
1249     float16 h2c = e2 >> 16;
1250     float32 t32;
1251 
1252     /* C.f. FPProcessNaNs4 */
1253     if (float16_is_any_nan(h1r) || float16_is_any_nan(h1c) ||
1254         float16_is_any_nan(h2r) || float16_is_any_nan(h2c)) {
1255         float16 t16;
1256 
1257         if (float16_is_signaling_nan(h1r, s_f16)) {
1258             t16 = h1r;
1259         } else if (float16_is_signaling_nan(h1c, s_f16)) {
1260             t16 = h1c;
1261         } else if (float16_is_signaling_nan(h2r, s_f16)) {
1262             t16 = h2r;
1263         } else if (float16_is_signaling_nan(h2c, s_f16)) {
1264             t16 = h2c;
1265         } else if (float16_is_any_nan(h1r)) {
1266             t16 = h1r;
1267         } else if (float16_is_any_nan(h1c)) {
1268             t16 = h1c;
1269         } else if (float16_is_any_nan(h2r)) {
1270             t16 = h2r;
1271         } else {
1272             t16 = h2c;
1273         }
1274         t32 = float16_to_float32(t16, true, s_f16);
1275     } else {
1276         float64 e1r = float16_to_float64(h1r, true, s_f16);
1277         float64 e1c = float16_to_float64(h1c, true, s_f16);
1278         float64 e2r = float16_to_float64(h2r, true, s_f16);
1279         float64 e2c = float16_to_float64(h2c, true, s_f16);
1280         float64 t64;
1281 
1282         /*
1283          * The ARM pseudocode function FPDot performs both multiplies
1284          * and the add with a single rounding operation.  Emulate this
1285          * by performing the first multiply in round-to-odd, then doing
1286          * the second multiply as fused multiply-add, and rounding to
1287          * float32 all in one step.
1288          */
1289         t64 = float64_mul(e1r, e2r, s_odd);
1290         t64 = float64r32_muladd(e1c, e2c, t64, 0, s_std);
1291 
1292         /* This conversion is exact, because we've already rounded. */
1293         t32 = float64_to_float32(t64, s_std);
1294     }
1295 
1296     /* The final accumulation step is not fused. */
1297     return float32_add(sum, t32, s_std);
1298 }
1299 
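/*
 * Widening half-precision outer product: each 32-bit column holds a pair
 * of float16 values which are accumulated into a float32 ZA element via
 * f16_dotadd.  The check against 0b0101 skips a pair only when neither
 * half is active in both predicates; f16mop_adj_pair zeroes whichever
 * half is individually disabled.
 */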
1300 static void do_fmopa_w_h(void *vza, void *vzn, void *vzm, uint16_t *pn,
1301                          uint16_t *pm, CPUARMState *env, uint32_t desc,
1302                          uint32_t negx, bool ah_neg)
1303 {
1304     intptr_t row, col, oprsz = simd_maxsz(desc);
1305     float_status fpst_odd = env->vfp.fp_status[FPST_ZA];
1306 
1307     set_float_rounding_mode(float_round_to_odd, &fpst_odd);
1308 
1309     for (row = 0; row < oprsz; ) {
1310         uint16_t prow = pn[H2(row >> 4)];
1311         do {
1312             void *vza_row = vza + tile_vslice_offset(row);
1313             uint32_t n = *(uint32_t *)(vzn + H1_4(row));
1314 
1315             if (ah_neg) {
1316                 n = f16mop_ah_neg_adj_pair(n, prow);
1317             } else {
1318                 n = f16mop_adj_pair(n, prow, negx);
1319             }
1320 
1321             for (col = 0; col < oprsz; ) {
1322                 uint16_t pcol = pm[H2(col >> 4)];
1323                 do {
1324                     if (prow & pcol & 0b0101) {
1325                         uint32_t *a = vza_row + H1_4(col);
1326                         uint32_t m = *(uint32_t *)(vzm + H1_4(col));
1327 
1328                         m = f16mop_adj_pair(m, pcol, 0);
1329                         *a = f16_dotadd(*a, n, m,
1330                                         &env->vfp.fp_status[FPST_ZA_F16],
1331                                         &env->vfp.fp_status[FPST_ZA],
1332                                         &fpst_odd);
1333                     }
1334                     col += 4;
1335                     pcol >>= 4;
1336                 } while (col & 15);
1337             }
1338             row += 4;
1339             prow >>= 4;
1340         } while (row & 15);
1341     }
1342 }
1343 
1344 void HELPER(sme_fmopa_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
1345                            void *vpm, CPUARMState *env, uint32_t desc)
1346 {
1347     do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0, false);
1348 }
1349 
1350 void HELPER(sme_fmops_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
1351                            void *vpm, CPUARMState *env, uint32_t desc)
1352 {
1353     do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0x80008000u, false);
1354 }
1355 
1356 void HELPER(sme_ah_fmops_w_h)(void *vza, void *vzn, void *vzm, void *vpn,
1357                               void *vpm, CPUARMState *env, uint32_t desc)
1358 {
1359     do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0, true);
1360 }
1361 
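/*
 * SME2 FDOT and FVDOT (widening float16 to float32 dot product).  These
 * reuse f16_dotadd; a descriptor data bit selects between the ZA and A64
 * float_status sets for FDOT, while FVDOT always uses the ZA set.
 */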
1362 void HELPER(sme2_fdot_h)(void *vd, void *vn, void *vm, void *va,
1363                          CPUARMState *env, uint32_t desc)
1364 {
1365     intptr_t i, oprsz = simd_maxsz(desc);
1366     bool za = extract32(desc, SIMD_DATA_SHIFT, 1);
1367     float_status *fpst_std = &env->vfp.fp_status[za ? FPST_ZA : FPST_A64];
1368     float_status *fpst_f16 = &env->vfp.fp_status[za ? FPST_ZA_F16 : FPST_A64_F16];
1369     float_status fpst_odd = *fpst_std;
1370     float32 *d = vd, *a = va;
1371     uint32_t *n = vn, *m = vm;
1372 
1373     set_float_rounding_mode(float_round_to_odd, &fpst_odd);
1374 
1375     for (i = 0; i < oprsz / sizeof(float32); ++i) {
1376         d[H4(i)] = f16_dotadd(a[H4(i)], n[H4(i)], m[H4(i)],
1377                               fpst_f16, fpst_std, &fpst_odd);
1378     }
1379 }
1380 
1381 void HELPER(sme2_fdot_idx_h)(void *vd, void *vn, void *vm, void *va,
1382                              CPUARMState *env, uint32_t desc)
1383 {
1384     intptr_t i, j, oprsz = simd_maxsz(desc);
1385     intptr_t elements = oprsz / sizeof(float32);
1386     intptr_t eltspersegment = MIN(4, elements);
1387     int idx = extract32(desc, SIMD_DATA_SHIFT, 2);
1388     bool za = extract32(desc, SIMD_DATA_SHIFT + 2, 1);
1389     float_status *fpst_std = &env->vfp.fp_status[za ? FPST_ZA : FPST_A64];
1390     float_status *fpst_f16 = &env->vfp.fp_status[za ? FPST_ZA_F16 : FPST_A64_F16];
1391     float_status fpst_odd = *fpst_std;
1392     float32 *d = vd, *a = va;
1393     uint32_t *n = vn, *m = (uint32_t *)vm + H4(idx);
1394 
1395     set_float_rounding_mode(float_round_to_odd, &fpst_odd);
1396 
1397     for (i = 0; i < elements; i += eltspersegment) {
1398         uint32_t mm = m[i];
1399         for (j = 0; j < eltspersegment; ++j) {
1400             d[H4(i + j)] = f16_dotadd(a[H4(i + j)], n[H4(i + j)], mm,
1401                                       fpst_f16, fpst_std, &fpst_odd);
1402         }
1403     }
1404 }
1405 
1406 void HELPER(sme2_fvdot_idx_h)(void *vd, void *vn, void *vm, void *va,
1407                               CPUARMState *env, uint32_t desc)
1408 {
1409     intptr_t i, j, oprsz = simd_maxsz(desc);
1410     intptr_t elements = oprsz / sizeof(float32);
1411     intptr_t eltspersegment = MIN(4, elements);
1412     int idx = extract32(desc, SIMD_DATA_SHIFT, 2);
1413     int sel = extract32(desc, SIMD_DATA_SHIFT + 2, 1);
1414     float_status fpst_odd, *fpst_std, *fpst_f16;
1415     float32 *d = vd, *a = va;
1416     uint16_t *n0 = vn;
1417     uint16_t *n1 = vn + sizeof(ARMVectorReg);
1418     uint32_t *m = (uint32_t *)vm + H4(idx);
1419 
1420     fpst_std = &env->vfp.fp_status[FPST_ZA];
1421     fpst_f16 = &env->vfp.fp_status[FPST_ZA_F16];
1422     fpst_odd = *fpst_std;
1423     set_float_rounding_mode(float_round_to_odd, &fpst_odd);
1424 
1425     for (i = 0; i < elements; i += eltspersegment) {
1426         uint32_t mm = m[i];
1427         for (j = 0; j < eltspersegment; ++j) {
1428             uint32_t nn = (n0[H2(2 * (i + j) + sel)])
1429                         | (n1[H2(2 * (i + j) + sel)] << 16);
1430             d[i + H4(j)] = f16_dotadd(a[i + H4(j)], nn, mm,
1431                                       fpst_f16, fpst_std, &fpst_odd);
1432         }
1433     }
1434 }
1435 
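/*
 * Widening bfloat16 outer product.  is_ebf() selects between the extended
 * (FPCR.EBF) path, which accumulates pairs with bfdotadd_ebf using a
 * round-to-odd intermediate status, and the standard path using bfdotadd;
 * the two loops are otherwise identical.
 */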
1436 static void do_bfmopa_w(void *vza, void *vzn, void *vzm,
1437                         uint16_t *pn, uint16_t *pm, CPUARMState *env,
1438                         uint32_t desc, uint32_t negx, bool ah_neg)
1439 {
1440     intptr_t row, col, oprsz = simd_maxsz(desc);
1441     float_status fpst, fpst_odd;
1442 
1443     if (is_ebf(env, &fpst, &fpst_odd)) {
1444         for (row = 0; row < oprsz; ) {
1445             uint16_t prow = pn[H2(row >> 4)];
1446             do {
1447                 void *vza_row = vza + tile_vslice_offset(row);
1448                 uint32_t n = *(uint32_t *)(vzn + H1_4(row));
1449 
1450                 if (ah_neg) {
1451                     n = bf16mop_ah_neg_adj_pair(n, prow);
1452                 } else {
1453                     n = f16mop_adj_pair(n, prow, negx);
1454                 }
1455 
1456                 for (col = 0; col < oprsz; ) {
1457                     uint16_t pcol = pm[H2(col >> 4)];
1458                     do {
1459                         if (prow & pcol & 0b0101) {
1460                             uint32_t *a = vza_row + H1_4(col);
1461                             uint32_t m = *(uint32_t *)(vzm + H1_4(col));
1462 
1463                             m = f16mop_adj_pair(m, pcol, 0);
1464                             *a = bfdotadd_ebf(*a, n, m, &fpst, &fpst_odd);
1465                         }
1466                         col += 4;
1467                         pcol >>= 4;
1468                     } while (col & 15);
1469                 }
1470                 row += 4;
1471                 prow >>= 4;
1472             } while (row & 15);
1473         }
1474     } else {
1475         for (row = 0; row < oprsz; ) {
1476             uint16_t prow = pn[H2(row >> 4)];
1477             do {
1478                 void *vza_row = vza + tile_vslice_offset(row);
1479                 uint32_t n = *(uint32_t *)(vzn + H1_4(row));
1480 
1481                 if (ah_neg) {
1482                     n = bf16mop_ah_neg_adj_pair(n, prow);
1483                 } else {
1484                     n = f16mop_adj_pair(n, prow, negx);
1485                 }
1486 
1487                 for (col = 0; col < oprsz; ) {
1488                     uint16_t pcol = pm[H2(col >> 4)];
1489                     do {
1490                         if (prow & pcol & 0b0101) {
1491                             uint32_t *a = vza_row + H1_4(col);
1492                             uint32_t m = *(uint32_t *)(vzm + H1_4(col));
1493 
1494                             m = f16mop_adj_pair(m, pcol, 0);
1495                             *a = bfdotadd(*a, n, m, &fpst);
1496                         }
1497                         col += 4;
1498                         pcol >>= 4;
1499                     } while (col & 15);
1500                 }
1501                 row += 4;
1502                 prow >>= 4;
1503             } while (row & 15);
1504         }
1505     }
1506 }
1507 
1508 void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm, void *vpn,
1509                           void *vpm, CPUARMState *env, uint32_t desc)
1510 {
1511     do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0, false);
1512 }
1513 
1514 void HELPER(sme_bfmops_w)(void *vza, void *vzn, void *vzm, void *vpn,
1515                           void *vpm, CPUARMState *env, uint32_t desc)
1516 {
1517     do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0x80008000u, false);
1518 }
1519 
1520 void HELPER(sme_ah_bfmops_w)(void *vza, void *vzn, void *vzm, void *vpn,
1521                              void *vpm, CPUARMState *env, uint32_t desc)
1522 {
1523     do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0, true);
1524 }
1525 
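/*
 * Integer outer-product accumulate.  Each 32-bit (or 64-bit) ZA element
 * is updated with a dot product of four 8-bit (or 16-bit) sub-elements.
 * The predicate bits covering the row element and the column element are
 * ANDed and passed to the element function, which applies them as a mask
 * to the N operand only: zeroing one factor is enough to remove an
 * inactive product from the sum.
 */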
1526 typedef uint32_t IMOPFn32(uint32_t, uint32_t, uint32_t, uint8_t, bool);
1527 static inline void do_imopa_s(uint32_t *za, uint32_t *zn, uint32_t *zm,
1528                               uint8_t *pn, uint8_t *pm,
1529                               uint32_t desc, IMOPFn32 *fn)
1530 {
1531     intptr_t row, col, oprsz = simd_oprsz(desc) / 4;
1532     bool neg = simd_data(desc);
1533 
1534     for (row = 0; row < oprsz; ++row) {
1535         uint8_t pa = (pn[H1(row >> 1)] >> ((row & 1) * 4)) & 0xf;
1536         uint32_t *za_row = &za[tile_vslice_index(row)];
1537         uint32_t n = zn[H4(row)];
1538 
1539         for (col = 0; col < oprsz; ++col) {
1540             uint8_t pb = pm[H1(col >> 1)] >> ((col & 1) * 4);
1541             uint32_t *a = &za_row[H4(col)];
1542 
1543             *a = fn(n, zm[H4(col)], *a, pa & pb, neg);
1544         }
1545     }
1546 }
1547 
1548 typedef uint64_t IMOPFn64(uint64_t, uint64_t, uint64_t, uint8_t, bool);
1549 static inline void do_imopa_d(uint64_t *za, uint64_t *zn, uint64_t *zm,
1550                               uint8_t *pn, uint8_t *pm,
1551                               uint32_t desc, IMOPFn64 *fn)
1552 {
1553     intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
1554     bool neg = simd_data(desc);
1555 
1556     for (row = 0; row < oprsz; ++row) {
1557         uint8_t pa = pn[H1(row)];
1558         uint64_t *za_row = &za[tile_vslice_index(row)];
1559         uint64_t n = zn[row];
1560 
1561         for (col = 0; col < oprsz; ++col) {
1562             uint8_t pb = pm[H1(col)];
1563             uint64_t *a = &za_row[col];
1564 
1565             *a = fn(n, zm[col], *a, pa & pb, neg);
1566         }
1567     }
1568 }
1569 
1570 #define DEF_IMOP_8x4_32(NAME, NTYPE, MTYPE) \
1571 static uint32_t NAME(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg) \
1572 {                                                                           \
1573     uint32_t sum = 0;                                                       \
1574     /* Apply P to N as a mask, making the inactive elements 0. */           \
1575     n &= expand_pred_b(p);                                                  \
1576     sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0);                               \
1577     sum += (NTYPE)(n >> 8) * (MTYPE)(m >> 8);                               \
1578     sum += (NTYPE)(n >> 16) * (MTYPE)(m >> 16);                             \
1579     sum += (NTYPE)(n >> 24) * (MTYPE)(m >> 24);                             \
1580     return neg ? a - sum : a + sum;                                         \
1581 }
1582 
1583 #define DEF_IMOP_16x4_64(NAME, NTYPE, MTYPE) \
1584 static uint64_t NAME(uint64_t n, uint64_t m, uint64_t a, uint8_t p, bool neg) \
1585 {                                                                           \
1586     uint64_t sum = 0;                                                       \
1587     /* Apply P to N as a mask, making the inactive elements 0. */           \
1588     n &= expand_pred_h(p);                                                  \
1589     sum += (int64_t)(NTYPE)(n >> 0) * (MTYPE)(m >> 0);                      \
1590     sum += (int64_t)(NTYPE)(n >> 16) * (MTYPE)(m >> 16);                    \
1591     sum += (int64_t)(NTYPE)(n >> 32) * (MTYPE)(m >> 32);                    \
1592     sum += (int64_t)(NTYPE)(n >> 48) * (MTYPE)(m >> 48);                    \
1593     return neg ? a - sum : a + sum;                                         \
1594 }
1595 
1596 DEF_IMOP_8x4_32(smopa_s, int8_t, int8_t)
1597 DEF_IMOP_8x4_32(umopa_s, uint8_t, uint8_t)
1598 DEF_IMOP_8x4_32(sumopa_s, int8_t, uint8_t)
1599 DEF_IMOP_8x4_32(usmopa_s, uint8_t, int8_t)
1600 
1601 DEF_IMOP_16x4_64(smopa_d, int16_t, int16_t)
1602 DEF_IMOP_16x4_64(umopa_d, uint16_t, uint16_t)
1603 DEF_IMOP_16x4_64(sumopa_d, int16_t, uint16_t)
1604 DEF_IMOP_16x4_64(usmopa_d, uint16_t, int16_t)
1605 
1606 #define DEF_IMOPH(P, NAME, S) \
1607     void HELPER(P##_##NAME##_##S)(void *vza, void *vzn, void *vzm,          \
1608                                   void *vpn, void *vpm, uint32_t desc)      \
1609     { do_imopa_##S(vza, vzn, vzm, vpn, vpm, desc, NAME##_##S); }
1610 
1611 DEF_IMOPH(sme, smopa, s)
1612 DEF_IMOPH(sme, umopa, s)
1613 DEF_IMOPH(sme, sumopa, s)
1614 DEF_IMOPH(sme, usmopa, s)
1615 
1616 DEF_IMOPH(sme, smopa, d)
1617 DEF_IMOPH(sme, umopa, d)
1618 DEF_IMOPH(sme, sumopa, d)
1619 DEF_IMOPH(sme, usmopa, d)
1620 
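/*
 * BMOPA/BMOPS: for each ZA element, count the bit positions at which the
 * two 32-bit source elements agree (popcount of XNOR), then add or
 * subtract that count, gated by the low bit of the combined row/column
 * predicate.
 */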
1621 static uint32_t bmopa_s(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg)
1622 {
1623     uint32_t sum = ctpop32(~(n ^ m));
1624     if (neg) {
1625         sum = -sum;
1626     }
1627     if (!(p & 1)) {
1628         sum = 0;
1629     }
1630     return a + sum;
1631 }
1632 
1633 DEF_IMOPH(sme2, bmopa, s)
1634 
1635 #define DEF_IMOP_16x2_32(NAME, NTYPE, MTYPE) \
1636 static uint32_t NAME(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg) \
1637 {                                                                           \
1638     uint32_t sum = 0;                                                       \
1639     /* Apply P to N as a mask, making the inactive elements 0. */           \
1640     n &= expand_pred_h(p);                                                  \
1641     sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0);                               \
1642     sum += (NTYPE)(n >> 16) * (MTYPE)(m >> 16);                             \
1643     return neg ? a - sum : a + sum;                                         \
1644 }
1645 
1646 DEF_IMOP_16x2_32(smopa2_s, int16_t, int16_t)
1647 DEF_IMOP_16x2_32(umopa2_s, uint16_t, uint16_t)
1648 
1649 DEF_IMOPH(sme2, smopa2, s)
1650 DEF_IMOPH(sme2, umopa2, s)
1651 
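/*
 * Note (added commentary): vertical dot product, indexed (SVDOT/UVDOT
 * and friends).  The narrow source elements for one widened result are
 * gathered "vertically", one from each of the nreg vectors of the Zn
 * multi-vector group, while Zm supplies the indexed group of narrow
 * elements from within each 128-bit segment.
 */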
1652 #define DO_VDOT_IDX(NAME, TYPED, TYPEN, TYPEM, HD, HN) \
1653 void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc)            \
1654 {                                                                         \
1655     intptr_t svl = simd_oprsz(desc);                                      \
1656     intptr_t elements = svl / sizeof(TYPED);                              \
1657     intptr_t eltperseg = 16 / sizeof(TYPED);                              \
1658     intptr_t nreg = sizeof(TYPED) / sizeof(TYPEN);                        \
1659     intptr_t vstride = (svl / nreg) * sizeof(ARMVectorReg);               \
1660     intptr_t zstride = sizeof(ARMVectorReg) / sizeof(TYPEN);              \
1661     intptr_t idx = extract32(desc, SIMD_DATA_SHIFT, 2);                   \
1662     TYPEN *n = vn;                                                        \
1663     TYPEM *m = vm;                                                        \
1664     for (intptr_t r = 0; r < nreg; r++) {                                 \
1665         TYPED *d = vd + r * vstride;                                      \
1666         for (intptr_t seg = 0; seg < elements; seg += eltperseg) {        \
1667             intptr_t s = seg + idx;                                       \
1668             for (intptr_t e = seg; e < seg + eltperseg; e++) {            \
1669                 TYPED sum = d[HD(e)];                                     \
1670                 for (intptr_t i = 0; i < nreg; i++) {                     \
1671                     TYPED nn = n[i * zstride + HN(nreg * e + r)];         \
1672                     TYPED mm = m[HN(nreg * s + i)];                       \
1673                     sum += nn * mm;                                       \
1674                 }                                                         \
1675                 d[HD(e)] = sum;                                           \
1676             }                                                             \
1677         }                                                                 \
1678     }                                                                     \
1679 }
1680 
1681 DO_VDOT_IDX(sme2_svdot_idx_4b, int32_t, int8_t, int8_t, H4, H1)
1682 DO_VDOT_IDX(sme2_uvdot_idx_4b, uint32_t, uint8_t, uint8_t, H4, H1)
1683 DO_VDOT_IDX(sme2_suvdot_idx_4b, int32_t, int8_t, uint8_t, H4, H1)
1684 DO_VDOT_IDX(sme2_usvdot_idx_4b, int32_t, uint8_t, int8_t, H4, H1)
1685 
1686 DO_VDOT_IDX(sme2_svdot_idx_4h, int64_t, int16_t, int16_t, H8, H2)
1687 DO_VDOT_IDX(sme2_uvdot_idx_4h, uint64_t, uint16_t, uint16_t, H8, H2)
1688 
1689 DO_VDOT_IDX(sme2_svdot_idx_2h, int32_t, int16_t, int16_t, H4, H2)
1690 DO_VDOT_IDX(sme2_uvdot_idx_2h, uint32_t, uint16_t, uint16_t, H4, H2)
1691 
1692 #undef DO_VDOT_IDX
1693 
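/*
 * Multiply-add long-long (SMLALL/UMLALL etc): each widened accumulator
 * element is updated with a single product of the 'sel'-selected narrow
 * element from the corresponding group of four in the source vectors.
 */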
1694 #define DO_MLALL(NAME, TYPEW, TYPEN, TYPEM, HW, HN, OP) \
1695 void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc) \
1696 {                                                               \
1697     intptr_t elements = simd_oprsz(desc) / sizeof(TYPEW);       \
1698     intptr_t sel = extract32(desc, SIMD_DATA_SHIFT, 2);         \
1699     TYPEW *d = vd, *a = va; TYPEN *n = vn; TYPEM *m = vm;       \
1700     for (intptr_t i = 0; i < elements; ++i) {                   \
1701         TYPEW nn = n[HN(i * 4 + sel)];                          \
1702         TYPEM mm = m[HN(i * 4 + sel)];                          \
1703         d[HW(i)] = a[HW(i)] OP (nn * mm);                       \
1704     }                                                           \
1705 }
1706 
1707 DO_MLALL(sme2_smlall_s, int32_t, int8_t, int8_t, H4, H1, +)
1708 DO_MLALL(sme2_smlall_d, int64_t, int16_t, int16_t, H8, H2, +)
1709 DO_MLALL(sme2_smlsll_s, int32_t, int8_t, int8_t, H4, H1, -)
1710 DO_MLALL(sme2_smlsll_d, int64_t, int16_t, int16_t, H8, H2, -)
1711 
1712 DO_MLALL(sme2_umlall_s, uint32_t, uint8_t, uint8_t, H4, H1, +)
1713 DO_MLALL(sme2_umlall_d, uint64_t, uint16_t, uint16_t, H8, H2, +)
1714 DO_MLALL(sme2_umlsll_s, uint32_t, uint8_t, uint8_t, H4, H1, -)
1715 DO_MLALL(sme2_umlsll_d, uint64_t, uint16_t, uint16_t, H8, H2, -)
1716 
1717 DO_MLALL(sme2_usmlall_s, uint32_t, uint8_t, int8_t, H4, H1, +)
1718 
1719 #undef DO_MLALL
1720 
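/*
 * Indexed form of the above: the Zm element is chosen by 'idx' within
 * each 128-bit segment and is shared by all results in that segment.
 */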
1721 #define DO_MLALL_IDX(NAME, TYPEW, TYPEN, TYPEM, HW, HN, OP) \
1722 void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc) \
1723 {                                                               \
1724     intptr_t elements = simd_oprsz(desc) / sizeof(TYPEW);       \
1725     intptr_t eltspersegment = 16 / sizeof(TYPEW);               \
1726     intptr_t sel = extract32(desc, SIMD_DATA_SHIFT, 2);         \
1727     intptr_t idx = extract32(desc, SIMD_DATA_SHIFT + 2, 4);     \
1728     TYPEW *d = vd, *a = va; TYPEN *n = vn; TYPEM *m = vm;       \
1729     for (intptr_t i = 0; i < elements; i += eltspersegment) {   \
1730         TYPEW mm = m[HN(i * 4 + idx)];                          \
1731         for (intptr_t j = 0; j < eltspersegment; ++j) {         \
1732             TYPEN nn = n[HN((i + j) * 4 + sel)];                \
1733             d[HW(i + j)] = a[HW(i + j)] OP (nn * mm);           \
1734         }                                                       \
1735     }                                                           \
1736 }
1737 
1738 DO_MLALL_IDX(sme2_smlall_idx_s, int32_t, int8_t, int8_t, H4, H1, +)
1739 DO_MLALL_IDX(sme2_smlall_idx_d, int64_t, int16_t, int16_t, H8, H2, +)
1740 DO_MLALL_IDX(sme2_smlsll_idx_s, int32_t, int8_t, int8_t, H4, H1, -)
1741 DO_MLALL_IDX(sme2_smlsll_idx_d, int64_t, int16_t, int16_t, H8, H2, -)
1742 
1743 DO_MLALL_IDX(sme2_umlall_idx_s, uint32_t, uint8_t, uint8_t, H4, H1, +)
1744 DO_MLALL_IDX(sme2_umlall_idx_d, uint64_t, uint16_t, uint16_t, H8, H2, +)
1745 DO_MLALL_IDX(sme2_umlsll_idx_s, uint32_t, uint8_t, uint8_t, H4, H1, -)
1746 DO_MLALL_IDX(sme2_umlsll_idx_d, uint64_t, uint16_t, uint16_t, H8, H2, -)
1747 
1748 DO_MLALL_IDX(sme2_usmlall_idx_s, uint32_t, uint8_t, int8_t, H4, H1, +)
1749 DO_MLALL_IDX(sme2_sumlall_idx_s, uint32_t, int8_t, uint8_t, H4, H1, +)
1750 
1751 #undef DO_MLALL_IDX
1752 
1753 /* Convert and compress */
1754 void HELPER(sme2_bfcvt)(void *vd, void *vs, float_status *fpst, uint32_t desc)
1755 {
1756     ARMVectorReg scratch;
1757     size_t oprsz = simd_oprsz(desc);
1758     size_t i, n = oprsz / 4;
1759     float32 *s0 = vs;
1760     float32 *s1 = vs + sizeof(ARMVectorReg);
1761     bfloat16 *d = vd;
1762 
1763     if (vd == s1) {
1764         s1 = memcpy(&scratch, s1, oprsz);
1765     }
1766 
1767     for (i = 0; i < n; ++i) {
1768         d[H2(i)] = float32_to_bfloat16(s0[H4(i)], fpst);
1769     }
1770     for (i = 0; i < n; ++i) {
1771         d[H2(i) + n] = float32_to_bfloat16(s1[H4(i)], fpst);
1772     }
1773 }
1774 
1775 void HELPER(sme2_fcvt_n)(void *vd, void *vs, float_status *fpst, uint32_t desc)
1776 {
1777     ARMVectorReg scratch;
1778     size_t oprsz = simd_oprsz(desc);
1779     size_t i, n = oprsz / 4;
1780     float32 *s0 = vs;
1781     float32 *s1 = vs + sizeof(ARMVectorReg);
1782     float16 *d = vd;
1783 
1784     if (vd == s1) {
1785         s1 = memcpy(&scratch, s1, oprsz);
1786     }
1787 
1788     for (i = 0; i < n; ++i) {
1789         d[H2(i)] = sve_f32_to_f16(s0[H4(i)], fpst);
1790     }
1791     for (i = 0; i < n; ++i) {
1792         d[H2(i) + n] = sve_f32_to_f16(s1[H4(i)], fpst);
1793     }
1794 }
1795 
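/*
 * SQCVT/UQCVT/SQCVTU with 2 or 4 source vectors: saturate each wide
 * element to the narrow type and pack the results so that the low part
 * of the destination comes from the first source register, the next
 * part from the second, and so on (the "compress" layout, as opposed to
 * the interleaved SQCVTN forms further below).
 */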
1796 #define SQCVT2(NAME, TW, TN, HW, HN, SAT)                       \
1797 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1798 {                                                               \
1799     ARMVectorReg scratch;                                       \
1800     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
1801     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
1802     TN *d = vd;                                                 \
1803     if (vectors_overlap(vd, 1, vs, 2)) {                        \
1804         d = (TN *)&scratch;                                     \
1805     }                                                           \
1806     for (size_t i = 0; i < n; ++i) {                            \
1807         d[HN(i)] = SAT(s0[HW(i)]);                              \
1808         d[HN(i + n)] = SAT(s1[HW(i)]);                          \
1809     }                                                           \
1810     if (d != vd) {                                              \
1811         memcpy(vd, d, oprsz);                                   \
1812     }                                                           \
1813 }
1814 
1815 SQCVT2(sme2_sqcvt_sh, int32_t, int16_t, H4, H2, do_ssat_h)
1816 SQCVT2(sme2_uqcvt_sh, uint32_t, uint16_t, H4, H2, do_usat_h)
1817 SQCVT2(sme2_sqcvtu_sh, int32_t, uint16_t, H4, H2, do_usat_h)
1818 
1819 #undef SQCVT2
1820 
1821 #define SQCVT4(NAME, TW, TN, HW, HN, SAT)                       \
1822 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1823 {                                                               \
1824     ARMVectorReg scratch;                                       \
1825     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
1826     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
1827     TW *s2 = vs + 2 * sizeof(ARMVectorReg);                     \
1828     TW *s3 = vs + 3 * sizeof(ARMVectorReg);                     \
1829     TN *d = vd;                                                 \
1830     if (vectors_overlap(vd, 1, vs, 4)) {                        \
1831         d = (TN *)&scratch;                                     \
1832     }                                                           \
1833     for (size_t i = 0; i < n; ++i) {                            \
1834         d[HN(i)] = SAT(s0[HW(i)]);                              \
1835         d[HN(i + n)] = SAT(s1[HW(i)]);                          \
1836         d[HN(i + 2 * n)] = SAT(s2[HW(i)]);                      \
1837         d[HN(i + 3 * n)] = SAT(s3[HW(i)]);                      \
1838     }                                                           \
1839     if (d != vd) {                                              \
1840         memcpy(vd, d, oprsz);                                   \
1841     }                                                           \
1842 }
1843 
1844 SQCVT4(sme2_sqcvt_sb, int32_t, int8_t, H4, H1, do_ssat_b)
1845 SQCVT4(sme2_uqcvt_sb, uint32_t, uint8_t, H4, H1, do_usat_b)
1846 SQCVT4(sme2_sqcvtu_sb, int32_t, uint8_t, H4, H1, do_usat_b)
1847 
1848 SQCVT4(sme2_sqcvt_dh, int64_t, int16_t, H8, H2, do_ssat_h)
1849 SQCVT4(sme2_uqcvt_dh, uint64_t, uint16_t, H8, H2, do_usat_h)
1850 SQCVT4(sme2_sqcvtu_dh, int64_t, uint16_t, H8, H2, do_usat_h)
1851 
1852 #undef SQCVT4
1853 
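/*
 * SQRSHR/UQRSHR/SQRSHRU: as above, but each wide element is first
 * rounded and shifted right by the immediate before saturating to the
 * narrow type.
 */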
1854 #define SQRSHR2(NAME, TW, TN, HW, HN, RSHR, SAT)                \
1855 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1856 {                                                               \
1857     ARMVectorReg scratch;                                       \
1858     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
1859     int shift = simd_data(desc);                                \
1860     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
1861     TN *d = vd;                                                 \
1862     if (vectors_overlap(vd, 1, vs, 2)) {                        \
1863         d = (TN *)&scratch;                                     \
1864     }                                                           \
1865     for (size_t i = 0; i < n; ++i) {                            \
1866         d[HN(i)] = SAT(RSHR(s0[HW(i)], shift));                 \
1867         d[HN(i + n)] = SAT(RSHR(s1[HW(i)], shift));             \
1868     }                                                           \
1869     if (d != vd) {                                              \
1870         memcpy(vd, d, oprsz);                                   \
1871     }                                                           \
1872 }
1873 
1874 SQRSHR2(sme2_sqrshr_sh, int32_t, int16_t, H4, H2, do_srshr, do_ssat_h)
1875 SQRSHR2(sme2_uqrshr_sh, uint32_t, uint16_t, H4, H2, do_urshr, do_usat_h)
1876 SQRSHR2(sme2_sqrshru_sh, int32_t, uint16_t, H4, H2, do_srshr, do_usat_h)
1877 
1878 #undef SQRSHR2
1879 
1880 #define SQRSHR4(NAME, TW, TN, HW, HN, RSHR, SAT)                \
1881 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1882 {                                                               \
1883     ARMVectorReg scratch;                                       \
1884     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
1885     int shift = simd_data(desc);                                \
1886     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
1887     TW *s2 = vs + 2 * sizeof(ARMVectorReg);                     \
1888     TW *s3 = vs + 3 * sizeof(ARMVectorReg);                     \
1889     TN *d = vd;                                                 \
1890     if (vectors_overlap(vd, 1, vs, 4)) {                        \
1891         d = (TN *)&scratch;                                     \
1892     }                                                           \
1893     for (size_t i = 0; i < n; ++i) {                            \
1894         d[HN(i)] = SAT(RSHR(s0[HW(i)], shift));                 \
1895         d[HN(i + n)] = SAT(RSHR(s1[HW(i)], shift));             \
1896         d[HN(i + 2 * n)] = SAT(RSHR(s2[HW(i)], shift));         \
1897         d[HN(i + 3 * n)] = SAT(RSHR(s3[HW(i)], shift));         \
1898     }                                                           \
1899     if (d != vd) {                                              \
1900         memcpy(vd, d, oprsz);                                   \
1901     }                                                           \
1902 }
1903 
1904 SQRSHR4(sme2_sqrshr_sb, int32_t, int8_t, H4, H1, do_srshr, do_ssat_b)
1905 SQRSHR4(sme2_uqrshr_sb, uint32_t, uint8_t, H4, H1, do_urshr, do_usat_b)
1906 SQRSHR4(sme2_sqrshru_sb, int32_t, uint8_t, H4, H1, do_srshr, do_usat_b)
1907 
1908 SQRSHR4(sme2_sqrshr_dh, int64_t, int16_t, H8, H2, do_srshr, do_ssat_h)
1909 SQRSHR4(sme2_uqrshr_dh, uint64_t, uint16_t, H8, H2, do_urshr, do_usat_h)
1910 SQRSHR4(sme2_sqrshru_dh, int64_t, uint16_t, H8, H2, do_srshr, do_usat_h)
1911 
1912 #undef SQRSHR4
1913 
1914 /* Convert and interleave */
1915 void HELPER(sme2_bfcvtn)(void *vd, void *vs, float_status *fpst, uint32_t desc)
1916 {
1917     size_t i, n = simd_oprsz(desc) / 4;
1918     float32 *s0 = vs;
1919     float32 *s1 = vs + sizeof(ARMVectorReg);
1920     bfloat16 *d = vd;
1921 
1922     for (i = 0; i < n; ++i) {
1923         bfloat16 d0 = float32_to_bfloat16(s0[H4(i)], fpst);
1924         bfloat16 d1 = float32_to_bfloat16(s1[H4(i)], fpst);
1925         d[H2(i * 2 + 0)] = d0;
1926         d[H2(i * 2 + 1)] = d1;
1927     }
1928 }
1929 
1930 void HELPER(sme2_fcvtn)(void *vd, void *vs, float_status *fpst, uint32_t desc)
1931 {
1932     size_t i, n = simd_oprsz(desc) / 4;
1933     float32 *s0 = vs;
1934     float32 *s1 = vs + sizeof(ARMVectorReg);
1935     float16 *d = vd;
1936 
1937     for (i = 0; i < n; ++i) {
1938         float16 d0 = sve_f32_to_f16(s0[H4(i)], fpst);
1939         float16 d1 = sve_f32_to_f16(s1[H4(i)], fpst);
1940         d[H2(i * 2 + 0)] = d0;
1941         d[H2(i * 2 + 1)] = d1;
1942     }
1943 }
1944 
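/*
 * SQCVTN/UQCVTN/SQCVTUN: the narrowing-with-interleave forms.  The
 * saturated results from the 2 or 4 source vectors are zipped, so
 * consecutive destination elements come from corresponding elements of
 * consecutive source registers.
 */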
1945 #define SQCVTN2(NAME, TW, TN, HW, HN, SAT)                      \
1946 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1947 {                                                               \
1948     ARMVectorReg scratch;                                       \
1949     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
1950     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
1951     TN *d = vd;                                                 \
1952     if (vectors_overlap(vd, 1, vs, 2)) {                        \
1953         d = (TN *)&scratch;                                     \
1954     }                                                           \
1955     for (size_t i = 0; i < n; ++i) {                            \
1956         d[HN(2 * i + 0)] = SAT(s0[HW(i)]);                      \
1957         d[HN(2 * i + 1)] = SAT(s1[HW(i)]);                      \
1958     }                                                           \
1959     if (d != vd) {                                              \
1960         memcpy(vd, d, oprsz);                                   \
1961     }                                                           \
1962 }
1963 
1964 SQCVTN2(sme2_sqcvtn_sh, int32_t, int16_t, H4, H2, do_ssat_h)
1965 SQCVTN2(sme2_uqcvtn_sh, uint32_t, uint16_t, H4, H2, do_usat_h)
1966 SQCVTN2(sme2_sqcvtun_sh, int32_t, uint16_t, H4, H2, do_usat_h)
1967 
1968 #undef SQCVTN2
1969 
1970 #define SQCVTN4(NAME, TW, TN, HW, HN, SAT)                      \
1971 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
1972 {                                                               \
1973     ARMVectorReg scratch;                                       \
1974     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
1975     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
1976     TW *s2 = vs + 2 * sizeof(ARMVectorReg);                     \
1977     TW *s3 = vs + 3 * sizeof(ARMVectorReg);                     \
1978     TN *d = vd;                                                 \
1979     if (vectors_overlap(vd, 1, vs, 4)) {                        \
1980         d = (TN *)&scratch;                                     \
1981     }                                                           \
1982     for (size_t i = 0; i < n; ++i) {                            \
1983         d[HN(4 * i + 0)] = SAT(s0[HW(i)]);                      \
1984         d[HN(4 * i + 1)] = SAT(s1[HW(i)]);                      \
1985         d[HN(4 * i + 2)] = SAT(s2[HW(i)]);                      \
1986         d[HN(4 * i + 3)] = SAT(s3[HW(i)]);                      \
1987     }                                                           \
1988     if (d != vd) {                                              \
1989         memcpy(vd, d, oprsz);                                   \
1990     }                                                           \
1991 }
1992 
1993 SQCVTN4(sme2_sqcvtn_sb, int32_t, int8_t, H4, H1, do_ssat_b)
1994 SQCVTN4(sme2_uqcvtn_sb, uint32_t, uint8_t, H4, H1, do_usat_b)
1995 SQCVTN4(sme2_sqcvtun_sb, int32_t, uint8_t, H4, H1, do_usat_b)
1996 
1997 SQCVTN4(sme2_sqcvtn_dh, int64_t, int16_t, H8, H2, do_ssat_h)
1998 SQCVTN4(sme2_uqcvtn_dh, uint64_t, uint16_t, H8, H2, do_usat_h)
1999 SQCVTN4(sme2_sqcvtun_dh, int64_t, uint16_t, H8, H2, do_usat_h)
2000 
2001 #undef SQCVTN4
2002 
2003 #define SQRSHRN2(NAME, TW, TN, HW, HN, RSHR, SAT)               \
2004 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
2005 {                                                               \
2006     ARMVectorReg scratch;                                       \
2007     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
2008     int shift = simd_data(desc);                                \
2009     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
2010     TN *d = vd;                                                 \
2011     if (vectors_overlap(vd, 1, vs, 2)) {                        \
2012         d = (TN *)&scratch;                                     \
2013     }                                                           \
2014     for (size_t i = 0; i < n; ++i) {                            \
2015         d[HN(2 * i + 0)] = SAT(RSHR(s0[HW(i)], shift));         \
2016         d[HN(2 * i + 1)] = SAT(RSHR(s1[HW(i)], shift));         \
2017     }                                                           \
2018     if (d != vd) {                                              \
2019         memcpy(vd, d, oprsz);                                   \
2020     }                                                           \
2021 }
2022 
2023 SQRSHRN2(sme2_sqrshrn_sh, int32_t, int16_t, H4, H2, do_srshr, do_ssat_h)
2024 SQRSHRN2(sme2_uqrshrn_sh, uint32_t, uint16_t, H4, H2, do_urshr, do_usat_h)
2025 SQRSHRN2(sme2_sqrshrun_sh, int32_t, uint16_t, H4, H2, do_srshr, do_usat_h)
2026 
2027 #undef SQRSHRN2
2028 
2029 #define SQRSHRN4(NAME, TW, TN, HW, HN, RSHR, SAT)               \
2030 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
2031 {                                                               \
2032     ARMVectorReg scratch;                                       \
2033     size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW);    \
2034     int shift = simd_data(desc);                                \
2035     TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg);               \
2036     TW *s2 = vs + 2 * sizeof(ARMVectorReg);                     \
2037     TW *s3 = vs + 3 * sizeof(ARMVectorReg);                     \
2038     TN *d = vd;                                                 \
2039     if (vectors_overlap(vd, 1, vs, 4)) {                        \
2040         d = (TN *)&scratch;                                     \
2041     }                                                           \
2042     for (size_t i = 0; i < n; ++i) {                            \
2043         d[HN(4 * i + 0)] = SAT(RSHR(s0[HW(i)], shift));         \
2044         d[HN(4 * i + 1)] = SAT(RSHR(s1[HW(i)], shift));         \
2045         d[HN(4 * i + 2)] = SAT(RSHR(s2[HW(i)], shift));         \
2046         d[HN(4 * i + 3)] = SAT(RSHR(s3[HW(i)], shift));         \
2047     }                                                           \
2048     if (d != vd) {                                              \
2049         memcpy(vd, d, oprsz);                                   \
2050     }                                                           \
2051 }
2052 
2053 SQRSHRN4(sme2_sqrshrn_sb, int32_t, int8_t, H4, H1, do_srshr, do_ssat_b)
2054 SQRSHRN4(sme2_uqrshrn_sb, uint32_t, uint8_t, H4, H1, do_urshr, do_usat_b)
2055 SQRSHRN4(sme2_sqrshrun_sb, int32_t, uint8_t, H4, H1, do_srshr, do_usat_b)
2056 
2057 SQRSHRN4(sme2_sqrshrn_dh, int64_t, int16_t, H8, H2, do_srshr, do_ssat_h)
2058 SQRSHRN4(sme2_uqrshrn_dh, uint64_t, uint16_t, H8, H2, do_urshr, do_usat_h)
2059 SQRSHRN4(sme2_sqrshrun_dh, int64_t, uint16_t, H8, H2, do_srshr, do_usat_h)
2060 
2061 #undef SQRSHRN4
2062 
2063 /* Expand and convert */
2064 void HELPER(sme2_fcvt_w)(void *vd, void *vs, float_status *fpst, uint32_t desc)
2065 {
2066     ARMVectorReg scratch;
2067     size_t oprsz = simd_oprsz(desc);
2068     size_t i, n = oprsz / 4;
2069     float16 *s = vs;
2070     float32 *d0 = vd;
2071     float32 *d1 = vd + sizeof(ARMVectorReg);
2072 
2073     if (vectors_overlap(vd, 2, vs, 1)) {
2074         s = memcpy(&scratch, s, oprsz);
2075     }
2076 
2077     for (i = 0; i < n; ++i) {
2078         d0[H4(i)] = sve_f16_to_f32(s[H2(i)], fpst);
2079     }
2080     for (i = 0; i < n; ++i) {
2081         d1[H4(i)] = sve_f16_to_f32(s[H2(n + i)], fpst);
2082     }
2083 }
2084 
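/*
 * SUNPK/UUNPK: widen each of SREG source vectors into two destination
 * vectors, sign- or zero-extending; the low half of each source fills
 * the first destination of the pair and the high half the second.
 */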
2085 #define UNPK(NAME, SREG, TW, TN, HW, HN)                        \
2086 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
2087 {                                                               \
2088     ARMVectorReg scratch[SREG];                                 \
2089     size_t oprsz = simd_oprsz(desc);                            \
2090     size_t n = oprsz / sizeof(TW);                              \
2091     if (vectors_overlap(vd, 2 * SREG, vs, SREG)) {              \
2092         vs = memcpy(scratch, vs, sizeof(scratch));              \
2093     }                                                           \
2094     for (size_t r = 0; r < SREG; ++r) {                         \
2095         TN *s = vs + r * sizeof(ARMVectorReg);                  \
2096         for (size_t i = 0; i < 2; ++i) {                        \
2097             TW *d = vd + (2 * r + i) * sizeof(ARMVectorReg);    \
2098             for (size_t e = 0; e < n; ++e) {                    \
2099                 d[HW(e)] = s[HN(i * n + e)];                    \
2100             }                                                   \
2101         }                                                       \
2102     }                                                           \
2103 }
2104 
2105 UNPK(sme2_sunpk2_bh, 1, int16_t, int8_t, H2, H1)
2106 UNPK(sme2_sunpk2_hs, 1, int32_t, int16_t, H4, H2)
2107 UNPK(sme2_sunpk2_sd, 1, int64_t, int32_t, H8, H4)
2108 
2109 UNPK(sme2_sunpk4_bh, 2, int16_t, int8_t, H2, H1)
2110 UNPK(sme2_sunpk4_hs, 2, int32_t, int16_t, H4, H2)
2111 UNPK(sme2_sunpk4_sd, 2, int64_t, int32_t, H8, H4)
2112 
2113 UNPK(sme2_uunpk2_bh, 1, uint16_t, uint8_t, H2, H1)
2114 UNPK(sme2_uunpk2_hs, 1, uint32_t, uint16_t, H4, H2)
2115 UNPK(sme2_uunpk2_sd, 1, uint64_t, uint32_t, H8, H4)
2116 
2117 UNPK(sme2_uunpk4_bh, 2, uint16_t, uint8_t, H2, H1)
2118 UNPK(sme2_uunpk4_hs, 2, uint32_t, uint16_t, H4, H2)
2119 UNPK(sme2_uunpk4_sd, 2, uint64_t, uint32_t, H8, H4)
2120 
2121 #undef UNPK
2122 
2123 /* Deinterleave and convert. */
2124 void HELPER(sme2_fcvtl)(void *vd, void *vs, float_status *fpst, uint32_t desc)
2125 {
2126     size_t i, n = simd_oprsz(desc) / 4;
2127     float16 *s = vs;
2128     float32 *d0 = vd;
2129     float32 *d1 = vd + sizeof(ARMVectorReg);
2130 
2131     for (i = 0; i < n; ++i) {
2132         float32 v0 = sve_f16_to_f32(s[H2(i * 2 + 0)], fpst);
2133         float32 v1 = sve_f16_to_f32(s[H2(i * 2 + 1)], fpst);
2134         d0[H4(i)] = v0;
2135         d1[H4(i)] = v1;
2136     }
2137 }
2138 
2139 void HELPER(sme2_scvtf)(void *vd, void *vs, float_status *fpst, uint32_t desc)
2140 {
2141     size_t i, n = simd_oprsz(desc) / 4;
2142     float32 *d = vd;
2143     int32_t *s = vs;
2144 
2145     for (i = 0; i < n; ++i) {
2146         d[i] = int32_to_float32(s[i], fpst);
2147     }
2148 }
2149 
2150 void HELPER(sme2_ucvtf)(void *vd, void *vs, float_status *fpst, uint32_t desc)
2151 {
2152     size_t i, n = simd_oprsz(desc) / 4;
2153     float32 *d = vd;
2154     uint32_t *s = vs;
2155 
2156     for (i = 0; i < n; ++i) {
2157         d[i] = uint32_to_float32(s[i], fpst);
2158     }
2159 }
2160 
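/*
 * ZIP with 2 or 4 registers: interleave corresponding elements of the
 * sources into the destination group; destination register r is built
 * from the r'th portion of each source vector.
 */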
2161 #define ZIP2(NAME, TYPE, H)                                     \
2162 void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc)  \
2163 {                                                               \
2164     ARMVectorReg scratch[2];                                    \
2165     size_t oprsz = simd_oprsz(desc);                            \
2166     size_t pairs = oprsz / (sizeof(TYPE) * 2);                  \
2167     TYPE *n = vn, *m = vm;                                      \
2168     if (vectors_overlap(vd, 2, vn, 1)) {                        \
2169         n = memcpy(&scratch[0], vn, oprsz);                     \
2170     }                                                           \
2171     if (vectors_overlap(vd, 2, vm, 1)) {                        \
2172         m = memcpy(&scratch[1], vm, oprsz);                     \
2173     }                                                           \
2174     for (size_t r = 0; r < 2; ++r) {                            \
2175         TYPE *d = vd + r * sizeof(ARMVectorReg);                \
2176         size_t base = r * pairs;                                \
2177         for (size_t p = 0; p < pairs; ++p) {                    \
2178             d[H(2 * p + 0)] = n[base + H(p)];                   \
2179             d[H(2 * p + 1)] = m[base + H(p)];                   \
2180         }                                                       \
2181     }                                                           \
2182 }
2183 
2184 ZIP2(sme2_zip2_b, uint8_t, H1)
2185 ZIP2(sme2_zip2_h, uint16_t, H2)
2186 ZIP2(sme2_zip2_s, uint32_t, H4)
2187 ZIP2(sme2_zip2_d, uint64_t, )
2188 ZIP2(sme2_zip2_q, Int128, )
2189 
2190 #undef ZIP2
2191 
2192 #define ZIP4(NAME, TYPE, H)                                     \
2193 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
2194 {                                                               \
2195     ARMVectorReg scratch[4];                                    \
2196     size_t oprsz = simd_oprsz(desc);                            \
2197     size_t quads = oprsz / (sizeof(TYPE) * 4);                  \
2198     TYPE *s0, *s1, *s2, *s3;                                    \
2199     if (vs == vd) {                                             \
2200         vs = memcpy(scratch, vs, sizeof(scratch));              \
2201     }                                                           \
2202     s0 = vs;                                                    \
2203     s1 = vs + sizeof(ARMVectorReg);                             \
2204     s2 = vs + 2 * sizeof(ARMVectorReg);                         \
2205     s3 = vs + 3 * sizeof(ARMVectorReg);                         \
2206     for (size_t r = 0; r < 4; ++r) {                            \
2207         TYPE *d = vd + r * sizeof(ARMVectorReg);                \
2208         size_t base = r * quads;                                \
2209         for (size_t q = 0; q < quads; ++q) {                    \
2210             d[H(4 * q + 0)] = s0[base + H(q)];                  \
2211             d[H(4 * q + 1)] = s1[base + H(q)];                  \
2212             d[H(4 * q + 2)] = s2[base + H(q)];                  \
2213             d[H(4 * q + 3)] = s3[base + H(q)];                  \
2214         }                                                       \
2215     }                                                           \
2216 }
2217 
2218 ZIP4(sme2_zip4_b, uint8_t, H1)
2219 ZIP4(sme2_zip4_h, uint16_t, H2)
2220 ZIP4(sme2_zip4_s, uint32_t, H4)
2221 ZIP4(sme2_zip4_d, uint64_t, )
2222 ZIP4(sme2_zip4_q, Int128, )
2223 
2224 #undef ZIP4
2225 
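/*
 * UZP with 2 or 4 registers: the inverse of ZIP above.  Even/odd (or
 * mod-4) element positions of each source register are gathered into
 * separate destination registers, preserving source order.
 */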
2226 #define UZP2(NAME, TYPE, H)                                     \
2227 void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc)  \
2228 {                                                               \
2229     ARMVectorReg scratch[2];                                    \
2230     size_t oprsz = simd_oprsz(desc);                            \
2231     size_t pairs = oprsz / (sizeof(TYPE) * 2);                  \
2232     TYPE *d0 = vd, *d1 = vd + sizeof(ARMVectorReg);             \
2233     if (vectors_overlap(vd, 2, vn, 1)) {                        \
2234         vn = memcpy(&scratch[0], vn, oprsz);                    \
2235     }                                                           \
2236     if (vectors_overlap(vd, 2, vm, 1)) {                        \
2237         vm = memcpy(&scratch[1], vm, oprsz);                    \
2238     }                                                           \
2239     for (size_t r = 0; r < 2; ++r) {                            \
2240         TYPE *s = r ? vm : vn;                                  \
2241         size_t base = r * pairs;                                \
2242         for (size_t p = 0; p < pairs; ++p) {                    \
2243             d0[base + H(p)] = s[H(2 * p + 0)];                  \
2244             d1[base + H(p)] = s[H(2 * p + 1)];                  \
2245         }                                                       \
2246     }                                                           \
2247 }
2248 
2249 UZP2(sme2_uzp2_b, uint8_t, H1)
2250 UZP2(sme2_uzp2_h, uint16_t, H2)
2251 UZP2(sme2_uzp2_s, uint32_t, H4)
2252 UZP2(sme2_uzp2_d, uint64_t, )
2253 UZP2(sme2_uzp2_q, Int128, )
2254 
2255 #undef UZP2
2256 
2257 #define UZP4(NAME, TYPE, H)                                     \
2258 void HELPER(NAME)(void *vd, void *vs, uint32_t desc)            \
2259 {                                                               \
2260     ARMVectorReg scratch[4];                                    \
2261     size_t oprsz = simd_oprsz(desc);                            \
2262     size_t quads = oprsz / (sizeof(TYPE) * 4);                  \
2263     TYPE *d0, *d1, *d2, *d3;                                    \
2264     if (vs == vd) {                                             \
2265         vs = memcpy(scratch, vs, sizeof(scratch));              \
2266     }                                                           \
2267     d0 = vd;                                                    \
2268     d1 = vd + sizeof(ARMVectorReg);                             \
2269     d2 = vd + 2 * sizeof(ARMVectorReg);                         \
2270     d3 = vd + 3 * sizeof(ARMVectorReg);                         \
2271     for (size_t r = 0; r < 4; ++r) {                            \
2272         TYPE *s = vs + r * sizeof(ARMVectorReg);                \
2273         size_t base = r * quads;                                \
2274         for (size_t q = 0; q < quads; ++q) {                    \
2275             d0[base + H(q)] = s[H(4 * q + 0)];                  \
2276             d1[base + H(q)] = s[H(4 * q + 1)];                  \
2277             d2[base + H(q)] = s[H(4 * q + 2)];                  \
2278             d3[base + H(q)] = s[H(4 * q + 3)];                  \
2279         }                                                       \
2280     }                                                           \
2281 }
2282 
2283 UZP4(sme2_uzp4_b, uint8_t, H1)
2284 UZP4(sme2_uzp4_h, uint16_t, H2)
2285 UZP4(sme2_uzp4_s, uint32_t, H4)
2286 UZP4(sme2_uzp4_d, uint64_t, )
2287 UZP4(sme2_uzp4_q, Int128, )
2288 
2289 #undef UZP4
2290 
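/*
 * SCLAMP/UCLAMP: clamp each element of the nreg destination vectors to
 * the range [n, m], i.e. dd = MIN(MAX(dd, n), m), with the same n/m
 * pair applied to all registers of the group.
 */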
2291 #define ICLAMP(NAME, TYPE, H) \
2292 void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc)  \
2293 {                                                               \
2294     size_t stride = sizeof(ARMVectorReg) / sizeof(TYPE);        \
2295     size_t elements = simd_oprsz(desc) / sizeof(TYPE);          \
2296     size_t nreg = simd_data(desc);                              \
2297     TYPE *d = vd, *n = vn, *m = vm;                             \
2298     for (size_t e = 0; e < elements; e++) {                     \
2299         TYPE nn = n[H(e)], mm = m[H(e)];                        \
2300         for (size_t r = 0; r < nreg; r++) {                     \
2301             TYPE *dd = &d[r * stride + H(e)];                   \
2302             *dd = MIN(MAX(*dd, nn), mm);                        \
2303         }                                                       \
2304     }                                                           \
2305 }
2306 
2307 ICLAMP(sme2_sclamp_b, int8_t, H1)
2308 ICLAMP(sme2_sclamp_h, int16_t, H2)
2309 ICLAMP(sme2_sclamp_s, int32_t, H4)
2310 ICLAMP(sme2_sclamp_d, int64_t, H8)
2311 
2312 ICLAMP(sme2_uclamp_b, uint8_t, H1)
2313 ICLAMP(sme2_uclamp_h, uint16_t, H2)
2314 ICLAMP(sme2_uclamp_s, uint32_t, H4)
2315 ICLAMP(sme2_uclamp_d, uint64_t, H8)
2316 
2317 #undef ICLAMP
2318 
2319 /*
2320  * Note the argument ordering to minnum and maxnum must match
2321  * the ARM pseudocode so that NaNs are propagated properly.
2322  */
2323 #define FCLAMP(NAME, TYPE, H) \
2324 void HELPER(NAME)(void *vd, void *vn, void *vm,                 \
2325                   float_status *fpst, uint32_t desc)            \
2326 {                                                               \
2327     size_t stride = sizeof(ARMVectorReg) / sizeof(TYPE);        \
2328     size_t elements = simd_oprsz(desc) / sizeof(TYPE);          \
2329     size_t nreg = simd_data(desc);                              \
2330     TYPE *d = vd, *n = vn, *m = vm;                             \
2331     for (size_t e = 0; e < elements; e++) {                     \
2332         TYPE nn = n[H(e)], mm = m[H(e)];                        \
2333         for (size_t r = 0; r < nreg; r++) {                     \
2334             TYPE *dd = &d[r * stride + H(e)];                   \
2335             *dd = TYPE##_minnum(TYPE##_maxnum(nn, *dd, fpst), mm, fpst); \
2336         }                                                       \
2337     }                                                           \
2338 }
2339 
2340 FCLAMP(sme2_fclamp_h, float16, H2)
2341 FCLAMP(sme2_fclamp_s, float32, H4)
2342 FCLAMP(sme2_fclamp_d, float64, H8)
2343 FCLAMP(sme2_bfclamp, bfloat16, H2)
2344 
2345 #undef FCLAMP
2346 
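/*
 * Note (added commentary): SEL governed by a predicate-as-counter.
 * decode_counter yields an element count over the whole register group,
 * an invert flag, and an element stride.  Within each register, elements
 * below the remaining count take one operand and the rest take the other
 * (which one depends on invert); with a non-zero lg2_stride only every
 * estride'th element of the "true" portion comes from Zn, the others
 * from Zm.
 */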
2347 void HELPER(sme2_sel_b)(void *vd, void *vn, void *vm,
2348                         uint32_t png, uint32_t desc)
2349 {
2350     int vl = simd_oprsz(desc);
2351     int nreg = simd_data(desc);
2352     int elements = vl / sizeof(uint8_t);
2353     DecodeCounter p = decode_counter(png, vl, MO_8);
2354 
2355     if (p.lg2_stride == 0) {
2356         if (p.invert) {
2357             for (int r = 0; r < nreg; r++) {
2358                 uint8_t *d = vd + r * sizeof(ARMVectorReg);
2359                 uint8_t *n = vn + r * sizeof(ARMVectorReg);
2360                 uint8_t *m = vm + r * sizeof(ARMVectorReg);
2361                 int split = p.count - r * elements;
2362 
2363                 if (split <= 0) {
2364                     memcpy(d, n, vl);  /* all true */
2365                 } else if (elements <= split) {
2366                     memcpy(d, m, vl);  /* all false */
2367                 } else {
2368                     for (int e = 0; e < split; e++) {
2369                         d[H1(e)] = m[H1(e)];
2370                     }
2371                     for (int e = split; e < elements; e++) {
2372                         d[H1(e)] = n[H1(e)];
2373                     }
2374                 }
2375             }
2376         } else {
2377             for (int r = 0; r < nreg; r++) {
2378                 uint8_t *d = vd + r * sizeof(ARMVectorReg);
2379                 uint8_t *n = vn + r * sizeof(ARMVectorReg);
2380                 uint8_t *m = vm + r * sizeof(ARMVectorReg);
2381                 int split = p.count - r * elements;
2382 
2383                 if (split <= 0) {
2384                     memcpy(d, m, vl);  /* all false */
2385                 } else if (elements <= split) {
2386                     memcpy(d, n, vl);  /* all true */
2387                 } else {
2388                     for (int e = 0; e < split; e++) {
2389                         d[H1(e)] = n[H1(e)];
2390                     }
2391                     for (int e = split; e < elements; e++) {
2392                         d[H1(e)] = m[H1(e)];
2393                     }
2394                 }
2395             }
2396         }
2397     } else {
2398         int estride = 1 << p.lg2_stride;
2399         if (p.invert) {
2400             for (int r = 0; r < nreg; r++) {
2401                 uint8_t *d = vd + r * sizeof(ARMVectorReg);
2402                 uint8_t *n = vn + r * sizeof(ARMVectorReg);
2403                 uint8_t *m = vm + r * sizeof(ARMVectorReg);
2404                 int split = p.count - r * elements;
2405                 int e = 0;
2406 
2407                 for (; e < MIN(split, elements); e++) {
2408                     d[H1(e)] = m[H1(e)];
2409                 }
2410                 for (; e < elements; e += estride) {
2411                     d[H1(e)] = n[H1(e)];
2412                     for (int i = 1; i < estride; i++) {
2413                         d[H1(e + i)] = m[H1(e + i)];
2414                     }
2415                 }
2416             }
2417         } else {
2418             for (int r = 0; r < nreg; r++) {
2419                 uint8_t *d = vd + r * sizeof(ARMVectorReg);
2420                 uint8_t *n = vn + r * sizeof(ARMVectorReg);
2421                 uint8_t *m = vm + r * sizeof(ARMVectorReg);
2422                 int split = p.count - r * elements;
2423                 int e = 0;
2424 
2425                 for (; e < MIN(split, elements); e += estride) {
2426                     d[H1(e)] = n[H1(e)];
2427                     for (int i = 1; i < estride; i++) {
2428                         d[H1(e + i)] = m[H1(e + i)];
2429                     }
2430                 }
2431                 for (; e < elements; e++) {
2432                     d[H1(e)] = m[H1(e)];
2433                 }
2434             }
2435         }
2436     }
2437 }
2438 
2439 void HELPER(sme2_sel_h)(void *vd, void *vn, void *vm,
2440                         uint32_t png, uint32_t desc)
2441 {
2442     int vl = simd_oprsz(desc);
2443     int nreg = simd_data(desc);
2444     int elements = vl / sizeof(uint16_t);
2445     DecodeCounter p = decode_counter(png, vl, MO_16);
2446 
2447     if (p.lg2_stride == 0) {
2448         if (p.invert) {
2449             for (int r = 0; r < nreg; r++) {
2450                 uint16_t *d = vd + r * sizeof(ARMVectorReg);
2451                 uint16_t *n = vn + r * sizeof(ARMVectorReg);
2452                 uint16_t *m = vm + r * sizeof(ARMVectorReg);
2453                 int split = p.count - r * elements;
2454 
2455                 if (split <= 0) {
2456                     memcpy(d, n, vl);  /* all true */
2457                 } else if (elements <= split) {
2458                     memcpy(d, m, vl);  /* all false */
2459                 } else {
2460                     for (int e = 0; e < split; e++) {
2461                         d[H2(e)] = m[H2(e)];
2462                     }
2463                     for (int e = split; e < elements; e++) {
2464                         d[H2(e)] = n[H2(e)];
2465                     }
2466                 }
2467             }
2468         } else {
2469             for (int r = 0; r < nreg; r++) {
2470                 uint16_t *d = vd + r * sizeof(ARMVectorReg);
2471                 uint16_t *n = vn + r * sizeof(ARMVectorReg);
2472                 uint16_t *m = vm + r * sizeof(ARMVectorReg);
2473                 int split = p.count - r * elements;
2474 
2475                 if (split <= 0) {
2476                     memcpy(d, m, vl);  /* all false */
2477                 } else if (elements <= split) {
2478                     memcpy(d, n, vl);  /* all true */
2479                 } else {
2480                     for (int e = 0; e < split; e++) {
2481                         d[H2(e)] = n[H2(e)];
2482                     }
2483                     for (int e = split; e < elements; e++) {
2484                         d[H2(e)] = m[H2(e)];
2485                     }
2486                 }
2487             }
2488         }
2489     } else {
2490         int estride = 1 << p.lg2_stride;
2491         if (p.invert) {
2492             for (int r = 0; r < nreg; r++) {
2493                 uint16_t *d = vd + r * sizeof(ARMVectorReg);
2494                 uint16_t *n = vn + r * sizeof(ARMVectorReg);
2495                 uint16_t *m = vm + r * sizeof(ARMVectorReg);
2496                 int split = p.count - r * elements;
2497                 int e = 0;
2498 
2499                 for (; e < MIN(split, elements); e++) {
2500                     d[H2(e)] = m[H2(e)];
2501                 }
2502                 for (; e < elements; e += estride) {
2503                     d[H2(e)] = n[H2(e)];
2504                     for (int i = 1; i < estride; i++) {
2505                         d[H2(e + i)] = m[H2(e + i)];
2506                     }
2507                 }
2508             }
2509         } else {
2510             for (int r = 0; r < nreg; r++) {
2511                 uint16_t *d = vd + r * sizeof(ARMVectorReg);
2512                 uint16_t *n = vn + r * sizeof(ARMVectorReg);
2513                 uint16_t *m = vm + r * sizeof(ARMVectorReg);
2514                 int split = p.count - r * elements;
2515                 int e = 0;
2516 
2517                 for (; e < MIN(split, elements); e += estride) {
2518                     d[H2(e)] = n[H2(e)];
2519                     for (int i = 1; i < estride; i++) {
2520                         d[H2(e + i)] = m[H2(e + i)];
2521                     }
2522                 }
2523                 for (; e < elements; e++) {
2524                     d[H2(e)] = m[H2(e)];
2525                 }
2526             }
2527         }
2528     }
2529 }
2530 
2531 void HELPER(sme2_sel_s)(void *vd, void *vn, void *vm,
2532                         uint32_t png, uint32_t desc)
2533 {
2534     int vl = simd_oprsz(desc);
2535     int nreg = simd_data(desc);
2536     int elements = vl / sizeof(uint32_t);
2537     DecodeCounter p = decode_counter(png, vl, MO_32);
2538 
2539     if (p.lg2_stride == 0) {
2540         if (p.invert) {
2541             for (int r = 0; r < nreg; r++) {
2542                 uint32_t *d = vd + r * sizeof(ARMVectorReg);
2543                 uint32_t *n = vn + r * sizeof(ARMVectorReg);
2544                 uint32_t *m = vm + r * sizeof(ARMVectorReg);
2545                 int split = p.count - r * elements;
2546 
2547                 if (split <= 0) {
2548                     memcpy(d, n, vl);  /* all true */
2549                 } else if (elements <= split) {
2550                     memcpy(d, m, vl);  /* all false */
2551                 } else {
2552                     for (int e = 0; e < split; e++) {
2553                         d[H4(e)] = m[H4(e)];
2554                     }
2555                     for (int e = split; e < elements; e++) {
2556                         d[H4(e)] = n[H4(e)];
2557                     }
2558                 }
2559             }
2560         } else {
2561             for (int r = 0; r < nreg; r++) {
2562                 uint32_t *d = vd + r * sizeof(ARMVectorReg);
2563                 uint32_t *n = vn + r * sizeof(ARMVectorReg);
2564                 uint32_t *m = vm + r * sizeof(ARMVectorReg);
2565                 int split = p.count - r * elements;
2566 
2567                 if (split <= 0) {
2568                     memcpy(d, m, vl);  /* all false */
2569                 } else if (elements <= split) {
2570                     memcpy(d, n, vl);  /* all true */
2571                 } else {
2572                     for (int e = 0; e < split; e++) {
2573                         d[H4(e)] = n[H4(e)];
2574                     }
2575                     for (int e = split; e < elements; e++) {
2576                         d[H4(e)] = m[H4(e)];
2577                     }
2578                 }
2579             }
2580         }
2581     } else {
2582         /* p.esz must be MO_64, so stride must be 2. */
2583         if (p.invert) {
2584             for (int r = 0; r < nreg; r++) {
2585                 uint32_t *d = vd + r * sizeof(ARMVectorReg);
2586                 uint32_t *n = vn + r * sizeof(ARMVectorReg);
2587                 uint32_t *m = vm + r * sizeof(ARMVectorReg);
2588                 int split = p.count - r * elements;
2589                 int e = 0;
2590 
2591                 for (; e < MIN(split, elements); e++) {
2592                     d[H4(e)] = m[H4(e)];
2593                 }
2594                 for (; e < elements; e += 2) {
2595                     d[H4(e)] = n[H4(e)];
2596                     d[H4(e + 1)] = m[H4(e + 1)];
2597                 }
2598             }
2599         } else {
2600             for (int r = 0; r < nreg; r++) {
2601                 uint32_t *d = vd + r * sizeof(ARMVectorReg);
2602                 uint32_t *n = vn + r * sizeof(ARMVectorReg);
2603                 uint32_t *m = vm + r * sizeof(ARMVectorReg);
2604                 int split = p.count - r * elements;
2605                 int e = 0;
2606 
2607                 for (; e < MIN(split, elements); e += 2) {
2608                     d[H4(e)] = n[H4(e)];
2609                     d[H4(e + 1)] = m[H4(e + 1)];
2610                 }
2611                 for (; e < elements; e++) {
2612                     d[H4(e)] = m[H4(e)];
2613                 }
2614             }
2615         }
2616     }
2617 }
2618 
2619 void HELPER(sme2_sel_d)(void *vd, void *vn, void *vm,
2620                         uint32_t png, uint32_t desc)
2621 {
2622     int vl = simd_oprsz(desc);
2623     int nreg = simd_data(desc);
2624     int elements = vl / sizeof(uint64_t);
2625     DecodeCounter p = decode_counter(png, vl, MO_64);
2626 
2627     if (p.invert) {
2628         for (int r = 0; r < nreg; r++) {
2629             uint64_t *d = vd + r * sizeof(ARMVectorReg);
2630             uint64_t *n = vn + r * sizeof(ARMVectorReg);
2631             uint64_t *m = vm + r * sizeof(ARMVectorReg);
2632             int split = p.count - r * elements;
2633 
2634             if (split <= 0) {
2635                 memcpy(d, n, vl);  /* all true */
2636             } else if (elements <= split) {
2637                 memcpy(d, m, vl);  /* all false */
2638             } else {
2639                 memcpy(d, m, split * sizeof(uint64_t));
2640                 memcpy(d + split, n + split,
2641                        (elements - split) * sizeof(uint64_t));
2642             }
2643         }
2644     } else {
2645         for (int r = 0; r < nreg; r++) {
2646             uint64_t *d = vd + r * sizeof(ARMVectorReg);
2647             uint64_t *n = vn + r * sizeof(ARMVectorReg);
2648             uint64_t *m = vm + r * sizeof(ARMVectorReg);
2649             int split = p.count - r * elements;
2650 
2651             if (split <= 0) {
2652                 memcpy(d, m, vl);  /* all false */
2653             } else if (elements <= split) {
2654                 memcpy(d, n, vl);  /* all true */
2655             } else {
2656                 memcpy(d, n, split * sizeof(uint64_t));
2657                 memcpy(d + split, m + split,
2658                        (elements - split) * sizeof(uint64_t));
2659             }
2660         }
2661     }
2662 }
2663