Skip to content

Commit

Permalink
Merge pull request #80 from angr/feat/pivot
Browse files Browse the repository at this point in the history
Feat/pivot
  • Loading branch information
Kyle-Kyle authored Feb 13, 2024
2 parents b70a888 + 57c438f commit f32174e
Show file tree
Hide file tree
Showing 20 changed files with 1,268 additions and 635 deletions.
9 changes: 8 additions & 1 deletion angrop/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ def __init__(self, project, kernel_mode=False):
self.reg_set = self._get_reg_set()

a = project.arch
self.stack_pointer = a.register_names[a.sp_offset]
self.base_pointer = a.register_names[a.bp_offset]
self.syscall_insts = None
self.ret_insts = None

def _get_reg_set(self):
"""
Expand All @@ -34,6 +37,8 @@ class X86(ROPArch):
def __init__(self, project, kernel_mode=False):
super().__init__(project, kernel_mode=kernel_mode)
self.max_block_size = 20 # X86 and AMD64 have alignment of 1, 8 bytes is certainly not good enough
self.syscall_insts = {b"\xcd\x80"} # int 0x80
self.ret_insts = {b"\xc2", b"\xc3", b"\xca", b"\xcb"}

def block_make_sense(self, block):
capstr = str(block.capstone).lower()
Expand All @@ -47,7 +52,9 @@ def block_make_sense(self, block):
return True

class AMD64(X86):
pass
def __init__(self, project, kernel_mode=False):
super().__init__(project, kernel_mode=kernel_mode)
self.syscall_insts = {b"\x0f\x05"} # syscall

arm_conditional_postfix = ['eq', 'ne', 'cs', 'hs', 'cc', 'lo', 'mi', 'pl',
'vs', 'vc', 'hi', 'ls', 'ge', 'lt', 'gt', 'le', 'al']
Expand Down
33 changes: 30 additions & 3 deletions angrop/chain_builder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .mem_changer import MemChanger
from .func_caller import FuncCaller
from .sys_caller import SysCaller
from .pivot import Pivot
from .. import rop_utils

l = logging.getLogger("angrop.chain_builder")
Expand All @@ -16,7 +17,7 @@ class ChainBuilder:
This class provides functions to generate common ropchains based on existing gadgets.
"""

def __init__(self, project, gadgets, arch, badbytes, roparg_filler):
def __init__(self, project, rop_gadgets, pivot_gadgets, syscall_gadgets, arch, badbytes, roparg_filler):
"""
Initializes the chain builder.
Expand All @@ -27,17 +28,24 @@ def __init__(self, project, gadgets, arch, badbytes, roparg_filler):
:param roparg_filler: An integer used when popping superfluous registers
"""
self.project = project
self.gadgets = gadgets
self.arch = arch
self.badbytes = badbytes
self.roparg_filler = roparg_filler

self.gadgets = rop_gadgets
self.pivot_gadgets = pivot_gadgets
self.syscall_gadgets = syscall_gadgets

self._reg_setter = RegSetter(self)
self._reg_mover = RegMover(self)
self._mem_writer = MemWriter(self)
self._mem_changer = MemChanger(self)
self._func_caller = FuncCaller(self)
self._pivot = Pivot(self)
self._sys_caller = SysCaller(self)
if not SysCaller.supported_os(self.project.loader.main_object.os):
l.warning("%s is not a fully supported OS, SysCaller may not work on this OS",
self.project.loader.main_object.os)

