diff --git a/angrop/arch.py b/angrop/arch.py index db2d6af..ceed396 100644 --- a/angrop/arch.py +++ b/angrop/arch.py @@ -12,7 +12,10 @@ def __init__(self, project, kernel_mode=False): self.reg_set = self._get_reg_set() a = project.arch + self.stack_pointer = a.register_names[a.sp_offset] self.base_pointer = a.register_names[a.bp_offset] + self.syscall_insts = None + self.ret_insts = None def _get_reg_set(self): """ @@ -34,6 +37,8 @@ class X86(ROPArch): def __init__(self, project, kernel_mode=False): super().__init__(project, kernel_mode=kernel_mode) self.max_block_size = 20 # X86 and AMD64 have alignment of 1, 8 bytes is certainly not good enough + self.syscall_insts = {b"\xcd\x80"} # int 0x80 + self.ret_insts = {b"\xc2", b"\xc3", b"\xca", b"\xcb"} def block_make_sense(self, block): capstr = str(block.capstone).lower() @@ -47,7 +52,9 @@ def block_make_sense(self, block): return True class AMD64(X86): - pass + def __init__(self, project, kernel_mode=False): + super().__init__(project, kernel_mode=kernel_mode) + self.syscall_insts = {b"\x0f\x05"} # syscall arm_conditional_postfix = ['eq', 'ne', 'cs', 'hs', 'cc', 'lo', 'mi', 'pl', 'vs', 'vc', 'hi', 'ls', 'ge', 'lt', 'gt', 'le', 'al'] diff --git a/angrop/chain_builder/__init__.py b/angrop/chain_builder/__init__.py index 7357ba5..f87bc44 100644 --- a/angrop/chain_builder/__init__.py +++ b/angrop/chain_builder/__init__.py @@ -6,6 +6,7 @@ from .mem_changer import MemChanger from .func_caller import FuncCaller from .sys_caller import SysCaller +from .pivot import Pivot from .. import rop_utils l = logging.getLogger("angrop.chain_builder") @@ -16,7 +17,7 @@ class ChainBuilder: This class provides functions to generate common ropchains based on existing gadgets. """ - def __init__(self, project, gadgets, arch, badbytes, roparg_filler): + def __init__(self, project, rop_gadgets, pivot_gadgets, syscall_gadgets, arch, badbytes, roparg_filler): """ Initializes the chain builder. @@ -27,17 +28,24 @@ def __init__(self, project, gadgets, arch, badbytes, roparg_filler): :param roparg_filler: An integer used when popping superfluous registers """ self.project = project - self.gadgets = gadgets self.arch = arch self.badbytes = badbytes self.roparg_filler = roparg_filler + self.gadgets = rop_gadgets + self.pivot_gadgets = pivot_gadgets + self.syscall_gadgets = syscall_gadgets + self._reg_setter = RegSetter(self) self._reg_mover = RegMover(self) self._mem_writer = MemWriter(self) self._mem_changer = MemChanger(self) self._func_caller = FuncCaller(self) + self._pivot = Pivot(self) self._sys_caller = SysCaller(self) + if not SysCaller.supported_os(self.project.loader.main_object.os): + l.warning("%s is not a fully supported OS, SysCaller may not work on this OS", + self.project.loader.main_object.os) def set_regs(self, *args, **kwargs): """ @@ -85,6 +93,10 @@ def write_to_mem(self, addr, data, fill_byte=b"\xff"): addr = rop_utils.cast_rop_value(addr, self.project) return self._mem_writer.write_to_mem(addr, data, fill_byte=fill_byte) + def pivot(self, thing): + thing = rop_utils.cast_rop_value(thing, self.project) + return self._pivot.pivot(thing) + def func_call(self, address, args, **kwargs): """ :param address: address or name of function to call @@ -105,6 +117,9 @@ def do_syscall(self, syscall_num, args, needs_return=True, **kwargs): :param needs_return: whether to continue the ROP after invoking the syscall :return: a RopChain which makes the system with the requested register contents """ + if not self._sys_caller: + l.exception("SysCaller does not support OS: %s", self.project.loader.main_object.os) + return None return self._sys_caller.do_syscall(syscall_num, args, needs_return=needs_return, **kwargs) def execve(self, path=None, path_addr=None): @@ -113,6 +128,9 @@ def execve(self, path=None, path_addr=None): :param path: path of binary of execute, default to b"/bin/sh\x00" :param path_addr: where to store this path string """ + if not self._sys_caller: + l.exception("SysCaller does not support OS: %s", self.project.loader.main_object.os) + return None return self._sys_caller.execve(path=path, path_addr=path_addr) def set_badbytes(self, badbytes): @@ -121,6 +139,15 @@ def set_badbytes(self, badbytes): def set_roparg_filler(self, roparg_filler): self.roparg_filler = roparg_filler + def update(self): + self._reg_setter.update() + self._reg_mover.update() + self._mem_writer.update() + self._mem_changer.update() + #self._func_caller.update() + if self._sys_caller: + self._sys_caller.update() + self._pivot.update() + # should also be able to do execve by providing writable memory - # todo pivot stack # todo pass values to setregs as symbolic variables diff --git a/angrop/chain_builder/builder.py b/angrop/chain_builder/builder.py index cb037ae..894c03f 100644 --- a/angrop/chain_builder/builder.py +++ b/angrop/chain_builder/builder.py @@ -1,4 +1,5 @@ import struct +from abc import abstractmethod from functools import cmp_to_key import claripy @@ -200,3 +201,7 @@ def _get_fill_val(self): return self.roparg_filler else: return claripy.BVS("filler", self.project.arch.bits) + + @abstractmethod + def update(self): + raise NotImplementedError("each Builder class should have an `update` method!") diff --git a/angrop/chain_builder/mem_changer.py b/angrop/chain_builder/mem_changer.py index 7ff3b5c..66d5e20 100644 --- a/angrop/chain_builder/mem_changer.py +++ b/angrop/chain_builder/mem_changer.py @@ -15,6 +15,11 @@ class MemChanger(Builder): """ def __init__(self, chain_builder): super().__init__(chain_builder) + self._mem_change_gadgets = None + self._mem_add_gadgets = None + self.update() + + def update(self): self._mem_change_gadgets = self._get_all_mem_change_gadgets(self.chain_builder.gadgets) self._mem_add_gadgets = self._get_all_mem_add_gadgets() @@ -41,8 +46,6 @@ def _get_all_mem_change_gadgets(gadgets): for g in gadgets: if len(g.mem_reads) + len(g.mem_writes) > 0 or len(g.mem_changes) != 1: continue - if g.bp_moves_to_sp: - continue if g.stack_change <= 0: continue for m_access in g.mem_changes: diff --git a/angrop/chain_builder/mem_writer.py b/angrop/chain_builder/mem_writer.py index 03d6313..9390539 100644 --- a/angrop/chain_builder/mem_writer.py +++ b/angrop/chain_builder/mem_writer.py @@ -17,6 +17,10 @@ class MemWriter(Builder): """ def __init__(self, chain_builder): super().__init__(chain_builder) + self._mem_write_gadgets = None + self.update() + + def update(self): self._mem_write_gadgets = self._get_all_mem_write_gadgets(self.chain_builder.gadgets) def _set_regs(self, *args, **kwargs): @@ -28,8 +32,6 @@ def _get_all_mem_write_gadgets(gadgets): for g in gadgets: if len(g.mem_reads) + len(g.mem_changes) > 0 or len(g.mem_writes) != 1: continue - if g.bp_moves_to_sp: - continue if g.stack_change <= 0: continue for m_access in g.mem_writes: diff --git a/angrop/chain_builder/pivot.py b/angrop/chain_builder/pivot.py new file mode 100644 index 0000000..a61926b --- /dev/null +++ b/angrop/chain_builder/pivot.py @@ -0,0 +1,150 @@ +import logging +import functools + +from .builder import Builder +from .. import rop_utils +from ..errors import RopException + +l = logging.getLogger(__name__) + +def cmp(g1, g2): + if len(g1.sp_reg_controllers) < len(g2.sp_reg_controllers): + return -1 + if len(g1.sp_reg_controllers) > len(g2.sp_reg_controllers): + return 1 + + if g1.stack_change + g1.stack_change_after_pivot < g2.stack_change + g2.stack_change_after_pivot: + return -1 + if g1.stack_change + g1.stack_change_after_pivot > g2.stack_change + g2.stack_change_after_pivot: + return 1 + + if g1.block_length < g2.block_length: + return -1 + if g1.block_length > g2.block_length: + return 1 + return 0 + +class Pivot(Builder): + """ + a chain_builder that builds stack pivoting rop chains + """ + def __init__(self, chain_builder): + super().__init__(chain_builder) + self._pivot_gadgets = None + self.update() + + def update(self): + self._pivot_gadgets = self._filter_gadgets(self.chain_builder.pivot_gadgets) + + def pivot(self, thing): + if thing.is_register: + return self.pivot_reg(thing) + return self.pivot_addr(thing) + + def pivot_addr(self, addr): + for gadget in self._pivot_gadgets: + # constrain the successor to be at the gadget + # emulate 'pop pc' + init_state = self.make_sim_state(gadget.addr) + + # step the gadget + final_state = rop_utils.step_to_unconstrained_successor(self.project, init_state) + + # constrain the final sp + final_state.solver.add(final_state.regs.sp == addr.data) + registers = {} + for x in gadget.sp_reg_controllers: + registers[x] = final_state.solver.eval(init_state.registers.load(x)) + chain = self.chain_builder.set_regs(**registers) + + try: + chain.add_gadget(gadget) + # iterate through the stack values that need to be in the chain + sp = init_state.regs.sp + arch_bytes = self.project.arch.bytes + for i in range(gadget.stack_change // arch_bytes): + sym_word = init_state.memory.load(sp + arch_bytes*i, arch_bytes, + endness=self.project.arch.memory_endness) + + val = final_state.solver.eval(sym_word) + chain.add_value(val) + state = chain.exec() + if state.solver.eval(state.regs.sp == addr.data): + return chain + except Exception: # pylint: disable=broad-exception-caught + continue + + raise RopException(f"Fail to pivot the stack to {addr.data}!") + + def pivot_reg(self, reg_val): + reg = reg_val.reg_name + for gadget in self._pivot_gadgets: + if reg not in gadget.sp_reg_controllers: + continue + + init_state = self.make_sim_state(gadget.addr) + final_state = rop_utils.step_to_unconstrained_successor(self.project, init_state) + + chain = self.chain_builder.set_regs() + + try: + chain.add_gadget(gadget) + # iterate through the stack values that need to be in the chain + sp = init_state.regs.sp + arch_bytes = self.project.arch.bytes + for i in range(gadget.stack_change // arch_bytes): + sym_word = init_state.memory.load(sp + arch_bytes*i, arch_bytes, + endness=self.project.arch.memory_endness) + + val = final_state.solver.eval(sym_word) + chain.add_value(val) + state = chain.exec() + variables = set(state.regs.sp.variables) + if len(variables) == 1 and variables.pop().startswith(f'reg_{reg}'): + return chain + else: + insts = [str(self.project.factory.block(g.addr).capstone) for g in chain._gadgets] + chain_str = '\n-----\n'.join(insts) + l.exception("Somehow angrop thinks\n%s\ncan be use for stack pivoting", chain_str) + except Exception: # pylint: disable=broad-exception-caught + continue + + raise RopException(f"Fail to pivot the stack to {reg}!") + + @staticmethod + def same_effect(g1, g2): + if g1.sp_controllers != g2.sp_controllers: + return False + if g1.changed_regs != g2.changed_regs: + return False + return True + + def better_than(self, g1, g2): + if not self.same_effect(g1, g2): + return False + if g1.stack_change > g2.stack_change: + return False + if g1.stack_change_after_pivot > g2.stack_change_after_pivot: + return False + if g1.num_mem_access > g2.num_mem_access: + return False + return True + + def _filter_gadgets(self, gadgets): + """ + filter gadgets having the same effect + """ + gadgets = set(gadgets) + skip = set({}) + while True: + to_remove = set({}) + for g in gadgets-skip: + to_remove.update({x for x in gadgets-{g} if self.better_than(g, x)}) + if to_remove: + break + skip.add(g) + if not to_remove: + break + gadgets -= to_remove + gadgets = sorted(gadgets, key=functools.cmp_to_key(cmp)) + return gadgets diff --git a/angrop/chain_builder/reg_mover.py b/angrop/chain_builder/reg_mover.py index 2c0bd8e..dc9e7ef 100644 --- a/angrop/chain_builder/reg_mover.py +++ b/angrop/chain_builder/reg_mover.py @@ -16,6 +16,10 @@ class RegMover(Builder): """ def __init__(self, chain_builder): super().__init__(chain_builder) + self._reg_moving_gadgets = None + self.update() + + def update(self): self._reg_moving_gadgets = self._filter_gadgets(self.chain_builder.gadgets) def verify(self, chain, preserve_regs, registers): @@ -142,12 +146,8 @@ def _find_relevant_gadgets(self, moves): """ gadgets = set() for g in self._reg_moving_gadgets: - if g.makes_syscall: - continue if g.has_symbolic_access(): continue - if g.bp_moves_to_sp: - continue if moves.intersection(set(g.reg_moves)): gadgets.add(g) return gadgets diff --git a/angrop/chain_builder/reg_setter.py b/angrop/chain_builder/reg_setter.py index aa84173..d484d32 100644 --- a/angrop/chain_builder/reg_setter.py +++ b/angrop/chain_builder/reg_setter.py @@ -17,6 +17,11 @@ class RegSetter(Builder): """ def __init__(self, chain_builder): super().__init__(chain_builder) + self._reg_setting_gadgets = None + self.hard_chain_cache = None + self.update() + + def update(self): self._reg_setting_gadgets = self._filter_gadgets(self.chain_builder.gadgets) self.hard_chain_cache = {} @@ -104,8 +109,6 @@ def _find_relevant_gadgets(self, **registers): """ gadgets = set({}) for g in self._reg_setting_gadgets: - if g.makes_syscall: - continue if g.has_symbolic_access(): continue for reg in registers: @@ -329,15 +332,9 @@ def _find_reg_setting_gadgets(self, modifiable_memory_range=None, use_partial_co continue for g in gadgets: - # ignore gadgets which make a syscall when setting regs - if g.makes_syscall: - continue # ignore gadgets which don't have a positive stack change if g.stack_change <= 0: continue - # ignore base pointer moves for now - if g.bp_moves_to_sp: - continue stack_change = data[regs][1] new_stack_change = stack_change + g.stack_change @@ -470,9 +467,6 @@ def _check_if_sufficient_partial_control(self, gadget, reg, value): # make sure the register doesnt depend on itself if reg in gadget.reg_dependencies and reg in gadget.reg_dependencies[reg]: return False - # make sure the gadget doesnt pop bp - if gadget.bp_moves_to_sp: - return False # set the register state = rop_utils.make_symbolic_state(self.project, self.arch.reg_set) diff --git a/angrop/chain_builder/sys_caller.py b/angrop/chain_builder/sys_caller.py index 32fcc78..327e2ed 100644 --- a/angrop/chain_builder/sys_caller.py +++ b/angrop/chain_builder/sys_caller.py @@ -1,14 +1,34 @@ import logging +import functools import angr from .func_caller import FuncCaller -from .. import common from ..errors import RopException -from ..rop_gadget import RopGadget l = logging.getLogger(__name__) +def cmp(g1, g2): + if g1.can_return and not g2.can_return: + return -1 + if not g1.can_return and g2.can_return: + return 1 + + if g1.starts_with_syscall and not g2.starts_with_syscall: + return -1 + if g2.starts_with_syscall and not g1.starts_with_syscall: + return 1 + + if g1.stack_change < g2.stack_change: + return -1 + if g1.stack_change > g2.stack_change: + return 1 + + if g1.block_length < g2.block_length: + return -1 + if g1.block_length > g2.block_length: + return 1 + return 0 class SysCaller(FuncCaller): """ handle linux system calls invocations, only support i386 and x86_64 atm @@ -16,36 +36,22 @@ class SysCaller(FuncCaller): def __init__(self, chain_builder): super().__init__(chain_builder) - self._syscall_instruction = None - if self.project.arch.linux_name == "x86_64": - self._syscall_instructions = {b"\x0f\x05"} - elif self.project.arch.linux_name == "i386": - self._syscall_instructions = {b"\xcd\x80"} - - self._execve_syscall = None - if "unix" in self.project.loader.main_object.os.lower(): - if self.project.arch.bits == 64: - self._execve_syscall = 59 - elif self.project.arch.bits == 32: - self._execve_syscall = 11 - else: - raise RopException("unknown unix platform") - - def _get_syscall_locations(self): - """ - :return: all the locations in the binary with a syscall instruction - """ - addrs = [] - for segment in self.project.loader.main_object.segments: - if segment.is_executable: - num_bytes = segment.max_addr + 1 - segment.min_addr - read_bytes = self.project.loader.memory.load(segment.min_addr, num_bytes) - for syscall_instruction in self._syscall_instructions: - for loc in common.str_find_all(read_bytes, syscall_instruction): - addrs.append(loc + segment.min_addr) - return sorted(addrs) + self.syscall_gadgets = None + self.update() + + @staticmethod + def supported_os(os): + return "unix" in os.lower() + + def update(self): + self.syscall_gadgets = self._filter_gadgets(self.chain_builder.syscall_gadgets) + + @staticmethod + def _filter_gadgets(gadgets): + return sorted(gadgets, key=functools.cmp_to_key(cmp)) def _try_invoke_execve(self, path_addr): + execve_syscall = 0x3b if self.project.arch.bits == 64 else 0xb # next, try to invoke execve(path, ptr, ptr), where ptr points is either NULL or nullptr if 0 not in self.badbytes: ptr = 0 @@ -54,7 +60,7 @@ def _try_invoke_execve(self, path_addr): ptr = nullptr try: - return self.do_syscall(self._execve_syscall, [path_addr, ptr, ptr], + return self.do_syscall(execve_syscall, [path_addr, ptr, ptr], use_partial_controllers=False, needs_return=False) except RopException: pass @@ -62,7 +68,7 @@ def _try_invoke_execve(self, path_addr): # Try to use partial controllers l.warning("Trying to use partial controllers for syscall") try: - return self.do_syscall(self._execve_syscall, [path_addr, 0, 0], + return self.do_syscall(execve_syscall, [path_addr, 0, 0], use_partial_controllers=True, needs_return=False) except RopException: pass @@ -70,11 +76,10 @@ def _try_invoke_execve(self, path_addr): raise RopException("Fail to invoke execve!") def execve(self, path=None, path_addr=None): - # look for good syscall gadgets - syscall_locs = self._get_syscall_locations() - syscall_locs = [x for x in syscall_locs if not self._word_contain_badbyte(x)] - if len(syscall_locs) == 0: - raise RopException("No syscall instruction available") + if "unix" not in self.project.loader.main_object.os.lower(): + raise RopException("unknown unix platform") + if not self.syscall_gadgets: + raise RopException("target does not contain syscall gadget!") # determine the execution path if path is None: @@ -113,7 +118,6 @@ def do_syscall(self, syscall_num, args, needs_return=True, **kwargs): :param needs_return: whether to continue the ROP after invoking the syscall :return: a RopChain which makes the system with the requested register contents """ - # set the system call number extra_regs = {} extra_regs[self.project.arch.register_names[self.project.arch.syscall_num_offset]] = syscall_num @@ -121,25 +125,17 @@ def do_syscall(self, syscall_num, args, needs_return=True, **kwargs): # find small stack change syscall gadget that also fits the stack arguments we want # FIXME: does any arch/OS take syscall arguments on stack? (windows? sysenter?) - smallest = None - stack_arguments = args[len(cc.ARG_REGS):] - for gadget in [x for x in self.chain_builder.gadgets if x.starts_with_syscall]: - # adjust stack change for ret - stack_change = gadget.stack_change - self.project.arch.bytes - required_space = len(stack_arguments) * self.project.arch.bytes - if stack_change >= required_space: - if smallest is None or gadget.stack_change < smallest.stack_change: - smallest = gadget - - if smallest is None and not needs_return: - syscall_locs = self._get_syscall_locations() - if len(syscall_locs) > 0: - smallest = RopGadget(syscall_locs[0]) - smallest.block_length = self.project.factory.block(syscall_locs[0]).size - smallest.stack_change = self.project.arch.bits - - if smallest is None: - raise RopException("No suitable syscall gadgets found") - - return self._func_call(smallest, cc, args, extra_regs=extra_regs, + + if not self.syscall_gadgets: + raise RopException("target does not contain syscall gadget!") + + for gadget in self.syscall_gadgets: + if needs_return and not gadget.can_return: + continue + try: + return self._func_call(gadget, cc, args, extra_regs=extra_regs, needs_return=needs_return, **kwargs) + except Exception: # pylint: disable=broad-exception-caught + continue + + raise RopException(f"Fail to invoke syscall {syscall_num} with arguments: {args}!") diff --git a/angrop/gadget_finder/__init__.py b/angrop/gadget_finder/__init__.py new file mode 100644 index 0000000..2ff6db4 --- /dev/null +++ b/angrop/gadget_finder/__init__.py @@ -0,0 +1,328 @@ +import re +import logging +from multiprocessing import Pool +from collections import defaultdict + +import tqdm + +from angr.errors import SimEngineError, SimMemoryError +from angr.misc.loggers import CuteFormatter +from angr.analyses.bindiff import differing_constants +from angr.analyses.bindiff import UnmatchedStatementsException + +from . import gadget_analyzer +from ..arch import get_arch +from ..errors import RopException +from ..arch import ARM, X86, AMD64 + +l = logging.getLogger(__name__) + +logging.getLogger('pyvex.lifting').setLevel("ERROR") + + +_global_gadget_analyzer = None + +# disable loggers in each worker +def _disable_loggers(): + for handler in logging.root.handlers: + if type(handler.formatter) == CuteFormatter: + logging.root.removeHandler(handler) + return + +# global initializer for multiprocessing +def _set_global_gadget_analyzer(rop_gadget_analyzer): + global _global_gadget_analyzer # pylint: disable=global-statement + _global_gadget_analyzer = rop_gadget_analyzer + _disable_loggers() + +def run_worker(addr): + return _global_gadget_analyzer.analyze_gadget(addr) + +class GadgetFinder: + """ + a class to find ROP gadgets + """ + def __init__(self, project, fast_mode=None, only_check_near_rets=True, max_block_size=None, + max_sym_mem_access=None, is_thumb=False, kernel_mode=False): + # configurations + self.project = project + self.fast_mode = fast_mode + self.arch = get_arch(self.project, kernel_mode=kernel_mode) + self.only_check_near_rets = only_check_near_rets + self.kernel_mode = kernel_mode + + if only_check_near_rets and not isinstance(self.arch, (X86, AMD64)): + l.warning("only_check_near_rets only makes sense for i386/amd64, setting it to False") + self.only_check_near_rets = False + + # override parameters + if max_block_size: + self.arch.max_block_size = max_block_size + if max_sym_mem_access: + self.arch.max_sym_mem_access = max_sym_mem_access + if is_thumb: + self.arch.set_thumb() + + # internal stuff + self._ret_locations = None + self._syscall_locations = None + self._cache = None # cache seen blocks, dict(block_hash => sets of addresses) + self._gadget_analyzer = None + + # silence annoying loggers + logging.getLogger('angr.engines.vex.ccall').setLevel(logging.CRITICAL) + logging.getLogger('angr.engines.vex.expressions.ccall').setLevel(logging.CRITICAL) + logging.getLogger('angr.engines.vex.irop').setLevel(logging.CRITICAL) + logging.getLogger('angr.state_plugins.symbolic_memory').setLevel(logging.CRITICAL) + logging.getLogger('pyvex.lifting.libvex').setLevel(logging.CRITICAL) + logging.getLogger('angr.procedures.cgc.deallocate').setLevel(logging.CRITICAL) + + @property + def gadget_analyzer(self): + if self._gadget_analyzer is not None: + return self._gadget_analyzer + self._initialize_gadget_analyzer() + return self._gadget_analyzer + + def _initialize_gadget_analyzer(self): + + self._syscall_locations = self._get_syscall_locations_by_string() + + # find locations to analyze + if self.only_check_near_rets and not self._ret_locations: + self._ret_locations = self._get_ret_locations() + num_to_check = self._num_addresses_to_check() + + # fast mode + if self.fast_mode is None: + if num_to_check > 20000: + self.fast_mode = True + l.warning("Enabling fast mode for large binary") + else: + self.fast_mode = False + if self.fast_mode: + self.arch.max_block_size = 12 + self.arch.max_sym_mem_access = 1 + # Recalculate num addresses to check based on fast_mode settings + num_to_check = self._num_addresses_to_check() + + l.info("There are %d addresses within %d bytes of a ret", + num_to_check, self.arch.max_block_size) + + self._gadget_analyzer = gadget_analyzer.GadgetAnalyzer(self.project, self.fast_mode, arch=self.arch, + kernel_mode=self.kernel_mode) + + def analyze_gadget(self, addr): + return self.gadget_analyzer.analyze_gadget(addr) + + def get_duplicates(self): + """ + return duplicates that have been seen at least twice + """ + cache = self._cache + return {k:v for k,v in cache.items() if len(v) >= 2} + + def find_gadgets(self, processes=4, show_progress=True): + gadgets = [] + self._cache = defaultdict(set) + + initargs = (self.gadget_analyzer,) + with Pool(processes=processes, initializer=_set_global_gadget_analyzer, initargs=initargs) as pool: + it = pool.imap_unordered(run_worker, self._addresses_to_check_with_caching(show_progress), chunksize=5) + for gadget in it: + if gadget is not None: + gadgets.append(gadget) + + return sorted(gadgets, key=lambda x: x.addr), self.get_duplicates() + + def find_gadgets_single_threaded(self, show_progress=True): + gadgets = [] + self._cache = defaultdict(set) + + assert self.gadget_analyzer is not None + + for addr in self._addresses_to_check_with_caching(show_progress): + gadget = self.gadget_analyzer.analyze_gadget(addr) + if gadget is not None: + gadgets.append(gadget) + + return sorted(gadgets, key=lambda x: x.addr), self.get_duplicates() + + def _block_has_ip_relative(self, addr, bl): + """ + Checks if a block has any ip relative instructions + """ + string = bl.bytes + test_addr = 0x41414140 + addr % 0x10 + bl2 = self.project.factory.block(test_addr, byte_string=string) + try: + diff_constants = differing_constants(bl, bl2) + except UnmatchedStatementsException: + return True + # check if it changes if we move it + bl_end = addr + bl.size + bl2_end = test_addr + bl2.size + filtered_diffs = [] + for d in diff_constants: + if d.value_a < addr or d.value_a >= bl_end or \ + d.value_b < test_addr or d.value_b >= bl2_end: + filtered_diffs.append(d) + return len(filtered_diffs) > 0 + + def _addresses_to_check_with_caching(self, show_progress=True): + num_addrs = self._num_addresses_to_check() + + iterable = self._addresses_to_check() + if show_progress: + iterable = tqdm.tqdm(iterable=iterable, smoothing=0, total=num_addrs, + desc="ROP", maxinterval=0.5, dynamic_ncols=True) + + for a in iterable: + try: + bl = self.project.factory.block(a) + if bl.size > self.arch.max_block_size: + continue + except (SimEngineError, SimMemoryError): + continue + if self._is_simple_gadget(a, bl): + h = self.block_hash(bl) + self._cache[h].add(a) + + yield a + + def block_hash(self, block):# pylint:disable=no-self-use + """ + a hash to uniquely identify a simple block + TODO: block.bytes is too primitive + """ + return block.bytes + + def _addresses_to_check(self): + """ + :return: all the addresses to check + """ + # align block size + alignment = self.arch.alignment + offset = 1 if isinstance(self.arch, ARM) and self.arch.is_thumb else 0 + if self.only_check_near_rets: + block_size = (self.arch.max_block_size & ((1 << self.project.arch.bits) - alignment)) + alignment + slices = [(addr-block_size, addr) for addr in self._ret_locations] + current_addr = 0 + for st, _ in slices: + current_addr = max(current_addr, st) + end_addr = st + block_size + alignment + for i in range(current_addr, end_addr, alignment): + segment = self.project.loader.main_object.find_segment_containing(i) + if segment is not None and segment.is_executable: + yield i+offset + current_addr = max(current_addr, end_addr) + else: + for addr in self._syscall_locations: + yield addr+offset + for segment in self.project.loader.main_object.segments: + if segment.is_executable: + l.debug("Analyzing segment with address range: 0x%x, 0x%x", segment.min_addr, segment.max_addr) + start = segment.min_addr + (alignment - segment.min_addr % alignment) + for addr in range(start, start+segment.memsize, alignment): + yield addr+offset + + def _num_addresses_to_check(self): + if self.only_check_near_rets: + # TODO: This could probably be optimized further by fewer segments checks (i.e. iterating for segments and + # adding ranges instead of incrementing, instead of calling _addressses_to_check) although this is still a + # significant improvement. + return sum(1 for _ in self._addresses_to_check()) + else: + num = 0 + alignment = self.arch.alignment + for segment in self.project.loader.main_object.segments: + if segment.is_executable: + num += segment.memsize // alignment + return num + len(self._syscall_locations) + + def _get_ret_locations(self): + """ + :return: all the locations in the binary with a ret instruction + """ + + try: + return self._get_ret_locations_by_string() + except RopException: + pass + + addrs = [] + seen = set() + for segment in self.project.loader.main_object.segments: + if not segment.is_executable: + continue + + alignment = self.arch.alignment + min_addr = segment.min_addr + (alignment - segment.min_addr % alignment) + + # iterate through the code looking for rets + for addr in range(min_addr, segment.max_addr, alignment): + # dont recheck addresses we've seen before + if addr in seen: + continue + try: + block = self.project.factory.block(addr) + # if it has a ret get the return address + if block.vex.jumpkind.startswith("Ijk_Ret"): + ret_addr = block.instruction_addrs[-1] + # hack for mips pipelining + if self.project.arch.linux_name.startswith("mips"): + ret_addr = block.instruction_addrs[-2] + if ret_addr not in seen: + addrs.append(ret_addr) + # save the addresses in the block + seen.update(block.instruction_addrs) + except (SimEngineError, SimMemoryError): + pass + + return sorted(addrs) + + def _get_ret_locations_by_string(self): + """ + uses a string filter to find the return instructions + :return: all the locations in the binary with a ret instruction + """ + if not self.arch.ret_insts: + raise RopException("Only have ret strings for i386 and x86_64") + return self._get_locations_by_strings(self.arch.ret_insts) + + def _get_syscall_locations_by_string(self): + """ + uses a string filter to find all the system calls instructions + :return: all the locations in the binary with a system call instruction + """ + if not self.arch.syscall_insts: + l.warning("Only have syscall strings for i386 and x86_64") + return [] + return self._get_locations_by_strings(self.arch.syscall_insts) + + def _get_locations_by_strings(self, strings): + fmt = b'(' + b')|('.join(strings) + b')' + + addrs = [] + state = self.project.factory.entry_state() + for segment in self.project.loader.main_object.segments: + if not segment.is_executable: + continue + read_bytes = state.solver.eval(state.memory.load(segment.min_addr, segment.memsize), cast_to=bytes) + # find all occurrences of the ret_instructions + addrs += [segment.min_addr + m.start() for m in re.finditer(fmt, read_bytes)] + return sorted(addrs) + + def _is_simple_gadget(self, addr, block): + """ + is the gadget a simple gadget like + pop rax; ret + """ + if block.vex.jumpkind not in {'Ijk_Boring', 'Ijk_Call', 'Ijk_Ret'}: + return False + if block.vex.constant_jump_targets: + return False + if self._block_has_ip_relative(addr, block): + return False + return True + \ No newline at end of file diff --git a/angrop/gadget_analyzer.py b/angrop/gadget_finder/gadget_analyzer.py similarity index 75% rename from angrop/gadget_analyzer.py rename to angrop/gadget_finder/gadget_analyzer.py index 5d58c3f..b287e04 100644 --- a/angrop/gadget_analyzer.py +++ b/angrop/gadget_finder/gadget_analyzer.py @@ -5,13 +5,17 @@ import pyvex import claripy -from . import rop_utils -from .arch import get_arch -from .rop_gadget import RopGadget, RopMemAccess, RopRegMove, StackPivot -from .errors import RopException, RegNotFoundException +from .. import rop_utils +from ..arch import get_arch +from ..rop_gadget import RopGadget, RopMemAccess, RopRegMove, PivotGadget, SyscallGadget +from ..errors import RopException, RegNotFoundException l = logging.getLogger("angrop.gadget_analyzer") +# the maximum amount of stack shifting after reading saved IP that is allowed after pivoting +# like, mov rsp, rax; ret 0x1000 is not OK +# mov rsp, rax; ret 0x20 is OK +MAX_PIVOT_BYTES = 0x100 class GadgetAnalyzer: @@ -30,14 +34,10 @@ def __init__(self, project, fast_mode, kernel_mode=False, arch=None, stack_gsize # initial state that others are based off, all analysis should copy the state first and work on # the copied state self._stack_bsize = stack_gsize * self.project.arch.bytes # number of controllable bytes on stack - self._state = rop_utils.make_symbolic_state(self.project, self.arch.reg_set, stack_gsize=stack_gsize) + sym_reg_set = self.arch.reg_set.union({self.arch.base_pointer}) + self._state = rop_utils.make_symbolic_state(self.project, sym_reg_set, stack_gsize=stack_gsize) self._concrete_sp = self._state.solver.eval(self._state.regs.sp) - # architecture stuff. we assume every architecture has registers that implements stackframes, - # but may have different names than bp/sp - self._bp_name = self.project.arch.register_names[self.project.arch.bp_offset] - self._sp_name = self.project.arch.register_names[self.project.arch.sp_offset] - @rop_utils.timeout(3) def analyze_gadget(self, addr): """ @@ -56,10 +56,10 @@ def analyze_gadget(self, addr): if not self._can_reach_unconstrained(addr): l.debug("... cannot get to unconstrained successor according to static analysis") return None - init_state, final_state = self._reach_unconstrained(addr) - # TODO: properly handle stack pivoting gadgets, since angrop does not handle them for now - # we do not analyze them as well, check_pivot function is for it - ctrl_type = self._check_for_controll_type(init_state, final_state) + + init_state, final_state = self._reach_unconstrained_or_syscall(addr) + + ctrl_type = self._check_for_control_type(init_state, final_state) if not ctrl_type: # for example, jump outside of the controllable region l.debug("... cannot maintain the control flow hijacking primitive after executing the gadget") @@ -68,10 +68,13 @@ def analyze_gadget(self, addr): # Step 3: gadget effect analysis l.debug("... analyzing rop potential of block") gadget = self._create_gadget(addr, init_state, final_state, ctrl_type) + if not gadget: + return None # Step 4: filter out bad gadgets - # too many mem accesses - if not self._satisfies_mem_access_limits(final_state): + # too many mem accesses, it can only be done after gadget creation + # specifically, memory access analysis + if gadget.num_mem_access > self.arch.max_sym_mem_access: l.debug("... too many symbolic memory accesses") return None @@ -153,6 +156,15 @@ def _block_make_sense(self, addr): return True + def is_in_kernel(self, state): + ip = state.ip + if not ip.symbolic: + obj = self.project.loader.find_object_containing(ip.concrete_value) + if obj.binary == 'cle##kernel': + return True + return False + return False + def _can_reach_unconstrained(self, addr, max_steps=2): """ Use static analysis to check whether the address can lead to unconstrained targets @@ -179,20 +191,50 @@ def _can_reach_unconstrained(self, addr, max_steps=2): return self._can_reach_unconstrained(target_block_addr, max_steps-1) - def _reach_unconstrained(self, addr): + def _reach_unconstrained_or_syscall(self, addr): init_state = self._state.copy() init_state.ip = addr # it will raise errors if angr fails to step the state - final_state = rop_utils.step_to_unconstrained_successor(self.project, state=init_state) - + final_state = rop_utils.step_to_unconstrained_successor(self.project, state=init_state, stop_at_syscall=True) + + if self.is_in_kernel(final_state): + state = final_state.copy() + try: + succ = self.project.factory.successors(state) + state = succ.flat_successors[0] + state2 = rop_utils.step_to_unconstrained_successor(self.project, state=state) + except Exception: # pylint: disable=broad-exception-caught + return init_state, final_state + return init_state, state2 return init_state, final_state - @staticmethod - def _identify_transit_type(final_state, ctrl_type): + def _identify_transit_type(self, final_state, ctrl_type): # FIXME: not always jump, could be call as well if ctrl_type == 'register': return "jmp_reg" + if ctrl_type == 'syscall': + return ctrl_type + + if ctrl_type == 'pivot': + variables = list(final_state.ip.variables) + if all(x.startswith("sreg_") for x in variables): + return "jmp_reg" + for act in final_state.history.actions: + if act.type != 'mem': + continue + if act.size != self.project.arch.bits: + continue + if (act.data.ast == final_state.ip).symbolic or \ + not final_state.solver.eval(act.data.ast == final_state.ip): + continue + sols = final_state.solver.eval_upto(final_state.regs.sp-act.addr.ast, 2) + if len(sols) != 1: + continue + if sols[0] != final_state.arch.bytes: + continue + return "ret" + return "jmp_mem" assert ctrl_type == 'stack' @@ -208,7 +250,15 @@ def _create_gadget(self, addr, init_state, final_state, ctrl_type): transit_type = self._identify_transit_type(final_state, ctrl_type) # create the gadget - gadget = RopGadget(addr=addr) + if ctrl_type == 'syscall' or self._does_syscall(final_state): + gadget = SyscallGadget(addr=addr) + gadget.makes_syscall = self._does_syscall(final_state) + gadget.starts_with_syscall = self._starts_with_syscall(addr) + elif ctrl_type == 'pivot' or self._does_pivot(final_state): + gadget = PivotGadget(addr=addr) + else: + gadget = RopGadget(addr=addr) + # FIXME this doesnt handle multiple steps gadget.block_length = self.project.factory.block(addr).size gadget.transit_type = transit_type @@ -233,7 +283,7 @@ def _create_gadget(self, addr, init_state, final_state, ctrl_type): # compute sp change l.debug("... computing sp change") - self._compute_sp_change(init_state, gadget) + self._compute_sp_change(init_state, final_state, gadget) if gadget.stack_change % (self.project.arch.bytes) != 0: l.debug("... uneven sp change") return None @@ -242,15 +292,6 @@ def _create_gadget(self, addr, init_state, final_state, ctrl_type): #FIXME: technically, it can be negative, e.g. call instructions return None - # if the sp moves to the bp we have to handle it differently - if not gadget.bp_moves_to_sp and self._bp_name != self._sp_name: - rop_utils.make_reg_symbolic(init_state, self._bp_name) - final_state = rop_utils.step_to_unconstrained_successor(self.project, init_state) - - l.info("... checking for syscall availability") - gadget.makes_syscall = self._does_syscall(final_state) - gadget.starts_with_syscall = self._starts_with_syscall(addr) - l.info("... checking for controlled regs") self._check_reg_changes(final_state, init_state, gadget) @@ -275,29 +316,6 @@ def _create_gadget(self, addr, init_state, final_state, ctrl_type): return gadget - def _satisfies_mem_access_limits(self, state): - """ - :param symbolic_path: the successor symbolic path - :return: True/False indicating whether or not to keep the gadget - """ - # get all the memory accesses - symbolic_mem_accesses = [] - for a in reversed(state.history.actions): - if a.type == 'mem' and a.addr.ast.symbolic: - symbolic_mem_accesses.append(a) - if len(symbolic_mem_accesses) <= self.arch.max_sym_mem_access: - return True - - # allow mem changes (only add/subtract) to count as a single access - # FIXME: this logic looks terrible - if len(symbolic_mem_accesses) == 2 and self.arch.max_sym_mem_access == 1: - if symbolic_mem_accesses[0].action == "read" and symbolic_mem_accesses[1].action == "write" and \ - symbolic_mem_accesses[1].data.ast.op in ("__sub__", "__add__") and \ - symbolic_mem_accesses[1].data.ast.size() == self.project.arch.bits and \ - symbolic_mem_accesses[0].addr.ast is symbolic_mem_accesses[1].addr.ast: - return True - return False - def _analyze_concrete_regs(self, state, gadget): """ collect registers that are concretized after symbolically executing the block (for example, xor rax, rax) @@ -324,7 +342,7 @@ def _check_reg_changes(self, final_state, init_state, gadget): exit_target = exit_action.target.ast - stack_change = gadget.stack_change if not gadget.bp_moves_to_sp else None + stack_change = gadget.stack_change if type(gadget) == RopGadget else None for reg in self._get_reg_writes(final_state): # we assume any register in reg_writes changed @@ -392,17 +410,21 @@ def _check_reg_movers(self, symbolic_state, symbolic_p, reg_reads, gadget): gadget.reg_moves.append(RopRegMove(from_reg, reg, half_bits)) # TODO: need to handle reg calls - def _check_for_controll_type(self, init_state, final_state): + def _check_for_control_type(self, init_state, final_state): """ :return: the data provenance of the controlled ip in the final state, either the stack or registers """ + ip = final_state.ip + + # this gadget arrives a syscall + if self.is_in_kernel(final_state): + return 'syscall' + # the ip is controlled by stack - if self._check_if_stack_controls_ast(final_state.ip, init_state): + if self._check_if_stack_controls_ast(ip, init_state): return "stack" - ip = final_state.ip - # the ip is not controlled by regs if not ip.variables: return None @@ -412,6 +434,35 @@ def _check_for_controll_type(self, init_state, final_state): if all(x.startswith("sreg_") for x in variables): return "register" + # this is a stack pivoting gadget + if all(x.startswith("symbolic_read_") for x in variables) and len(final_state.regs.sp.variables) == 1: + # we don't fully control sp + if not init_state.solver.satisfiable(extra_constraints=[final_state.regs.sp == 0x41414100]): + return None + # make sure the control after pivot is reasonable + + # find where the ip is read from + saved_ip_addr = None + for act in final_state.history.actions: + if act.type == 'mem' and act.action == 'read': + if act.size == self.project.arch.bits and not (act.data.ast == ip).symbolic: + if init_state.solver.eval(act.data.ast == ip): + saved_ip_addr = act.addr.ast + break + if saved_ip_addr is None: + return None + + # if the saved ip is too far away from the final sp, that's a bad gadget + sols = final_state.solver.eval_upto(final_state.regs.sp - saved_ip_addr, 2) + if len(sols) != 1: # the saved ip has a symbolic distance from the final sp, bad + return None + offset = sols[0] + if offset > MAX_PIVOT_BYTES: # filter out gadgets like mov rsp, rax; ret 0x1000 + return None + if offset % self.project.arch.bytes != 0: # filter misaligned gadgets + return None + return "pivot" + return None @staticmethod @@ -458,51 +509,77 @@ def _check_if_stack_controls_ast(self, ast, initial_state, gadget_stack_change=N concrete_stack_s = initial_state.copy() concrete_stack_s.add_constraints( initial_state.memory.load(initial_state.regs.sp, stack_bytes_length) == concrete_stack) - test_constraint = (ast != test_val) + test_constraint = ast != test_val # stack must have set the register and it must be able to set the register to all 1's or all 0's ans = not concrete_stack_s.solver.satisfiable(extra_constraints=(test_constraint,)) and \ rop_utils.fast_unconstrained_check(initial_state, ast) return ans - def _compute_sp_change(self, symbolic_state, gadget): + def _compute_sp_change(self, init_state, final_state, gadget): """ - Computes the change in the stack pointer for a gadget, including whether or not it moves to the base pointer + Computes the change in the stack pointer for a gadget + for a PivotGadget, it is the sp change right before pivoting :param symbolic_state: the input symbolic state :param gadget: the gadget in which to store the sp change """ - # store symbolic sp and bp and check for dependencies - init_state = symbolic_state.copy() - init_state.regs.bp = init_state.solver.BVS("sreg_" + self._bp_name + "-", self.project.arch.bits) - init_state.regs.sp = init_state.solver.BVS("sreg_" + self._sp_name+ "-", self.project.arch.bits) - final_state = rop_utils.step_to_unconstrained_successor(self.project, init_state) - dependencies = self._get_reg_dependencies(final_state, "sp") - sp_change = final_state.regs.sp - init_state.regs.sp - - # analyze the results - gadget.bp_moves_to_sp = False - if len(dependencies) > 1: - raise RopException("SP has multiple dependencies") - if len(dependencies) == 0 and sp_change.symbolic: - raise RopException("SP change is uncontrolled") - - if len(dependencies) == 0 and not sp_change.symbolic: - stack_changes = [init_state.solver.eval(sp_change)] - elif list(dependencies)[0] == self._sp_name: - stack_changes = init_state.solver.eval_upto(sp_change, 2) - elif list(dependencies)[0] == self._bp_name: - # FIXME: I think this code is meant to handle leave; ret - # but I wonder whether lea rsp, [rbp+offset] is a thing - sp_change = final_state.regs.sp - init_state.regs.bp - stack_changes = init_state.solver.eval_upto(sp_change, 2) - gadget.bp_moves_to_sp = True - else: - raise RopException("SP does not depend on SP or BP") - - if len(stack_changes) != 1: - raise RopException("SP change is symbolic") - - gadget.stack_change = stack_changes[0] + if type(gadget) in (RopGadget, SyscallGadget): + dependencies = self._get_reg_dependencies(final_state, "sp") + sp_change = final_state.regs.sp - init_state.regs.sp + + # analyze the results + if len(dependencies) > 1: + raise RopException("SP has multiple dependencies") + if len(dependencies) == 0 and sp_change.symbolic: + raise RopException("SP change is uncontrolled") + + assert self.arch.base_pointer not in dependencies + if len(dependencies) == 0 and not sp_change.symbolic: + stack_changes = [init_state.solver.eval(sp_change)] + elif list(dependencies)[0] == self.arch.stack_pointer: + stack_changes = init_state.solver.eval_upto(sp_change, 2) + else: + raise RopException("SP does not depend on SP or BP") + + if len(stack_changes) != 1: + raise RopException("SP change is symbolic") + + gadget.stack_change = stack_changes[0] + + elif type(gadget) is PivotGadget: + final_state = rop_utils.step_to_unconstrained_successor(self.project, state=init_state, precise_action=True) + dependencies = self._get_reg_dependencies(final_state, "sp") + last_sp = None + init_sym_sp = None + prev_act = None + for act in final_state.history.actions: + if act.type == 'reg' and act.action == 'write' and act.storage == self.arch.stack_pointer: + if not act.data.ast.symbolic: + last_sp = act.data.ast + else: + init_sym_sp = act.data.ast + break + prev_act = act + if last_sp is not None: + gadget.stack_change = (last_sp - init_state.regs.sp).concrete_value + else: + gadget.stack_change = 0 + + # if is popped from stack, we need to compensate for the popped sp value on the stack + # if it is a pop, then sp comes from stack and the previous action must be a mem read + # and the data is the new sp + variables = init_sym_sp.variables + if prev_act and variables and all(x.startswith('symbolic_stack_') for x in variables): + if prev_act.type == 'mem' and prev_act.action == 'read' and prev_act.data.ast is init_sym_sp: + gadget.stack_change += self.project.arch.bytes + + assert init_sym_sp is not None + sols = final_state.solver.eval_upto(final_state.regs.sp - init_sym_sp, 2) + if len(sols) != 1: + raise RopException("This gadget pivots more than once, which is currently not handled") + gadget.stack_change_after_pivot = sols[0] + gadget.sp_reg_controllers = set(self._get_reg_controllers(init_state, final_state, 'sp', dependencies)) + gadget.sp_stack_controllers = {x for x in final_state.regs.sp.variables if x.startswith("symbolic_stack_")} def _build_mem_access(self, a, gadget, init_state, final_state): """ @@ -536,7 +613,7 @@ def _build_mem_access(self, a, gadget, init_state, final_state): elif len(test_data) == 1: mem_access.data_constant = test_data[0] else: - raise Exception("No data values, something went wrong") + raise RopException("No data values, something went wrong") elif a.action == "read": # for reads we want to know if any register will have the data after succ_state = final_state @@ -624,6 +701,25 @@ def _does_syscall(self, symbolic_p): return False + def _does_pivot(self, final_state): + """ + checks if the path does a stack pivoting at some point + :param final_state: the state that finishes the gadget execution + """ + for act in final_state.history.actions: + if act.type != 'reg' or act.action != 'write': + continue + try: + storage = act.storage + except KeyError: + continue + if storage != self.arch.stack_pointer: + continue + # this gadget has done symbolic pivoting if there is a symbolic write to the stack pointer + if act.data.symbolic: + return True + return False + def _analyze_mem_access(self, final_state, init_state, gadget): """ analyzes memory accesses and stores their info in the gadget @@ -632,6 +728,7 @@ def _analyze_mem_access(self, final_state, init_state, gadget): :param gadget: the gadget to store mem acccess in """ all_mem_actions = [] + sp_vars = final_state.regs.sp.variables # step 1: filter out irrelevant actions and irrelevant memory accesses for a in final_state.history.actions.hardcopy: @@ -642,6 +739,10 @@ def _analyze_mem_access(self, final_state, init_state, gadget): if isinstance(a.data.ast, (claripy.fp.FPV, claripy.ast.FP)): continue + # ignore read/write on stack after pivot + if a.addr.ast.symbolic and not a.addr.ast.variables - sp_vars: + continue + # ignore read/write on stack if not a.addr.ast.symbolic: addr_constant = init_state.solver.eval(a.addr.ast) @@ -682,63 +783,6 @@ def _analyze_mem_access(self, final_state, init_state, gadget): if a.action == "write": gadget.mem_writes.append(mem_access) - def _check_pivot(self, symbolic_p, symbolic_state, addr): - """ - Super basic pivot analysis. Pivots are not really used by angrop right now - :param symbolic_p: the stepped path, symbolic_state is an ancestor of it. - :param symbolic_state: input state for testing - :return: the pivot object - """ - if symbolic_p.history.depth > 1: - return None - pivot = None - reg_deps = rop_utils.get_ast_dependency(symbolic_p.regs.sp) - if len(reg_deps) == 1: - pivot = StackPivot(addr) - pivot.sp_from_reg = list(reg_deps)[0] - elif len(symbolic_p.regs.sp.variables) == 1 and \ - list(symbolic_p.regs.sp.variables)[0].startswith("symbolic_stack"): - offset = None - for a in symbolic_p.regs.sp.recursive_children_asts: - if a.op == "Extract" and a.depth == 2: - offset = a.args[2].size() - 1 - a.args[0] - if offset is None or offset % 8 != 0: - return None - offset_bytes = offset//8 - pivot = StackPivot(addr) - pivot.sp_popped_offset = offset_bytes - - if pivot is not None: - # verify no weird mem accesses - test_p = self.project.factory.simulation_manager(symbolic_state.copy()) - # step until we find the pivot action - for _ in range(self.project.factory.block(symbolic_state.addr).instructions): - test_p.step(num_inst=1) - if len(test_p.active) != 1: - return None - if test_p.one_active.regs.sp.symbolic: - # found the pivot action - break - # now iterate through the remaining instructions with a clean state - test_p.step(num_inst=1) - if len(test_p.active) != 1: - return None - succ1 = test_p.active[0] - ss = symbolic_state.copy() - ss.regs.ip = succ1.addr - succ = self.project.factory.successors(ss) - if len(succ.flat_successors + succ.unconstrained_successors) == 0: - return None - succ2 = (succ.flat_successors + succ.unconstrained_successors)[0] - - all_actions = succ1.history.actions.hardcopy + succ2.history.actions.hardcopy - for a in all_actions: - if a.type == "mem" and a.addr.ast.symbolic: - return None - return pivot - - return None - def _starts_with_syscall(self, addr): """ checks if the path starts with a system call @@ -804,7 +848,7 @@ def _get_reg_reads(self, path): reg_name = rop_utils.get_reg_name(self.project.arch, a.offset) if reg_name in self.arch.reg_set: all_reg_reads.add(reg_name) - elif reg_name != self._sp_name: + elif reg_name != self.arch.stack_pointer: l.info("reg read from register not in reg_set: %s", reg_name) except RegNotFoundException as e: l.debug(e) @@ -823,7 +867,7 @@ def _get_reg_writes(self, path): reg_name = rop_utils.get_reg_name(self.project.arch, a.offset) if reg_name in self.arch.reg_set: all_reg_writes.add(reg_name) - elif reg_name != self._sp_name: + elif reg_name != self.arch.stack_pointer: l.info("reg write from register not in reg_set: %s", reg_name) except RegNotFoundException as e: l.debug(e) diff --git a/angrop/rop.py b/angrop/rop.py index 81ecab0..d928232 100644 --- a/angrop/rop.py +++ b/angrop/rop.py @@ -1,44 +1,14 @@ import pickle import inspect import logging -from multiprocessing import Pool -import tqdm from angr import Analysis, register_analysis -from angr.errors import SimEngineError, SimMemoryError -from angr.misc.loggers import CuteFormatter -from angr.analyses.bindiff import differing_constants -from angr.analyses.bindiff import UnmatchedStatementsException from . import chain_builder -from . import gadget_analyzer -from . import common -from .arch import get_arch, ARM -from .errors import RopException -from .rop_gadget import RopGadget, StackPivot +from .gadget_finder import GadgetFinder +from .rop_gadget import RopGadget, PivotGadget, SyscallGadget l = logging.getLogger('angrop.rop') -logging.getLogger('pyvex.lifting').setLevel("ERROR") - - -_global_gadget_analyzer = None - -# disable loggers in each worker -def _disable_loggers(): - for handler in logging.root.handlers: - if type(handler.formatter) == CuteFormatter: - logging.root.removeHandler(handler) - return - -# global initializer for multiprocessing -def _set_global_gadget_analyzer(rop_gadget_analyzer): - global _global_gadget_analyzer # pylint: disable=global-statement - _global_gadget_analyzer = rop_gadget_analyzer - _disable_loggers() - -def run_worker(addr): - return _global_gadget_analyzer.analyze_gadget(addr) - # todo what if we have mov eax, [rsp+0x20]; ret (cache would need to know where it is or at least a min/max) # todo what if we have pop eax; mov ebx, eax; need to encode that we cannot set them to different values @@ -66,36 +36,24 @@ def __init__(self, only_check_near_rets=True, max_block_size=None, max_sym_mem_a :return: """ - # params - self.arch = get_arch(self.project, kernel_mode=kernel_mode) - self.kernel_mode = kernel_mode - self._only_check_near_rets = only_check_near_rets - - # override parameters - if max_block_size: - self.arch.max_block_size = max_block_size - if max_sym_mem_access: - self.arch.max_sym_mem_access = max_sym_mem_access - if is_thumb: - self.arch.set_thumb() - - # get ret locations - self._ret_locations = None - self._cache = {} + # private list of RopGadget's + self._all_gadgets = [] # all types of gadgets + self._duplicates = None # all equivalent gadgets (with the same instructions) - # list of RopGadget's - self._gadgets = [] - self.stack_pivots = [] - self._duplicates = [] + # public list of RopGadget's + self.rop_gadgets = [] # gadgets used for ROP, like pop rax; ret + self.pivot_gadgets = [] # gadgets used for stack pivoting, like mov rsp, rbp; ret + self.syscall_gadgets = [] # gadgets used for invoking system calls, such as syscall; ret or int 0x80; ret # RopChain settings self.badbytes = [] self.roparg_filler = None - self._fast_mode = fast_mode - - # gadget analyzer - self._gadget_analyzer = None + # gadget finder configurations + self.gadget_finder = GadgetFinder(self.project, fast_mode=fast_mode, only_check_near_rets=only_check_near_rets, + max_block_size=max_block_size, max_sym_mem_access=max_sym_mem_access, + is_thumb=is_thumb, kernel_mode=kernel_mode) + self.arch = self.gadget_finder.arch # chain builder self._chain_builder = None @@ -103,43 +61,43 @@ def __init__(self, only_check_near_rets=True, max_block_size=None, max_sym_mem_a if rebase is not None: l.warning("rebase is deprecated in angrop!") - # silence annoying loggers - logging.getLogger('angr.engines.vex.ccall').setLevel(logging.CRITICAL) - logging.getLogger('angr.engines.vex.expressions.ccall').setLevel(logging.CRITICAL) - logging.getLogger('angr.engines.vex.irop').setLevel(logging.CRITICAL) - logging.getLogger('angr.state_plugins.symbolic_memory').setLevel(logging.CRITICAL) - logging.getLogger('pyvex.lifting.libvex').setLevel(logging.CRITICAL) - logging.getLogger('angr.procedures.cgc.deallocate').setLevel(logging.CRITICAL) - - @property - def gadgets(self): - return [x for x in self._gadgets if not self._contain_badbytes(x.addr)] - - def _initialize_gadget_analyzer(self): - - # find locations to analyze - if self._only_check_near_rets and not self._ret_locations: - self._ret_locations = self._get_ret_locations() - num_to_check = self._num_addresses_to_check() - - # fast mode - if self._fast_mode is None: - if num_to_check > 20000: - self._fast_mode = True - l.warning("Enabling fast mode for large binary") - else: - self._fast_mode = False - if self._fast_mode: - self.arch.max_block_size = 12 - self.arch.max_sym_mem_access = 1 - # Recalculate num addresses to check based on fast_mode settings - num_to_check = self._num_addresses_to_check() - - l.info("There are %d addresses within %d bytes of a ret", - num_to_check, self.arch.max_block_size) - - self._gadget_analyzer = gadget_analyzer.GadgetAnalyzer(self.project, self._fast_mode, arch=self.arch, - kernel_mode=self.kernel_mode) + def _screen_gadgets(self): + # screen gadgets based on badbytes and gadget types + self.rop_gadgets = [] + self.pivot_gadgets = [] + for g in self._all_gadgets: + if self._contain_badbytes(g.addr): + # in case the gadget contains bad byte, try to take an equivalent one from + # the duplicates (other gadgets with the same instructions) + block = self.project.factory.block(g.addr) + h = self.gadget_finder.block_hash(block) + addr = None + if h not in self._duplicates: + continue + for addr in self._duplicates[h]: + if not self._contain_badbytes(addr): + break + if not addr: + continue + g = self.gadget_finder.analyze_gadget(addr) + if type(g) is RopGadget: + self.rop_gadgets.append(g) + if type(g) is PivotGadget: + self.pivot_gadgets.append(g) + if type(g) is SyscallGadget: + self.syscall_gadgets.append(g) + + self.chain_builder.gadgets = self.rop_gadgets + self.chain_builder.pivot_gadgets = self.pivot_gadgets + self.chain_builder.syscall_gadgets = self.syscall_gadgets + self.chain_builder.update() + + def analyze_gadget(self, addr): + g = self.gadget_finder.analyze_gadget(addr) + if g: + self._all_gadgets.append(g) + self._screen_gadgets() + return g def find_gadgets(self, processes=4, show_progress=True): """ @@ -148,33 +106,10 @@ def find_gadgets(self, processes=4, show_progress=True): Saves stack pivots in self.stack_pivots :param processes: number of processes to use """ - self._initialize_gadget_analyzer() - self._gadgets = [] - - - initargs = (self._gadget_analyzer,) - with Pool(processes=processes, initializer=_set_global_gadget_analyzer, initargs=initargs) as pool: - - it = pool.imap_unordered(run_worker, self._addresses_to_check_with_caching(show_progress), chunksize=5) - for gadget in it: - if gadget is not None: - if isinstance(gadget, RopGadget): - self._gadgets.append(gadget) - elif isinstance(gadget, StackPivot): - self.stack_pivots.append(gadget) - - # fix up gadgets from cache - for g in self._gadgets: - if g.addr in self._cache: - dups = {g.addr} - for addr in self._cache[g.addr]: - dups.add(addr) - g_copy = g.copy() - g_copy.addr = addr - self._gadgets.append(g_copy) - self._duplicates.append(dups) - self._gadgets = sorted(self._gadgets, key=lambda x: x.addr) - self._reload_chain_funcs() + self._all_gadgets, self._duplicates = self.gadget_finder.find_gadgets(processes=processes, + show_progress=show_progress) + self._screen_gadgets() + return self.rop_gadgets def find_gadgets_single_threaded(self, show_progress=True): """ @@ -182,30 +117,18 @@ def find_gadgets_single_threaded(self, show_progress=True): Saves gadgets in self.gadgets Saves stack pivots in self.stack_pivots """ - self._initialize_gadget_analyzer() - self._gadgets = [] + self._all_gadgets, self._duplicates = self.gadget_finder.find_gadgets_single_threaded( + show_progress=show_progress) + self._screen_gadgets() + return self.rop_gadgets - _set_global_gadget_analyzer(self._gadget_analyzer) - for _, addr in enumerate(self._addresses_to_check_with_caching(show_progress)): - gadget = _global_gadget_analyzer.analyze_gadget(addr) - if gadget is not None: - if isinstance(gadget, RopGadget): - self._gadgets.append(gadget) - elif isinstance(gadget, StackPivot): - self.stack_pivots.append(gadget) + def _get_cache_tuple(self): + return (self._all_gadgets, self._duplicates) - # fix up gadgets from cache - for g in self._gadgets: - if g.addr in self._cache: - dups = {g.addr} - for addr in self._cache[g.addr]: - dups.add(addr) - g_copy = g.copy() - g_copy.addr = addr - self._gadgets.append(g_copy) - self._duplicates.append(dups) - self._gadgets = sorted(self._gadgets, key=lambda x: x.addr) - self._reload_chain_funcs() + def _load_cache_tuple(self, tup): + self._all_gadgets = tup[0] + self._duplicates = tup[1] + self._screen_gadgets() def save_gadgets(self, path): """ @@ -230,12 +153,13 @@ def set_badbytes(self, badbytes): :param badbytes: a list of 8 bit integers """ if not isinstance(badbytes, list): - print("Require a list, e.g: [0x00, 0x09]") + l.error("Require a list, e.g: [0x00, 0x09]") return badbytes = [x if type(x) == int else ord(x) for x in badbytes] self.badbytes = badbytes - if len(self._gadgets) > 0: - self.chain_builder.set_badbytes(self.badbytes) + if self._chain_builder: + self._chain_builder.set_badbytes(self.badbytes) + self._screen_gadgets() def set_roparg_filler(self, roparg_filler): """ @@ -246,12 +170,11 @@ def set_roparg_filler(self, roparg_filler): :param roparg_filler: A integer which is used when popping useless register or None. """ if not isinstance(roparg_filler, (int, type(None))): - print("Require an integer, e.g: 0x41414141 or None") + l.error("Require an integer, e.g: 0x41414141 or None") return self.roparg_filler = roparg_filler - if len(self._gadgets) > 0: - self.chain_builder.set_roparg_filler(self.roparg_filler) + self.chain_builder.set_roparg_filler(self.roparg_filler) def get_badbytes(self): """ @@ -260,203 +183,22 @@ def get_badbytes(self): """ return self.badbytes - def _get_cache_tuple(self): - return self._gadgets, self.stack_pivots, self._duplicates, self._ret_locations, self._fast_mode, \ - self.arch, self._gadget_analyzer - - def _load_cache_tuple(self, cache_tuple): - self._gadgets, self.stack_pivots, self._duplicates, self._ret_locations, self._fast_mode, \ - self.arch, self._gadget_analyzer = cache_tuple - self._reload_chain_funcs() - - def _reload_chain_funcs(self): - for f_name, f in inspect.getmembers(self.chain_builder, predicate=inspect.ismethod): - if f_name.startswith("_"): - continue - setattr(self, f_name, f) - @property def chain_builder(self): if self._chain_builder is not None: return self._chain_builder - if len(self._gadgets) == 0: + + if len(self._all_gadgets) == 0: l.warning("Could not find gadgets for %s", self.project) l.warning("check your badbytes and make sure find_gadgets() or load_gadgets() was called.") - self._chain_builder = chain_builder.ChainBuilder(self.project, self.gadgets, - self.arch, self.badbytes, + self._chain_builder = chain_builder.ChainBuilder(self.project, self.rop_gadgets, self.pivot_gadgets, + self.syscall_gadgets, self.arch, self.badbytes, self.roparg_filler) - return self._chain_builder - - def _block_has_ip_relative(self, addr, bl): - """ - Checks if a block has any ip relative instructions - """ - string = bl.bytes - test_addr = 0x41414140 + addr % 0x10 - bl2 = self.project.factory.block(test_addr, byte_string=string) - try: - diff_constants = differing_constants(bl, bl2) - except UnmatchedStatementsException: - return True - # check if it changes if we move it - bl_end = addr + bl.size - bl2_end = test_addr + bl2.size - filtered_diffs = [] - for d in diff_constants: - if d.value_a < addr or d.value_a >= bl_end or \ - d.value_b < test_addr or d.value_b >= bl2_end: - filtered_diffs.append(d) - return len(filtered_diffs) > 0 - - def _addresses_to_check_with_caching(self, show_progress=True): - num_addrs = self._num_addresses_to_check() - seen = {} - - iterable = self._addresses_to_check() - if show_progress: - iterable = tqdm.tqdm(iterable=iterable, smoothing=0, total=num_addrs, - desc="ROP", maxinterval=0.5, dynamic_ncols=True) - - for a in iterable: - try: - bl = self.project.factory.block(a) - if bl.size > self.arch.max_block_size: - continue - block_data = bl.bytes - except (SimEngineError, SimMemoryError): - continue - if block_data in seen: - self._cache[seen[block_data]].add(a) + for f_name, f in inspect.getmembers(self._chain_builder, predicate=inspect.ismethod): + if f_name.startswith("_"): continue - if self._is_jumpkind_valid(bl.vex.jumpkind) and \ - len(bl.vex.constant_jump_targets) == 0 and \ - not self._block_has_ip_relative(a, bl): - seen[block_data] = a - self._cache[a] = set() - yield a - - def _addresses_to_check(self): - """ - :return: all the addresses to check - """ - # align block size - alignment = self.arch.alignment - offset = 1 if isinstance(self.arch, ARM) and self.arch.is_thumb else 0 - if self._only_check_near_rets: - block_size = (self.arch.max_block_size & ((1 << self.project.arch.bits) - alignment)) + alignment - slices = [(addr-block_size, addr) for addr in self._ret_locations] - current_addr = 0 - for st, _ in slices: - current_addr = max(current_addr, st) - end_addr = st + block_size + alignment - for i in range(current_addr, end_addr, alignment): - segment = self.project.loader.main_object.find_segment_containing(i) - if segment is not None and segment.is_executable: - yield i+offset - current_addr = max(current_addr, end_addr) - else: - for segment in self.project.loader.main_object.segments: - if segment.is_executable: - l.debug("Analyzing segment with address range: 0x%x, 0x%x", segment.min_addr, segment.max_addr) - start = segment.min_addr + (alignment - segment.min_addr % alignment) - for addr in range(start, segment.max_addr, alignment): - yield addr+offset - - def _num_addresses_to_check(self): - if self._only_check_near_rets: - # TODO: This could probably be optimized further by fewer segments checks (i.e. iterating for segments and - # adding ranges instead of incrementing, instead of calling _addressses_to_check) although this is still a - # significant improvement. - return sum(1 for _ in self._addresses_to_check()) - else: - num = 0 - alignment = self.arch.alignment - for segment in self.project.loader.main_object.segments: - if segment.is_executable: - start = segment.min_addr + (alignment - segment.min_addr % alignment) - num += (segment.max_addr - start) // alignment - return num - - def _get_ret_locations(self): - """ - :return: all the locations in the binary with a ret instruction - """ - - try: - return self._get_ret_locations_by_string() - except RopException: - pass - - addrs = [] - seen = set() - for segment in self.project.loader.main_object.segments: - if segment.is_executable: - num_bytes = segment.max_addr-segment.min_addr - - alignment = self.arch.alignment - - # iterate through the code looking for rets - for addr in range(segment.min_addr, segment.min_addr + num_bytes, alignment): - # dont recheck addresses we've seen before - if addr in seen: - continue - try: - block = self.project.factory.block(addr) - # it it has a ret get the return address - if block.vex.jumpkind.startswith("Ijk_Ret"): - ret_addr = block.instruction_addrs[-1] - # hack for mips pipelining - if self.project.arch.linux_name.startswith("mips"): - ret_addr = block.instruction_addrs[-2] - if ret_addr not in seen: - addrs.append(ret_addr) - # save the addresses in the block - seen.update(block.instruction_addrs) - except (SimEngineError, SimMemoryError): - pass - - return sorted(addrs) - - def _get_ret_locations_by_string(self): - """ - uses a string filter to find the return instructions - :return: all the locations in the binary with a ret instruction - """ - if self.project.arch.linux_name in ("x86_64", "i386"): - ret_instructions = {b"\xc2", b"\xc3", b"\xca", b"\xcb"} - else: - raise RopException("Only have ret strings for i386 and x86_64") - - addrs = [] - try: - for segment in self.project.loader.main_object.segments: - if segment.is_executable: - num_bytes = segment.max_addr-segment.min_addr - read_bytes = self.project.loader.memory.load(segment.min_addr, num_bytes) - for ret_instruction in ret_instructions: - for loc in common.str_find_all(read_bytes, ret_instruction): - addrs.append(loc + segment.min_addr) - except KeyError: - l.warning("Key error with segment analysis") - # try reading from state - state = self.project.factory.entry_state() - for segment in self.project.loader.main_object.segments: - if segment.is_executable: - num_bytes = segment.max_addr - segment.min_addr - - read_bytes = state.solver.eval(state.memory.load(segment.min_addr, num_bytes), cast_to=bytes) - for ret_instruction in ret_instructions: - for loc in common.str_find_all(read_bytes, ret_instruction): - addrs.append(loc + segment.min_addr) - - return sorted(addrs) - - @staticmethod - def _is_jumpkind_valid(jk): - - if jk in {'Ijk_Boring', 'Ijk_Call', 'Ijk_Ret'}: - return True - return False + setattr(self, f_name, f) + return self._chain_builder # inspired by ropper def _contain_badbytes(self, addr): diff --git a/angrop/rop_gadget.py b/angrop/rop_gadget.py index ecf4daf..09db40e 100644 --- a/angrop/rop_gadget.py +++ b/angrop/rop_gadget.py @@ -95,20 +95,23 @@ class RopGadget: """ def __init__(self, addr): self.addr = addr + self.block_length = None + self.stack_change = None + + # register effect information self.changed_regs = set() self.popped_regs = set() self.concrete_regs = {} self.reg_dependencies = {} # like rax might depend on rbx, rcx self.reg_controllers = {} # like rax might be able to be controlled by rbx (for any value of rcx) - self.stack_change = None + self.reg_moves = [] + + # memory effect information self.mem_reads = [] self.mem_writes = [] self.mem_changes = [] - self.reg_moves = [] - self.bp_moves_to_sp = None # whether the new sp depends on bp, e.g. 'leave; ret' overwrites sp with bp - self.block_length = None - self.makes_syscall = False - self.starts_with_syscall = False + + # transition information, i.e. how to pass the control flow to the next gadget self.transit_type = None self.jump_reg = None self.pc_reg = None @@ -129,8 +132,6 @@ def reg_set_same_effect(self, other): return False if self.concrete_regs != other.concrete_regs: return False - if self.bp_moves_to_sp != other.bp_moves_to_sp: - return False if self.reg_dependencies != other.reg_dependencies: return False return True @@ -169,18 +170,15 @@ def reg_move_better_than(self, other): def __str__(self): s = "Gadget %#x\n" % self.addr - if self.bp_moves_to_sp: - s += "Stack change: bp + %#x\n" % self.stack_change - else: - s += "Stack change: %#x\n" % self.stack_change + s += "Stack change: %#x\n" % self.stack_change s += "Changed registers: " + str(self.changed_regs) + "\n" s += "Popped registers: " + str(self.popped_regs) + "\n" for move in self.reg_moves: s += "Register move: [%s to %s, %d bits]\n" % (move.from_reg, move.to_reg, move.bits) s += "Register dependencies:\n" - for reg in self.reg_dependencies: + for reg, deps in self.reg_dependencies.items(): controllers = self.reg_controllers.get(reg, []) - dependencies = [x for x in self.reg_dependencies[reg] if x not in controllers] + dependencies = [x for x in deps if x not in controllers] s += " " + reg + ": [" + " ".join(controllers) + " (" + " ".join(dependencies) + ")]" + "\n" for mem_access in self.mem_changes: if mem_access.op == "__add__": @@ -221,8 +219,6 @@ def __str__(self): s += " " + "address (%d bits): %#x" % (mem_access.addr_size, mem_access.addr_constant) s += " " + "data (%d bits) stored in regs:" % mem_access.data_size s += str(list(mem_access.data_dependencies)) + "\n" - if self.makes_syscall: - s += "Makes a syscall\n" return s def __repr__(self): @@ -241,32 +237,78 @@ def copy(self): out.mem_changes = list(self.mem_changes) out.mem_writes = list(self.mem_writes) out.reg_moves = list(self.reg_moves) - out.bp_moves_to_sp = self.bp_moves_to_sp out.block_length = self.block_length - out.makes_syscall = self.makes_syscall - out.starts_with_syscall = self.starts_with_syscall out.transit_type = self.transit_type out.jump_reg = self.jump_reg out.pc_reg = self.pc_reg return out -class StackPivot: +class PivotGadget(RopGadget): """ - stack pivot gadget + stack pivot gadget, the definition of a PivotGadget is that + it can arbitrarily control the stack pointer register, and do the pivot exactly once + TODO: so currently, it cannot directly construct a `pop rbp; leave ret;` + chain to pivot stack """ def __init__(self, addr): - self.addr = addr - self.sp_from_reg = None - self.sp_popped_offset = None + super().__init__(addr) + self.stack_change_after_pivot = None + # TODO: sp_controllers can be registers, payload on stack, and symbolic read data + # but we do not handle symbolic read data, yet + self.sp_reg_controllers = set() + self.sp_stack_controllers = set() + + def __str__(self): + s = f"PivotGadget {self.addr:#x}\n" + s += f" sp_controllers: {self.sp_controllers}\n" + s += f" stack change: {self.stack_change:#x}\n" + s += f" stack change after pivot: {self.stack_change_after_pivot:#x}\n" + return s + + @property + def sp_controllers(self): + s = self.sp_reg_controllers.copy() + return s.union(self.sp_stack_controllers) + + def __repr__(self): + return f"" + + def copy(self): + new = super().copy() + new.stack_change_after_pivot = self.stack_change_after_pivot + new.sp_reg_controllers = set(self.sp_reg_controllers) + new.sp_stack_controllers = set(self.sp_stack_controllers) + return new + +class SyscallGadget(RopGadget): + """ + we collect two types of syscall gadgets: + 1. with return: syscall; ret + 2. without return: syscall; xxxx + """ + def __init__(self, addr): + super().__init__(addr) + self.makes_syscall = False + self.starts_with_syscall = False def __str__(self): - s = "Pivot %#x\n" % self.addr - if self.sp_from_reg is not None: - s += "sp from reg: %s\n" % self.sp_from_reg - elif self.sp_popped_offset is not None: - s += "sp popped at %#x\n" % self.sp_popped_offset + s = f"SyscallGadget {self.addr:#x}\n" + s += f" stack change: {self.stack_change:#x}\n" + s += f" transit type: {self.transit_type}\n" + s += f" can return: {self.can_return}\n" + s += f" starts_with_syscall: {self.starts_with_syscall}\n" return s def __repr__(self): - return "" % self.addr + return f"" + + @property + def can_return(self): + return self.transit_type != 'syscall' + + def copy(self): + new = super().copy() + new.makes_syscall = self.makes_syscall + new.starts_with_syscall = self.starts_with_syscall + return new diff --git a/angrop/rop_utils.py b/angrop/rop_utils.py index c35053b..469722c 100644 --- a/angrop/rop_utils.py +++ b/angrop/rop_utils.py @@ -187,11 +187,8 @@ def make_symbolic_state(project, reg_set, stack_gsize=80): symbolic_state.registers.store(reg, symbolic_state.solver.BVS("sreg_" + reg + "-", project.arch.bits)) # restore sp symbolic_state.regs.sp = input_state.regs.sp - # restore bp - symbolic_state.regs.bp = input_state.regs.bp return symbolic_state - def make_reg_symbolic(state, reg): state.registers.store(reg, state.solver.BVS("sreg_" + reg + "-", state.arch.bits)) @@ -202,7 +199,70 @@ def cast_rop_value(val, project): val.rebase_analysis() return val -def step_to_unconstrained_successor(project, state, max_steps=2, allow_simprocedures=False): +def is_in_kernel(project, state): + ip = state.ip + if not ip.symbolic: + obj = project.loader.find_object_containing(ip.concrete_value) + if obj is None: + return False + if obj.binary == 'cle##kernel': + return True + return False + return False + +def step_one_block(project, state, stop_at_syscall=False): + block = state.block() + num_insts = len(block.capstone.insns) + + if not num_insts: + raise RopException("No instructions!") + + if project.is_hooked(state.addr): + succ = project.factory.successors(state) + return succ, None + + if is_in_kernel(project, state): + succ = project.factory.successors(state) + if stop_at_syscall: + return None, succ.flat_successors[0] + return succ, None + + if project.arch.linux_name.startswith("mips"): + last_inst_addr = block.capstone.insns[-2].address + else: + last_inst_addr = block.capstone.insns[-1].address + for _ in range(num_insts): # considering that it may get into kernel mode + if state.addr != last_inst_addr: + state = step_one_inst(project, state, stop_at_syscall=stop_at_syscall) + if stop_at_syscall and is_in_kernel(project, state): + return None, state + else: + succ = project.factory.successors(state, num_inst=1) + if not succ.flat_successors: + return succ, None + if stop_at_syscall and is_in_kernel(project, succ.flat_successors[0]): + return None, succ.flat_successors[0] + return succ, None + raise RopException("Fail to reach the last instruction!") + +def step_one_inst(project, state, stop_at_syscall=False): + if is_in_kernel(project, state): + if stop_at_syscall: + return state + succ = project.factory.successors(state) + return step_one_inst(project, succ.flat_successors[0]) + + if project.is_hooked(state.addr): + succ = project.factory.successors(state) + return step_one_inst(project, succ.flat_successors[0]) + + succ = project.factory.successors(state, num_inst=1) + if not succ.flat_successors: + raise RopException(f"fail to step state: {state}") + return succ.flat_successors[0] + +def step_to_unconstrained_successor(project, state, max_steps=2, allow_simprocedures=False, + stop_at_syscall=False, precise_action=False): """ steps up to two times to try to find an unconstrained successor :param state: the input state @@ -214,7 +274,20 @@ def step_to_unconstrained_successor(project, state, max_steps=2, allow_simproced # nums state.options.add(angr.options.BYPASS_UNSUPPORTED_SYSCALL) - succ = project.factory.successors(state) + if not precise_action: + succ = project.factory.successors(state) + if stop_at_syscall and succ.flat_successors: + next_state = succ.flat_successors[0] + if is_in_kernel(project, next_state): + return next_state + else: + # FIXME: we step instruction by instruction because of an angr bug: xxxx + # the bug makes angr may merge sim_actions from two instructions into one + # making analysis based on sim_actions inaccurate + succ, state = step_one_block(project, state, stop_at_syscall=stop_at_syscall) + if state: + return state + if len(succ.flat_successors) + len(succ.unconstrained_successors) != 1: raise RopException("Does not get to a single successor") if len(succ.flat_successors) == 1 and max_steps > 0: diff --git a/angrop/rop_value.py b/angrop/rop_value.py index 0bda88f..efe7e92 100644 --- a/angrop/rop_value.py +++ b/angrop/rop_value.py @@ -12,7 +12,7 @@ def __init__(self, value, project): self.reg_name = None if type(value) is str: if value not in project.arch.default_symbolic_registers: - raise ValueError(f"unknown register: {value}!") + raise ValueError(f"{value} is not a general purpose register!") self.reg_name = value value = claripy.BVS(value, project.arch.bits) diff --git a/tests/test_chainbuilder.py b/tests/test_chainbuilder.py index b1de830..44d42f8 100644 --- a/tests/test_chainbuilder.py +++ b/tests/test_chainbuilder.py @@ -46,8 +46,9 @@ def test_i386_syscall(): rop.load_gadgets(cache_path) else: rop.find_gadgets() + rop.save_gadgets(cache_path) - chain =rop.do_syscall(4, [1, 0x80AC5E8, 17]) + chain = rop.do_syscall(4, [1, 0x80AC5E8, 17]) state = chain.exec() assert state.posix.dumps(1) == b'/usr/share/locale' @@ -182,6 +183,30 @@ def test_add_to_mem(): rop.add_to_mem(0x41414140, 0x42424242) +def test_pivot(): + cache_path = os.path.join(CACHE_DIR, "i386_glibc_2.35") + proj = angr.Project(os.path.join(BIN_DIR, "tests", "i386", "i386_glibc_2.35"), auto_load_libs=False) + rop = proj.analyses.ROP() + + if os.path.exists(cache_path): + rop.load_gadgets(cache_path) + else: + rop.find_gadgets() + rop.save_gadgets(cache_path) + + chain = rop.pivot(0x41414140) + state = chain.exec() + assert state.solver.eval(state.regs.sp == 0x41414140) + + chain = rop.pivot(0x41414140) + state = chain.exec() + assert state.solver.eval(state.regs.sp == 0x41414140) + + chain = rop.set_regs(eax=0x41414140) + chain += rop.pivot('eax') + state = chain.exec() + assert state.solver.eval(state.regs.sp == 0x41414140+4) + def run_all(): functions = globals() all_functions = {x:y for x, y in functions.items() if x.startswith('test_')} diff --git a/tests/test_find_gadgets.py b/tests/test_find_gadgets.py index 84d2d9b..1f59b7a 100644 --- a/tests/test_find_gadgets.py +++ b/tests/test_find_gadgets.py @@ -18,20 +18,18 @@ The logging should say why the gadget was discarded. rop = p.analyses.ROP() -angrop.gadget_analyzer.l.setLevel("DEBUG") -rop._gadget_analyzer.analyze_gadget(addr) +rop.analyze_gadget(addr) If a gadget is missing memory reads / memory writes / memory changes, the actions are probably missing. Memory changes require a read action followed by a write action to the same address. """ def gadget_exists(rop, addr): - return rop._gadget_analyzer.analyze_gadget(addr) is not None + return rop.analyze_gadget(addr) is not None def test_badbyte(): proj = angr.Project(os.path.join(tests_dir, "i386", "bronze_ropchain"), auto_load_libs=False) rop = proj.analyses.ROP() - rop._initialize_gadget_analyzer() assert all(gadget_exists(rop, x) for x in [0x080a9773, 0x08091cf5, 0x08092d80, 0x080920d3]) @@ -46,20 +44,58 @@ def local_multiprocess_find_gadgets(): def test_symbolic_memory_access_from_stack(): proj = angr.Project(os.path.join(tests_dir, "armel", "test_angrop_arm_gadget"), auto_load_libs=False) rop = proj.analyses.ROP() - rop._initialize_gadget_analyzer() assert all(gadget_exists(rop, x) for x in [0x000103f4]) def test_arm_thumb_mode(): proj = angr.Project(os.path.join(bin_path, "tests", "armel", "libc-2.31.so"), auto_load_libs=False) rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) - rop._initialize_gadget_analyzer() - gadget = rop._gadget_analyzer.analyze_gadget(0x4bf858+1) + gadget = rop.analyze_gadget(0x4bf858+1) assert gadget assert gadget.block_length == 6 +def test_pivot_gadget(): + # pylint: disable=pointless-string-statement + proj = angr.Project(os.path.join(tests_dir, "i386", "bronze_ropchain"), auto_load_libs=False) + rop = proj.analyses.ROP() + + assert all(gadget_exists(rop, x) for x in [0x80488e8, 0x8048998, 0x8048fd6, 0x8052cac, 0x805658c, ]) + + gadget = rop.analyze_gadget(0x8048592) + assert not gadget + + proj = angr.Project(os.path.join(bin_path, "tests", "armel", "libc-2.31.so"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) + + """ + 4c7b5a mov sp, r7 + 4c7b5c pop.w {r4, r5, r6, r7, r8, sb, sl, fp, pc + """ + + gadget = rop.analyze_gadget(0x4c7b5a+1) + assert gadget is not None + + proj = angr.Project(os.path.join(tests_dir, "i386", "i386_glibc_2.35"), auto_load_libs=False) + rop = proj.analyses.ROP() + """ + 439ad3 pop esp + 439ad4 lea esp, [ebp-0xc] + 439ad7 pop ebx + 439ad8 pop esi + 439ad9 pop edi + 439ada pop ebp + 439adb ret + """ + gadget = rop.analyze_gadget(0x439ad3) + assert gadget is None + +def test_syscall_gadget(): + proj = angr.Project(os.path.join(tests_dir, "i386", "bronze_ropchain"), auto_load_libs=False) + rop = proj.analyses.ROP() + assert all(gadget_exists(rop, x) for x in [0x0806f860, 0x0806f85e, 0x080939e3, 0x0806f2f1]) + def run_all(): functions = globals() all_functions = {x:y for x, y in functions.items() if x.startswith('test_')} diff --git a/tests/test_gadgets.py b/tests/test_gadgets.py index 7bf0909..06735ed 100644 --- a/tests/test_gadgets.py +++ b/tests/test_gadgets.py @@ -2,6 +2,7 @@ import angr import angrop # pylint: disable=unused-import +from angrop.rop_gadget import RopGadget, PivotGadget, SyscallGadget BIN_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "binaries") CACHE_DIR = os.path.join(BIN_DIR, 'tests_data', 'angrop_gadgets_cache') @@ -27,7 +28,7 @@ def test_arm_conditional(): cond_gadget_addrs = [0x10368, 0x1036c, 0x10370, 0x10380, 0x10384, 0x1038c, 0x1039c, 0x103a0, 0x103b8, 0x103bc, 0x103c4, 0x104e8, 0x104ec] - assert all(x.addr not in cond_gadget_addrs for x in rop._gadgets) + assert all(x.addr not in cond_gadget_addrs for x in rop._all_gadgets) def test_jump_gadget(): """ @@ -36,10 +37,9 @@ def test_jump_gadget(): """ rop = get_rop(os.path.join(BIN_DIR, "tests", "mipsel", "fauxware")) - jump_gadgets = [x for x in rop._gadgets if x.transit_type == "jmp_reg"] + jump_gadgets = [x for x in rop._all_gadgets if x.transit_type == "jmp_reg"] assert len(jump_gadgets) > 0 - jump_regs = [x.jump_reg for x in jump_gadgets] assert 't9' in jump_regs assert 'ra' in jump_regs @@ -49,7 +49,6 @@ def test_arm_mem_change_gadget(): proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel", "libc-2.31.so"), auto_load_libs=False) rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) - rop._initialize_gadget_analyzer() """ 0x0004f08c <+28>: ldr r2, [r4, #48] ; 0x30 @@ -59,11 +58,11 @@ def test_arm_mem_change_gadget(): 0x0004f094 <+36>: str r5, [r4, #48] ; 0x30 0x0004f096 <+38>: pop {r3, r4, r5, pc} """ - gadget = rop._gadget_analyzer.analyze_gadget(0x44f08c+1) # thumb mode + gadget = rop.analyze_gadget(0x44f08c+1) # thumb mode assert gadget assert not gadget.mem_changes - gadget = rop._gadget_analyzer.analyze_gadget(0x459eea+1) # thumb mode + gadget = rop.analyze_gadget(0x459eea+1) # thumb mode assert gadget assert not gadget.mem_changes @@ -73,7 +72,7 @@ def test_arm_mem_change_gadget(): 4b1e34 str r4, [r6] 4b1e36 pop {r3, r4, r5, r6, r7, pc} """ - gadget = rop._gadget_analyzer.analyze_gadget(0x4b1e30+1) # thumb mode + gadget = rop.analyze_gadget(0x4b1e30+1) # thumb mode assert gadget.mem_changes """ @@ -82,7 +81,7 @@ def test_arm_mem_change_gadget(): 4c1e7c str r1, [r4,#0x14] 4c1e7e pop {r3, r4, r5, pc} """ - gadget = rop._gadget_analyzer.analyze_gadget(0x4c1e78+1) # thumb mode + gadget = rop.analyze_gadget(0x4c1e78+1) # thumb mode assert gadget.mem_changes """ @@ -91,7 +90,7 @@ def test_arm_mem_change_gadget(): 4c1ea8 str r2, [r3,#0x14] 4c1eaa bx lr """ - gadget = rop._gadget_analyzer.analyze_gadget(0x4c1ea4+1) # thumb mode + gadget = rop.analyze_gadget(0x4c1ea4+1) # thumb mode assert not gadget.mem_changes """ @@ -101,9 +100,168 @@ def test_arm_mem_change_gadget(): 4c1e94 str r1, [r4,#0x14] 4c1e96 pop {r3, r4, r5, pc} """ - gadget = rop._gadget_analyzer.analyze_gadget(0x4c1e8e+1) # thumb mode + gadget = rop.analyze_gadget(0x4c1e8e+1) # thumb mode assert gadget.mem_changes +def test_pivot_gadget(): + # pylint: disable=pointless-string-statement + + proj = angr.Project(os.path.join(BIN_DIR, "tests", "i386", "i386_glibc_2.35"), auto_load_libs=False) + rop = proj.analyses.ROP() + + """ + 5719da pop esp + 5719db ret + """ + gadget = rop.analyze_gadget(0x5719da) + assert gadget.stack_change == 0x4 + assert gadget.stack_change_after_pivot == 0x4 + assert len(gadget.sp_controllers) == 1 + assert len(gadget.sp_reg_controllers) == 0 + + proj = angr.Project(os.path.join(BIN_DIR, "tests", "i386", "bronze_ropchain"), auto_load_libs=False) + rop = proj.analyses.ROP() + + """ + 80488e8 leave + 80488e9 ret + """ + gadget = rop.analyze_gadget(0x80488e8) + assert type(gadget) == PivotGadget + assert gadget.stack_change == 0 + assert gadget.stack_change_after_pivot == 0x8 + assert len(gadget.sp_controllers) == 1 and gadget.sp_controllers.pop() == 'ebp' + + + """ + 8048592 xchg esp, eax + 8048593 ret 0xca21 + """ + gadget = rop.analyze_gadget(0x8048592) + assert not gadget + + """ + 8048998 pop ecx + 8048999 pop ebx + 804899a pop ebp + 804899b lea esp, [ecx-0x4] + 804899e ret + """ + gadget = rop.analyze_gadget(0x8048998) + assert type(gadget) == PivotGadget + assert gadget.stack_change == 0xc + assert gadget.stack_change_after_pivot == 0x4 + assert len(gadget.sp_controllers) == 1 and gadget.sp_controllers.pop().startswith('symbolic_stack_') + + """ + 8048fd6 xchg esp, eax + 8048fd7 ret + """ + gadget = rop.analyze_gadget(0x8048fd6) + assert type(gadget) == PivotGadget + assert gadget.stack_change == 0 + assert gadget.stack_change_after_pivot == 0x4 + assert len(gadget.sp_controllers) == 1 and gadget.sp_controllers.pop() == 'eax' + + """ + 8052cac lea esp, [ebp-0xc] + 8052caf pop ebx + 8052cb0 pop esi + 8052cb1 pop edi + 8052cb2 pop ebp + 8052cb3 ret + """ + gadget = rop.analyze_gadget(0x8052cac) + assert type(gadget) == PivotGadget + assert gadget.stack_change == 0 + assert gadget.stack_change_after_pivot == 0x14 + assert len(gadget.sp_controllers) == 1 and gadget.sp_controllers.pop() == 'ebp' + + """ + 805658c add BYTE PTR [eax],al + 805658e pop ebx + 805658f pop esi + 8056590 pop edi + 8056591 ret + """ + gadget = rop.analyze_gadget(0x805658c) + assert type(gadget) == RopGadget + assert gadget.stack_change == 0x10 # 3 pops + 1 ret + + proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel", "libc-2.31.so"), auto_load_libs=False) + rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True) + + """ + 4c7b5a mov sp, r7 + 4c7b5c pop.w {r4, r5, r6, r7, r8, sb, sl, fp, pc} + """ + + #rop.find_gadgets(show_progress=False) + gadget = rop.analyze_gadget(0x4c7b5a+1) + assert type(gadget) == PivotGadget + assert gadget.stack_change == 0 + assert gadget.stack_change_after_pivot == 0x24 + assert len(gadget.sp_controllers) == 1 and gadget.sp_controllers.pop() == 'r7' + + proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel", "manysum"), load_options={"auto_load_libs": False}) + rop = proj.analyses.ROP() + + """ + 1040c mov r0, r3 + 10410 sub sp, fp, #0x0 + 10414 pop {fp} + 10418 bx lr + """ + gadget = rop.analyze_gadget(0x1040c) + assert type(gadget) == PivotGadget + assert gadget.stack_change == 0 + assert gadget.stack_change_after_pivot == 0x4 + assert len(gadget.sp_controllers) == 1 and gadget.sp_controllers.pop() == 'r11' + +def test_syscall_gadget(): + proj = angr.Project(os.path.join(BIN_DIR, "tests", "i386", "i386_glibc_2.35"), auto_load_libs=False) + rop = proj.analyses.ROP() + + gadget = rop.analyze_gadget(0x437765) + assert type(gadget) == SyscallGadget + assert gadget.stack_change == 0 + assert not gadget.can_return + + gadget = rop.analyze_gadget(0x5212f6) + assert type(gadget) == SyscallGadget + assert gadget.stack_change == 0 + assert not gadget.can_return + + proj = angr.Project(os.path.join(BIN_DIR, "tests", "i386", "bronze_ropchain"), auto_load_libs=False) + rop = proj.analyses.ROP() + + gadget = rop.analyze_gadget(0x0806f860) + assert type(gadget) == SyscallGadget + assert gadget.stack_change == 0x4 + assert gadget.can_return + + gadget = rop.analyze_gadget(0x0806f85e) + assert type(gadget) == SyscallGadget + assert gadget.stack_change == 0x4 + assert gadget.can_return + + gadget = rop.analyze_gadget(0x080939e3) + assert type(gadget) == SyscallGadget + assert gadget.stack_change == 0x0 + assert not gadget.can_return + + gadget = rop.analyze_gadget(0x0806f2f1) + assert type(gadget) == SyscallGadget + assert gadget.stack_change == 0x0 + assert not gadget.can_return + + proj = angr.Project(os.path.join(BIN_DIR, "tests", "x86_64", "roptest"), auto_load_libs=False) + rop = proj.analyses.ROP() + gadget = rop.analyze_gadget(0x4000c1) + assert type(gadget) == SyscallGadget + assert gadget.stack_change == 0 + assert not gadget.can_return + def run_all(): functions = globals() all_functions = {x:y for x, y in functions.items() if x.startswith('test_')} diff --git a/tests/test_rop.py b/tests/test_rop.py index c1abadf..8191b01 100644 --- a/tests/test_rop.py +++ b/tests/test_rop.py @@ -43,7 +43,8 @@ def assert_gadgets_equal(known_gadget, test_gadget): assert known_gadget.reg_dependencies == test_gadget.reg_dependencies assert known_gadget.reg_controllers == test_gadget.reg_controllers assert known_gadget.stack_change == test_gadget.stack_change - assert known_gadget.makes_syscall == test_gadget.makes_syscall + if hasattr(known_gadget, "makes_syscall"): + assert known_gadget.makes_syscall == test_gadget.makes_syscall assert len(known_gadget.mem_reads) == len(test_gadget.mem_reads) for m1, m2 in zip(known_gadget.mem_reads, test_gadget.mem_reads): @@ -109,7 +110,7 @@ def test_rop_x86_64(): # check gadgets tup = pickle.load(open(cache_path, "rb")) - compare_gadgets(rop.gadgets, tup[0]) + compare_gadgets(rop._all_gadgets, tup[0]) # test creating a rop chain chain = rop.set_regs(rbp=0x1212, rbx=0x1234567890123456) @@ -137,7 +138,7 @@ def test_rop_i386_cgc(): # check gadgets tup = pickle.load(open(os.path.join(test_data_location, "0b32aa01_01_gadgets"), "rb")) - compare_gadgets(rop.gadgets, tup[0]) + compare_gadgets(rop._all_gadgets, tup[0]) # test creating a rop chain chain = rop.set_regs(ebx=0x98765432, ecx=0x12345678) @@ -164,7 +165,7 @@ def test_rop_arm(): # check gadgets tup = pickle.load(open(os.path.join(test_data_location, "arm_manysum_test_gadgets"), "rb")) - compare_gadgets(rop.gadgets, tup[0]) + compare_gadgets(rop._all_gadgets, tup[0]) # test creating a rop chain chain = rop.set_regs(r11=0x99887766) @@ -184,9 +185,9 @@ def test_rop_arm(): def test_roptest_x86_64(): p = angr.Project(os.path.join(public_bin_location, "x86_64/roptest"), auto_load_libs=False) - r = p.analyses.ROP() + r = p.analyses.ROP(only_check_near_rets=False) r.find_gadgets_single_threaded(show_progress=False) - c = r.execve(b"/bin/sh") + c = r.execve(path=b"/bin/sh") # verifying this is a giant pain, partially because the binary is so tiny, and there's no code beyond the syscall assert len(c._gadgets) == 8 diff --git a/tests/test_ropchain.py b/tests/test_ropchain.py index 66cd762..d410875 100644 --- a/tests/test_ropchain.py +++ b/tests/test_ropchain.py @@ -21,7 +21,7 @@ def test_chain_exec(): rop.save_gadgets(cache_path) # make sure the target gadget exist - gadgets = [x for x in rop.gadgets if x.addr == 0x402503] + gadgets = [x for x in rop._all_gadgets if x.addr == 0x402503] assert len(gadgets) == 1 gadget = gadgets[0]