1 /* 2 * ARM SME Operations 3 * 4 * Copyright (c) 2022 Linaro, Ltd. 5 * 6 * This library is free software; you can redistribute it and/or 7 * modify it under the terms of the GNU Lesser General Public 8 * License as published by the Free Software Foundation; either 9 * version 2.1 of the License, or (at your option) any later version. 10 * 11 * This library is distributed in the hope that it will be useful, 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 14 * Lesser General Public License for more details. 15 * 16 * You should have received a copy of the GNU Lesser General Public 17 * License along with this library; if not, see <http://www.gnu.org/licenses/>. 18 */ 19 20 #include "qemu/osdep.h" 21 #include "cpu.h" 22 #include "internals.h" 23 #include "tcg/tcg-gvec-desc.h" 24 #include "exec/helper-proto.h" 25 #include "accel/tcg/cpu-ldst.h" 26 #include "accel/tcg/helper-retaddr.h" 27 #include "qemu/int128.h" 28 #include "fpu/softfloat.h" 29 #include "vec_internal.h" 30 #include "sve_ldst_internal.h" 31 32 33 static bool vectors_overlap(ARMVectorReg *x, unsigned nx, 34 ARMVectorReg *y, unsigned ny) 35 { 36 return !(x + nx <= y || y + ny <= x); 37 } 38 39 void helper_set_svcr(CPUARMState *env, uint32_t val, uint32_t mask) 40 { 41 aarch64_set_svcr(env, val, mask); 42 } 43 44 void helper_sme_zero(CPUARMState *env, uint32_t imm, uint32_t svl) 45 { 46 uint32_t i; 47 48 /* 49 * Special case clearing the entire ZArray. 50 * This falls into the CONSTRAINED UNPREDICTABLE zeroing of any 51 * parts of the ZA storage outside of SVL. 52 */ 53 if (imm == 0xff) { 54 memset(env->za_state.za, 0, sizeof(env->za_state.za)); 55 return; 56 } 57 58 /* 59 * Recall that ZAnH.D[m] is spread across ZA[n+8*m], 60 * so each row is discontiguous within ZA[]. 61 */ 62 for (i = 0; i < svl; i++) { 63 if (imm & (1 << (i % 8))) { 64 memset(&env->za_state.za[i], 0, svl); 65 } 66 } 67 } 68 69 70 /* 71 * When considering the ZA storage as an array of elements of 72 * type T, the index within that array of the Nth element of 73 * a vertical slice of a tile can be calculated like this, 74 * regardless of the size of type T. This is because the tiles 75 * are interleaved, so if type T is size N bytes then row 1 of 76 * the tile is N rows away from row 0. The division by N to 77 * convert a byte offset into an array index and the multiplication 78 * by N to convert from vslice-index-within-the-tile to 79 * the index within the ZA storage cancel out. 80 */ 81 #define tile_vslice_index(i) ((i) * sizeof(ARMVectorReg)) 82 83 /* 84 * When doing byte arithmetic on the ZA storage, the element 85 * byteoff bytes away in a tile vertical slice is always this 86 * many bytes away in the ZA storage, regardless of the 87 * size of the tile element, assuming that byteoff is a multiple 88 * of the element size. Again this is because of the interleaving 89 * of the tiles. For instance if we have 1 byte per element then 90 * each row of the ZA storage has one byte of the vslice data, 91 * and (counting from 0) byte 8 goes in row 8 of the storage 92 * at offset (8 * row-size-in-bytes). 93 * If we have 8 bytes per element then each row of the ZA storage 94 * has 8 bytes of the data, but there are 8 interleaved tiles and 95 * so byte 8 of the data goes into row 1 of the tile, 96 * which is again row 8 of the storage, so the offset is still 97 * (8 * row-size-in-bytes). Similarly for other element sizes. 98 */ 99 #define tile_vslice_offset(byteoff) ((byteoff) * sizeof(ARMVectorReg)) 100 101 102 /* 103 * Move Zreg vector to ZArray column. 104 */ 105 #define DO_MOVA_C(NAME, TYPE, H) \ 106 void HELPER(NAME)(void *za, void *vn, void *vg, uint32_t desc) \ 107 { \ 108 int i, oprsz = simd_oprsz(desc); \ 109 for (i = 0; i < oprsz; ) { \ 110 uint16_t pg = *(uint16_t *)(vg + H1_2(i >> 3)); \ 111 do { \ 112 if (pg & 1) { \ 113 *(TYPE *)(za + tile_vslice_offset(i)) = *(TYPE *)(vn + H(i)); \ 114 } \ 115 i += sizeof(TYPE); \ 116 pg >>= sizeof(TYPE); \ 117 } while (i & 15); \ 118 } \ 119 } 120 121 DO_MOVA_C(sme_mova_cz_b, uint8_t, H1) 122 DO_MOVA_C(sme_mova_cz_h, uint16_t, H1_2) 123 DO_MOVA_C(sme_mova_cz_s, uint32_t, H1_4) 124 125 void HELPER(sme_mova_cz_d)(void *za, void *vn, void *vg, uint32_t desc) 126 { 127 int i, oprsz = simd_oprsz(desc) / 8; 128 uint8_t *pg = vg; 129 uint64_t *n = vn; 130 uint64_t *a = za; 131 132 for (i = 0; i < oprsz; i++) { 133 if (pg[H1(i)] & 1) { 134 a[tile_vslice_index(i)] = n[i]; 135 } 136 } 137 } 138 139 void HELPER(sme_mova_cz_q)(void *za, void *vn, void *vg, uint32_t desc) 140 { 141 int i, oprsz = simd_oprsz(desc) / 16; 142 uint16_t *pg = vg; 143 Int128 *n = vn; 144 Int128 *a = za; 145 146 /* 147 * Int128 is used here simply to copy 16 bytes, and to simplify 148 * the address arithmetic. 149 */ 150 for (i = 0; i < oprsz; i++) { 151 if (pg[H2(i)] & 1) { 152 a[tile_vslice_index(i)] = n[i]; 153 } 154 } 155 } 156 157 #undef DO_MOVA_C 158 159 /* 160 * Move ZArray column to Zreg vector. 161 */ 162 #define DO_MOVA_Z(NAME, TYPE, H) \ 163 void HELPER(NAME)(void *vd, void *za, void *vg, uint32_t desc) \ 164 { \ 165 int i, oprsz = simd_oprsz(desc); \ 166 for (i = 0; i < oprsz; ) { \ 167 uint16_t pg = *(uint16_t *)(vg + H1_2(i >> 3)); \ 168 do { \ 169 if (pg & 1) { \ 170 *(TYPE *)(vd + H(i)) = *(TYPE *)(za + tile_vslice_offset(i)); \ 171 } \ 172 i += sizeof(TYPE); \ 173 pg >>= sizeof(TYPE); \ 174 } while (i & 15); \ 175 } \ 176 } 177 178 DO_MOVA_Z(sme_mova_zc_b, uint8_t, H1) 179 DO_MOVA_Z(sme_mova_zc_h, uint16_t, H1_2) 180 DO_MOVA_Z(sme_mova_zc_s, uint32_t, H1_4) 181 182 void HELPER(sme_mova_zc_d)(void *vd, void *za, void *vg, uint32_t desc) 183 { 184 int i, oprsz = simd_oprsz(desc) / 8; 185 uint8_t *pg = vg; 186 uint64_t *d = vd; 187 uint64_t *a = za; 188 189 for (i = 0; i < oprsz; i++) { 190 if (pg[H1(i)] & 1) { 191 d[i] = a[tile_vslice_index(i)]; 192 } 193 } 194 } 195 196 void HELPER(sme_mova_zc_q)(void *vd, void *za, void *vg, uint32_t desc) 197 { 198 int i, oprsz = simd_oprsz(desc) / 16; 199 uint16_t *pg = vg; 200 Int128 *d = vd; 201 Int128 *a = za; 202 203 /* 204 * Int128 is used here simply to copy 16 bytes, and to simplify 205 * the address arithmetic. 206 */ 207 for (i = 0; i < oprsz; i++, za += sizeof(ARMVectorReg)) { 208 if (pg[H2(i)] & 1) { 209 d[i] = a[tile_vslice_index(i)]; 210 } 211 } 212 } 213 214 #undef DO_MOVA_Z 215 216 void HELPER(sme2_mova_zc_b)(void *vdst, void *vsrc, uint32_t desc) 217 { 218 const uint8_t *src = vsrc; 219 uint8_t *dst = vdst; 220 size_t i, n = simd_oprsz(desc); 221 222 for (i = 0; i < n; ++i) { 223 dst[i] = src[tile_vslice_index(i)]; 224 } 225 } 226 227 void HELPER(sme2_mova_zc_h)(void *vdst, void *vsrc, uint32_t desc) 228 { 229 const uint16_t *src = vsrc; 230 uint16_t *dst = vdst; 231 size_t i, n = simd_oprsz(desc) / 2; 232 233 for (i = 0; i < n; ++i) { 234 dst[i] = src[tile_vslice_index(i)]; 235 } 236 } 237 238 void HELPER(sme2_mova_zc_s)(void *vdst, void *vsrc, uint32_t desc) 239 { 240 const uint32_t *src = vsrc; 241 uint32_t *dst = vdst; 242 size_t i, n = simd_oprsz(desc) / 4; 243 244 for (i = 0; i < n; ++i) { 245 dst[i] = src[tile_vslice_index(i)]; 246 } 247 } 248 249 void HELPER(sme2_mova_zc_d)(void *vdst, void *vsrc, uint32_t desc) 250 { 251 const uint64_t *src = vsrc; 252 uint64_t *dst = vdst; 253 size_t i, n = simd_oprsz(desc) / 8; 254 255 for (i = 0; i < n; ++i) { 256 dst[i] = src[tile_vslice_index(i)]; 257 } 258 } 259 260 void HELPER(sme2p1_movaz_zc_b)(void *vdst, void *vsrc, uint32_t desc) 261 { 262 uint8_t *src = vsrc; 263 uint8_t *dst = vdst; 264 size_t i, n = simd_oprsz(desc); 265 266 for (i = 0; i < n; ++i) { 267 dst[i] = src[tile_vslice_index(i)]; 268 src[tile_vslice_index(i)] = 0; 269 } 270 } 271 272 void HELPER(sme2p1_movaz_zc_h)(void *vdst, void *vsrc, uint32_t desc) 273 { 274 uint16_t *src = vsrc; 275 uint16_t *dst = vdst; 276 size_t i, n = simd_oprsz(desc) / 2; 277 278 for (i = 0; i < n; ++i) { 279 dst[i] = src[tile_vslice_index(i)]; 280 src[tile_vslice_index(i)] = 0; 281 } 282 } 283 284 void HELPER(sme2p1_movaz_zc_s)(void *vdst, void *vsrc, uint32_t desc) 285 { 286 uint32_t *src = vsrc; 287 uint32_t *dst = vdst; 288 size_t i, n = simd_oprsz(desc) / 4; 289 290 for (i = 0; i < n; ++i) { 291 dst[i] = src[tile_vslice_index(i)]; 292 src[tile_vslice_index(i)] = 0; 293 } 294 } 295 296 void HELPER(sme2p1_movaz_zc_d)(void *vdst, void *vsrc, uint32_t desc) 297 { 298 uint64_t *src = vsrc; 299 uint64_t *dst = vdst; 300 size_t i, n = simd_oprsz(desc) / 8; 301 302 for (i = 0; i < n; ++i) { 303 dst[i] = src[tile_vslice_index(i)]; 304 src[tile_vslice_index(i)] = 0; 305 } 306 } 307 308 void HELPER(sme2p1_movaz_zc_q)(void *vdst, void *vsrc, uint32_t desc) 309 { 310 Int128 *src = vsrc; 311 Int128 *dst = vdst; 312 size_t i, n = simd_oprsz(desc) / 16; 313 314 for (i = 0; i < n; ++i) { 315 dst[i] = src[tile_vslice_index(i)]; 316 memset(&src[tile_vslice_index(i)], 0, 16); 317 } 318 } 319 320 /* 321 * Clear elements in a tile slice comprising len bytes. 322 */ 323 324 typedef void ClearFn(void *ptr, size_t off, size_t len); 325 326 static void clear_horizontal(void *ptr, size_t off, size_t len) 327 { 328 memset(ptr + off, 0, len); 329 } 330 331 static void clear_vertical_b(void *vptr, size_t off, size_t len) 332 { 333 for (size_t i = 0; i < len; ++i) { 334 *(uint8_t *)(vptr + tile_vslice_offset(i + off)) = 0; 335 } 336 } 337 338 static void clear_vertical_h(void *vptr, size_t off, size_t len) 339 { 340 for (size_t i = 0; i < len; i += 2) { 341 *(uint16_t *)(vptr + tile_vslice_offset(i + off)) = 0; 342 } 343 } 344 345 static void clear_vertical_s(void *vptr, size_t off, size_t len) 346 { 347 for (size_t i = 0; i < len; i += 4) { 348 *(uint32_t *)(vptr + tile_vslice_offset(i + off)) = 0; 349 } 350 } 351 352 static void clear_vertical_d(void *vptr, size_t off, size_t len) 353 { 354 for (size_t i = 0; i < len; i += 8) { 355 *(uint64_t *)(vptr + tile_vslice_offset(i + off)) = 0; 356 } 357 } 358 359 static void clear_vertical_q(void *vptr, size_t off, size_t len) 360 { 361 for (size_t i = 0; i < len; i += 16) { 362 memset(vptr + tile_vslice_offset(i + off), 0, 16); 363 } 364 } 365 366 /* 367 * Copy elements from an array into a tile slice comprising len bytes. 368 */ 369 370 typedef void CopyFn(void *dst, const void *src, size_t len); 371 372 static void copy_horizontal(void *dst, const void *src, size_t len) 373 { 374 memcpy(dst, src, len); 375 } 376 377 static void copy_vertical_b(void *vdst, const void *vsrc, size_t len) 378 { 379 const uint8_t *src = vsrc; 380 uint8_t *dst = vdst; 381 size_t i; 382 383 for (i = 0; i < len; ++i) { 384 dst[tile_vslice_index(i)] = src[i]; 385 } 386 } 387 388 static void copy_vertical_h(void *vdst, const void *vsrc, size_t len) 389 { 390 const uint16_t *src = vsrc; 391 uint16_t *dst = vdst; 392 size_t i; 393 394 for (i = 0; i < len / 2; ++i) { 395 dst[tile_vslice_index(i)] = src[i]; 396 } 397 } 398 399 static void copy_vertical_s(void *vdst, const void *vsrc, size_t len) 400 { 401 const uint32_t *src = vsrc; 402 uint32_t *dst = vdst; 403 size_t i; 404 405 for (i = 0; i < len / 4; ++i) { 406 dst[tile_vslice_index(i)] = src[i]; 407 } 408 } 409 410 static void copy_vertical_d(void *vdst, const void *vsrc, size_t len) 411 { 412 const uint64_t *src = vsrc; 413 uint64_t *dst = vdst; 414 size_t i; 415 416 for (i = 0; i < len / 8; ++i) { 417 dst[tile_vslice_index(i)] = src[i]; 418 } 419 } 420 421 static void copy_vertical_q(void *vdst, const void *vsrc, size_t len) 422 { 423 for (size_t i = 0; i < len; i += 16) { 424 memcpy(vdst + tile_vslice_offset(i), vsrc + i, 16); 425 } 426 } 427 428 void HELPER(sme2_mova_cz_b)(void *vdst, void *vsrc, uint32_t desc) 429 { 430 copy_vertical_b(vdst, vsrc, simd_oprsz(desc)); 431 } 432 433 void HELPER(sme2_mova_cz_h)(void *vdst, void *vsrc, uint32_t desc) 434 { 435 copy_vertical_h(vdst, vsrc, simd_oprsz(desc)); 436 } 437 438 void HELPER(sme2_mova_cz_s)(void *vdst, void *vsrc, uint32_t desc) 439 { 440 copy_vertical_s(vdst, vsrc, simd_oprsz(desc)); 441 } 442 443 void HELPER(sme2_mova_cz_d)(void *vdst, void *vsrc, uint32_t desc) 444 { 445 copy_vertical_d(vdst, vsrc, simd_oprsz(desc)); 446 } 447 448 /* 449 * Host and TLB primitives for vertical tile slice addressing. 450 */ 451 452 #define DO_LD(NAME, TYPE, HOST, TLB) \ 453 static inline void sme_##NAME##_v_host(void *za, intptr_t off, void *host) \ 454 { \ 455 TYPE val = HOST(host); \ 456 *(TYPE *)(za + tile_vslice_offset(off)) = val; \ 457 } \ 458 static inline void sme_##NAME##_v_tlb(CPUARMState *env, void *za, \ 459 intptr_t off, target_ulong addr, uintptr_t ra) \ 460 { \ 461 TYPE val = TLB(env, useronly_clean_ptr(addr), ra); \ 462 *(TYPE *)(za + tile_vslice_offset(off)) = val; \ 463 } 464 465 #define DO_ST(NAME, TYPE, HOST, TLB) \ 466 static inline void sme_##NAME##_v_host(void *za, intptr_t off, void *host) \ 467 { \ 468 TYPE val = *(TYPE *)(za + tile_vslice_offset(off)); \ 469 HOST(host, val); \ 470 } \ 471 static inline void sme_##NAME##_v_tlb(CPUARMState *env, void *za, \ 472 intptr_t off, target_ulong addr, uintptr_t ra) \ 473 { \ 474 TYPE val = *(TYPE *)(za + tile_vslice_offset(off)); \ 475 TLB(env, useronly_clean_ptr(addr), val, ra); \ 476 } 477 478 #define DO_LDQ(HNAME, VNAME) \ 479 static inline void VNAME##_v_host(void *za, intptr_t off, void *host) \ 480 { \ 481 HNAME##_host(za, tile_vslice_offset(off), host); \ 482 } \ 483 static inline void VNAME##_v_tlb(CPUARMState *env, void *za, intptr_t off, \ 484 target_ulong addr, uintptr_t ra) \ 485 { \ 486 HNAME##_tlb(env, za, tile_vslice_offset(off), addr, ra); \ 487 } 488 489 #define DO_STQ(HNAME, VNAME) \ 490 static inline void VNAME##_v_host(void *za, intptr_t off, void *host) \ 491 { \ 492 HNAME##_host(za, tile_vslice_offset(off), host); \ 493 } \ 494 static inline void VNAME##_v_tlb(CPUARMState *env, void *za, intptr_t off, \ 495 target_ulong addr, uintptr_t ra) \ 496 { \ 497 HNAME##_tlb(env, za, tile_vslice_offset(off), addr, ra); \ 498 } 499 500 DO_LD(ld1b, uint8_t, ldub_p, cpu_ldub_data_ra) 501 DO_LD(ld1h_be, uint16_t, lduw_be_p, cpu_lduw_be_data_ra) 502 DO_LD(ld1h_le, uint16_t, lduw_le_p, cpu_lduw_le_data_ra) 503 DO_LD(ld1s_be, uint32_t, ldl_be_p, cpu_ldl_be_data_ra) 504 DO_LD(ld1s_le, uint32_t, ldl_le_p, cpu_ldl_le_data_ra) 505 DO_LD(ld1d_be, uint64_t, ldq_be_p, cpu_ldq_be_data_ra) 506 DO_LD(ld1d_le, uint64_t, ldq_le_p, cpu_ldq_le_data_ra) 507 508 DO_LDQ(sve_ld1qq_be, sme_ld1q_be) 509 DO_LDQ(sve_ld1qq_le, sme_ld1q_le) 510 511 DO_ST(st1b, uint8_t, stb_p, cpu_stb_data_ra) 512 DO_ST(st1h_be, uint16_t, stw_be_p, cpu_stw_be_data_ra) 513 DO_ST(st1h_le, uint16_t, stw_le_p, cpu_stw_le_data_ra) 514 DO_ST(st1s_be, uint32_t, stl_be_p, cpu_stl_be_data_ra) 515 DO_ST(st1s_le, uint32_t, stl_le_p, cpu_stl_le_data_ra) 516 DO_ST(st1d_be, uint64_t, stq_be_p, cpu_stq_be_data_ra) 517 DO_ST(st1d_le, uint64_t, stq_le_p, cpu_stq_le_data_ra) 518 519 DO_STQ(sve_st1qq_be, sme_st1q_be) 520 DO_STQ(sve_st1qq_le, sme_st1q_le) 521 522 #undef DO_LD 523 #undef DO_ST 524 #undef DO_LDQ 525 #undef DO_STQ 526 527 /* 528 * Common helper for all contiguous predicated loads. 529 */ 530 531 static inline QEMU_ALWAYS_INLINE 532 void sme_ld1(CPUARMState *env, void *za, uint64_t *vg, 533 const target_ulong addr, uint32_t desc, const uintptr_t ra, 534 const int esz, uint32_t mtedesc, bool vertical, 535 sve_ldst1_host_fn *host_fn, 536 sve_ldst1_tlb_fn *tlb_fn, 537 ClearFn *clr_fn, 538 CopyFn *cpy_fn) 539 { 540 const intptr_t reg_max = simd_oprsz(desc); 541 const intptr_t esize = 1 << esz; 542 intptr_t reg_off, reg_last; 543 SVEContLdSt info; 544 void *host; 545 int flags; 546 547 /* Find the active elements. */ 548 if (!sve_cont_ldst_elements(&info, addr, vg, reg_max, esz, esize)) { 549 /* The entire predicate was false; no load occurs. */ 550 clr_fn(za, 0, reg_max); 551 return; 552 } 553 554 /* Probe the page(s). Exit with exception for any invalid page. */ 555 sve_cont_ldst_pages(&info, FAULT_ALL, env, addr, MMU_DATA_LOAD, ra); 556 557 /* Handle watchpoints for all active elements. */ 558 sve_cont_ldst_watchpoints(&info, env, vg, addr, esize, esize, 559 BP_MEM_READ, ra); 560 561 /* 562 * Handle mte checks for all active elements. 563 * Since TBI must be set for MTE, !mtedesc => !mte_active. 564 */ 565 if (mtedesc) { 566 sve_cont_ldst_mte_check(&info, env, vg, addr, esize, esize, 567 mtedesc, ra); 568 } 569 570 flags = info.page[0].flags | info.page[1].flags; 571 if (unlikely(flags != 0)) { 572 #ifdef CONFIG_USER_ONLY 573 g_assert_not_reached(); 574 #else 575 /* 576 * At least one page includes MMIO. 577 * Any bus operation can fail with cpu_transaction_failed, 578 * which for ARM will raise SyncExternal. Perform the load 579 * into scratch memory to preserve register state until the end. 580 */ 581 ARMVectorReg scratch = { }; 582 583 reg_off = info.reg_off_first[0]; 584 reg_last = info.reg_off_last[1]; 585 if (reg_last < 0) { 586 reg_last = info.reg_off_split; 587 if (reg_last < 0) { 588 reg_last = info.reg_off_last[0]; 589 } 590 } 591 592 do { 593 uint64_t pg = vg[reg_off >> 6]; 594 do { 595 if ((pg >> (reg_off & 63)) & 1) { 596 tlb_fn(env, &scratch, reg_off, addr + reg_off, ra); 597 } 598 reg_off += esize; 599 } while (reg_off & 63); 600 } while (reg_off <= reg_last); 601 602 cpy_fn(za, &scratch, reg_max); 603 return; 604 #endif 605 } 606 607 /* The entire operation is in RAM, on valid pages. */ 608 609 reg_off = info.reg_off_first[0]; 610 reg_last = info.reg_off_last[0]; 611 host = info.page[0].host; 612 613 if (!vertical) { 614 memset(za, 0, reg_max); 615 } else if (reg_off) { 616 clr_fn(za, 0, reg_off); 617 } 618 619 set_helper_retaddr(ra); 620 621 while (reg_off <= reg_last) { 622 uint64_t pg = vg[reg_off >> 6]; 623 do { 624 if ((pg >> (reg_off & 63)) & 1) { 625 host_fn(za, reg_off, host + reg_off); 626 } else if (vertical) { 627 clr_fn(za, reg_off, esize); 628 } 629 reg_off += esize; 630 } while (reg_off <= reg_last && (reg_off & 63)); 631 } 632 633 clear_helper_retaddr(); 634 635 /* 636 * Use the slow path to manage the cross-page misalignment. 637 * But we know this is RAM and cannot trap. 638 */ 639 reg_off = info.reg_off_split; 640 if (unlikely(reg_off >= 0)) { 641 tlb_fn(env, za, reg_off, addr + reg_off, ra); 642 } 643 644 reg_off = info.reg_off_first[1]; 645 if (unlikely(reg_off >= 0)) { 646 reg_last = info.reg_off_last[1]; 647 host = info.page[1].host; 648 649 set_helper_retaddr(ra); 650 651 do { 652 uint64_t pg = vg[reg_off >> 6]; 653 do { 654 if ((pg >> (reg_off & 63)) & 1) { 655 host_fn(za, reg_off, host + reg_off); 656 } else if (vertical) { 657 clr_fn(za, reg_off, esize); 658 } 659 reg_off += esize; 660 } while (reg_off & 63); 661 } while (reg_off <= reg_last); 662 663 clear_helper_retaddr(); 664 } 665 } 666 667 static inline QEMU_ALWAYS_INLINE 668 void sme_ld1_mte(CPUARMState *env, void *za, uint64_t *vg, 669 target_ulong addr, uint32_t desc, uintptr_t ra, 670 const int esz, bool vertical, 671 sve_ldst1_host_fn *host_fn, 672 sve_ldst1_tlb_fn *tlb_fn, 673 ClearFn *clr_fn, 674 CopyFn *cpy_fn) 675 { 676 uint32_t mtedesc = desc >> (SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT); 677 int bit55 = extract64(addr, 55, 1); 678 679 /* Remove mtedesc from the normal sve descriptor. */ 680 desc = extract32(desc, 0, SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT); 681 682 /* Perform gross MTE suppression early. */ 683 if (!tbi_check(mtedesc, bit55) || 684 tcma_check(mtedesc, bit55, allocation_tag_from_addr(addr))) { 685 mtedesc = 0; 686 } 687 688 sme_ld1(env, za, vg, addr, desc, ra, esz, mtedesc, vertical, 689 host_fn, tlb_fn, clr_fn, cpy_fn); 690 } 691 692 #define DO_LD(L, END, ESZ) \ 693 void HELPER(sme_ld1##L##END##_h)(CPUARMState *env, void *za, void *vg, \ 694 target_ulong addr, uint32_t desc) \ 695 { \ 696 sme_ld1(env, za, vg, addr, desc, GETPC(), ESZ, 0, false, \ 697 sve_ld1##L##L##END##_host, sve_ld1##L##L##END##_tlb, \ 698 clear_horizontal, copy_horizontal); \ 699 } \ 700 void HELPER(sme_ld1##L##END##_v)(CPUARMState *env, void *za, void *vg, \ 701 target_ulong addr, uint32_t desc) \ 702 { \ 703 sme_ld1(env, za, vg, addr, desc, GETPC(), ESZ, 0, true, \ 704 sme_ld1##L##END##_v_host, sme_ld1##L##END##_v_tlb, \ 705 clear_vertical_##L, copy_vertical_##L); \ 706 } \ 707 void HELPER(sme_ld1##L##END##_h_mte)(CPUARMState *env, void *za, void *vg, \ 708 target_ulong addr, uint32_t desc) \ 709 { \ 710 sme_ld1_mte(env, za, vg, addr, desc, GETPC(), ESZ, false, \ 711 sve_ld1##L##L##END##_host, sve_ld1##L##L##END##_tlb, \ 712 clear_horizontal, copy_horizontal); \ 713 } \ 714 void HELPER(sme_ld1##L##END##_v_mte)(CPUARMState *env, void *za, void *vg, \ 715 target_ulong addr, uint32_t desc) \ 716 { \ 717 sme_ld1_mte(env, za, vg, addr, desc, GETPC(), ESZ, true, \ 718 sme_ld1##L##END##_v_host, sme_ld1##L##END##_v_tlb, \ 719 clear_vertical_##L, copy_vertical_##L); \ 720 } 721 722 DO_LD(b, , MO_8) 723 DO_LD(h, _be, MO_16) 724 DO_LD(h, _le, MO_16) 725 DO_LD(s, _be, MO_32) 726 DO_LD(s, _le, MO_32) 727 DO_LD(d, _be, MO_64) 728 DO_LD(d, _le, MO_64) 729 DO_LD(q, _be, MO_128) 730 DO_LD(q, _le, MO_128) 731 732 #undef DO_LD 733 734 /* 735 * Common helper for all contiguous predicated stores. 736 */ 737 738 static inline QEMU_ALWAYS_INLINE 739 void sme_st1(CPUARMState *env, void *za, uint64_t *vg, 740 const target_ulong addr, uint32_t desc, const uintptr_t ra, 741 const int esz, uint32_t mtedesc, bool vertical, 742 sve_ldst1_host_fn *host_fn, 743 sve_ldst1_tlb_fn *tlb_fn) 744 { 745 const intptr_t reg_max = simd_oprsz(desc); 746 const intptr_t esize = 1 << esz; 747 intptr_t reg_off, reg_last; 748 SVEContLdSt info; 749 void *host; 750 int flags; 751 752 /* Find the active elements. */ 753 if (!sve_cont_ldst_elements(&info, addr, vg, reg_max, esz, esize)) { 754 /* The entire predicate was false; no store occurs. */ 755 return; 756 } 757 758 /* Probe the page(s). Exit with exception for any invalid page. */ 759 sve_cont_ldst_pages(&info, FAULT_ALL, env, addr, MMU_DATA_STORE, ra); 760 761 /* Handle watchpoints for all active elements. */ 762 sve_cont_ldst_watchpoints(&info, env, vg, addr, esize, esize, 763 BP_MEM_WRITE, ra); 764 765 /* 766 * Handle mte checks for all active elements. 767 * Since TBI must be set for MTE, !mtedesc => !mte_active. 768 */ 769 if (mtedesc) { 770 sve_cont_ldst_mte_check(&info, env, vg, addr, esize, esize, 771 mtedesc, ra); 772 } 773 774 flags = info.page[0].flags | info.page[1].flags; 775 if (unlikely(flags != 0)) { 776 #ifdef CONFIG_USER_ONLY 777 g_assert_not_reached(); 778 #else 779 /* 780 * At least one page includes MMIO. 781 * Any bus operation can fail with cpu_transaction_failed, 782 * which for ARM will raise SyncExternal. We cannot avoid 783 * this fault and will leave with the store incomplete. 784 */ 785 reg_off = info.reg_off_first[0]; 786 reg_last = info.reg_off_last[1]; 787 if (reg_last < 0) { 788 reg_last = info.reg_off_split; 789 if (reg_last < 0) { 790 reg_last = info.reg_off_last[0]; 791 } 792 } 793 794 do { 795 uint64_t pg = vg[reg_off >> 6]; 796 do { 797 if ((pg >> (reg_off & 63)) & 1) { 798 tlb_fn(env, za, reg_off, addr + reg_off, ra); 799 } 800 reg_off += esize; 801 } while (reg_off & 63); 802 } while (reg_off <= reg_last); 803 return; 804 #endif 805 } 806 807 reg_off = info.reg_off_first[0]; 808 reg_last = info.reg_off_last[0]; 809 host = info.page[0].host; 810 811 set_helper_retaddr(ra); 812 813 while (reg_off <= reg_last) { 814 uint64_t pg = vg[reg_off >> 6]; 815 do { 816 if ((pg >> (reg_off & 63)) & 1) { 817 host_fn(za, reg_off, host + reg_off); 818 } 819 reg_off += 1 << esz; 820 } while (reg_off <= reg_last && (reg_off & 63)); 821 } 822 823 clear_helper_retaddr(); 824 825 /* 826 * Use the slow path to manage the cross-page misalignment. 827 * But we know this is RAM and cannot trap. 828 */ 829 reg_off = info.reg_off_split; 830 if (unlikely(reg_off >= 0)) { 831 tlb_fn(env, za, reg_off, addr + reg_off, ra); 832 } 833 834 reg_off = info.reg_off_first[1]; 835 if (unlikely(reg_off >= 0)) { 836 reg_last = info.reg_off_last[1]; 837 host = info.page[1].host; 838 839 set_helper_retaddr(ra); 840 841 do { 842 uint64_t pg = vg[reg_off >> 6]; 843 do { 844 if ((pg >> (reg_off & 63)) & 1) { 845 host_fn(za, reg_off, host + reg_off); 846 } 847 reg_off += 1 << esz; 848 } while (reg_off & 63); 849 } while (reg_off <= reg_last); 850 851 clear_helper_retaddr(); 852 } 853 } 854 855 static inline QEMU_ALWAYS_INLINE 856 void sme_st1_mte(CPUARMState *env, void *za, uint64_t *vg, target_ulong addr, 857 uint32_t desc, uintptr_t ra, int esz, bool vertical, 858 sve_ldst1_host_fn *host_fn, 859 sve_ldst1_tlb_fn *tlb_fn) 860 { 861 uint32_t mtedesc = desc >> (SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT); 862 int bit55 = extract64(addr, 55, 1); 863 864 /* Remove mtedesc from the normal sve descriptor. */ 865 desc = extract32(desc, 0, SIMD_DATA_SHIFT + SVE_MTEDESC_SHIFT); 866 867 /* Perform gross MTE suppression early. */ 868 if (!tbi_check(mtedesc, bit55) || 869 tcma_check(mtedesc, bit55, allocation_tag_from_addr(addr))) { 870 mtedesc = 0; 871 } 872 873 sme_st1(env, za, vg, addr, desc, ra, esz, mtedesc, 874 vertical, host_fn, tlb_fn); 875 } 876 877 #define DO_ST(L, END, ESZ) \ 878 void HELPER(sme_st1##L##END##_h)(CPUARMState *env, void *za, void *vg, \ 879 target_ulong addr, uint32_t desc) \ 880 { \ 881 sme_st1(env, za, vg, addr, desc, GETPC(), ESZ, 0, false, \ 882 sve_st1##L##L##END##_host, sve_st1##L##L##END##_tlb); \ 883 } \ 884 void HELPER(sme_st1##L##END##_v)(CPUARMState *env, void *za, void *vg, \ 885 target_ulong addr, uint32_t desc) \ 886 { \ 887 sme_st1(env, za, vg, addr, desc, GETPC(), ESZ, 0, true, \ 888 sme_st1##L##END##_v_host, sme_st1##L##END##_v_tlb); \ 889 } \ 890 void HELPER(sme_st1##L##END##_h_mte)(CPUARMState *env, void *za, void *vg, \ 891 target_ulong addr, uint32_t desc) \ 892 { \ 893 sme_st1_mte(env, za, vg, addr, desc, GETPC(), ESZ, false, \ 894 sve_st1##L##L##END##_host, sve_st1##L##L##END##_tlb); \ 895 } \ 896 void HELPER(sme_st1##L##END##_v_mte)(CPUARMState *env, void *za, void *vg, \ 897 target_ulong addr, uint32_t desc) \ 898 { \ 899 sme_st1_mte(env, za, vg, addr, desc, GETPC(), ESZ, true, \ 900 sme_st1##L##END##_v_host, sme_st1##L##END##_v_tlb); \ 901 } 902 903 DO_ST(b, , MO_8) 904 DO_ST(h, _be, MO_16) 905 DO_ST(h, _le, MO_16) 906 DO_ST(s, _be, MO_32) 907 DO_ST(s, _le, MO_32) 908 DO_ST(d, _be, MO_64) 909 DO_ST(d, _le, MO_64) 910 DO_ST(q, _be, MO_128) 911 DO_ST(q, _le, MO_128) 912 913 #undef DO_ST 914 915 void HELPER(sme_addha_s)(void *vzda, void *vzn, void *vpn, 916 void *vpm, uint32_t desc) 917 { 918 intptr_t row, col, oprsz = simd_oprsz(desc) / 4; 919 uint64_t *pn = vpn, *pm = vpm; 920 uint32_t *zda = vzda, *zn = vzn; 921 922 for (row = 0; row < oprsz; ) { 923 uint64_t pa = pn[row >> 4]; 924 do { 925 if (pa & 1) { 926 for (col = 0; col < oprsz; ) { 927 uint64_t pb = pm[col >> 4]; 928 do { 929 if (pb & 1) { 930 zda[tile_vslice_index(row) + H4(col)] += zn[H4(col)]; 931 } 932 pb >>= 4; 933 } while (++col & 15); 934 } 935 } 936 pa >>= 4; 937 } while (++row & 15); 938 } 939 } 940 941 void HELPER(sme_addha_d)(void *vzda, void *vzn, void *vpn, 942 void *vpm, uint32_t desc) 943 { 944 intptr_t row, col, oprsz = simd_oprsz(desc) / 8; 945 uint8_t *pn = vpn, *pm = vpm; 946 uint64_t *zda = vzda, *zn = vzn; 947 948 for (row = 0; row < oprsz; ++row) { 949 if (pn[H1(row)] & 1) { 950 for (col = 0; col < oprsz; ++col) { 951 if (pm[H1(col)] & 1) { 952 zda[tile_vslice_index(row) + col] += zn[col]; 953 } 954 } 955 } 956 } 957 } 958 959 void HELPER(sme_addva_s)(void *vzda, void *vzn, void *vpn, 960 void *vpm, uint32_t desc) 961 { 962 intptr_t row, col, oprsz = simd_oprsz(desc) / 4; 963 uint64_t *pn = vpn, *pm = vpm; 964 uint32_t *zda = vzda, *zn = vzn; 965 966 for (row = 0; row < oprsz; ) { 967 uint64_t pa = pn[row >> 4]; 968 do { 969 if (pa & 1) { 970 uint32_t zn_row = zn[H4(row)]; 971 for (col = 0; col < oprsz; ) { 972 uint64_t pb = pm[col >> 4]; 973 do { 974 if (pb & 1) { 975 zda[tile_vslice_index(row) + H4(col)] += zn_row; 976 } 977 pb >>= 4; 978 } while (++col & 15); 979 } 980 } 981 pa >>= 4; 982 } while (++row & 15); 983 } 984 } 985 986 void HELPER(sme_addva_d)(void *vzda, void *vzn, void *vpn, 987 void *vpm, uint32_t desc) 988 { 989 intptr_t row, col, oprsz = simd_oprsz(desc) / 8; 990 uint8_t *pn = vpn, *pm = vpm; 991 uint64_t *zda = vzda, *zn = vzn; 992 993 for (row = 0; row < oprsz; ++row) { 994 if (pn[H1(row)] & 1) { 995 uint64_t zn_row = zn[row]; 996 for (col = 0; col < oprsz; ++col) { 997 if (pm[H1(col)] & 1) { 998 zda[tile_vslice_index(row) + col] += zn_row; 999 } 1000 } 1001 } 1002 } 1003 } 1004 1005 static void do_fmopa_h(void *vza, void *vzn, void *vzm, uint16_t *pn, 1006 uint16_t *pm, float_status *fpst, uint32_t desc, 1007 uint16_t negx, int negf) 1008 { 1009 intptr_t row, col, oprsz = simd_maxsz(desc); 1010 1011 for (row = 0; row < oprsz; ) { 1012 uint16_t pa = pn[H2(row >> 4)]; 1013 do { 1014 if (pa & 1) { 1015 void *vza_row = vza + tile_vslice_offset(row); 1016 uint16_t n = *(uint32_t *)(vzn + H1_2(row)) ^ negx; 1017 1018 for (col = 0; col < oprsz; ) { 1019 uint16_t pb = pm[H2(col >> 4)]; 1020 do { 1021 if (pb & 1) { 1022 uint16_t *a = vza_row + H1_2(col); 1023 uint16_t *m = vzm + H1_2(col); 1024 *a = float16_muladd(n, *m, *a, negf, fpst); 1025 } 1026 col += 2; 1027 pb >>= 2; 1028 } while (col & 15); 1029 } 1030 } 1031 row += 2; 1032 pa >>= 2; 1033 } while (row & 15); 1034 } 1035 } 1036 1037 void HELPER(sme_fmopa_h)(void *vza, void *vzn, void *vzm, void *vpn, 1038 void *vpm, float_status *fpst, uint32_t desc) 1039 { 1040 do_fmopa_h(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0); 1041 } 1042 1043 void HELPER(sme_fmops_h)(void *vza, void *vzn, void *vzm, void *vpn, 1044 void *vpm, float_status *fpst, uint32_t desc) 1045 { 1046 do_fmopa_h(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 15, 0); 1047 } 1048 1049 void HELPER(sme_ah_fmops_h)(void *vza, void *vzn, void *vzm, void *vpn, 1050 void *vpm, float_status *fpst, uint32_t desc) 1051 { 1052 do_fmopa_h(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 1053 float_muladd_negate_product); 1054 } 1055 1056 static void do_fmopa_s(void *vza, void *vzn, void *vzm, uint16_t *pn, 1057 uint16_t *pm, float_status *fpst, uint32_t desc, 1058 uint32_t negx, int negf) 1059 { 1060 intptr_t row, col, oprsz = simd_maxsz(desc); 1061 1062 for (row = 0; row < oprsz; ) { 1063 uint16_t pa = pn[H2(row >> 4)]; 1064 do { 1065 if (pa & 1) { 1066 void *vza_row = vza + tile_vslice_offset(row); 1067 uint32_t n = *(uint32_t *)(vzn + H1_4(row)) ^ negx; 1068 1069 for (col = 0; col < oprsz; ) { 1070 uint16_t pb = pm[H2(col >> 4)]; 1071 do { 1072 if (pb & 1) { 1073 uint32_t *a = vza_row + H1_4(col); 1074 uint32_t *m = vzm + H1_4(col); 1075 *a = float32_muladd(n, *m, *a, negf, fpst); 1076 } 1077 col += 4; 1078 pb >>= 4; 1079 } while (col & 15); 1080 } 1081 } 1082 row += 4; 1083 pa >>= 4; 1084 } while (row & 15); 1085 } 1086 } 1087 1088 void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn, 1089 void *vpm, float_status *fpst, uint32_t desc) 1090 { 1091 do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0); 1092 } 1093 1094 void HELPER(sme_fmops_s)(void *vza, void *vzn, void *vzm, void *vpn, 1095 void *vpm, float_status *fpst, uint32_t desc) 1096 { 1097 do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 31, 0); 1098 } 1099 1100 void HELPER(sme_ah_fmops_s)(void *vza, void *vzn, void *vzm, void *vpn, 1101 void *vpm, float_status *fpst, uint32_t desc) 1102 { 1103 do_fmopa_s(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 1104 float_muladd_negate_product); 1105 } 1106 1107 static void do_fmopa_d(uint64_t *za, uint64_t *zn, uint64_t *zm, uint8_t *pn, 1108 uint8_t *pm, float_status *fpst, uint32_t desc, 1109 uint64_t negx, int negf) 1110 { 1111 intptr_t row, col, oprsz = simd_oprsz(desc) / 8; 1112 1113 for (row = 0; row < oprsz; ++row) { 1114 if (pn[H1(row)] & 1) { 1115 uint64_t *za_row = &za[tile_vslice_index(row)]; 1116 uint64_t n = zn[row] ^ negx; 1117 1118 for (col = 0; col < oprsz; ++col) { 1119 if (pm[H1(col)] & 1) { 1120 uint64_t *a = &za_row[col]; 1121 *a = float64_muladd(n, zm[col], *a, negf, fpst); 1122 } 1123 } 1124 } 1125 } 1126 } 1127 1128 void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn, 1129 void *vpm, float_status *fpst, uint32_t desc) 1130 { 1131 do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0); 1132 } 1133 1134 void HELPER(sme_fmops_d)(void *vza, void *vzn, void *vzm, void *vpn, 1135 void *vpm, float_status *fpst, uint32_t desc) 1136 { 1137 do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 1ull << 63, 0); 1138 } 1139 1140 void HELPER(sme_ah_fmops_d)(void *vza, void *vzn, void *vzm, void *vpn, 1141 void *vpm, float_status *fpst, uint32_t desc) 1142 { 1143 do_fmopa_d(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 1144 float_muladd_negate_product); 1145 } 1146 1147 static void do_bfmopa(void *vza, void *vzn, void *vzm, uint16_t *pn, 1148 uint16_t *pm, float_status *fpst, uint32_t desc, 1149 uint16_t negx, int negf) 1150 { 1151 intptr_t row, col, oprsz = simd_maxsz(desc); 1152 1153 for (row = 0; row < oprsz; ) { 1154 uint16_t pa = pn[H2(row >> 4)]; 1155 do { 1156 if (pa & 1) { 1157 void *vza_row = vza + tile_vslice_offset(row); 1158 uint16_t n = *(uint32_t *)(vzn + H1_2(row)) ^ negx; 1159 1160 for (col = 0; col < oprsz; ) { 1161 uint16_t pb = pm[H2(col >> 4)]; 1162 do { 1163 if (pb & 1) { 1164 uint16_t *a = vza_row + H1_2(col); 1165 uint16_t *m = vzm + H1_2(col); 1166 *a = bfloat16_muladd(n, *m, *a, negf, fpst); 1167 } 1168 col += 2; 1169 pb >>= 2; 1170 } while (col & 15); 1171 } 1172 } 1173 row += 2; 1174 pa >>= 2; 1175 } while (row & 15); 1176 } 1177 } 1178 1179 void HELPER(sme_bfmopa)(void *vza, void *vzn, void *vzm, void *vpn, 1180 void *vpm, float_status *fpst, uint32_t desc) 1181 { 1182 do_bfmopa(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 0); 1183 } 1184 1185 void HELPER(sme_bfmops)(void *vza, void *vzn, void *vzm, void *vpn, 1186 void *vpm, float_status *fpst, uint32_t desc) 1187 { 1188 do_bfmopa(vza, vzn, vzm, vpn, vpm, fpst, desc, 1u << 15, 0); 1189 } 1190 1191 void HELPER(sme_ah_bfmops)(void *vza, void *vzn, void *vzm, void *vpn, 1192 void *vpm, float_status *fpst, uint32_t desc) 1193 { 1194 do_bfmopa(vza, vzn, vzm, vpn, vpm, fpst, desc, 0, 1195 float_muladd_negate_product); 1196 } 1197 1198 /* 1199 * Alter PAIR as needed for controlling predicates being false, 1200 * and for NEG on an enabled row element. 1201 */ 1202 static inline uint32_t f16mop_adj_pair(uint32_t pair, uint32_t pg, uint32_t neg) 1203 { 1204 /* 1205 * The pseudocode uses a conditional negate after the conditional zero. 1206 * It is simpler here to unconditionally negate before conditional zero. 1207 */ 1208 pair ^= neg; 1209 if (!(pg & 1)) { 1210 pair &= 0xffff0000u; 1211 } 1212 if (!(pg & 4)) { 1213 pair &= 0x0000ffffu; 1214 } 1215 return pair; 1216 } 1217 1218 static inline uint32_t f16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg) 1219 { 1220 uint32_t l = pg & 1 ? float16_ah_chs(pair) : 0; 1221 uint32_t h = pg & 4 ? float16_ah_chs(pair >> 16) : 0; 1222 return l | (h << 16); 1223 } 1224 1225 static inline uint32_t bf16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg) 1226 { 1227 uint32_t l = pg & 1 ? bfloat16_ah_chs(pair) : 0; 1228 uint32_t h = pg & 4 ? bfloat16_ah_chs(pair >> 16) : 0; 1229 return l | (h << 16); 1230 } 1231 1232 static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2, 1233 float_status *s_f16, float_status *s_std, 1234 float_status *s_odd) 1235 { 1236 /* 1237 * We need three different float_status for different parts of this 1238 * operation: 1239 * - the input conversion of the float16 values must use the 1240 * f16-specific float_status, so that the FPCR.FZ16 control is applied 1241 * - operations on float32 including the final accumulation must use 1242 * the normal float_status, so that FPCR.FZ is applied 1243 * - we have pre-set-up copy of s_std which is set to round-to-odd, 1244 * for the multiply (see below) 1245 */ 1246 float16 h1r = e1 & 0xffff; 1247 float16 h1c = e1 >> 16; 1248 float16 h2r = e2 & 0xffff; 1249 float16 h2c = e2 >> 16; 1250 float32 t32; 1251 1252 /* C.f. FPProcessNaNs4 */ 1253 if (float16_is_any_nan(h1r) || float16_is_any_nan(h1c) || 1254 float16_is_any_nan(h2r) || float16_is_any_nan(h2c)) { 1255 float16 t16; 1256 1257 if (float16_is_signaling_nan(h1r, s_f16)) { 1258 t16 = h1r; 1259 } else if (float16_is_signaling_nan(h1c, s_f16)) { 1260 t16 = h1c; 1261 } else if (float16_is_signaling_nan(h2r, s_f16)) { 1262 t16 = h2r; 1263 } else if (float16_is_signaling_nan(h2c, s_f16)) { 1264 t16 = h2c; 1265 } else if (float16_is_any_nan(h1r)) { 1266 t16 = h1r; 1267 } else if (float16_is_any_nan(h1c)) { 1268 t16 = h1c; 1269 } else if (float16_is_any_nan(h2r)) { 1270 t16 = h2r; 1271 } else { 1272 t16 = h2c; 1273 } 1274 t32 = float16_to_float32(t16, true, s_f16); 1275 } else { 1276 float64 e1r = float16_to_float64(h1r, true, s_f16); 1277 float64 e1c = float16_to_float64(h1c, true, s_f16); 1278 float64 e2r = float16_to_float64(h2r, true, s_f16); 1279 float64 e2c = float16_to_float64(h2c, true, s_f16); 1280 float64 t64; 1281 1282 /* 1283 * The ARM pseudocode function FPDot performs both multiplies 1284 * and the add with a single rounding operation. Emulate this 1285 * by performing the first multiply in round-to-odd, then doing 1286 * the second multiply as fused multiply-add, and rounding to 1287 * float32 all in one step. 1288 */ 1289 t64 = float64_mul(e1r, e2r, s_odd); 1290 t64 = float64r32_muladd(e1c, e2c, t64, 0, s_std); 1291 1292 /* This conversion is exact, because we've already rounded. */ 1293 t32 = float64_to_float32(t64, s_std); 1294 } 1295 1296 /* The final accumulation step is not fused. */ 1297 return float32_add(sum, t32, s_std); 1298 } 1299 1300 static void do_fmopa_w_h(void *vza, void *vzn, void *vzm, uint16_t *pn, 1301 uint16_t *pm, CPUARMState *env, uint32_t desc, 1302 uint32_t negx, bool ah_neg) 1303 { 1304 intptr_t row, col, oprsz = simd_maxsz(desc); 1305 float_status fpst_odd = env->vfp.fp_status[FPST_ZA]; 1306 1307 set_float_rounding_mode(float_round_to_odd, &fpst_odd); 1308 1309 for (row = 0; row < oprsz; ) { 1310 uint16_t prow = pn[H2(row >> 4)]; 1311 do { 1312 void *vza_row = vza + tile_vslice_offset(row); 1313 uint32_t n = *(uint32_t *)(vzn + H1_4(row)); 1314 1315 if (ah_neg) { 1316 n = f16mop_ah_neg_adj_pair(n, prow); 1317 } else { 1318 n = f16mop_adj_pair(n, prow, negx); 1319 } 1320 1321 for (col = 0; col < oprsz; ) { 1322 uint16_t pcol = pm[H2(col >> 4)]; 1323 do { 1324 if (prow & pcol & 0b0101) { 1325 uint32_t *a = vza_row + H1_4(col); 1326 uint32_t m = *(uint32_t *)(vzm + H1_4(col)); 1327 1328 m = f16mop_adj_pair(m, pcol, 0); 1329 *a = f16_dotadd(*a, n, m, 1330 &env->vfp.fp_status[FPST_ZA_F16], 1331 &env->vfp.fp_status[FPST_ZA], 1332 &fpst_odd); 1333 } 1334 col += 4; 1335 pcol >>= 4; 1336 } while (col & 15); 1337 } 1338 row += 4; 1339 prow >>= 4; 1340 } while (row & 15); 1341 } 1342 } 1343 1344 void HELPER(sme_fmopa_w_h)(void *vza, void *vzn, void *vzm, void *vpn, 1345 void *vpm, CPUARMState *env, uint32_t desc) 1346 { 1347 do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0, false); 1348 } 1349 1350 void HELPER(sme_fmops_w_h)(void *vza, void *vzn, void *vzm, void *vpn, 1351 void *vpm, CPUARMState *env, uint32_t desc) 1352 { 1353 do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0x80008000u, false); 1354 } 1355 1356 void HELPER(sme_ah_fmops_w_h)(void *vza, void *vzn, void *vzm, void *vpn, 1357 void *vpm, CPUARMState *env, uint32_t desc) 1358 { 1359 do_fmopa_w_h(vza, vzn, vzm, vpn, vpm, env, desc, 0, true); 1360 } 1361 1362 void HELPER(sme2_fdot_h)(void *vd, void *vn, void *vm, void *va, 1363 CPUARMState *env, uint32_t desc) 1364 { 1365 intptr_t i, oprsz = simd_maxsz(desc); 1366 bool za = extract32(desc, SIMD_DATA_SHIFT, 1); 1367 float_status *fpst_std = &env->vfp.fp_status[za ? FPST_ZA : FPST_A64]; 1368 float_status *fpst_f16 = &env->vfp.fp_status[za ? FPST_ZA_F16 : FPST_A64_F16]; 1369 float_status fpst_odd = *fpst_std; 1370 float32 *d = vd, *a = va; 1371 uint32_t *n = vn, *m = vm; 1372 1373 set_float_rounding_mode(float_round_to_odd, &fpst_odd); 1374 1375 for (i = 0; i < oprsz / sizeof(float32); ++i) { 1376 d[H4(i)] = f16_dotadd(a[H4(i)], n[H4(i)], m[H4(i)], 1377 fpst_f16, fpst_std, &fpst_odd); 1378 } 1379 } 1380 1381 void HELPER(sme2_fdot_idx_h)(void *vd, void *vn, void *vm, void *va, 1382 CPUARMState *env, uint32_t desc) 1383 { 1384 intptr_t i, j, oprsz = simd_maxsz(desc); 1385 intptr_t elements = oprsz / sizeof(float32); 1386 intptr_t eltspersegment = MIN(4, elements); 1387 int idx = extract32(desc, SIMD_DATA_SHIFT, 2); 1388 bool za = extract32(desc, SIMD_DATA_SHIFT + 2, 1); 1389 float_status *fpst_std = &env->vfp.fp_status[za ? FPST_ZA : FPST_A64]; 1390 float_status *fpst_f16 = &env->vfp.fp_status[za ? FPST_ZA_F16 : FPST_A64_F16]; 1391 float_status fpst_odd = *fpst_std; 1392 float32 *d = vd, *a = va; 1393 uint32_t *n = vn, *m = (uint32_t *)vm + H4(idx); 1394 1395 set_float_rounding_mode(float_round_to_odd, &fpst_odd); 1396 1397 for (i = 0; i < elements; i += eltspersegment) { 1398 uint32_t mm = m[i]; 1399 for (j = 0; j < eltspersegment; ++j) { 1400 d[H4(i + j)] = f16_dotadd(a[H4(i + j)], n[H4(i + j)], mm, 1401 fpst_f16, fpst_std, &fpst_odd); 1402 } 1403 } 1404 } 1405 1406 void HELPER(sme2_fvdot_idx_h)(void *vd, void *vn, void *vm, void *va, 1407 CPUARMState *env, uint32_t desc) 1408 { 1409 intptr_t i, j, oprsz = simd_maxsz(desc); 1410 intptr_t elements = oprsz / sizeof(float32); 1411 intptr_t eltspersegment = MIN(4, elements); 1412 int idx = extract32(desc, SIMD_DATA_SHIFT, 2); 1413 int sel = extract32(desc, SIMD_DATA_SHIFT + 2, 1); 1414 float_status fpst_odd, *fpst_std, *fpst_f16; 1415 float32 *d = vd, *a = va; 1416 uint16_t *n0 = vn; 1417 uint16_t *n1 = vn + sizeof(ARMVectorReg); 1418 uint32_t *m = (uint32_t *)vm + H4(idx); 1419 1420 fpst_std = &env->vfp.fp_status[FPST_ZA]; 1421 fpst_f16 = &env->vfp.fp_status[FPST_ZA_F16]; 1422 fpst_odd = *fpst_std; 1423 set_float_rounding_mode(float_round_to_odd, &fpst_odd); 1424 1425 for (i = 0; i < elements; i += eltspersegment) { 1426 uint32_t mm = m[i]; 1427 for (j = 0; j < eltspersegment; ++j) { 1428 uint32_t nn = (n0[H2(2 * (i + j) + sel)]) 1429 | (n1[H2(2 * (i + j) + sel)] << 16); 1430 d[i + H4(j)] = f16_dotadd(a[i + H4(j)], nn, mm, 1431 fpst_f16, fpst_std, &fpst_odd); 1432 } 1433 } 1434 } 1435 1436 static void do_bfmopa_w(void *vza, void *vzn, void *vzm, 1437 uint16_t *pn, uint16_t *pm, CPUARMState *env, 1438 uint32_t desc, uint32_t negx, bool ah_neg) 1439 { 1440 intptr_t row, col, oprsz = simd_maxsz(desc); 1441 float_status fpst, fpst_odd; 1442 1443 if (is_ebf(env, &fpst, &fpst_odd)) { 1444 for (row = 0; row < oprsz; ) { 1445 uint16_t prow = pn[H2(row >> 4)]; 1446 do { 1447 void *vza_row = vza + tile_vslice_offset(row); 1448 uint32_t n = *(uint32_t *)(vzn + H1_4(row)); 1449 1450 if (ah_neg) { 1451 n = bf16mop_ah_neg_adj_pair(n, prow); 1452 } else { 1453 n = f16mop_adj_pair(n, prow, negx); 1454 } 1455 1456 for (col = 0; col < oprsz; ) { 1457 uint16_t pcol = pm[H2(col >> 4)]; 1458 do { 1459 if (prow & pcol & 0b0101) { 1460 uint32_t *a = vza_row + H1_4(col); 1461 uint32_t m = *(uint32_t *)(vzm + H1_4(col)); 1462 1463 m = f16mop_adj_pair(m, pcol, 0); 1464 *a = bfdotadd_ebf(*a, n, m, &fpst, &fpst_odd); 1465 } 1466 col += 4; 1467 pcol >>= 4; 1468 } while (col & 15); 1469 } 1470 row += 4; 1471 prow >>= 4; 1472 } while (row & 15); 1473 } 1474 } else { 1475 for (row = 0; row < oprsz; ) { 1476 uint16_t prow = pn[H2(row >> 4)]; 1477 do { 1478 void *vza_row = vza + tile_vslice_offset(row); 1479 uint32_t n = *(uint32_t *)(vzn + H1_4(row)); 1480 1481 if (ah_neg) { 1482 n = bf16mop_ah_neg_adj_pair(n, prow); 1483 } else { 1484 n = f16mop_adj_pair(n, prow, negx); 1485 } 1486 1487 for (col = 0; col < oprsz; ) { 1488 uint16_t pcol = pm[H2(col >> 4)]; 1489 do { 1490 if (prow & pcol & 0b0101) { 1491 uint32_t *a = vza_row + H1_4(col); 1492 uint32_t m = *(uint32_t *)(vzm + H1_4(col)); 1493 1494 m = f16mop_adj_pair(m, pcol, 0); 1495 *a = bfdotadd(*a, n, m, &fpst); 1496 } 1497 col += 4; 1498 pcol >>= 4; 1499 } while (col & 15); 1500 } 1501 row += 4; 1502 prow >>= 4; 1503 } while (row & 15); 1504 } 1505 } 1506 } 1507 1508 void HELPER(sme_bfmopa_w)(void *vza, void *vzn, void *vzm, void *vpn, 1509 void *vpm, CPUARMState *env, uint32_t desc) 1510 { 1511 do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0, false); 1512 } 1513 1514 void HELPER(sme_bfmops_w)(void *vza, void *vzn, void *vzm, void *vpn, 1515 void *vpm, CPUARMState *env, uint32_t desc) 1516 { 1517 do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0x80008000u, false); 1518 } 1519 1520 void HELPER(sme_ah_bfmops_w)(void *vza, void *vzn, void *vzm, void *vpn, 1521 void *vpm, CPUARMState *env, uint32_t desc) 1522 { 1523 do_bfmopa_w(vza, vzn, vzm, vpn, vpm, env, desc, 0, true); 1524 } 1525 1526 typedef uint32_t IMOPFn32(uint32_t, uint32_t, uint32_t, uint8_t, bool); 1527 static inline void do_imopa_s(uint32_t *za, uint32_t *zn, uint32_t *zm, 1528 uint8_t *pn, uint8_t *pm, 1529 uint32_t desc, IMOPFn32 *fn) 1530 { 1531 intptr_t row, col, oprsz = simd_oprsz(desc) / 4; 1532 bool neg = simd_data(desc); 1533 1534 for (row = 0; row < oprsz; ++row) { 1535 uint8_t pa = (pn[H1(row >> 1)] >> ((row & 1) * 4)) & 0xf; 1536 uint32_t *za_row = &za[tile_vslice_index(row)]; 1537 uint32_t n = zn[H4(row)]; 1538 1539 for (col = 0; col < oprsz; ++col) { 1540 uint8_t pb = pm[H1(col >> 1)] >> ((col & 1) * 4); 1541 uint32_t *a = &za_row[H4(col)]; 1542 1543 *a = fn(n, zm[H4(col)], *a, pa & pb, neg); 1544 } 1545 } 1546 } 1547 1548 typedef uint64_t IMOPFn64(uint64_t, uint64_t, uint64_t, uint8_t, bool); 1549 static inline void do_imopa_d(uint64_t *za, uint64_t *zn, uint64_t *zm, 1550 uint8_t *pn, uint8_t *pm, 1551 uint32_t desc, IMOPFn64 *fn) 1552 { 1553 intptr_t row, col, oprsz = simd_oprsz(desc) / 8; 1554 bool neg = simd_data(desc); 1555 1556 for (row = 0; row < oprsz; ++row) { 1557 uint8_t pa = pn[H1(row)]; 1558 uint64_t *za_row = &za[tile_vslice_index(row)]; 1559 uint64_t n = zn[row]; 1560 1561 for (col = 0; col < oprsz; ++col) { 1562 uint8_t pb = pm[H1(col)]; 1563 uint64_t *a = &za_row[col]; 1564 1565 *a = fn(n, zm[col], *a, pa & pb, neg); 1566 } 1567 } 1568 } 1569 1570 #define DEF_IMOP_8x4_32(NAME, NTYPE, MTYPE) \ 1571 static uint32_t NAME(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg) \ 1572 { \ 1573 uint32_t sum = 0; \ 1574 /* Apply P to N as a mask, making the inactive elements 0. */ \ 1575 n &= expand_pred_b(p); \ 1576 sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0); \ 1577 sum += (NTYPE)(n >> 8) * (MTYPE)(m >> 8); \ 1578 sum += (NTYPE)(n >> 16) * (MTYPE)(m >> 16); \ 1579 sum += (NTYPE)(n >> 24) * (MTYPE)(m >> 24); \ 1580 return neg ? a - sum : a + sum; \ 1581 } 1582 1583 #define DEF_IMOP_16x4_64(NAME, NTYPE, MTYPE) \ 1584 static uint64_t NAME(uint64_t n, uint64_t m, uint64_t a, uint8_t p, bool neg) \ 1585 { \ 1586 uint64_t sum = 0; \ 1587 /* Apply P to N as a mask, making the inactive elements 0. */ \ 1588 n &= expand_pred_h(p); \ 1589 sum += (int64_t)(NTYPE)(n >> 0) * (MTYPE)(m >> 0); \ 1590 sum += (int64_t)(NTYPE)(n >> 16) * (MTYPE)(m >> 16); \ 1591 sum += (int64_t)(NTYPE)(n >> 32) * (MTYPE)(m >> 32); \ 1592 sum += (int64_t)(NTYPE)(n >> 48) * (MTYPE)(m >> 48); \ 1593 return neg ? a - sum : a + sum; \ 1594 } 1595 1596 DEF_IMOP_8x4_32(smopa_s, int8_t, int8_t) 1597 DEF_IMOP_8x4_32(umopa_s, uint8_t, uint8_t) 1598 DEF_IMOP_8x4_32(sumopa_s, int8_t, uint8_t) 1599 DEF_IMOP_8x4_32(usmopa_s, uint8_t, int8_t) 1600 1601 DEF_IMOP_16x4_64(smopa_d, int16_t, int16_t) 1602 DEF_IMOP_16x4_64(umopa_d, uint16_t, uint16_t) 1603 DEF_IMOP_16x4_64(sumopa_d, int16_t, uint16_t) 1604 DEF_IMOP_16x4_64(usmopa_d, uint16_t, int16_t) 1605 1606 #define DEF_IMOPH(P, NAME, S) \ 1607 void HELPER(P##_##NAME##_##S)(void *vza, void *vzn, void *vzm, \ 1608 void *vpn, void *vpm, uint32_t desc) \ 1609 { do_imopa_##S(vza, vzn, vzm, vpn, vpm, desc, NAME##_##S); } 1610 1611 DEF_IMOPH(sme, smopa, s) 1612 DEF_IMOPH(sme, umopa, s) 1613 DEF_IMOPH(sme, sumopa, s) 1614 DEF_IMOPH(sme, usmopa, s) 1615 1616 DEF_IMOPH(sme, smopa, d) 1617 DEF_IMOPH(sme, umopa, d) 1618 DEF_IMOPH(sme, sumopa, d) 1619 DEF_IMOPH(sme, usmopa, d) 1620 1621 static uint32_t bmopa_s(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg) 1622 { 1623 uint32_t sum = ctpop32(~(n ^ m)); 1624 if (neg) { 1625 sum = -sum; 1626 } 1627 if (!(p & 1)) { 1628 sum = 0; 1629 } 1630 return a + sum; 1631 } 1632 1633 DEF_IMOPH(sme2, bmopa, s) 1634 1635 #define DEF_IMOP_16x2_32(NAME, NTYPE, MTYPE) \ 1636 static uint32_t NAME(uint32_t n, uint32_t m, uint32_t a, uint8_t p, bool neg) \ 1637 { \ 1638 uint32_t sum = 0; \ 1639 /* Apply P to N as a mask, making the inactive elements 0. */ \ 1640 n &= expand_pred_h(p); \ 1641 sum += (NTYPE)(n >> 0) * (MTYPE)(m >> 0); \ 1642 sum += (NTYPE)(n >> 16) * (MTYPE)(m >> 16); \ 1643 return neg ? a - sum : a + sum; \ 1644 } 1645 1646 DEF_IMOP_16x2_32(smopa2_s, int16_t, int16_t) 1647 DEF_IMOP_16x2_32(umopa2_s, uint16_t, uint16_t) 1648 1649 DEF_IMOPH(sme2, smopa2, s) 1650 DEF_IMOPH(sme2, umopa2, s) 1651 1652 #define DO_VDOT_IDX(NAME, TYPED, TYPEN, TYPEM, HD, HN) \ 1653 void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc) \ 1654 { \ 1655 intptr_t svl = simd_oprsz(desc); \ 1656 intptr_t elements = svl / sizeof(TYPED); \ 1657 intptr_t eltperseg = 16 / sizeof(TYPED); \ 1658 intptr_t nreg = sizeof(TYPED) / sizeof(TYPEN); \ 1659 intptr_t vstride = (svl / nreg) * sizeof(ARMVectorReg); \ 1660 intptr_t zstride = sizeof(ARMVectorReg) / sizeof(TYPEN); \ 1661 intptr_t idx = extract32(desc, SIMD_DATA_SHIFT, 2); \ 1662 TYPEN *n = vn; \ 1663 TYPEM *m = vm; \ 1664 for (intptr_t r = 0; r < nreg; r++) { \ 1665 TYPED *d = vd + r * vstride; \ 1666 for (intptr_t seg = 0; seg < elements; seg += eltperseg) { \ 1667 intptr_t s = seg + idx; \ 1668 for (intptr_t e = seg; e < seg + eltperseg; e++) { \ 1669 TYPED sum = d[HD(e)]; \ 1670 for (intptr_t i = 0; i < nreg; i++) { \ 1671 TYPED nn = n[i * zstride + HN(nreg * e + r)]; \ 1672 TYPED mm = m[HN(nreg * s + i)]; \ 1673 sum += nn * mm; \ 1674 } \ 1675 d[HD(e)] = sum; \ 1676 } \ 1677 } \ 1678 } \ 1679 } 1680 1681 DO_VDOT_IDX(sme2_svdot_idx_4b, int32_t, int8_t, int8_t, H4, H1) 1682 DO_VDOT_IDX(sme2_uvdot_idx_4b, uint32_t, uint8_t, uint8_t, H4, H1) 1683 DO_VDOT_IDX(sme2_suvdot_idx_4b, int32_t, int8_t, uint8_t, H4, H1) 1684 DO_VDOT_IDX(sme2_usvdot_idx_4b, int32_t, uint8_t, int8_t, H4, H1) 1685 1686 DO_VDOT_IDX(sme2_svdot_idx_4h, int64_t, int16_t, int16_t, H8, H2) 1687 DO_VDOT_IDX(sme2_uvdot_idx_4h, uint64_t, uint16_t, uint16_t, H8, H2) 1688 1689 DO_VDOT_IDX(sme2_svdot_idx_2h, int32_t, int16_t, int16_t, H4, H2) 1690 DO_VDOT_IDX(sme2_uvdot_idx_2h, uint32_t, uint16_t, uint16_t, H4, H2) 1691 1692 #undef DO_VDOT_IDX 1693 1694 #define DO_MLALL(NAME, TYPEW, TYPEN, TYPEM, HW, HN, OP) \ 1695 void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc) \ 1696 { \ 1697 intptr_t elements = simd_oprsz(desc) / sizeof(TYPEW); \ 1698 intptr_t sel = extract32(desc, SIMD_DATA_SHIFT, 2); \ 1699 TYPEW *d = vd, *a = va; TYPEN *n = vn; TYPEM *m = vm; \ 1700 for (intptr_t i = 0; i < elements; ++i) { \ 1701 TYPEW nn = n[HN(i * 4 + sel)]; \ 1702 TYPEM mm = m[HN(i * 4 + sel)]; \ 1703 d[HW(i)] = a[HW(i)] OP (nn * mm); \ 1704 } \ 1705 } 1706 1707 DO_MLALL(sme2_smlall_s, int32_t, int8_t, int8_t, H4, H1, +) 1708 DO_MLALL(sme2_smlall_d, int64_t, int16_t, int16_t, H8, H2, +) 1709 DO_MLALL(sme2_smlsll_s, int32_t, int8_t, int8_t, H4, H1, -) 1710 DO_MLALL(sme2_smlsll_d, int64_t, int16_t, int16_t, H8, H2, -) 1711 1712 DO_MLALL(sme2_umlall_s, uint32_t, uint8_t, uint8_t, H4, H1, +) 1713 DO_MLALL(sme2_umlall_d, uint64_t, uint16_t, uint16_t, H8, H2, +) 1714 DO_MLALL(sme2_umlsll_s, uint32_t, uint8_t, uint8_t, H4, H1, -) 1715 DO_MLALL(sme2_umlsll_d, uint64_t, uint16_t, uint16_t, H8, H2, -) 1716 1717 DO_MLALL(sme2_usmlall_s, uint32_t, uint8_t, int8_t, H4, H1, +) 1718 1719 #undef DO_MLALL 1720 1721 #define DO_MLALL_IDX(NAME, TYPEW, TYPEN, TYPEM, HW, HN, OP) \ 1722 void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc) \ 1723 { \ 1724 intptr_t elements = simd_oprsz(desc) / sizeof(TYPEW); \ 1725 intptr_t eltspersegment = 16 / sizeof(TYPEW); \ 1726 intptr_t sel = extract32(desc, SIMD_DATA_SHIFT, 2); \ 1727 intptr_t idx = extract32(desc, SIMD_DATA_SHIFT + 2, 4); \ 1728 TYPEW *d = vd, *a = va; TYPEN *n = vn; TYPEM *m = vm; \ 1729 for (intptr_t i = 0; i < elements; i += eltspersegment) { \ 1730 TYPEW mm = m[HN(i * 4 + idx)]; \ 1731 for (intptr_t j = 0; j < eltspersegment; ++j) { \ 1732 TYPEN nn = n[HN((i + j) * 4 + sel)]; \ 1733 d[HW(i + j)] = a[HW(i + j)] OP (nn * mm); \ 1734 } \ 1735 } \ 1736 } 1737 1738 DO_MLALL_IDX(sme2_smlall_idx_s, int32_t, int8_t, int8_t, H4, H1, +) 1739 DO_MLALL_IDX(sme2_smlall_idx_d, int64_t, int16_t, int16_t, H8, H2, +) 1740 DO_MLALL_IDX(sme2_smlsll_idx_s, int32_t, int8_t, int8_t, H4, H1, -) 1741 DO_MLALL_IDX(sme2_smlsll_idx_d, int64_t, int16_t, int16_t, H8, H2, -) 1742 1743 DO_MLALL_IDX(sme2_umlall_idx_s, uint32_t, uint8_t, uint8_t, H4, H1, +) 1744 DO_MLALL_IDX(sme2_umlall_idx_d, uint64_t, uint16_t, uint16_t, H8, H2, +) 1745 DO_MLALL_IDX(sme2_umlsll_idx_s, uint32_t, uint8_t, uint8_t, H4, H1, -) 1746 DO_MLALL_IDX(sme2_umlsll_idx_d, uint64_t, uint16_t, uint16_t, H8, H2, -) 1747 1748 DO_MLALL_IDX(sme2_usmlall_idx_s, uint32_t, uint8_t, int8_t, H4, H1, +) 1749 DO_MLALL_IDX(sme2_sumlall_idx_s, uint32_t, int8_t, uint8_t, H4, H1, +) 1750 1751 #undef DO_MLALL_IDX 1752 1753 /* Convert and compress */ 1754 void HELPER(sme2_bfcvt)(void *vd, void *vs, float_status *fpst, uint32_t desc) 1755 { 1756 ARMVectorReg scratch; 1757 size_t oprsz = simd_oprsz(desc); 1758 size_t i, n = oprsz / 4; 1759 float32 *s0 = vs; 1760 float32 *s1 = vs + sizeof(ARMVectorReg); 1761 bfloat16 *d = vd; 1762 1763 if (vd == s1) { 1764 s1 = memcpy(&scratch, s1, oprsz); 1765 } 1766 1767 for (i = 0; i < n; ++i) { 1768 d[H2(i)] = float32_to_bfloat16(s0[H4(i)], fpst); 1769 } 1770 for (i = 0; i < n; ++i) { 1771 d[H2(i) + n] = float32_to_bfloat16(s1[H4(i)], fpst); 1772 } 1773 } 1774 1775 void HELPER(sme2_fcvt_n)(void *vd, void *vs, float_status *fpst, uint32_t desc) 1776 { 1777 ARMVectorReg scratch; 1778 size_t oprsz = simd_oprsz(desc); 1779 size_t i, n = oprsz / 4; 1780 float32 *s0 = vs; 1781 float32 *s1 = vs + sizeof(ARMVectorReg); 1782 float16 *d = vd; 1783 1784 if (vd == s1) { 1785 s1 = memcpy(&scratch, s1, oprsz); 1786 } 1787 1788 for (i = 0; i < n; ++i) { 1789 d[H2(i)] = sve_f32_to_f16(s0[H4(i)], fpst); 1790 } 1791 for (i = 0; i < n; ++i) { 1792 d[H2(i) + n] = sve_f32_to_f16(s1[H4(i)], fpst); 1793 } 1794 } 1795 1796 #define SQCVT2(NAME, TW, TN, HW, HN, SAT) \ 1797 void HELPER(NAME)(void *vd, void *vs, uint32_t desc) \ 1798 { \ 1799 ARMVectorReg scratch; \ 1800 size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW); \ 1801 TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg); \ 1802 TN *d = vd; \ 1803 if (vectors_overlap(vd, 1, vs, 2)) { \ 1804 d = (TN *)&scratch; \ 1805 } \ 1806 for (size_t i = 0; i < n; ++i) { \ 1807 d[HN(i)] = SAT(s0[HW(i)]); \ 1808 d[HN(i + n)] = SAT(s1[HW(i)]); \ 1809 } \ 1810 if (d != vd) { \ 1811 memcpy(vd, d, oprsz); \ 1812 } \ 1813 } 1814 1815 SQCVT2(sme2_sqcvt_sh, int32_t, int16_t, H4, H2, do_ssat_h) 1816 SQCVT2(sme2_uqcvt_sh, uint32_t, uint16_t, H4, H2, do_usat_h) 1817 SQCVT2(sme2_sqcvtu_sh, int32_t, uint16_t, H4, H2, do_usat_h) 1818 1819 #undef SQCVT2 1820 1821 #define SQCVT4(NAME, TW, TN, HW, HN, SAT) \ 1822 void HELPER(NAME)(void *vd, void *vs, uint32_t desc) \ 1823 { \ 1824 ARMVectorReg scratch; \ 1825 size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW); \ 1826 TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg); \ 1827 TW *s2 = vs + 2 * sizeof(ARMVectorReg); \ 1828 TW *s3 = vs + 3 * sizeof(ARMVectorReg); \ 1829 TN *d = vd; \ 1830 if (vectors_overlap(vd, 1, vs, 4)) { \ 1831 d = (TN *)&scratch; \ 1832 } \ 1833 for (size_t i = 0; i < n; ++i) { \ 1834 d[HN(i)] = SAT(s0[HW(i)]); \ 1835 d[HN(i + n)] = SAT(s1[HW(i)]); \ 1836 d[HN(i + 2 * n)] = SAT(s2[HW(i)]); \ 1837 d[HN(i + 3 * n)] = SAT(s3[HW(i)]); \ 1838 } \ 1839 if (d != vd) { \ 1840 memcpy(vd, d, oprsz); \ 1841 } \ 1842 } 1843 1844 SQCVT4(sme2_sqcvt_sb, int32_t, int8_t, H4, H2, do_ssat_b) 1845 SQCVT4(sme2_uqcvt_sb, uint32_t, uint8_t, H4, H2, do_usat_b) 1846 SQCVT4(sme2_sqcvtu_sb, int32_t, uint8_t, H4, H2, do_usat_b) 1847 1848 SQCVT4(sme2_sqcvt_dh, int64_t, int16_t, H8, H2, do_ssat_h) 1849 SQCVT4(sme2_uqcvt_dh, uint64_t, uint16_t, H8, H2, do_usat_h) 1850 SQCVT4(sme2_sqcvtu_dh, int64_t, uint16_t, H8, H2, do_usat_h) 1851 1852 #undef SQCVT4 1853 1854 #define SQRSHR2(NAME, TW, TN, HW, HN, RSHR, SAT) \ 1855 void HELPER(NAME)(void *vd, void *vs, uint32_t desc) \ 1856 { \ 1857 ARMVectorReg scratch; \ 1858 size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW); \ 1859 int shift = simd_data(desc); \ 1860 TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg); \ 1861 TN *d = vd; \ 1862 if (vectors_overlap(vd, 1, vs, 2)) { \ 1863 d = (TN *)&scratch; \ 1864 } \ 1865 for (size_t i = 0; i < n; ++i) { \ 1866 d[HN(i)] = SAT(RSHR(s0[HW(i)], shift)); \ 1867 d[HN(i + n)] = SAT(RSHR(s1[HW(i)], shift)); \ 1868 } \ 1869 if (d != vd) { \ 1870 memcpy(vd, d, oprsz); \ 1871 } \ 1872 } 1873 1874 SQRSHR2(sme2_sqrshr_sh, int32_t, int16_t, H4, H2, do_srshr, do_ssat_h) 1875 SQRSHR2(sme2_uqrshr_sh, uint32_t, uint16_t, H4, H2, do_urshr, do_usat_h) 1876 SQRSHR2(sme2_sqrshru_sh, int32_t, uint16_t, H4, H2, do_srshr, do_usat_h) 1877 1878 #undef SQRSHR2 1879 1880 #define SQRSHR4(NAME, TW, TN, HW, HN, RSHR, SAT) \ 1881 void HELPER(NAME)(void *vd, void *vs, uint32_t desc) \ 1882 { \ 1883 ARMVectorReg scratch; \ 1884 size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW); \ 1885 int shift = simd_data(desc); \ 1886 TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg); \ 1887 TW *s2 = vs + 2 * sizeof(ARMVectorReg); \ 1888 TW *s3 = vs + 3 * sizeof(ARMVectorReg); \ 1889 TN *d = vd; \ 1890 if (vectors_overlap(vd, 1, vs, 4)) { \ 1891 d = (TN *)&scratch; \ 1892 } \ 1893 for (size_t i = 0; i < n; ++i) { \ 1894 d[HN(i)] = SAT(RSHR(s0[HW(i)], shift)); \ 1895 d[HN(i + n)] = SAT(RSHR(s1[HW(i)], shift)); \ 1896 d[HN(i + 2 * n)] = SAT(RSHR(s2[HW(i)], shift)); \ 1897 d[HN(i + 3 * n)] = SAT(RSHR(s3[HW(i)], shift)); \ 1898 } \ 1899 if (d != vd) { \ 1900 memcpy(vd, d, oprsz); \ 1901 } \ 1902 } 1903 1904 SQRSHR4(sme2_sqrshr_sb, int32_t, int8_t, H4, H2, do_srshr, do_ssat_b) 1905 SQRSHR4(sme2_uqrshr_sb, uint32_t, uint8_t, H4, H2, do_urshr, do_usat_b) 1906 SQRSHR4(sme2_sqrshru_sb, int32_t, uint8_t, H4, H2, do_srshr, do_usat_b) 1907 1908 SQRSHR4(sme2_sqrshr_dh, int64_t, int16_t, H8, H2, do_srshr, do_ssat_h) 1909 SQRSHR4(sme2_uqrshr_dh, uint64_t, uint16_t, H8, H2, do_urshr, do_usat_h) 1910 SQRSHR4(sme2_sqrshru_dh, int64_t, uint16_t, H8, H2, do_srshr, do_usat_h) 1911 1912 #undef SQRSHR4 1913 1914 /* Convert and interleave */ 1915 void HELPER(sme2_bfcvtn)(void *vd, void *vs, float_status *fpst, uint32_t desc) 1916 { 1917 size_t i, n = simd_oprsz(desc) / 4; 1918 float32 *s0 = vs; 1919 float32 *s1 = vs + sizeof(ARMVectorReg); 1920 bfloat16 *d = vd; 1921 1922 for (i = 0; i < n; ++i) { 1923 bfloat16 d0 = float32_to_bfloat16(s0[H4(i)], fpst); 1924 bfloat16 d1 = float32_to_bfloat16(s1[H4(i)], fpst); 1925 d[H2(i * 2 + 0)] = d0; 1926 d[H2(i * 2 + 1)] = d1; 1927 } 1928 } 1929 1930 void HELPER(sme2_fcvtn)(void *vd, void *vs, float_status *fpst, uint32_t desc) 1931 { 1932 size_t i, n = simd_oprsz(desc) / 4; 1933 float32 *s0 = vs; 1934 float32 *s1 = vs + sizeof(ARMVectorReg); 1935 bfloat16 *d = vd; 1936 1937 for (i = 0; i < n; ++i) { 1938 bfloat16 d0 = sve_f32_to_f16(s0[H4(i)], fpst); 1939 bfloat16 d1 = sve_f32_to_f16(s1[H4(i)], fpst); 1940 d[H2(i * 2 + 0)] = d0; 1941 d[H2(i * 2 + 1)] = d1; 1942 } 1943 } 1944 1945 #define SQCVTN2(NAME, TW, TN, HW, HN, SAT) \ 1946 void HELPER(NAME)(void *vd, void *vs, uint32_t desc) \ 1947 { \ 1948 ARMVectorReg scratch; \ 1949 size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW); \ 1950 TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg); \ 1951 TN *d = vd; \ 1952 if (vectors_overlap(vd, 1, vs, 2)) { \ 1953 d = (TN *)&scratch; \ 1954 } \ 1955 for (size_t i = 0; i < n; ++i) { \ 1956 d[HN(2 * i + 0)] = SAT(s0[HW(i)]); \ 1957 d[HN(2 * i + 1)] = SAT(s1[HW(i)]); \ 1958 } \ 1959 if (d != vd) { \ 1960 memcpy(vd, d, oprsz); \ 1961 } \ 1962 } 1963 1964 SQCVTN2(sme2_sqcvtn_sh, int32_t, int16_t, H4, H2, do_ssat_h) 1965 SQCVTN2(sme2_uqcvtn_sh, uint32_t, uint16_t, H4, H2, do_usat_h) 1966 SQCVTN2(sme2_sqcvtun_sh, int32_t, uint16_t, H4, H2, do_usat_h) 1967 1968 #undef SQCVTN2 1969 1970 #define SQCVTN4(NAME, TW, TN, HW, HN, SAT) \ 1971 void HELPER(NAME)(void *vd, void *vs, uint32_t desc) \ 1972 { \ 1973 ARMVectorReg scratch; \ 1974 size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW); \ 1975 TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg); \ 1976 TW *s2 = vs + 2 * sizeof(ARMVectorReg); \ 1977 TW *s3 = vs + 3 * sizeof(ARMVectorReg); \ 1978 TN *d = vd; \ 1979 if (vectors_overlap(vd, 1, vs, 4)) { \ 1980 d = (TN *)&scratch; \ 1981 } \ 1982 for (size_t i = 0; i < n; ++i) { \ 1983 d[HN(4 * i + 0)] = SAT(s0[HW(i)]); \ 1984 d[HN(4 * i + 1)] = SAT(s1[HW(i)]); \ 1985 d[HN(4 * i + 2)] = SAT(s2[HW(i)]); \ 1986 d[HN(4 * i + 3)] = SAT(s3[HW(i)]); \ 1987 } \ 1988 if (d != vd) { \ 1989 memcpy(vd, d, oprsz); \ 1990 } \ 1991 } 1992 1993 SQCVTN4(sme2_sqcvtn_sb, int32_t, int8_t, H4, H1, do_ssat_b) 1994 SQCVTN4(sme2_uqcvtn_sb, uint32_t, uint8_t, H4, H1, do_usat_b) 1995 SQCVTN4(sme2_sqcvtun_sb, int32_t, uint8_t, H4, H1, do_usat_b) 1996 1997 SQCVTN4(sme2_sqcvtn_dh, int64_t, int16_t, H8, H2, do_ssat_h) 1998 SQCVTN4(sme2_uqcvtn_dh, uint64_t, uint16_t, H8, H2, do_usat_h) 1999 SQCVTN4(sme2_sqcvtun_dh, int64_t, uint16_t, H8, H2, do_usat_h) 2000 2001 #undef SQCVTN4 2002 2003 #define SQRSHRN2(NAME, TW, TN, HW, HN, RSHR, SAT) \ 2004 void HELPER(NAME)(void *vd, void *vs, uint32_t desc) \ 2005 { \ 2006 ARMVectorReg scratch; \ 2007 size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW); \ 2008 int shift = simd_data(desc); \ 2009 TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg); \ 2010 TN *d = vd; \ 2011 if (vectors_overlap(vd, 1, vs, 2)) { \ 2012 d = (TN *)&scratch; \ 2013 } \ 2014 for (size_t i = 0; i < n; ++i) { \ 2015 d[HN(2 * i + 0)] = SAT(RSHR(s0[HW(i)], shift)); \ 2016 d[HN(2 * i + 1)] = SAT(RSHR(s1[HW(i)], shift)); \ 2017 } \ 2018 if (d != vd) { \ 2019 memcpy(vd, d, oprsz); \ 2020 } \ 2021 } 2022 2023 SQRSHRN2(sme2_sqrshrn_sh, int32_t, int16_t, H4, H2, do_srshr, do_ssat_h) 2024 SQRSHRN2(sme2_uqrshrn_sh, uint32_t, uint16_t, H4, H2, do_urshr, do_usat_h) 2025 SQRSHRN2(sme2_sqrshrun_sh, int32_t, uint16_t, H4, H2, do_srshr, do_usat_h) 2026 2027 #undef SQRSHRN2 2028 2029 #define SQRSHRN4(NAME, TW, TN, HW, HN, RSHR, SAT) \ 2030 void HELPER(NAME)(void *vd, void *vs, uint32_t desc) \ 2031 { \ 2032 ARMVectorReg scratch; \ 2033 size_t oprsz = simd_oprsz(desc), n = oprsz / sizeof(TW); \ 2034 int shift = simd_data(desc); \ 2035 TW *s0 = vs, *s1 = vs + sizeof(ARMVectorReg); \ 2036 TW *s2 = vs + 2 * sizeof(ARMVectorReg); \ 2037 TW *s3 = vs + 3 * sizeof(ARMVectorReg); \ 2038 TN *d = vd; \ 2039 if (vectors_overlap(vd, 1, vs, 4)) { \ 2040 d = (TN *)&scratch; \ 2041 } \ 2042 for (size_t i = 0; i < n; ++i) { \ 2043 d[HN(4 * i + 0)] = SAT(RSHR(s0[HW(i)], shift)); \ 2044 d[HN(4 * i + 1)] = SAT(RSHR(s1[HW(i)], shift)); \ 2045 d[HN(4 * i + 2)] = SAT(RSHR(s2[HW(i)], shift)); \ 2046 d[HN(4 * i + 3)] = SAT(RSHR(s3[HW(i)], shift)); \ 2047 } \ 2048 if (d != vd) { \ 2049 memcpy(vd, d, oprsz); \ 2050 } \ 2051 } 2052 2053 SQRSHRN4(sme2_sqrshrn_sb, int32_t, int8_t, H4, H1, do_srshr, do_ssat_b) 2054 SQRSHRN4(sme2_uqrshrn_sb, uint32_t, uint8_t, H4, H1, do_urshr, do_usat_b) 2055 SQRSHRN4(sme2_sqrshrun_sb, int32_t, uint8_t, H4, H1, do_srshr, do_usat_b) 2056 2057 SQRSHRN4(sme2_sqrshrn_dh, int64_t, int16_t, H8, H2, do_srshr, do_ssat_h) 2058 SQRSHRN4(sme2_uqrshrn_dh, uint64_t, uint16_t, H8, H2, do_urshr, do_usat_h) 2059 SQRSHRN4(sme2_sqrshrun_dh, int64_t, uint16_t, H8, H2, do_srshr, do_usat_h) 2060 2061 #undef SQRSHRN4 2062 2063 /* Expand and convert */ 2064 void HELPER(sme2_fcvt_w)(void *vd, void *vs, float_status *fpst, uint32_t desc) 2065 { 2066 ARMVectorReg scratch; 2067 size_t oprsz = simd_oprsz(desc); 2068 size_t i, n = oprsz / 4; 2069 float16 *s = vs; 2070 float32 *d0 = vd; 2071 float32 *d1 = vd + sizeof(ARMVectorReg); 2072 2073 if (vectors_overlap(vd, 1, vs, 2)) { 2074 s = memcpy(&scratch, s, oprsz); 2075 } 2076 2077 for (i = 0; i < n; ++i) { 2078 d0[H4(i)] = sve_f16_to_f32(s[H2(i)], fpst); 2079 } 2080 for (i = 0; i < n; ++i) { 2081 d1[H4(i)] = sve_f16_to_f32(s[H2(n + i)], fpst); 2082 } 2083 } 2084 2085 #define UNPK(NAME, SREG, TW, TN, HW, HN) \ 2086 void HELPER(NAME)(void *vd, void *vs, uint32_t desc) \ 2087 { \ 2088 ARMVectorReg scratch[SREG]; \ 2089 size_t oprsz = simd_oprsz(desc); \ 2090 size_t n = oprsz / sizeof(TW); \ 2091 if (vectors_overlap(vd, 2 * SREG, vs, SREG)) { \ 2092 vs = memcpy(scratch, vs, sizeof(scratch)); \ 2093 } \ 2094 for (size_t r = 0; r < SREG; ++r) { \ 2095 TN *s = vs + r * sizeof(ARMVectorReg); \ 2096 for (size_t i = 0; i < 2; ++i) { \ 2097 TW *d = vd + (2 * r + i) * sizeof(ARMVectorReg); \ 2098 for (size_t e = 0; e < n; ++e) { \ 2099 d[HW(e)] = s[HN(i * n + e)]; \ 2100 } \ 2101 } \ 2102 } \ 2103 } 2104 2105 UNPK(sme2_sunpk2_bh, 1, int16_t, int8_t, H2, H1) 2106 UNPK(sme2_sunpk2_hs, 1, int32_t, int16_t, H4, H2) 2107 UNPK(sme2_sunpk2_sd, 1, int64_t, int32_t, H8, H4) 2108 2109 UNPK(sme2_sunpk4_bh, 2, int16_t, int8_t, H2, H1) 2110 UNPK(sme2_sunpk4_hs, 2, int32_t, int16_t, H4, H2) 2111 UNPK(sme2_sunpk4_sd, 2, int64_t, int32_t, H8, H4) 2112 2113 UNPK(sme2_uunpk2_bh, 1, uint16_t, uint8_t, H2, H1) 2114 UNPK(sme2_uunpk2_hs, 1, uint32_t, uint16_t, H4, H2) 2115 UNPK(sme2_uunpk2_sd, 1, uint64_t, uint32_t, H8, H4) 2116 2117 UNPK(sme2_uunpk4_bh, 2, uint16_t, uint8_t, H2, H1) 2118 UNPK(sme2_uunpk4_hs, 2, uint32_t, uint16_t, H4, H2) 2119 UNPK(sme2_uunpk4_sd, 2, uint64_t, uint32_t, H8, H4) 2120 2121 #undef UNPK 2122 2123 /* Deinterleave and convert. */ 2124 void HELPER(sme2_fcvtl)(void *vd, void *vs, float_status *fpst, uint32_t desc) 2125 { 2126 size_t i, n = simd_oprsz(desc) / 4; 2127 float16 *s = vs; 2128 float32 *d0 = vd; 2129 float32 *d1 = vd + sizeof(ARMVectorReg); 2130 2131 for (i = 0; i < n; ++i) { 2132 float32 v0 = sve_f16_to_f32(s[H2(i * 2 + 0)], fpst); 2133 float32 v1 = sve_f16_to_f32(s[H2(i * 2 + 1)], fpst); 2134 d0[H4(i)] = v0; 2135 d1[H4(i)] = v1; 2136 } 2137 } 2138 2139 void HELPER(sme2_scvtf)(void *vd, void *vs, float_status *fpst, uint32_t desc) 2140 { 2141 size_t i, n = simd_oprsz(desc) / 4; 2142 int32_t *d = vd; 2143 float32 *s = vs; 2144 2145 for (i = 0; i < n; ++i) { 2146 d[i] = int32_to_float32(s[i], fpst); 2147 } 2148 } 2149 2150 void HELPER(sme2_ucvtf)(void *vd, void *vs, float_status *fpst, uint32_t desc) 2151 { 2152 size_t i, n = simd_oprsz(desc) / 4; 2153 uint32_t *d = vd; 2154 float32 *s = vs; 2155 2156 for (i = 0; i < n; ++i) { 2157 d[i] = uint32_to_float32(s[i], fpst); 2158 } 2159 } 2160 2161 #define ZIP2(NAME, TYPE, H) \ 2162 void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc) \ 2163 { \ 2164 ARMVectorReg scratch[2]; \ 2165 size_t oprsz = simd_oprsz(desc); \ 2166 size_t pairs = oprsz / (sizeof(TYPE) * 2); \ 2167 TYPE *n = vn, *m = vm; \ 2168 if (vectors_overlap(vd, 2, vn, 1)) { \ 2169 n = memcpy(&scratch[0], vn, oprsz); \ 2170 } \ 2171 if (vectors_overlap(vd, 2, vm, 1)) { \ 2172 m = memcpy(&scratch[1], vm, oprsz); \ 2173 } \ 2174 for (size_t r = 0; r < 2; ++r) { \ 2175 TYPE *d = vd + r * sizeof(ARMVectorReg); \ 2176 size_t base = r * pairs; \ 2177 for (size_t p = 0; p < pairs; ++p) { \ 2178 d[H(2 * p + 0)] = n[base + H(p)]; \ 2179 d[H(2 * p + 1)] = m[base + H(p)]; \ 2180 } \ 2181 } \ 2182 } 2183 2184 ZIP2(sme2_zip2_b, uint8_t, H1) 2185 ZIP2(sme2_zip2_h, uint16_t, H2) 2186 ZIP2(sme2_zip2_s, uint32_t, H4) 2187 ZIP2(sme2_zip2_d, uint64_t, ) 2188 ZIP2(sme2_zip2_q, Int128, ) 2189 2190 #undef ZIP2 2191 2192 #define ZIP4(NAME, TYPE, H) \ 2193 void HELPER(NAME)(void *vd, void *vs, uint32_t desc) \ 2194 { \ 2195 ARMVectorReg scratch[4]; \ 2196 size_t oprsz = simd_oprsz(desc); \ 2197 size_t quads = oprsz / (sizeof(TYPE) * 4); \ 2198 TYPE *s0, *s1, *s2, *s3; \ 2199 if (vs == vd) { \ 2200 vs = memcpy(scratch, vs, sizeof(scratch)); \ 2201 } \ 2202 s0 = vs; \ 2203 s1 = vs + sizeof(ARMVectorReg); \ 2204 s2 = vs + 2 * sizeof(ARMVectorReg); \ 2205 s3 = vs + 3 * sizeof(ARMVectorReg); \ 2206 for (size_t r = 0; r < 4; ++r) { \ 2207 TYPE *d = vd + r * sizeof(ARMVectorReg); \ 2208 size_t base = r * quads; \ 2209 for (size_t q = 0; q < quads; ++q) { \ 2210 d[H(4 * q + 0)] = s0[base + H(q)]; \ 2211 d[H(4 * q + 1)] = s1[base + H(q)]; \ 2212 d[H(4 * q + 2)] = s2[base + H(q)]; \ 2213 d[H(4 * q + 3)] = s3[base + H(q)]; \ 2214 } \ 2215 } \ 2216 } 2217 2218 ZIP4(sme2_zip4_b, uint8_t, H1) 2219 ZIP4(sme2_zip4_h, uint16_t, H2) 2220 ZIP4(sme2_zip4_s, uint32_t, H4) 2221 ZIP4(sme2_zip4_d, uint64_t, ) 2222 ZIP4(sme2_zip4_q, Int128, ) 2223 2224 #undef ZIP4 2225 2226 #define UZP2(NAME, TYPE, H) \ 2227 void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc) \ 2228 { \ 2229 ARMVectorReg scratch[2]; \ 2230 size_t oprsz = simd_oprsz(desc); \ 2231 size_t pairs = oprsz / (sizeof(TYPE) * 2); \ 2232 TYPE *d0 = vd, *d1 = vd + sizeof(ARMVectorReg); \ 2233 if (vectors_overlap(vd, 2, vn, 1)) { \ 2234 vn = memcpy(&scratch[0], vn, oprsz); \ 2235 } \ 2236 if (vectors_overlap(vd, 2, vm, 1)) { \ 2237 vm = memcpy(&scratch[1], vm, oprsz); \ 2238 } \ 2239 for (size_t r = 0; r < 2; ++r) { \ 2240 TYPE *s = r ? vm : vn; \ 2241 size_t base = r * pairs; \ 2242 for (size_t p = 0; p < pairs; ++p) { \ 2243 d0[base + H(p)] = s[H(2 * p + 0)]; \ 2244 d1[base + H(p)] = s[H(2 * p + 1)]; \ 2245 } \ 2246 } \ 2247 } 2248 2249 UZP2(sme2_uzp2_b, uint8_t, H1) 2250 UZP2(sme2_uzp2_h, uint16_t, H2) 2251 UZP2(sme2_uzp2_s, uint32_t, H4) 2252 UZP2(sme2_uzp2_d, uint64_t, ) 2253 UZP2(sme2_uzp2_q, Int128, ) 2254 2255 #undef UZP2 2256 2257 #define UZP4(NAME, TYPE, H) \ 2258 void HELPER(NAME)(void *vd, void *vs, uint32_t desc) \ 2259 { \ 2260 ARMVectorReg scratch[4]; \ 2261 size_t oprsz = simd_oprsz(desc); \ 2262 size_t quads = oprsz / (sizeof(TYPE) * 4); \ 2263 TYPE *d0, *d1, *d2, *d3; \ 2264 if (vs == vd) { \ 2265 vs = memcpy(scratch, vs, sizeof(scratch)); \ 2266 } \ 2267 d0 = vd; \ 2268 d1 = vd + sizeof(ARMVectorReg); \ 2269 d2 = vd + 2 * sizeof(ARMVectorReg); \ 2270 d3 = vd + 3 * sizeof(ARMVectorReg); \ 2271 for (size_t r = 0; r < 4; ++r) { \ 2272 TYPE *s = vs + r * sizeof(ARMVectorReg); \ 2273 size_t base = r * quads; \ 2274 for (size_t q = 0; q < quads; ++q) { \ 2275 d0[base + H(q)] = s[H(4 * q + 0)]; \ 2276 d1[base + H(q)] = s[H(4 * q + 1)]; \ 2277 d2[base + H(q)] = s[H(4 * q + 2)]; \ 2278 d3[base + H(q)] = s[H(4 * q + 3)]; \ 2279 } \ 2280 } \ 2281 } 2282 2283 UZP4(sme2_uzp4_b, uint8_t, H1) 2284 UZP4(sme2_uzp4_h, uint16_t, H2) 2285 UZP4(sme2_uzp4_s, uint32_t, H4) 2286 UZP4(sme2_uzp4_d, uint64_t, ) 2287 UZP4(sme2_uzp4_q, Int128, ) 2288 2289 #undef UZP4 2290 2291 #define ICLAMP(NAME, TYPE, H) \ 2292 void HELPER(NAME)(void *vd, void *vn, void *vm, uint32_t desc) \ 2293 { \ 2294 size_t stride = sizeof(ARMVectorReg) / sizeof(TYPE); \ 2295 size_t elements = simd_oprsz(desc) / sizeof(TYPE); \ 2296 size_t nreg = simd_data(desc); \ 2297 TYPE *d = vd, *n = vn, *m = vm; \ 2298 for (size_t e = 0; e < elements; e++) { \ 2299 TYPE nn = n[H(e)], mm = m[H(e)]; \ 2300 for (size_t r = 0; r < nreg; r++) { \ 2301 TYPE *dd = &d[r * stride + H(e)]; \ 2302 *dd = MIN(MAX(*dd, nn), mm); \ 2303 } \ 2304 } \ 2305 } 2306 2307 ICLAMP(sme2_sclamp_b, int8_t, H1) 2308 ICLAMP(sme2_sclamp_h, int16_t, H2) 2309 ICLAMP(sme2_sclamp_s, int32_t, H4) 2310 ICLAMP(sme2_sclamp_d, int64_t, H8) 2311 2312 ICLAMP(sme2_uclamp_b, uint8_t, H1) 2313 ICLAMP(sme2_uclamp_h, uint16_t, H2) 2314 ICLAMP(sme2_uclamp_s, uint32_t, H4) 2315 ICLAMP(sme2_uclamp_d, uint64_t, H8) 2316 2317 #undef ICLAMP 2318 2319 /* 2320 * Note the argument ordering to minnum and maxnum must match 2321 * the ARM pseudocode so that NaNs are propagated properly. 2322 */ 2323 #define FCLAMP(NAME, TYPE, H) \ 2324 void HELPER(NAME)(void *vd, void *vn, void *vm, \ 2325 float_status *fpst, uint32_t desc) \ 2326 { \ 2327 size_t stride = sizeof(ARMVectorReg) / sizeof(TYPE); \ 2328 size_t elements = simd_oprsz(desc) / sizeof(TYPE); \ 2329 size_t nreg = simd_data(desc); \ 2330 TYPE *d = vd, *n = vn, *m = vm; \ 2331 for (size_t e = 0; e < elements; e++) { \ 2332 TYPE nn = n[H(e)], mm = m[H(e)]; \ 2333 for (size_t r = 0; r < nreg; r++) { \ 2334 TYPE *dd = &d[r * stride + H(e)]; \ 2335 *dd = TYPE##_minnum(TYPE##_maxnum(nn, *dd, fpst), mm, fpst); \ 2336 } \ 2337 } \ 2338 } 2339 2340 FCLAMP(sme2_fclamp_h, float16, H2) 2341 FCLAMP(sme2_fclamp_s, float32, H4) 2342 FCLAMP(sme2_fclamp_d, float64, H8) 2343 FCLAMP(sme2_bfclamp, bfloat16, H2) 2344 2345 #undef FCLAMP 2346 2347 void HELPER(sme2_sel_b)(void *vd, void *vn, void *vm, 2348 uint32_t png, uint32_t desc) 2349 { 2350 int vl = simd_oprsz(desc); 2351 int nreg = simd_data(desc); 2352 int elements = vl / sizeof(uint8_t); 2353 DecodeCounter p = decode_counter(png, vl, MO_8); 2354 2355 if (p.lg2_stride == 0) { 2356 if (p.invert) { 2357 for (int r = 0; r < nreg; r++) { 2358 uint8_t *d = vd + r * sizeof(ARMVectorReg); 2359 uint8_t *n = vn + r * sizeof(ARMVectorReg); 2360 uint8_t *m = vm + r * sizeof(ARMVectorReg); 2361 int split = p.count - r * elements; 2362 2363 if (split <= 0) { 2364 memcpy(d, n, vl); /* all true */ 2365 } else if (elements <= split) { 2366 memcpy(d, m, vl); /* all false */ 2367 } else { 2368 for (int e = 0; e < split; e++) { 2369 d[H1(e)] = m[H1(e)]; 2370 } 2371 for (int e = split; e < elements; e++) { 2372 d[H1(e)] = n[H1(e)]; 2373 } 2374 } 2375 } 2376 } else { 2377 for (int r = 0; r < nreg; r++) { 2378 uint8_t *d = vd + r * sizeof(ARMVectorReg); 2379 uint8_t *n = vn + r * sizeof(ARMVectorReg); 2380 uint8_t *m = vm + r * sizeof(ARMVectorReg); 2381 int split = p.count - r * elements; 2382 2383 if (split <= 0) { 2384 memcpy(d, m, vl); /* all false */ 2385 } else if (elements <= split) { 2386 memcpy(d, n, vl); /* all true */ 2387 } else { 2388 for (int e = 0; e < split; e++) { 2389 d[H1(e)] = n[H1(e)]; 2390 } 2391 for (int e = split; e < elements; e++) { 2392 d[H1(e)] = m[H1(e)]; 2393 } 2394 } 2395 } 2396 } 2397 } else { 2398 int estride = 1 << p.lg2_stride; 2399 if (p.invert) { 2400 for (int r = 0; r < nreg; r++) { 2401 uint8_t *d = vd + r * sizeof(ARMVectorReg); 2402 uint8_t *n = vn + r * sizeof(ARMVectorReg); 2403 uint8_t *m = vm + r * sizeof(ARMVectorReg); 2404 int split = p.count - r * elements; 2405 int e = 0; 2406 2407 for (; e < MIN(split, elements); e++) { 2408 d[H1(e)] = m[H1(e)]; 2409 } 2410 for (; e < elements; e += estride) { 2411 d[H1(e)] = n[H1(e)]; 2412 for (int i = 1; i < estride; i++) { 2413 d[H1(e + i)] = m[H1(e + i)]; 2414 } 2415 } 2416 } 2417 } else { 2418 for (int r = 0; r < nreg; r++) { 2419 uint8_t *d = vd + r * sizeof(ARMVectorReg); 2420 uint8_t *n = vn + r * sizeof(ARMVectorReg); 2421 uint8_t *m = vm + r * sizeof(ARMVectorReg); 2422 int split = p.count - r * elements; 2423 int e = 0; 2424 2425 for (; e < MIN(split, elements); e += estride) { 2426 d[H1(e)] = n[H1(e)]; 2427 for (int i = 1; i < estride; i++) { 2428 d[H1(e + i)] = m[H1(e + i)]; 2429 } 2430 } 2431 for (; e < elements; e++) { 2432 d[H1(e)] = m[H1(e)]; 2433 } 2434 } 2435 } 2436 } 2437 } 2438 2439 void HELPER(sme2_sel_h)(void *vd, void *vn, void *vm, 2440 uint32_t png, uint32_t desc) 2441 { 2442 int vl = simd_oprsz(desc); 2443 int nreg = simd_data(desc); 2444 int elements = vl / sizeof(uint16_t); 2445 DecodeCounter p = decode_counter(png, vl, MO_16); 2446 2447 if (p.lg2_stride == 0) { 2448 if (p.invert) { 2449 for (int r = 0; r < nreg; r++) { 2450 uint16_t *d = vd + r * sizeof(ARMVectorReg); 2451 uint16_t *n = vn + r * sizeof(ARMVectorReg); 2452 uint16_t *m = vm + r * sizeof(ARMVectorReg); 2453 int split = p.count - r * elements; 2454 2455 if (split <= 0) { 2456 memcpy(d, n, vl); /* all true */ 2457 } else if (elements <= split) { 2458 memcpy(d, m, vl); /* all false */ 2459 } else { 2460 for (int e = 0; e < split; e++) { 2461 d[H2(e)] = m[H2(e)]; 2462 } 2463 for (int e = split; e < elements; e++) { 2464 d[H2(e)] = n[H2(e)]; 2465 } 2466 } 2467 } 2468 } else { 2469 for (int r = 0; r < nreg; r++) { 2470 uint16_t *d = vd + r * sizeof(ARMVectorReg); 2471 uint16_t *n = vn + r * sizeof(ARMVectorReg); 2472 uint16_t *m = vm + r * sizeof(ARMVectorReg); 2473 int split = p.count - r * elements; 2474 2475 if (split <= 0) { 2476 memcpy(d, m, vl); /* all false */ 2477 } else if (elements <= split) { 2478 memcpy(d, n, vl); /* all true */ 2479 } else { 2480 for (int e = 0; e < split; e++) { 2481 d[H2(e)] = n[H2(e)]; 2482 } 2483 for (int e = split; e < elements; e++) { 2484 d[H2(e)] = m[H2(e)]; 2485 } 2486 } 2487 } 2488 } 2489 } else { 2490 int estride = 1 << p.lg2_stride; 2491 if (p.invert) { 2492 for (int r = 0; r < nreg; r++) { 2493 uint16_t *d = vd + r * sizeof(ARMVectorReg); 2494 uint16_t *n = vn + r * sizeof(ARMVectorReg); 2495 uint16_t *m = vm + r * sizeof(ARMVectorReg); 2496 int split = p.count - r * elements; 2497 int e = 0; 2498 2499 for (; e < MIN(split, elements); e++) { 2500 d[H2(e)] = m[H2(e)]; 2501 } 2502 for (; e < elements; e += estride) { 2503 d[H2(e)] = n[H2(e)]; 2504 for (int i = 1; i < estride; i++) { 2505 d[H2(e + i)] = m[H2(e + i)]; 2506 } 2507 } 2508 } 2509 } else { 2510 for (int r = 0; r < nreg; r++) { 2511 uint16_t *d = vd + r * sizeof(ARMVectorReg); 2512 uint16_t *n = vn + r * sizeof(ARMVectorReg); 2513 uint16_t *m = vm + r * sizeof(ARMVectorReg); 2514 int split = p.count - r * elements; 2515 int e = 0; 2516 2517 for (; e < MIN(split, elements); e += estride) { 2518 d[H2(e)] = n[H2(e)]; 2519 for (int i = 1; i < estride; i++) { 2520 d[H2(e + i)] = m[H2(e + i)]; 2521 } 2522 } 2523 for (; e < elements; e++) { 2524 d[H2(e)] = m[H2(e)]; 2525 } 2526 } 2527 } 2528 } 2529 } 2530 2531 void HELPER(sme2_sel_s)(void *vd, void *vn, void *vm, 2532 uint32_t png, uint32_t desc) 2533 { 2534 int vl = simd_oprsz(desc); 2535 int nreg = simd_data(desc); 2536 int elements = vl / sizeof(uint32_t); 2537 DecodeCounter p = decode_counter(png, vl, MO_32); 2538 2539 if (p.lg2_stride == 0) { 2540 if (p.invert) { 2541 for (int r = 0; r < nreg; r++) { 2542 uint32_t *d = vd + r * sizeof(ARMVectorReg); 2543 uint32_t *n = vn + r * sizeof(ARMVectorReg); 2544 uint32_t *m = vm + r * sizeof(ARMVectorReg); 2545 int split = p.count - r * elements; 2546 2547 if (split <= 0) { 2548 memcpy(d, n, vl); /* all true */ 2549 } else if (elements <= split) { 2550 memcpy(d, m, vl); /* all false */ 2551 } else { 2552 for (int e = 0; e < split; e++) { 2553 d[H4(e)] = m[H4(e)]; 2554 } 2555 for (int e = split; e < elements; e++) { 2556 d[H4(e)] = n[H4(e)]; 2557 } 2558 } 2559 } 2560 } else { 2561 for (int r = 0; r < nreg; r++) { 2562 uint32_t *d = vd + r * sizeof(ARMVectorReg); 2563 uint32_t *n = vn + r * sizeof(ARMVectorReg); 2564 uint32_t *m = vm + r * sizeof(ARMVectorReg); 2565 int split = p.count - r * elements; 2566 2567 if (split <= 0) { 2568 memcpy(d, m, vl); /* all false */ 2569 } else if (elements <= split) { 2570 memcpy(d, n, vl); /* all true */ 2571 } else { 2572 for (int e = 0; e < split; e++) { 2573 d[H4(e)] = n[H4(e)]; 2574 } 2575 for (int e = split; e < elements; e++) { 2576 d[H4(e)] = m[H4(e)]; 2577 } 2578 } 2579 } 2580 } 2581 } else { 2582 /* p.esz must be MO_64, so stride must be 2. */ 2583 if (p.invert) { 2584 for (int r = 0; r < nreg; r++) { 2585 uint32_t *d = vd + r * sizeof(ARMVectorReg); 2586 uint32_t *n = vn + r * sizeof(ARMVectorReg); 2587 uint32_t *m = vm + r * sizeof(ARMVectorReg); 2588 int split = p.count - r * elements; 2589 int e = 0; 2590 2591 for (; e < MIN(split, elements); e++) { 2592 d[H4(e)] = m[H4(e)]; 2593 } 2594 for (; e < elements; e += 2) { 2595 d[H4(e)] = n[H4(e)]; 2596 d[H4(e + 1)] = m[H4(e + 1)]; 2597 } 2598 } 2599 } else { 2600 for (int r = 0; r < nreg; r++) { 2601 uint32_t *d = vd + r * sizeof(ARMVectorReg); 2602 uint32_t *n = vn + r * sizeof(ARMVectorReg); 2603 uint32_t *m = vm + r * sizeof(ARMVectorReg); 2604 int split = p.count - r * elements; 2605 int e = 0; 2606 2607 for (; e < MIN(split, elements); e += 2) { 2608 d[H4(e)] = n[H4(e)]; 2609 d[H4(e + 1)] = m[H4(e + 1)]; 2610 } 2611 for (; e < elements; e++) { 2612 d[H4(e)] = m[H4(e)]; 2613 } 2614 } 2615 } 2616 } 2617 } 2618 2619 void HELPER(sme2_sel_d)(void *vd, void *vn, void *vm, 2620 uint32_t png, uint32_t desc) 2621 { 2622 int vl = simd_oprsz(desc); 2623 int nreg = simd_data(desc); 2624 int elements = vl / sizeof(uint64_t); 2625 DecodeCounter p = decode_counter(png, vl, MO_64); 2626 2627 if (p.invert) { 2628 for (int r = 0; r < nreg; r++) { 2629 uint64_t *d = vd + r * sizeof(ARMVectorReg); 2630 uint64_t *n = vn + r * sizeof(ARMVectorReg); 2631 uint64_t *m = vm + r * sizeof(ARMVectorReg); 2632 int split = p.count - r * elements; 2633 2634 if (split <= 0) { 2635 memcpy(d, n, vl); /* all true */ 2636 } else if (elements <= split) { 2637 memcpy(d, m, vl); /* all false */ 2638 } else { 2639 memcpy(d, m, split * sizeof(uint64_t)); 2640 memcpy(d + split, n + split, 2641 (elements - split) * sizeof(uint64_t)); 2642 } 2643 } 2644 } else { 2645 for (int r = 0; r < nreg; r++) { 2646 uint64_t *d = vd + r * sizeof(ARMVectorReg); 2647 uint64_t *n = vn + r * sizeof(ARMVectorReg); 2648 uint64_t *m = vm + r * sizeof(ARMVectorReg); 2649 int split = p.count - r * elements; 2650 2651 if (split <= 0) { 2652 memcpy(d, m, vl); /* all false */ 2653 } else if (elements <= split) { 2654 memcpy(d, n, vl); /* all true */ 2655 } else { 2656 memcpy(d, n, split * sizeof(uint64_t)); 2657 memcpy(d + split, m + split, 2658 (elements - split) * sizeof(uint64_t)); 2659 } 2660 } 2661 } 2662 } 2663