#include "fault_test.h"

/*
 * Jump target armed by fault_test() right before it calls the function
 * under test; restore_exec_to_jmpbuf() jumps back through it.
 */
static jmp_buf jmpbuf;

/* Unwind to the setjmp() in fault_test(), making it return 1. */
static void restore_exec_to_jmpbuf(void)
{
	longjmp(jmpbuf, 1);
}

/*
 * Exception handler installed by fault_test() for the vector under test.
 * It redirects regs->rip to restore_exec_to_jmpbuf(); presumably execution
 * continues there once the handler returns (depends on the exception
 * entry code, which is not visible here).
 */
static void fault_test_fault(struct ex_regs *regs)
{
	regs->rip = (unsigned long)restore_exec_to_jmpbuf;
}

/*
 * Run arg->func with arg->arg[0..3] and check the outcome against
 * arg->should_fault.
 *
 * When arg->usermode is set, the call is delegated to run_in_user(),
 * which is handed a pointer to the "faulted" flag (presumably it sets the
 * flag when arg->fault_vector is raised -- confirm against run_in_user()).
 * Otherwise fault_test_fault() is installed for arg->fault_vector, the
 * function is called directly, and the previous handler is put back.
 *
 * If no fault was observed, the return value is stored in arg->retval and
 * arg->callback, when non-NULL, is invoked with arg.
 *
 * Returns true when:
 *  - arg->should_fault is set and the vector was raised, or
 *  - arg->should_fault is clear, the vector was not raised, and the
 *    callback (if any) returned true.
 */
static bool fault_test(struct fault_test_arg *arg)
{
	test_fault_func target = (test_fault_func) arg->func;
	/* volatile: assigned between setjmp() and a possible longjmp(). */
	volatile uint64_t result;
	bool faulted = false;
	/* Without a callback there is nothing further that can fail. */
	bool callback_ok = true;
	handler saved;

	if (!arg->usermode) {
		saved = handle_exception(arg->fault_vector, fault_test_fault);

		/*
		 * A non-zero return means we got here through
		 * restore_exec_to_jmpbuf(), i.e. the handler ran.
		 */
		if (setjmp(jmpbuf) != 0)
			faulted = true;
		else
			result = target(arg->arg[0], arg->arg[1],
					arg->arg[2], arg->arg[3]);

		/* Put the previous handler back on both paths. */
		handle_exception(arg->fault_vector, saved);
	} else {
		result = run_in_user((usermode_func) target,
				     arg->fault_vector,
				     arg->arg[0], arg->arg[1],
				     arg->arg[2], arg->arg[3],
				     &faulted);
	}

	/* The return value is only meaningful if the call completed. */
	if (!faulted) {
		arg->retval = result;
		if (arg->callback != NULL)
			callback_ok = arg->callback(arg);
	}

	if (arg->should_fault)
		return faulted;

	return !faulted && callback_ok;
}

/*
 * Execute one fault test and report its pass/fail status under
 * test->name.
 */
void test_run(struct fault_test *test)
{
	bool ok = fault_test(&test->arg);

	report(ok, "%s", test->name);
}