xref: /kvm-unit-tests/x86/taskswitch.c (revision 7d36db351752e29ad27eaafe3f102de7064e429b)
1 /*
2  * Copyright 2010 Siemens AG
3  * Author: Jan Kiszka
4  *
5  * Released under GPLv2.
6  */
7 
8 #include "libcflat.h"
9 
/*
 * Selector of GDT entry 3 (table indicator 0, RPL 0): the first slot after
 * the null, code and data descriptors, used for the TSS descriptors.
 */
#define FIRST_SPARE_SEL		(3 << 3)
11 
/*
 * Exception stack frame for a fault that pushes an error code, lowest
 * address first: error code, then return EIP, CS and EFLAGS.
 * NOTE(review): not referenced by the code visible in this file.
 */
struct exception_frame {
	unsigned long error_code, ip, cs, flags;
};
18 
/*
 * 32-bit task-state segment.  The CPU reads and writes this structure
 * directly on task switches, so member order and widths must not change.
 * NOTE(review): the layout is only correct where unsigned long is 32 bits
 * (i386 builds).
 */
struct tss32 {
	unsigned short prev, res1;		/* previous-task link */
	unsigned long esp0;
	unsigned short ss0, res2;
	unsigned long esp1;
	unsigned short ss1, res3;
	unsigned long esp2;
	unsigned short ss2, res4;
	unsigned long cr3, eip, eflags;
	unsigned long eax, ecx, edx, ebx, esp, ebp, esi, edi;
	unsigned short es, res5;
	unsigned short cs, res6;
	unsigned short ss, res7;
	unsigned short ds, res8;
	unsigned short fs, res9;
	unsigned short gs, res10;
	unsigned short ldt, res11;
	unsigned short t:1, res12:15;		/* debug-trap flag + reserved */
	unsigned short iomap_base;
};
53 
54 static char main_stack[4096];
55 static char fault_stack[4096];
56 static struct tss32 main_tss;
57 static struct tss32 fault_tss;
58 
/*
 * Flat-model GDT loaded by main().  Entries 1 and 2 are ring-0 code and data
 * segments with base 0 and a page-granular 4 GiB limit; entries 3-5 are left
 * zero here and filled in at run time.
 */
static unsigned long long gdt[] __attribute__((aligned(16))) = {
	[0] = 0,			/* null descriptor */
	[1] = 0x00cf9b000000ffffull,	/* 0x08: code */
	[2] = 0x00cf93000000ffffull,	/* 0x10: data */
	[3] = 0,			/* 0x18: TSS segment (main task) */
	[4] = 0,			/* 0x20: TSS segment (#GP task) */
	[5] = 0,			/* 0x28: task return gate */
};
66 
/* Operand for lgdt: GDT limit in bits 0-15, GDT base starting at bit 16. */
static unsigned long long gdtr;

/* Entry point of the #GP handler task; defined in top-level asm. */
extern void fault_entry(void);
70 
71 static __attribute__((used, regparm(1))) void
72 fault_handler(unsigned long error_code)
73 {
74 	unsigned short *desc;
75 
76 	printf("fault at %x:%x, prev task %x, error code %x\n",
77 	       main_tss.cs, main_tss.eip, fault_tss.prev, error_code);
78 
79 	main_tss.eip += 2;
80 
81 	desc = (unsigned short *)&gdt[3];
82 	desc[2] &= ~0x0200;
83 
84 	desc = (unsigned short *)&gdt[5];
85 	desc[0] = 0;
86 	desc[1] = fault_tss.prev;
87 	desc[2] = 0x8500;
88 	desc[3] = 0;
89 }
90 
/*
 * Entry point of the #GP handler task (installed as fault_tss.eip in main).
 *
 * The #GP error code is found at the top of this task's stack.  It is loaded
 * into %eax, which is where the regparm(1) fault_handler expects its
 * argument.  After the handler returns, a far jump through the task gate at
 * selector 0x28 (gdt[5], built by fault_handler) switches back to the task
 * that faulted.
 */
asm (
	"fault_entry:\n"
	"	mov (%esp),%eax\n"
	"	call fault_handler\n"
	"	jmp $0x28, $0\n"
);
97 
98 static void setup_tss(struct tss32 *tss, void *entry,
99 		      void *stack_base, unsigned long stack_size)
100 {
101 	unsigned long cr3;
102 	unsigned short cs, ds;
103 
104 	asm ("mov %%cr3,%0" : "=r" (cr3));
105 	asm ("mov %%cs,%0" : "=r" (cs));
106 	asm ("mov %%ds,%0" : "=r" (ds));
107 
108 	tss->ss0 = tss->ss1 = tss->ss2 = tss->ss = ds;
109 	tss->esp0 = tss->esp1 = tss->esp2 = tss->esp =
110 		(unsigned long)stack_base + stack_size;
111 	tss->ds = tss->es = tss->fs = tss->gs = ds;
112 	tss->cs = cs;
113 	tss->eip = (unsigned long)entry;
114 	tss->cr3 = cr3;
115 }
116 
117 static void setup_tss_desc(unsigned short tss_sel, struct tss32 *tss)
118 {
119 	unsigned long addr = (unsigned long)tss;
120 	unsigned short *desc;
121 
122 	desc = (unsigned short *)&gdt[tss_sel/8];
123 	desc[0] = sizeof(*tss) - 1;
124 	desc[1] = addr;
125 	desc[2] = 0x8900 | ((addr & 0x00ff0000) >> 16);
126 	desc[3] = (addr & 0xff000000) >> 16;
127 }
128 
129 static void set_intr_task(unsigned short tss_sel, int intr, struct tss32 *tss)
130 {
131 	unsigned short *desc = (void *)(intr* sizeof(long) * 2);
132 
133 	setup_tss_desc(tss_sel, tss);
134 
135 	desc[0] = 0;
136 	desc[1] = tss_sel;
137 	desc[2] = 0x8500;
138 	desc[3] = 0;
139 }
140 
141 int main(int ac, char **av)
142 {
143 	const long invalid_segment = 0x1234;
144 
145 	gdtr = ((unsigned long long)(unsigned long)&gdt << 16) |
146 		(sizeof(gdt) - 1);
147 	asm ("lgdt %0" : : "m" (gdtr));
148 
149 	setup_tss(&main_tss, 0, main_stack, sizeof(main_stack));
150 	setup_tss_desc(FIRST_SPARE_SEL, &main_tss);
151 	asm ("ltr %0" : : "r" ((unsigned short)FIRST_SPARE_SEL));
152 
153 	setup_tss(&fault_tss, fault_entry, fault_stack, sizeof(fault_stack));
154 	set_intr_task(FIRST_SPARE_SEL+8, 13, &fault_tss);
155 
156 	asm (
157 		"mov %0,%%es\n"
158 		: : "r" (invalid_segment) : "edi"
159 	);
160 
161 	printf("post fault\n");
162 
163 	return 0;
164 }
165