xref: /kvm-unit-tests/x86/cet.c (revision c865f654ffe4c5955038aaf74f702ba62f3eb014)
1 
2 #include "libcflat.h"
3 #include "x86/desc.h"
4 #include "x86/processor.h"
5 #include "x86/vm.h"
6 #include "x86/msr.h"
7 #include "vmalloc.h"
8 #include "alloc_page.h"
9 #include "fault_test.h"
10 
11 
12 static unsigned char user_stack[0x400];
13 static unsigned long rbx, rsi, rdi, rsp, rbp, r8, r9,
14 		     r10, r11, r12, r13, r14, r15;
15 
16 static unsigned long expected_rip;
17 static int cp_count;
18 typedef u64 (*cet_test_func)(void);
19 
20 cet_test_func func;
21 
22 static u64 cet_shstk_func(void)
23 {
24 	unsigned long *ret_addr, *ssp;
25 
26 	/* rdsspq %rax */
27 	asm volatile (".byte 0xf3, 0x48, 0x0f, 0x1e, 0xc8" : "=a"(ssp));
28 
29 	asm("movq %%rbp,%0" : "=r"(ret_addr));
30 	printf("The return-address in shadow-stack = 0x%lx, in normal stack = 0x%lx\n",
31 	       *ssp, *(ret_addr + 1));
32 
33 	/*
34 	 * In below line, it modifies the return address, it'll trigger #CP
35 	 * while function is returning. The error-code is 0x1, meaning it's
36 	 * caused by a near RET instruction, and the execution is terminated
37 	 * when HW detects the violation.
38 	 */
39 	printf("Try to temper the return-address, this causes #CP on returning...\n");
40 	*(ret_addr + 1) = 0xdeaddead;
41 
42 	return 0;
43 }
44 
45 static u64 cet_ibt_func(void)
46 {
47 	/*
48 	 * In below assembly code, the first instruction at lable 2 is not
49 	 * endbr64, it'll trigger #CP with error code 0x3, and the execution
50 	 * is terminated when HW detects the violation.
51 	 */
52 	printf("No endbr64 instruction at jmp target, this triggers #CP...\n");
53 	asm volatile ("movq $2, %rcx\n"
54 		      "dec %rcx\n"
55 		      "leaq 2f, %rax\n"
56 		      "jmp *%rax \n"
57 		      "2:\n"
58 		      "dec %rcx\n");
59 	return 0;
60 }
61 
62 void test_func(void);
63 void test_func(void) {
64 	asm volatile (
65 			/* IRET into user mode */
66 			"pushq %[user_ds]\n\t"
67 			"pushq %[user_stack_top]\n\t"
68 			"pushfq\n\t"
69 			"pushq %[user_cs]\n\t"
70 			"pushq $user_mode\n\t"
71 			"iretq\n"
72 
73 			"user_mode:\n\t"
74 			"call *%[func]\n\t"
75 			::
76 			[func]"m"(func),
77 			[user_ds]"i"(USER_DS),
78 			[user_cs]"i"(USER_CS),
79 			[user_stack_top]"r"(user_stack +
80 					sizeof(user_stack)));
81 }
82 
/*
 * Snapshot the general-purpose registers (including %rsp and %rbp) into the
 * file-scope variables of the same names, so that RESTOR_REGS() can rebuild
 * the C context once the #CP handler has jumped back to expected_rip.
 *
 * The asm stores into the variables, so they must be output ("=m")
 * operands: GCC requires that input-only operands are never modified.
 * volatile keeps the snapshot from being dropped or moved.
 */
#define SAVE_REGS() \
	asm volatile ("movq %%rbx, %0\t\n"  \
	     "movq %%rsi, %1\t\n"  \
	     "movq %%rdi, %2\t\n"  \
	     "movq %%rsp, %3\t\n"  \
	     "movq %%rbp, %4\t\n"  \
	     "movq %%r8, %5\t\n"   \
	     "movq %%r9, %6\t\n"   \
	     "movq %%r10, %7\t\n"  \
	     "movq %%r11, %8\t\n"  \
	     "movq %%r12, %9\t\n"  \
	     "movq %%r13, %10\t\n" \
	     "movq %%r14, %11\t\n" \
	     "movq %%r15, %12\t\n" : \
	     "=m"(rbx), "=m"(rsi), "=m"(rdi), "=m"(rsp), "=m"(rbp), \
	     "=m"(r8), "=m"(r9), "=m"(r10), "=m"(r11), "=m"(r12),  \
	     "=m"(r13), "=m"(r14), "=m"(r15));
100 
/*
 * Reload each register from the copy SAVE_REGS() stored in the file-scope
 * variable of the same name.  RUN_TEST() uses this right after the #CP
 * handler has jumped back, when %rsp and the other registers still hold
 * whatever the handler path left in them.  Loads are issued in the same
 * order SAVE_REGS() stored them.
 */
