Skip to content

Commit

Permalink
Leave the CPU in a good state on decode segfaults
Browse files Browse the repository at this point in the history
The segfault_addr is not always the first byte of the faulting
instruction. If the instruction spans two pages, and the second page is
inaccessible, it needs to point to the first byte of that page.
  • Loading branch information
tbodt committed Jan 18, 2021
1 parent 5f2d6fe commit 99dc0f5
Show file tree
Hide file tree
Showing 13 changed files with 70 additions and 43 deletions.
6 changes: 3 additions & 3 deletions emu/modrm.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ static const unsigned rm_disp32 = reg_ebp;
// read modrm and maybe sib, output information into *modrm, return false for segfault
static inline bool modrm_decode32(addr_t *ip, struct tlb *tlb, struct modrm *modrm) {
#define READ(thing) \
if (!tlb_read(tlb, *ip, &(thing), sizeof(thing))) \
return false; \
*ip += sizeof(thing);
*ip += sizeof(thing); \
if (!tlb_read(tlb, *ip - sizeof(thing), &(thing), sizeof(thing))) \
return false

byte_t modrm_byte;
READ(modrm_byte);
Expand Down
12 changes: 9 additions & 3 deletions emu/tlb.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ void tlb_free(struct tlb *tlb) {

bool __tlb_read_cross_page(struct tlb *tlb, addr_t addr, char *value, unsigned size) {
char *ptr1 = __tlb_read_ptr(tlb, addr);
if (ptr1 == NULL)
return false;
char *ptr2 = __tlb_read_ptr(tlb, (PAGE(addr) + 1) << PAGE_BITS);
if (ptr1 == NULL || ptr2 == NULL)
if (ptr2 == NULL)
return false;
size_t part1 = PAGE_SIZE - PGOFFSET(addr);
assert(part1 < size);
Expand All @@ -34,8 +36,10 @@ bool __tlb_read_cross_page(struct tlb *tlb, addr_t addr, char *value, unsigned s

bool __tlb_write_cross_page(struct tlb *tlb, addr_t addr, const char *value, unsigned size) {
char *ptr1 = __tlb_write_ptr(tlb, addr);
if (ptr1 == NULL)
return false;
char *ptr2 = __tlb_write_ptr(tlb, (PAGE(addr) + 1) << PAGE_BITS);
if (ptr1 == NULL || ptr2 == NULL)
if (ptr2 == NULL)
return false;
size_t part1 = PAGE_SIZE - PGOFFSET(addr);
assert(part1 < size);
Expand All @@ -48,8 +52,10 @@ __no_instrument void *tlb_handle_miss(struct tlb *tlb, addr_t addr, int type) {
char *ptr = mmu_translate(tlb->mmu, TLB_PAGE(addr), type);
if (tlb->mmu->changes != tlb->mem_changes)
tlb_flush(tlb);
if (ptr == NULL)
if (ptr == NULL) {
tlb->segfault_addr = addr;
return NULL;
}
tlb->dirty_page = TLB_PAGE(addr);

struct tlb_entry *tlb_ent = &tlb->entries[TLB_INDEX(addr)];
Expand Down
3 changes: 3 additions & 0 deletions emu/tlb.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ struct tlb {
struct mmu *mmu;
page_t dirty_page;
unsigned mem_changes;
// this is basically one of the return values of tlb_handle_miss, tlb_{read,write}, and __tlb_{read,write}_cross_page
// yes, this sucks
addr_t segfault_addr;
struct tlb_entry entries[TLB_SIZE];
};

Expand Down
2 changes: 2 additions & 0 deletions jit/gadgets-aarch64/entry.S
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ jit_exit:

.gadget interrupt
ldr _tmp, [_ip]
ldr w8, [_ip, 16]
str w8, [_cpu, CPU_segfault_addr]
ldr eip, [_ip, 8]
b jit_exit

Expand Down
1 change: 1 addition & 0 deletions jit/gadgets-aarch64/memory.S
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ NAME(si_gadgets):
ret

segfault_\type:
ldr _addr, [_tlb, -TLB_entries+TLB_segfault_addr]
str _addr, [_cpu, CPU_segfault_addr]
ldr eip, [_ip]
mov x0, INT_GPF
Expand Down
5 changes: 4 additions & 1 deletion jit/gadgets-aarch64/misc.S
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,11 @@
.endif
.ifin(\type, read,write)
mov x1, _xaddr
ldr x8, [_ip, 8]
.endifin
.ifin(\type, 0,1,2)
ldr x8, [_ip]
.endifin
ldr x8, [_ip]
blr x8
restore_c
load_regs
Expand Down
3 changes: 2 additions & 1 deletion jit/gadgets-x86_64/entry.S
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ jit_exit:

.gadget interrupt
movl (%_ip), %_tmp
movl 16(%_ip), %r14d
movl %r14d, CPU_segfault_addr(%_cpu)
movl 8(%_ip), %_eip
movl %_eip, CPU_segfault_addr(%_cpu)
movb $0, CPU_segfault_was_write(%_cpu)
jmp jit_exit

Expand Down
1 change: 1 addition & 0 deletions jit/gadgets-x86_64/memory.S
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
ret

segfault_\type:
movl -TLB_entries+TLB_segfault_addr(%_tlb), %_addr
movl %_addr, CPU_segfault_addr(%_cpu)
.ifc \type,read
movb $0, CPU_segfault_was_write(%_cpu)
Expand Down
5 changes: 4 additions & 1 deletion jit/gadgets-x86_64/misc.S
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,11 @@
.endif
.ifin(\type, read,write)
movq %_addrq, %rsi
callq *8(%_ip)
.endifin
.ifin(\type, 0,1,2)
callq *(%_ip)
.endifin
callq *(%_ip)
restore_c
load_regs
.ifc \type,write
Expand Down
66 changes: 36 additions & 30 deletions jit/gen.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
#include "emu/vec.h"
#include "emu/interrupt.h"

static int gen_step32(struct gen_state *state, struct tlb *tlb);
static int gen_step16(struct gen_state *state, struct tlb *tlb);

int gen_step(struct gen_state *state, struct tlb *tlb) {
state->orig_ip = state->ip;
return gen_step32(state, tlb);
}

static void gen(struct gen_state *state, unsigned long thing) {
assert(state->size <= state->capacity);
if (state->size >= state->capacity) {
Expand Down Expand Up @@ -71,18 +79,17 @@ void gen_exit(struct gen_state *state) {
}

#define DECLARE_LOCALS \
dword_t saved_ip = state->ip; \
dword_t addr_offset = 0; \
bool end_block = false; \
bool seg_gs = false

#define FINISH \
return !end_block

#define RESTORE_IP state->ip = saved_ip
#define RESTORE_IP state->ip = state->orig_ip
#define _READIMM(name, size) \
if (!tlb_read(tlb, state->ip, &name, size/8)) SEGFAULT; \
state->ip += size/8
state->ip += size/8; \
if (!tlb_read(tlb, state->ip - size/8, &name, size/8)) SEGFAULT; else

#define READMODRM if (!modrm_decode32(&state->ip, tlb, &modrm)) SEGFAULT
#define READADDR _READIMM(addr_offset, 32)
Expand Down Expand Up @@ -134,11 +141,10 @@ typedef void (*gadget_t)(void);
#define h(h) gg(helper_0, h)
#define hh(h, a) ggg(helper_1, h, a)
#define hhh(h, a, b) gggg(helper_2, h, a, b)
#define h_read(h, z) do { g_addr(); gg_here(helper_read##z, h##z); } while (0)
#define h_write(h, z) do { g_addr(); gg_here(helper_write##z, h##z); } while (0)
#define gg_here(g, a) ggg(g, a, saved_ip)
#define UNDEFINED do { gg_here(interrupt, INT_UNDEFINED); return false; } while (0)
#define SEGFAULT do { gg_here(interrupt, INT_GPF); return false; } while (0)
#define h_read(h, z) do { g_addr(); ggg(helper_read##z, state->orig_ip, h##z); } while (0)
#define h_write(h, z) do { g_addr(); ggg(helper_write##z, state->orig_ip, h##z); } while (0)
#define UNDEFINED do { gggg(interrupt, INT_UNDEFINED, state->orig_ip, state->orig_ip); return false; } while (0)
#define SEGFAULT do { gggg(interrupt, INT_GPF, state->orig_ip, tlb->segfault_addr); return false; } while (0)

static inline int sz(int size) {
switch (size) {
Expand All @@ -149,7 +155,7 @@ static inline int sz(int size) {
}
}

bool gen_addr(struct gen_state *state, struct modrm *modrm, bool seg_gs, dword_t saved_ip) {
bool gen_addr(struct gen_state *state, struct modrm *modrm, bool seg_gs) {
if (modrm->base == reg_none)
gg(addr_none, modrm->offset);
else
Expand All @@ -160,12 +166,12 @@ bool gen_addr(struct gen_state *state, struct modrm *modrm, bool seg_gs, dword_t
g(seg_gs);
return true;
}
#define g_addr() gen_addr(state, &modrm, seg_gs, saved_ip)
#define g_addr() gen_addr(state, &modrm, seg_gs)

// this really wants to use all the locals of the decoder, which we can do
// really nicely in gcc using nested functions, but that won't work in clang,
// so we explicitly pass 500 arguments. sorry for the mess
static inline bool gen_op(struct gen_state *state, gadget_t *gadgets, enum arg arg, struct modrm *modrm, uint64_t *imm, int size, dword_t saved_ip, bool seg_gs, dword_t addr_offset) {
static inline bool gen_op(struct gen_state *state, gadget_t *gadgets, enum arg arg, struct modrm *modrm, uint64_t *imm, int size, bool seg_gs, dword_t addr_offset) {
size = sz(size);
gadgets = gadgets + size * arg_count;

Expand Down Expand Up @@ -194,19 +200,19 @@ static inline bool gen_op(struct gen_state *state, gadget_t *gadgets, enum arg a
UNDEFINED;
}
if (arg == arg_mem || arg == arg_addr) {
if (!gen_addr(state, modrm, seg_gs, saved_ip))
if (!gen_addr(state, modrm, seg_gs))
return false;
}
GEN(gadgets[arg]);
if (arg == arg_imm)
GEN(*imm);
else if (arg == arg_mem)
GEN(saved_ip);
GEN(state->orig_ip);
return true;
}
#define op(type, thing, z) do { \
extern gadget_t type##_gadgets[]; \
if (!gen_op(state, type##_gadgets, arg_##thing, &modrm, &imm, z, saved_ip, seg_gs, addr_offset)) return false; \
if (!gen_op(state, type##_gadgets, arg_##thing, &modrm, &imm, z, seg_gs, addr_offset)) return false; \
} while (0)

#define load(thing, z) op(load, thing, z)
Expand All @@ -233,8 +239,8 @@ static inline bool gen_op(struct gen_state *state, gadget_t *gadgets, enum arg a
#define NOT(val,z) load(val,z); gz(not, z); store(val,z)
#define NEG(val,z) imm = 0; load(imm,z); op(sub, val,z); store(val,z)

#define POP(thing,z) gg(pop, saved_ip); store(thing, z)
#define PUSH(thing,z) load(thing, z); gg(push, saved_ip)
#define POP(thing,z) gg(pop, state->orig_ip); store(thing, z)
#define PUSH(thing,z) load(thing, z); gg(push, state->orig_ip)

#define INC(val,z) load(val, z); gz(inc, z); store(val, z)
#define DEC(val,z) load(val, z); gz(dec, z); store(val, z)
Expand All @@ -252,27 +258,27 @@ static inline bool gen_op(struct gen_state *state, gadget_t *gadgets, enum arg a
#define J_REL(cc, off) jcc(cc, fake_ip + off, fake_ip)
#define JN_REL(cc, off) jcc(cc, fake_ip, fake_ip + off)

// saved_ip: for use with page fault handler;
// state->orig_ip: for use with page fault handler;
// -1: will be patched to block address in gen_end();
// fake_ip: the first one is the return address, used for saving to stack and verifying the cached ip in return cache is correct;
// fake_ip: the second one is the return target, patchable by return chaining.
#define CALL(loc) do { \
load(loc, OP_SIZE); \
ggggg(call_indir, saved_ip, -1, fake_ip, fake_ip); \
ggggg(call_indir, state->orig_ip, -1, fake_ip, fake_ip); \
state->block_patch_ip = state->size - 3; \
jump_ips(-1, 0); \
end_block = true; \
} while (0)
// the first four arguments are the same with CALL,
// the last one is the call target, patchable by return chaining.
#define CALL_REL(off) do { \
gggggg(call, saved_ip, -1, fake_ip, fake_ip, fake_ip + off); \
gggggg(call, state->orig_ip, -1, fake_ip, fake_ip, fake_ip + off); \
state->block_patch_ip = state->size - 4; \
jump_ips(-2, -1); \
end_block = true; \
} while (0)
#define RET_NEAR(imm) ggg(ret, saved_ip, 4 + imm); end_block = true
#define INT(code) ggg(interrupt, (uint8_t) code, state->ip); end_block = true
#define RET_NEAR(imm) ggg(ret, state->orig_ip, 4 + imm); end_block = true
#define INT(code) gggg(interrupt, (uint8_t) code, state->ip, 0); end_block = true

#define SET(cc, dst) ga(set, cond_##cc); store(dst, 8)
#define SETN(cc, dst) ga(setn, cond_##cc); store(dst, 8)
Expand Down Expand Up @@ -335,14 +341,14 @@ static inline bool gen_op(struct gen_state *state, gadget_t *gadgets, enum arg a

#define BSWAP(dst) ga(bswap, arg_##dst)

#define strop(op, rep, z) gag(op, sz(z) * size_count + rep_##rep, saved_ip)
#define strop(op, rep, z) gag(op, sz(z) * size_count + rep_##rep, state->orig_ip)
#define STR(op, z) strop(op, once, z)
#define REP(op, z) strop(op, rep, z)
#define REPZ(op, z) strop(op, repz, z)
#define REPNZ(op, z) strop(op, repnz, z)

#define CMPXCHG(src, dst,z) load(src, z); op(cmpxchg, dst, z)
#define CMPXCHG8B(dst,z) g_addr(); gg(cmpxchg8b, saved_ip)
#define CMPXCHG8B(dst,z) g_addr(); gg(cmpxchg8b, state->orig_ip)
#define XADD(src, dst,z) XCHG(src, dst,z); ADD(src, dst,z)

void helper_rdtsc(struct cpu_state *cpu);
Expand All @@ -365,7 +371,7 @@ void helper_rdtsc(struct cpu_state *cpu);
#define ATOMIC_BTC(bit, val,z) lo(atomic_btc, val, bit, z)
#define ATOMIC_BTS(bit, val,z) lo(atomic_bts, val, bit, z)
#define ATOMIC_BTR(bit, val,z) lo(atomic_btr, val, bit, z)
#define ATOMIC_CMPXCHG8B(dst,z) g_addr(); gg(atomic_cmpxchg8b, saved_ip)
#define ATOMIC_CMPXCHG8B(dst,z) g_addr(); gg(atomic_cmpxchg8b, state->orig_ip)

// fpu
#define st_0 0
Expand Down Expand Up @@ -443,7 +449,7 @@ static inline uint16_t cpu_reg_offset(enum arg arg, int index) {
return 0;
}

static inline bool gen_vec(enum arg src, enum arg dst, void (*helper)(), gadget_t read_mem_gadget, gadget_t write_mem_gadget, struct gen_state *state, struct modrm *modrm, uint8_t imm, dword_t saved_ip, bool seg_gs, bool has_imm) {
static inline bool gen_vec(enum arg src, enum arg dst, void (*helper)(), gadget_t read_mem_gadget, gadget_t write_mem_gadget, struct gen_state *state, struct modrm *modrm, uint8_t imm, bool seg_gs, bool has_imm) {
bool rm_is_src = !could_be_memory(dst);
enum arg rm = rm_is_src ? src : dst;
enum arg reg = rm_is_src ? dst : src;
Expand Down Expand Up @@ -479,9 +485,9 @@ static inline bool gen_vec(enum arg src, enum arg dst, void (*helper)(), gadget_
break;

case arg_mem:
gen_addr(state, modrm, seg_gs, saved_ip);
gen_addr(state, modrm, seg_gs);
GEN(rm_is_src ? read_mem_gadget : write_mem_gadget);
GEN(saved_ip);
GEN(state->orig_ip);
GEN(helper);
GEN(reg_offset | imm_arg);
break;
Expand All @@ -504,7 +510,7 @@ static inline bool gen_vec(enum arg src, enum arg dst, void (*helper)(), gadget_
#define _v(src, dst, helper, _imm, z) do { \
extern void gadget_vec_helper_read##z##_imm(void); \
extern void gadget_vec_helper_write##z##_imm(void); \
if (!gen_vec(src, dst, (void (*)()) helper, gadget_vec_helper_read##z##_imm, gadget_vec_helper_write##z##_imm, state, &modrm, imm, saved_ip, seg_gs, has_imm_##_imm)) return false; \
if (!gen_vec(src, dst, (void (*)()) helper, gadget_vec_helper_read##z##_imm, gadget_vec_helper_write##z##_imm, state, &modrm, imm, seg_gs, has_imm_##_imm)) return false; \
} while (0)
#define v_(op, src, dst, _imm,z) _v(arg_##src, arg_##dst, vec_##op##z, _imm,z)
#define v(op, src, dst,z) v_(op, src, dst,,z)
Expand Down Expand Up @@ -535,7 +541,7 @@ static inline bool gen_vec(enum arg src, enum arg dst, void (*helper)(), gadget_
#define V_OP(op, src, dst, z) v(op, src, dst, z)
#define V_OP_IMM(op, src, dst, z) v_imm(op, src, dst, z)

#define DECODER_RET int
#define DECODER_RET static int
#define DECODER_NAME gen_step
#define DECODER_ARGS struct gen_state *state, struct tlb *tlb
#define DECODER_PASS_ARGS state, tlb
Expand Down
4 changes: 2 additions & 2 deletions jit/gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

struct gen_state {
addr_t ip;
addr_t orig_ip;
struct jit_block *block;
unsigned size;
unsigned capacity;
Expand All @@ -17,7 +18,6 @@ void gen_start(addr_t addr, struct gen_state *state);
void gen_exit(struct gen_state *state);
void gen_end(struct gen_state *state);

int gen_step32(struct gen_state *state, struct tlb *tlb);
int gen_step16(struct gen_state *state, struct tlb *tlb);
int gen_step(struct gen_state *state, struct tlb *tlb);

#endif
4 changes: 2 additions & 2 deletions jit/jit.c
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ static struct jit_block *jit_block_compile(addr_t ip, struct tlb *tlb) {
TRACE("%d %08x --- compiling:\n", current_pid(), ip);
gen_start(ip, &state);
while (true) {
if (!gen_step32(&state, tlb))
if (!gen_step(&state, tlb))
break;
// no block should span more than 2 pages
// guarantee this by limiting total block size to 1 page
Expand Down Expand Up @@ -243,7 +243,7 @@ static int cpu_step_to_interrupt(struct cpu_state *cpu, struct tlb *tlb) {
static int cpu_single_step(struct cpu_state *cpu, struct tlb *tlb) {
struct gen_state state;
gen_start(cpu->eip, &state);
gen_step32(&state, tlb);
gen_step(&state, tlb);
gen_exit(&state);
gen_end(&state);

Expand Down
1 change: 1 addition & 0 deletions jit/offsets.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ void cpu() {

OFFSET(TLB, tlb, entries);
OFFSET(TLB, tlb, dirty_page);
OFFSET(TLB, tlb, segfault_addr);
OFFSET(TLB_ENTRY, tlb_entry, page);
OFFSET(TLB_ENTRY, tlb_entry, page_if_writable);
OFFSET(TLB_ENTRY, tlb_entry, data_minus_addr);
Expand Down

0 comments on commit 99dc0f5

Please sign in to comment.