1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/static_call.h>
3 #include <linux/memory.h>
4 #include <linux/bug.h>
5 #include <asm/text-patching.h>
6
/*
 * Instruction to place in a static_call slot.
 *
 * The numeric values are load-bearing: __sc_insn() derives them
 * arithmetically from the (tail, null) pair, so they must stay 0..3 in
 * this exact order.
 */
enum insn_type {
	/* plain call site with a target */
	CALL	= 0,
	/* plain call site with a NULL target (cond-call) */
	NOP	= 1,
	/* trampoline or tail-call site with a target */
	JMP	= 2,
	/* trampoline or tail-call site with a NULL target (cond-tail-call) */
	RET	= 3,
};
13
__static_call_transform(void * insn,enum insn_type type,void * func)14 static void __ref __static_call_transform(void *insn, enum insn_type type, void *func)
15 {
16 int size = CALL_INSN_SIZE;
17 const void *code;
18
19 switch (type) {
20 case CALL:
21 code = text_gen_insn(CALL_INSN_OPCODE, insn, func);
22 break;
23
24 case NOP:
25 code = ideal_nops[NOP_ATOMIC5];
26 break;
27
28 case JMP:
29 code = text_gen_insn(JMP32_INSN_OPCODE, insn, func);
30 break;
31
32 case RET:
33 code = text_gen_insn(RET_INSN_OPCODE, insn, func);
34 size = RET_INSN_SIZE;
35 break;
36 }
37
38 if (memcmp(insn, code, size) == 0)
39 return;
40
41 if (unlikely(system_state == SYSTEM_BOOTING))
42 return text_poke_early(insn, code, size);
43
44 text_poke_bp(insn, code, size, NULL);
45 }
46
__static_call_validate(void * insn,bool tail)47 static void __static_call_validate(void *insn, bool tail)
48 {
49 u8 opcode = *(u8 *)insn;
50
51 if (tail) {
52 if (opcode == JMP32_INSN_OPCODE ||
53 opcode == RET_INSN_OPCODE)
54 return;
55 } else {
56 if (opcode == CALL_INSN_OPCODE ||
57 !memcmp(insn, ideal_nops[NOP_ATOMIC5], 5))
58 return;
59 }
60
61 /*
62 * If we ever trigger this, our text is corrupt, we'll probably not live long.
63 */
64 WARN_ONCE(1, "unexpected static_call insn opcode 0x%x at %pS\n", opcode, insn);
65 }
66
__sc_insn(bool null,bool tail)67 static inline enum insn_type __sc_insn(bool null, bool tail)
68 {
69 /*
70 * Encode the following table without branches:
71 *
72 * tail null insn
73 * -----+-------+------
74 * 0 | 0 | CALL
75 * 0 | 1 | NOP
76 * 1 | 0 | JMP
77 * 1 | 1 | RET
78 */
79 return 2*tail + null;
80 }
81
arch_static_call_transform(void * site,void * tramp,void * func,bool tail)82 void arch_static_call_transform(void *site, void *tramp, void *func, bool tail)
83 {
84 mutex_lock(&text_mutex);
85
86 if (tramp) {
87 __static_call_validate(tramp, true);
88 __static_call_transform(tramp, __sc_insn(!func, true), func);
89 }
90
91 if (IS_ENABLED(CONFIG_HAVE_STATIC_CALL_INLINE) && site) {
92 __static_call_validate(site, tail);
93 __static_call_transform(site, __sc_insn(!func, tail), func);
94 }
95
96 mutex_unlock(&text_mutex);
97 }
98 EXPORT_SYMBOL_GPL(arch_static_call_transform);
99