#define RESTOR_REGS()						\
	asm ("movq %[s_rbx], %%rbx\n\t"				\
	     "movq %[s_rsi], %%rsi\n\t"				\
	     "movq %[s_rdi], %%rdi\n\t"				\
	     "movq %[s_rsp], %%rsp\n\t"				\
	     "movq %[s_rbp], %%rbp\n\t"				\
	     "movq %[s_r8], %%r8\n\t"				\
	     "movq %[s_r9], %%r9\n\t"				\
	     "movq %[s_r10], %%r10\n\t"				\
	     "movq %[s_r11], %%r11\n\t"				\
	     "movq %[s_r12], %%r12\n\t"				\
	     "movq %[s_r13], %%r13\n\t"				\
	     "movq %[s_r14], %%r14\n\t"				\
	     "movq %[s_r15], %%r15\n\t"				\
	     :							\
	     : [s_rbx] "m"(rbx), [s_rsi] "m"(rsi),		\
	       [s_rdi] "m"(rdi), [s_rsp] "m"(rsp),		\
	       [s_rbp] "m"(rbp), [s_r8] "m"(r8),		\
	       [s_r9] "m"(r9), [s_r10] "m"(r10),		\
	       [s_r11] "m"(r11), [s_r12] "m"(r12),		\
	       [s_r13] "m"(r13), [s_r14] "m"(r14),		\
	       [s_r15] "m"(r15));
118 
/*
 * Run the current test body (func) in user mode via test_func() and resume
 * at local label 1 once the #CP handler jumps back.
 *
 * The address of label 1 is published in expected_rip, which the asm
 * writes, so it is an output ("=m") operand.  The detour through
 * test_func(), the test body and handle_cp() runs arbitrary C code and
 * never returns normally.  The asm therefore declares every caller-saved
 * register, the flags and memory as clobbered; cp_count in particular is
 * updated behind the compiler's back.  Callee-saved state, including %rsp
 * and %rbp, is rebuilt by RESTOR_REGS().  Keep the three statements
 * adjacent so the compiler has no room to place code between them.
 */
#define RUN_TEST() \
	do {		\
		SAVE_REGS();    \
		asm volatile ("leaq 1f(%%rip), %%rax\t\n" \
			      "movq %%rax, %0\t\n"        \
			      "call test_func\t\n"        \
			      "1:"                        \
			      : "=m"(expected_rip)        \
			      :                           \
			      : "rax", "rcx", "rdx", "rsi", "rdi", \
				"r8", "r9", "r10", "r11",  \
				"memory", "cc");          \
		RESTOR_REGS(); \
	} while (0)
130 
/* IA32_U_CET control bits: bit 0 is SH_STK_EN, bit 2 is ENDBR_EN. */
#define ENABLE_SHSTK_BIT (1 << 0)
#define ENABLE_IBT_BIT   (1 << 2)
133 
134 static void handle_cp(struct ex_regs *regs)
135 {
136 	cp_count++;
137 	printf("In #CP exception handler, error_code = 0x%lx\n",
138 		regs->error_code);
139 	asm("jmp *%0" :: "m"(expected_rip));
140 }
141 
142 int main(int ac, char **av)
143 {
144 	char *shstk_virt;
145 	unsigned long shstk_phys;
146 	unsigned long *ptep;
147 	pteval_t pte = 0;
148 
149 	cp_count = 0;
150 	if (!this_cpu_has(X86_FEATURE_SHSTK)) {
151 		printf("SHSTK not enabled\n");
152 		return report_summary();
153 	}
154 
155 	if (!this_cpu_has(X86_FEATURE_IBT)) {
156 		printf("IBT not enabled\n");
157 		return report_summary();
158 	}
159 
160 	setup_vm();
161 	setup_idt();
162 	handle_exception(21, handle_cp);
163 
164 	/* Allocate one page for shadow-stack. */
165 	shstk_virt = alloc_vpage();
166 	shstk_phys = (unsigned long)virt_to_phys(alloc_page());
167 
168 	/* Install the new page. */
169 	pte = shstk_phys | PT_PRESENT_MASK | PT_WRITABLE_MASK | PT_USER_MASK;
170 	install_pte(current_page_table(), 1, shstk_virt, pte, 0);
171 	memset(shstk_virt, 0x0, PAGE_SIZE);
172 
173 	/* Mark it as shadow-stack page. */
174 	ptep = get_pte_level(current_page_table(), shstk_virt, 1);
175 	*ptep &= ~PT_WRITABLE_MASK;
176 	*ptep |= PT_DIRTY_MASK;
177 
178 	/* Flush the paging cache. */
179 	invlpg((void *)shstk_phys);
180 
181 	/* Enable shadow-stack protection */
182 	wrmsr(MSR_IA32_U_CET, ENABLE_SHSTK_BIT);
183 
184 	/* Store shadow-stack pointer. */
185 	wrmsr(MSR_IA32_PL3_SSP, (u64)(shstk_virt + 0x1000));
186 
187 	/* Enable CET master control bit in CR4. */
188 	write_cr4(read_cr4() | X86_CR4_CET);
189 
190 	func = cet_shstk_func;
191 	RUN_TEST();
192 	report(cp_count == 1, "Completed shadow-stack protection test successfully.");
193 	cp_count = 0;
194 
195 	/* Do user-mode indirect-branch-tracking test.*/
196 	func = cet_ibt_func;
197 	/* Enable indirect-branch tracking */
198 	wrmsr(MSR_IA32_U_CET, ENABLE_IBT_BIT);
199 
200 	RUN_TEST();
201 	report(cp_count == 1, "Completed Indirect-branch tracking test successfully.");
202 
203 	write_cr4(read_cr4() & ~X86_CR4_CET);
204 	wrmsr(MSR_IA32_U_CET, 0);
205 
206 	return report_summary();
207 }
208