def set_regs(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -85,6 +93,10 @@ def write_to_mem(self, addr, data, fill_byte=b"\xff"):
addr = rop_utils.cast_rop_value(addr, self.project)
return self._mem_writer.write_to_mem(addr, data, fill_byte=fill_byte)

def pivot(self, thing):
thing = rop_utils.cast_rop_value(thing, self.project)
return self._pivot.pivot(thing)

def func_call(self, address, args, **kwargs):
"""
:param address: address or name of function to call
Expand All @@ -105,6 +117,9 @@ def do_syscall(self, syscall_num, args, needs_return=True, **kwargs):
:param needs_return: whether to continue the ROP after invoking the syscall
:return: a RopChain which makes the system with the requested register contents
"""
if not self._sys_caller:
l.exception("SysCaller does not support OS: %s", self.project.loader.main_object.os)
return None
return self._sys_caller.do_syscall(syscall_num, args, needs_return=needs_return, **kwargs)

def execve(self, path=None, path_addr=None):
Expand All @@ -113,6 +128,9 @@ def execve(self, path=None, path_addr=None):
:param path: path of binary of execute, default to b"/bin/sh\x00"
:param path_addr: where to store this path string
"""
if not self._sys_caller:
l.exception("SysCaller does not support OS: %s", self.project.loader.main_object.os)
return None
return self._sys_caller.execve(path=path, path_addr=path_addr)

def set_badbytes(self, badbytes):
Expand All @@ -121,6 +139,15 @@ def set_badbytes(self, badbytes):
def set_roparg_filler(self, roparg_filler):
self.roparg_filler = roparg_filler

def update(self):
self._reg_setter.update()
self._reg_mover.update()
self._mem_writer.update()
self._mem_changer.update()
#self._func_caller.update()
if self._sys_caller:
self._sys_caller.update()
self._pivot.update()

# should also be able to do execve by providing writable memory
# todo pivot stack
# todo pass values to setregs as symbolic variables
5 changes: 5 additions & 0 deletions angrop/chain_builder/builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import struct
from abc import abstractmethod
from functools import cmp_to_key

import claripy
Expand Down Expand Up @@ -200,3 +201,7 @@ def _get_fill_val(self):
return self.roparg_filler
else:
return claripy.BVS("filler", self.project.arch.bits)

@abstractmethod
def update(self):
raise NotImplementedError("each Builder class should have an `update` method!")
7 changes: 5 additions & 2 deletions angrop/chain_builder/mem_changer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ class MemChanger(Builder):
"""
def __init__(self, chain_builder):
super().__init__(chain_builder)
self._mem_change_gadgets = None
self._mem_add_gadgets = None
self.update()

def update(self):
self._mem_change_gadgets = self._get_all_mem_change_gadgets(self.chain_builder.gadgets)
self._mem_add_gadgets = self._get_all_mem_add_gadgets()

Expand All @@ -41,8 +46,6 @@ def _get_all_mem_change_gadgets(gadgets):
for g in gadgets:
if len(g.mem_reads) + len(g.mem_writes) > 0 or len(g.mem_changes) != 1:
continue
if g.bp_moves_to_sp:
continue
if g.stack_change <= 0:
continue
for m_access in g.mem_changes:
Expand Down
6 changes: 4 additions & 2 deletions angrop/chain_builder/mem_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class MemWriter(Builder):
"""
def __init__(self, chain_builder):
super().__init__(chain_builder)
self._mem_write_gadgets = None
self.update()

def update(self):
self._mem_write_gadgets = self._get_all_mem_write_gadgets(self.chain_builder.gadgets)

def _set_regs(self, *args, **kwargs):
Expand All @@ -28,8 +32,6 @@ def _get_all_mem_write_gadgets(gadgets):
for g in gadgets:
if len(g.mem_reads) + len(g.mem_changes) > 0 or len(g.mem_writes) != 1:
continue
if g.bp_moves_to_sp:
continue
if g.stack_change <= 0:
continue
for m_access in g.mem_writes:
Expand Down
150 changes: 150 additions & 0 deletions angrop/chain_builder/pivot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import logging
import functools

from .builder import Builder
from .. import rop_utils
from ..errors import RopException

l = logging.getLogger(__name__)

def cmp(g1, g2):
if len(g1.sp_reg_controllers) < len(g2.sp_reg_controllers):
return -1
if len(g1.sp_reg_controllers) > len(g2.sp_reg_controllers):
return 1

if g1.stack_change + g1.stack_change_after_pivot < g2.stack_change + g2.stack_change_after_pivot:
return -1
if g1.stack_change + g1.stack_change_after_pivot > g2.stack_change + g2.stack_change_after_pivot:
return 1

if g1.block_length < g2.block_length:
return -1
if g1.block_length > g2.block_length:
return 1
return 0

class Pivot(Builder):
"""
a chain_builder that builds stack pivoting rop chains
"""
def __init__(self, chain_builder):
super().__init__(chain_builder)
self._pivot_gadgets = None
self.update()

def update(self):
self._pivot_gadgets = self._filter_gadgets(self.chain_builder.pivot_gadgets)

def pivot(self, thing):
if thing.is_register:
return self.pivot_reg(thing)
return self.pivot_addr(thing)

def pivot_addr(self, addr):
for gadget in self._pivot_gadgets:
# constrain the successor to be at the gadget
# emulate 'pop pc'
init_state = self.make_sim_state(gadget.addr)

# step the gadget
final_state = rop_utils.step_to_unconstrained_successor(self.project, init_state)

# constrain the final sp
final_state.solver.add(final_state.regs.sp == addr.data)
registers = {}
for x in gadget.sp_reg_controllers:
registers[x] = final_state.solver.eval(init_state.registers.load(x))
chain = self.chain_builder.set_regs(**registers)

try:
chain.add_gadget(gadget)
# iterate through the stack values that need to be in the chain
sp = init_state.regs.sp
arch_bytes = self.project.arch.bytes
for i in range(gadget.stack_change // arch_bytes):
sym_word = init_state.memory.load(sp + arch_bytes*i, arch_bytes,
endness=self.project.arch.memory_endness)

val = final_state.solver.eval(sym_word)
chain.add_value(val)
state = chain.exec()
if state.solver.eval(state.regs.sp == addr.data):
return chain
except Exception: # pylint: disable=broad-exception-caught
continue

raise RopException(f"Fail to pivot the stack to {addr.data}!")

def pivot_reg(self, reg_val):
reg = reg_val.reg_name
for gadget in self._pivot_gadgets:
if reg not in gadget.sp_reg_controllers:
continue

init_state = self.make_sim_state(gadget.addr)
final_state = rop_utils.step_to_unconstrained_successor(self.project, init_state)

chain = self.chain_builder.set_regs()

try:
chain.add_gadget(gadget)
# iterate through the stack values that need to be in the chain
sp = init_state.regs.sp
arch_bytes = self.project.arch.bytes
for i in range(gadget.stack_change // arch_bytes):
sym_word = init_state.memory.load(sp + arch_bytes*i, arch_bytes,
endness=self.project.arch.memory_endness)

val = final_state.solver.eval(sym_word)
chain.add_value(val)
state = chain.exec()
variables = set(state.regs.sp.variables)
if len(variables) == 1 and variables.pop().startswith(f'reg_{reg}'):
return chain
else:
insts = [str(self.project.factory.block(g.addr).capstone) for g in chain._gadgets]
chain_str = '\n-----\n'.join(insts)
l.exception("Somehow angrop thinks\n%s\ncan be use for stack pivoting", chain_str)
except Exception: # pylint: disable=broad-exception-caught
continue

raise RopException(f"Fail to pivot the stack to {reg}!")

@staticmethod
def same_effect(g1, g2):
if g1.sp_controllers != g2.sp_controllers:
return False
if g1.changed_regs != g2.changed_regs:
return False
return True

def better_than(self, g1, g2):
if not self.same_effect(g1, g2):
return False
if g1.stack_change > g2.stack_change:
return False
if g1.stack_change_after_pivot > g2.stack_change_after_pivot:
return False
if g1.num_mem_access > g2.num_mem_access:
return False
return True

def _filter_gadgets(self, gadgets):
"""
filter gadgets having the same effect
"""
gadgets = set(gadgets)
skip = set({})
while True:
to_remove = set({})
for g in gadgets-skip:
to_remove.update({x for x in gadgets-{g} if self.better_than(g, x)})
if to_remove:
break
skip.add(g)
if not to_remove:
break
gadgets -= to_remove
gadgets = sorted(gadgets, key=functools.cmp_to_key(cmp))
return gadgets
8 changes: 4 additions & 4 deletions angrop/chain_builder/reg_mover.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class RegMover(Builder):
"""
def __init__(self, chain_builder):
super().__init__(chain_builder)
self._reg_moving_gadgets = None
self.update()

def update(self):
self._reg_moving_gadgets = self._filter_gadgets(self.chain_builder.gadgets)

def verify(self, chain, preserve_regs, registers):
Expand Down Expand Up @@ -142,12 +146,8 @@ def _find_relevant_gadgets(self, moves):
"""
gadgets = set()
for g in self._reg_moving_gadgets:
if g.makes_syscall:
continue
if g.has_symbolic_access():
continue
if g.bp_moves_to_sp:
continue
if moves.intersection(set(g.reg_moves)):
gadgets.add(g)
return gadgets
16 changes: 5 additions & 11 deletions angrop/chain_builder/reg_setter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ class RegSetter(Builder):
"""
def __init__(self, chain_builder):
super().__init__(chain_builder)
self._reg_setting_gadgets = None
self.hard_chain_cache = None
self.update()

def update(self):
self._reg_setting_gadgets = self._filter_gadgets(self.chain_builder.gadgets)
self.hard_chain_cache = {}

Expand Down Expand Up @@ -104,8 +109,6 @@ def _find_relevant_gadgets(self, **registers):
"""
gadgets = set({})
for g in self._reg_setting_gadgets:
if g.makes_syscall:
continue
if g.has_symbolic_access():
continue
for reg in registers:
Expand Down Expand Up @@ -329,15 +332,9 @@ def _find_reg_setting_gadgets(self, modifiable_memory_range=None, use_partial_co
continue

for g in gadgets:
# ignore gadgets which make a syscall when setting regs
if g.makes_syscall:
continue
# ignore gadgets which don't have a positive stack change
if g.stack_change <= 0:
continue
# ignore base pointer moves for now
if g.bp_moves_to_sp:
continue

stack_change = data[regs][1]
new_stack_change = stack_change + g.stack_change
Expand Down Expand Up @@ -470,9 +467,6 @@ def _check_if_sufficient_partial_control(self, gadget, reg, value):
# make sure the register doesnt depend on itself
if reg in gadget.reg_dependencies and reg in gadget.reg_dependencies[reg]:
return False
# make sure the gadget doesnt pop bp
if gadget.bp_moves_to_sp:
return False

# set the register
state = rop_utils.make_symbolic_state(self.project, self.arch.reg_set)
Expand Down
Loading

0 comments on commit f32174e

Please sign in to comment.