#!/usr/bin/env python

'''
vw is a vivisect workspace/dev environment, allowing easier access to 
Vivisect internals.  
'''

import os
import sys
import code
import time
import struct
import logging
import argparse
import readline
import rlcompleter
readline.parse_and_bind("tab: complete")

import cobra

import envi
import envi.archs.i386 as e_i386
import envi.memcanvas as e_memcanvas

import visgraph.pathcore as vg_pathcore
import visgraph.graphcore as vg_graphcore
import vivisect
import vivisect.cli as v_cli
import vivisect.remote.server as viv_server
import vivisect.tools.graphutil as viv_graph
import vivisect.symboliks as v_symbx
import vivisect.symboliks.common as vs_cmn
import vivisect.symboliks.emulator as vs_emu
import vivisect.symboliks.analysis as vs_anal
import vivisect.symboliks.archs.i386 as vs_i386
from vivisect.const import *
from pprint import pprint

import atlasutils.emutils as aemu


if sys.version_info.major == 3:
    from importlib import reload

from binascii import hexlify, unhexlify

# send all log output through one stream handler, with the root logger at DEBUG
loghandler = logging.StreamHandler()
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[loghandler])
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
# drop records from any logger whose name starts with "parso" at the handler
# NOTE(review): presumably the parser library used by the interactive shell's completer -- confirm
loghandler.addFilter(lambda record: not record.name.startswith("parso"))

#formatter = get_ipython().display_formatter.formatters['text/plain']
#formatter.for_type(int, lambda n, p, cycle: p.text("0x%x" % n))

# make Vivisect's Scripts and Extensions, as well as VDB Scripts, importable.
# each environment variable holds an os.pathsep-separated list of directories.
for pathvar in ('VIV_SCRIPT_PATH', 'VIV_EXT_PATH', 'VDB_SCRIPT_PATH'):
    pathstr = os.getenv(pathvar)
    if pathstr is None:
        continue

    for entry in pathstr.split(os.pathsep):
        if entry in sys.path:
            # already searchable, don't add it twice
            continue
        sys.path.append(entry)

# the current directory is appended after the entries above
sys.path.append('.')


# banner printed when the interactive shell starts.  the object names listed
# here must match what setupVW() and the __main__ block publish into globals.
intro = '''
welcome to atlas' viv-world playground.
if provided, the file/vivworkspace provided on the commandline has been instantiated as 'vw'
other objects at the ready:
    vw - the vivisect workspace object (locked and loaded)
    emu - the appropriate envi emulator for the vw
    temu - emutils TestEmulator wrapping the emu (if exists)
    xlate - the symbolik translator object for the vw's architecture (if symboliks are supported)
    symemu - if the vw is an i386 bin, this is the symbolik emulator object
    sctx - symboliks analysis context for this workspace

also, a few modules have been imported:
    import envi
    import envi.archs.i386 as e_i386

    import visgraph.pathcore as vg_pathcore
    import visgraph.graphcore as vg_graphcore
    import vivisect
    import vivisect.cli as v_cli
    import vivisect.remote.server as viv_server
    import vivisect.tools.graphutil as viv_graph
    import vivisect.symboliks as v_symbx
    import vivisect.symboliks.common as vs_cmn
    import vivisect.symboliks.emulator as vs_emu
    import vivisect.symboliks.analysis as vs_anal
    import vivisect.symboliks.archs.i386 as vs_i386
    from vivisect.const import *
    '''

# reverse lookup: value of each VWE_* global -> its name, for readable event dumps
vwes = dict((value, name) for name, value in globals().items() if name.startswith('VWE_'))

def printVWEs(evts, etype=None, paginate=50):
    '''
    print a numbered listing of workspace events, pausing every _paginate_ lines.

    evts:     iterable of (event, einfo) tuples
    etype:    if not None, only events equal to etype are listed
    paginate: number of lines printed between "PAUSE" prompts (0 disables pausing)

    events are shown by name using the module-level 'vwes' lookup (None if unknown).
    '''
    # raw_input() only exists on Python 2; Python 3 calls the same thing input()
    try:
        prompt = raw_input
    except NameError:
        prompt = input

    reprevts = [(vwes.get(evt), einfo) for evt, einfo in evts
                if etype is None or etype == evt]

    print("\tlen: %d" % len(reprevts))
    for eidx, (evt, einfo) in enumerate(reprevts):
        print("%.7d  %-20s: %r" % (eidx, evt, einfo))
        if paginate and not ((eidx+1) % paginate):
            prompt("PAUSE... Press Enter")


def printInterestingVWEs(evts, paginate=50):
    '''
    print the events of each "interesting" type in turn (functions, names,
    comments, meta, file meta, vaset rows and chat), one printVWEs() listing
    per type, waiting for Enter after each listing.

    evts:     iterable of (event, einfo) tuples
    paginate: handed through to printVWEs()
    '''
    # raw_input() only exists on Python 2; Python 3 calls the same thing input()
    try:
        prompt = raw_input
    except NameError:
        prompt = input

    for evt in (VWE_ADDFUNCTION, VWE_SETNAME, VWE_COMMENT, VWE_SETMETA, VWE_SETFILEMETA, VWE_SETVASETROW, VWE_CHAT, ):
        print("==== %s ====" % (vwes.get(evt)))
        printVWEs(evts, evt, paginate)
        prompt("---- DONE ----")

def insertComment(vw, va, comment):
    '''
    set _comment_ at _va_ without losing a comment that is already there:
    the existing text is kept after the new one, separated by " :: ".
    '''
    existing = vw.getComment(va)
    text = comment if existing is None else comment + " :: " + existing
    vw.setComment(va, text)

def mergeInterestingEvents(vw, evts, offset = 0):
    '''
    merge the functions, comments and names recorded in an event list into _vw_.

    vw:     the workspace receiving the merged information
    evts:   re-iterable sequence of (event, einfo) tuples (it is walked several times)
    offset: added to every address before it is applied to _vw_

    functions are re-created with makeCode()/makeFunction(), comments are put
    in front of any existing comment, and names are applied unless the target
    address already carries a non-default name, in which case the imported
    name is recorded as a comment so the existing name is not clobbered.
    '''
    # most recent ADDLOCATION event for each address (later events win)
    lastlocs = {}
    for evt, einfo in evts:
        if evt == VWE_ADDLOCATION:
            lastlocs[einfo[0]] = einfo

    # first merge functions
    for evt, funcinfo in evts:
        if evt != VWE_ADDFUNCTION:
            continue

        funcva = funcinfo[0]
        print("FUNCTION: 0x%x" % (funcva))
        loc = lastlocs.get(funcva)
        if loc is not None:
            locva, lsz, ltype, linfo = loc
            # the location's info field is handed to makeCode() as the architecture
            vw.makeCode(locva+offset, arch=linfo)
        else:
            insertComment(vw, funcva+offset, "MERGED, but couldn't find original location.")

        vw.makeFunction(funcva + offset)

    # next merge comments
    for evt, einfo in evts:
        if evt != VWE_COMMENT:
            continue

        va, impcmt = einfo
        print("COMMENT: 0x%x: %s" % (va, impcmt))
        insertComment(vw, va + offset, impcmt)

    # next merge names
    for evt, einfo in evts:
        if evt != VWE_SETNAME:
            continue

        va, impname = einfo
        print("NAME: 0x%x: %s" % (va, impname))
        if impname == "sub_%08x" % va:
            # auto-generated name in the source workspace, nothing worth importing
            print("skipping...")
            continue

        newva = va + offset
        name = vw.getName(newva)

        if name is not None and name != "sub_%08x" % newva:
            # the target already has a real name: keep it, note the imported one
            insertComment(vw, newva, "Import Name: %s" % impname)
        else:
            vw.makeName(newva, impname)


def vprint(strthing):
    '''
    this helps emulate a Viv script, which has vprint() strapped into the globals.
    here it simply print()s _strthing_.
    '''
    print(strthing)

def _setVar_cb(symobj, ctx):
    '''
    callback for setVar, which uses walkTree, which requires a callback for each node.

    ctx is the (var, val, solved_var) tuple built by setVar.  a node whose
    solve() value equals solved_var is replaced by val; any other node is
    returned unchanged.
    '''
    if symobj.solve() == ctx[2]:
        return ctx[1]
    return symobj

def setVar(symobj, var, val):
    '''
    walk the symbolik tree (aka. Abstract Data Tree), _symobj_, and swap in _val_ wherever _var_ exists.
    uses _setVar_cb as the callback for walkTree.

    var and val may each be a symbolik object, an integer (wrapped in
    vs_cmn.Const) or a string (wrapped in vs_cmn.Var).
    returns whatever symobj.walkTree() returns.
    '''
    # numbers.Integral covers int on Python 3 as well as int/long on Python 2
    import numbers

    if isinstance(val, numbers.Integral):
        val = vs_cmn.Const(val)
    elif isinstance(val, str):
        val = vs_cmn.Var(val)

    if isinstance(var, numbers.Integral):
        var = vs_cmn.Const(var)
    elif isinstance(var, str):
        var = vs_cmn.Var(var)

    # nodes are matched by comparing solve() values, so solve the target once up front
    ctx = (var, val, var.solve())
    return symobj.walkTree(_setVar_cb, ctx)

def itersymviewpaths(spaths):
    '''
    step through a list/generator of symbolik paths, print the important effects, pause between paths.

    each path is indexed as path[1] to get its list of effects.  every effect
    is reduce()d, then all effects are printed except SetVariable effects
    whose variable name starts with 'efl'.
    '''
    # raw_input() only exists on Python 2; Python 3 calls the same thing input()
    try:
        prompt = raw_input
    except NameError:
        prompt = input

    for path in spaths:
        effects = path[1]
        for eff in effects:
            eff.reduce()

        print("\n".join(["%x: %s" % (eff.va, str(eff)) for eff in effects
                         if not isinstance(eff, SetVariable) or not eff.varname.startswith('efl')]))
        prompt("\n PRESS ENTER FOR NEXT PATH\n")

def itersymview(funcva, args=None, maxpath=1000):
    '''
    like itersymviewpaths, but the paths come from the module-level symbolik
    analysis context 'sctx': sctx.getSymbolikPaths(funcva, args, maxpath).
    '''
    itersymviewpaths(sctx.getSymbolikPaths(funcva, args, maxpath))

def symrepr(syms):
    '''
    reduce every symbolik effect in _syms_, then print the REPR of the
    interesting ones, one per line (eflag updates are left out).
    '''
    for eff in syms:
        eff.reduce()

    lines = []
    for eff in syms:
        # efftype 1 (presumably SetVariable) on an 'efl*' variable is a flags update
        if eff.efftype == 1 and eff.varname.startswith('efl'):
            continue
        lines.append(repr(eff))

    print("\n".join(lines))

def symview(syms):
    '''
    reduce every symbolik effect in _syms_, then print the STR of the
    interesting ones, one per line (eflag updates are left out).
    '''
    for eff in syms:
        eff.reduce()

    lines = []
    for eff in syms:
        # efftype 1 (presumably SetVariable) on an 'efl*' variable is a flags update
        if eff.efftype == 1 and eff.varname.startswith('efl'):
            continue
        lines.append(str(eff))

    print("\n".join(lines))

def symreproutputs(syms):
    '''
    reduce each symbolik effect in _syms_ and print the REPR of the OUTPUT effects.
    transient effects are left out: every SetVariable, and memory reads/writes
    whose address is discrete and has all the bits of 0xbfbf0000 set
    (presumably the emulated stack -- confirm).
    '''
    for eff in syms:
        eff.reduce()

        if eff.efftype == EFFTYPE_SETVAR:
            continue

        if eff.efftype in (EFFTYPE_READMEM, EFFTYPE_WRITEMEM):
            addr = eff.symaddr
            if addr.isDiscrete() and (addr.solve() & 0xbfbf0000) == 0xbfbf0000:
                continue

        print(repr(eff))

def symviewoutputs(syms):
    '''
    reduce each symbolik effect in _syms_ and print the STR of the OUTPUT effects.
    transient effects are left out: every SetVariable, and memory reads/writes
    whose address is discrete and has all the bits of 0xbfbf0000 set
    (presumably the emulated stack -- confirm).
    '''
    for eff in syms:
        eff.reduce()

        if eff.efftype == EFFTYPE_SETVAR:
            continue

        if eff.efftype in (EFFTYPE_READMEM, EFFTYPE_WRITEMEM):
            addr = eff.symaddr
            if addr.isDiscrete() and (addr.solve() & 0xbfbf0000) == 0xbfbf0000:
                continue

        print(str(eff))

def symeff(eff):
    '''
    not ready for prime time yet...
    intended to repr a symbolik effect taking into account structures and arrays.

    works on a copy of _eff_ (rebuilt by evaluating its repr) and walks that
    copy with structure_cb.
    '''
    # NOTE(review): eval() of a repr -- only safe for effects created locally
    clone = eval(repr(eff))
    clone.walkTree(structure_cb, [])


# handed ( eff, ctx )
def structure_cb(eff, ctx):
    '''
    broken, don't use yet.

    walkTree callback for symeff: it recognizes a Mem node whose first kid is
    an o_add with a discrete second kid, but does nothing with it so far.
    '''
    if not isinstance(eff, Mem):
        return

    addr = eff.kids[0]
    if not isinstance(addr, o_add):
        return

    if addr.kids[1].isDiscrete():
        # this is a structure reference, clobber the next few reprs
        # if we hit multi-layers, the parent needs to pass down repr to the left child... but that left child doesn't know it's something special.  this really should be in the __str__ method.  but would that help?  coordination is required here.
        # this could be all completed with the right context-handling dict.
        pass


def sym(expr):
    '''
    turn an arbitrary expression, like symbol ('foobin.memcpy'), or "4 + 5" into a value/address

    the expression is handed to parseExpression() on the module-level
    workspace 'vw', so a workspace must already be loaded.
    '''
    return vw.parseExpression(expr)

# event slices chopped off by revert(), oldest first
history = []
def revert(vw, skipsaves=0):
    '''
    ALPHA/TESTING!!  allows some canned manipulation of a workspace's event list.
    intended to allow the saving/slicing/dicing of an event list in order to get 
    it to the desired state (sometimes useful for regenning a new workspace and
    merging in events from an older one)

    vw:        workspace whose _event_list is truncated in place
    skipsaves: number of save markers to step over before stopping; 0 reverts
               to the most recent save

    a save marker is an event (24, ('StorageName', ...)).  events after the
    chosen marker are removed from the workspace and appended, as one list,
    to the module-level 'history'.  nothing is changed when there are not
    enough save markers to satisfy the request.
    '''
    events = vw._event_list
    idx = len(events)

    # walk backwards over (skipsaves + 1) save markers
    for _ in range(max(skipsaves, 0) + 1):
        idx -= 1
        while idx > 0 and not (events[idx][0] == 24 and events[idx][1][0] == 'StorageName'):
            idx -= 1

        if idx <= 0:
            # ran off the front of the list: there is no such save
            break

    if idx <= 0:
        print("No more reverts can be done on this workspace.")
        return

    if idx == len(events) - 1:
        print("Nothing done.  Already at last save (see skipsaves argument)")
        return

    print("idx: %d: %s" % (idx, repr(events[idx][1])))
    print("Reverting to last save.")
    # keep the save marker itself
    idx += 1

    history.append(events[idx:])
    vw._event_list = events[:idx]
    print("Done.  Workspace has Not been saved...  if you went too far, you may be able to salvage the workspace by reloading.")
    print("alternately, chopped events have been saved to the 'history' list")

def getRemoteWorkspaceDialog(servername=None):
    '''
    pops up a dialog box listing available workspaces on a remote server.
    returns the workspace name

    if servername is not provided, it's assumed a workspace server is already
    setup and stored in global 'wsserver'

    relies on the module-level name 'vq_remote' (vivisect.qt.remote), which
    the __main__ block imports when the GUI modules are available.
    '''
    # without this declaration the assignment below makes wsserver local, and
    # the "already connected" path dies with UnboundLocalError
    global wsserver
    if servername is not None:
        wsserver = viv_server.connectToServer(servername)

    wslist = wsserver.listWorkspaces()
    dialog = vq_remote.VivServerDialog(wslist, parent=None)
    workspace = dialog.getWorkspaceName()
    return workspace

def listWorkspaces(servername=None):
    '''
    return the list of workspaces a remote server offers.

    when _servername_ is given, a connection to it is made first and kept in
    the global 'wsserver'; otherwise the existing 'wsserver' is queried.
    '''
    global wsserver
    if servername is not None:
        wsserver = viv_server.connectToServer(servername)

    return wsserver.listWorkspaces()

def loadRemoteWorkspace(vwname=None, servername=None):
    '''
    one-stop loader for a remote workspace.

    connects to _servername_ when given (storing the connection in the global
    'wsserver'), asks for a workspace through the Remote Workspace Dialog when
    _vwname_ is None, then fetches that workspace from the server, prints its
    metadata and returns it.
    '''
    global wsserver
    if servername is not None:
        wsserver = viv_server.connectToServer(servername)

    chosen = getRemoteWorkspaceDialog() if vwname is None else vwname

    workspace = viv_server.getServerWorkspace(wsserver, chosen)
    print(repr(workspace.metadata))
    return workspace

def listNewVars():
    '''
    allows the checking of the globals for any variable we've added since startup.
    this helps me stay sane when shutting down a vw instance, making sure i don't
    have temp data that i don't want to lose.

    the baseline is the module-level '__orig_globals' mapping.  names present
    in the baseline (and the baseline itself) are left out of the returned
    dict.  if no baseline has been recorded, every global is returned rather
    than failing with a NameError.
    '''
    baseline = globals().get('__orig_globals') or {}

    out = {}
    for name, value in list(globals().items()):
        if name == '__orig_globals' or name in baseline:
            continue
        out[name] = value
    return out

def step(width=8):
    '''
    single-step the module-level emulator in a loop: each round calls
    stepi(width), then waits for input.  an answer starting with 'q' ends
    the loop, anything else (including just Enter) steps again.
    '''
    # raw_input() only exists on Python 2; Python 3 calls the same thing input()
    try:
        prompt = raw_input
    except NameError:
        prompt = input

    while True:
        try:
            stepi(width)
            if prompt("--next--:\n").startswith('q'):
                break
        except KeyboardInterrupt:
            # Ctrl-C abandons the current step/prompt but stays in the loop; 'q' is the way out
            pass

def stepi(width=8):
    '''
    execute one instruction on the module-level emulator 'emu' and show what happened.

    prints the address and value of every dereferencing operand before and
    after the step, then the registers whose index is below 0xffff, _width_
    per line, and finally the instruction now at the program counter (parsed
    through the module-level workspace 'vw').
    '''
    op = emu.parseOpcode(emu.getProgramCounter())

    def show_derefs(fmt):
        # report address/value for each operand that dereferences memory
        for oper in op.opers:
            if not oper.isDeref():
                continue
            addr = oper.getOperAddr(op, emu)
            mem = oper.getOperValue(op, emu)
            print(fmt % (addr, mem))

    show_derefs("PRE:  oper addr: 0x%x      value: 0x%x")
    emu.stepi()
    show_derefs("POST: oper addr: 0x%x      value: 0x%x")

    regvalues = []
    for index, regname in emu._rctx_ids.items():
        if index < 0xffff:
            regvalues.append((regname, emu.getRegister(index)))

    start = 0
    while start < len(regvalues):
        row = regvalues[start:start+width]
        print("  ".join(['%-4s: %.4x' % (regname, val) for regname, val in row]))
        start += width

    nextop = vw.parseOpcode(emu.getProgramCounter())
    print("%x:  %s" % (nextop.va, nextop))

def vwdis(vw, startva, count=50, endva=None):
    '''
    print a linear disassembly listing from a workspace.

    vw:      workspace to read from
    startva: address of the first instruction
    count:   maximum number of lines to print
    endva:   if not None, stop after the instruction (or bad byte) that covers this address

    each line shows the address, the instruction bytes in hex and the repr of
    the opcode.  a byte that can't be parsed is printed as "(bad)" and
    skipped.  the listing stops early on an envi.SegmentationViolation.
    '''
    def hexbytes(va, size):
        # hexlify() returns bytes on Python 3; decode so "%s" doesn't print b'..'
        return hexlify(vw.readMemory(va, size)).decode('ascii')

    va = startva
    for _ in range(count):
        try:
            loc = vw.getLocation(va)
            if loc is not None:
                # a known location's info field is handed to parseOpcode() as the arch
                lva, lsz, ltp, ltinfo = loc
            else:
                ltinfo = envi.ARCH_DEFAULT

            op = vw.parseOpcode(va, arch=ltinfo)
            oplen = len(op)
            print("0x%x:\t%30s\t%s" % (op.va, hexbytes(va, oplen), repr(op)))
        except envi.SegmentationViolation:
            print("SegmentationViolation at 0x%x" % (va))
            return
        except Exception as e:  #envi.InvalidInstruction:
            print(e)
            print("0x%x:\t%30s\t  (bad)" % (va, hexbytes(va, 1)))
            oplen = 1

        if endva is not None and va <= endva < va+oplen:
            break
        va += oplen

def disasm(binstr, offset=0, baseva=0x41410000, arch=envi.ARCH_I386):
    '''
    disassemble raw bytes without needing a loaded workspace.

    binstr: the bytes to decode
    offset: index into binstr of the first instruction
    baseva: virtual address assigned to binstr[0]
    arch:   envi architecture constant; its upper bits (arch >> 16) index the
            architecture modules of a fresh VivWorkspace

    prints one "address: opcode" line per instruction until binstr is used up.
    '''
    archmod = vivisect.VivWorkspace().imem_archs[arch >> 16]
    end = len(binstr)

    while offset < end:
        op = archmod.archParseOpcode(binstr, offset, baseva + offset)
        print("0x%x:\t\t%s" % (op.va, repr(op)))
        offset += len(op)

def timeit(testfunc, *args, **kwargs):
    '''
    call testfunc(*args, **kwargs) and report how long it took.

    an Exception raised by testfunc is printed (message plus traceback via
    sys.excepthook) rather than propagated.  the elapsed wall-clock time is
    always printed, and returned in seconds.
    '''
    import sys
    import time

    began = time.time()
    try:
        testfunc(*args, **kwargs)
    except Exception as e:
        print("Error: %s" % e)
        sys.excepthook(*sys.exc_info())
    finally:
        tdelta = time.time() - began
        print("Function Completed in %.3fsecs" % (tdelta))
    return tdelta


def _troubled_dis(vw, va, count):
    '''
    disassemble the run-up to _va_ with vwdis(), which stops at _va_.

    the listing is first tried from 40 bytes before _va_; whenever vwdis()
    raises, the start is moved one byte closer.  the retries stop once the
    start reaches _va_ itself, so a region that can never be disassembled
    no longer loops forever, and Ctrl-C is no longer swallowed.
    '''
    possibles = 10
    for backoff in range(30 + possibles, -1, -1):
        try:
            vwdis(vw, va - backoff, count, va)
            return
        except Exception:
            # couldn't list from this start address, try one byte later
            continue


def rop_amd64(vw, negoff=0):
    '''
    interactively browse candidate amd64 ROP gadgets in a workspace.

    vw:     workspace to search
    negoff: currently unused

    every hit of each byte pattern below is shown with _troubled_dis().
    at the "PRESS ENTER" prompt an answer starting with 'n' moves on to the
    next pattern; a hex address is disassembled with vwdis() and followed by
    a "? " prompt, which repeats until an empty line is entered.
    '''
    # raw_input() only exists on Python 2; Python 3 calls the same thing input()
    try:
        prompt = raw_input
    except NameError:
        prompt = input

    patlist = (
        '5cc3',     # pop rsp; ret
        '5fc3',     # pop rdi; ret
        '5ec3',     # pop rsi; ret
        '5ac3',     # pop rdx; ret
        '5e5fc3',   # pop rsi; pop rdi; ret
        '5f5ec3',   # pop rdi; pop rsi; ret
        'c3',       # straight up ret
        'c2',       # ret imm16
    )

    for pattern in patlist:
        # unhexlify() works on Python 2 and 3; str.decode('hex') is Python 2 only
        for va in vw.searchMemory(unhexlify(pattern)):
            _troubled_dis(vw, va, 20)
            udata = prompt("PRESS ENTER or 'next'")
            if len(udata) and udata[0].lower() == 'n':
                break

            while udata:
                try:
                    va = int(udata,16)
                    vwdis(vw, va, 20)

                except ValueError:
                    # not a hex address, just ask again
                    pass

                except Exception as e:
                    print(type(e))

                udata = prompt("? ")

def rop_i386(vw, startoff=10):
    '''
    interactively browse candidate i386 ROP gadgets in a workspace.

    vw:       workspace to search
    startoff: currently unused

    every 0xc3 / 0xc2 byte found in memory is shown with _troubled_dis().
    at each "? " prompt a hex address is disassembled with vwdis(); an empty
    line moves on to the next hit.
    '''
    # raw_input() only exists on Python 2; Python 3 calls the same thing input()
    try:
        prompt = raw_input
    except NameError:
        prompt = input

    # bytes literals: memory is searched as bytes on Python 3 (and b'' is str on Python 2)
    possiblelist = list(vw.searchMemory(b'\xc3'))
    possiblelist.extend(vw.searchMemory(b'\xc2'))

    for va in possiblelist:
        _troubled_dis(vw, va, 20)
        udata = prompt("? ")
        while udata:
            try:
                va = int(udata,16)
                vwdis(vw, va, 20)

            except ValueError:
                # not a hex address, just ask again
                pass

            except Exception as e:
                print(type(e))

            udata = prompt("? ")


# reverse lookup: value -> name for every global defined so far whose name contains 'SYM'
symtypes = dict((value, name) for name, value in globals().items() if 'SYM' in name)

def getUnknowns(symvar):
    '''
    collect the unknowns of a symbolik object: every node of type SYMT_VAR,
    SYMT_MEM or SYMT_ARG handed to the walkTree callback, each listed once,
    in the order walkTree delivers them.
    '''
    found = []

    def collect(path, symobj, ctx):
        # walkTree callback: remember Var/Mem/Arg nodes not seen before
        wanted = symobj.symtype in (SYMT_VAR, SYMT_MEM, SYMT_ARG)
        if wanted and symobj not in ctx:
            ctx.append(symobj)

    symvar.walkTree(collect, found, False)
    return found

def deida(string):
    '''
    convert a space-separated hex dump (eg. "41 42 43") into the raw bytes it represents.
    '''
    # unhexlify() works on Python 2 and 3; str.decode('hex') is Python 2 only
    return unhexlify(string.replace(' ', ''))


def shareWorkspace(vw):
    '''
    share _vw_ through vivisect.remote.share and announce, via vw.vprint(),
    the port clients can connect to.
    '''
    import vivisect.remote.share as viv_share

    port = viv_share.shareWorkspace(vw).port
    vw.vprint('Workspace Listening Port: %d' % port)
    vw.vprint('Clients may now connect to your host on port %d' % port)


def connectToRemoteWorkspace():
    '''
    create an empty VivWorkspace, hand it to vivisect.qt.remote's
    openSharedWorkspace(), and return it.
    '''
    import vivisect.qt.remote as viv_q_remote

    workspace = vivisect.VivWorkspace()
    viv_q_remote.openSharedWorkspace(workspace)
    return workspace


def _menuShareConnectServer(self):
    '''
    open vivisect.qt.remote's server-and-workspace dialog for self.vw, with
    _self_ as the dialog's parent.
    '''
    # imported here because the module is not bound to this name at file level
    import vivisect.qt.remote as viv_q_remote

    viv_q_remote.openServerAndWorkspace(self.vw, parent=self)


def setupVW(vw):
    '''
    publish the helper objects for _vw_ into this module's globals:

        emu, temu           - the workspace emulator and a TestEmulator wrapping it
        sctx, xlate, symemu - symbolik analysis context, translator and emulator

    waits up to ~3 seconds for the workspace's 'Architecture' meta to appear
    before asking for the emulator.  each group is best-effort: a failure is
    reported and whatever was set before the failure stays in place.
    does nothing (beyond a message) when vw is None.
    '''
    if vw is None:
        print("vw is None, not setting up environment.  run setupVW(vw) with a valid vw")
        return

    # this module's namespace is what the interactive shell is handed, and
    # unlike the __main__-only 'gbls' name it always exists
    gbls = globals()

    try:
        import time
        vprint('loading emulator...')
        count = 30
        while count:
            if vw.getMeta('Architecture') is not None:
                break
            print("waiting....")
            count -= 1
            time.sleep(.1)
        emu = gbls['emu'] = vw.getEmulator()
        gbls['temu'] = aemu.TestEmulator(emu)

    except Exception as e:
        print("\n\t" + '\n\t'.join(["%s: \t%s" % (k,v) for k,v in vw.metadata.items()]))
        print("FAILED to get Emulator: %s" % e)

    try:
        sctx = gbls['sctx'] = vs_anal.getSymbolikAnalysisContext(vw)
        gbls['xlate'] = sctx.getTranslator()
        gbls['symemu'] = vs_emu.SymbolikEmulator(vw)
    except Exception as e:
        # symboliks are optional, but say why they are missing
        print("FAILED to set up symboliks: %s" % e)

# which kind of interactive shell interact() managed to create
STYPE_NONE = 0
STYPE_IPYTHON = 1
STYPE_CODE_INTERACT = 2

def _getTerminalShellClass():
    '''
    return IPython's TerminalInteractiveShell class from whichever of the two
    known module paths provides it.  raises ImportError if neither does.
    '''
    try:
        from IPython.terminal.interactiveshell import TerminalInteractiveShell
    except ImportError:
        from IPython.frontend.terminal.interactiveshell import TerminalInteractiveShell
    return TerminalInteractiveShell

def interact(lcls, gbls, intro=""):
    '''
    start the best interactive shell available, print _intro_ and run the
    shell until the user leaves it.

    lcls:  local namespace for the shell
    gbls:  global namespace for the shell
    intro: banner printed before the shell starts

    tried in order: IPython.Shell.IPShell, IPython's TerminalInteractiveShell,
    and finally the stdlib code.InteractiveConsole (which only gets _gbls_).
    '''
    shelltype = STYPE_NONE
    ipsh = None
    shell = None

    try:
        import IPython.Shell
        ipsh = IPython.Shell.IPShell(argv=[''], user_ns=lcls, user_global_ns=gbls)
        shelltype = STYPE_IPYTHON

    except ImportError:
        try:
            ipsh = _getTerminalShellClass()()
            ipsh.user_global_ns.update(gbls)
            ipsh.user_global_ns.update(lcls)
            ipsh.autocall = 2       # don't require parenthesis around *everything*.  be smart!
            shelltype = STYPE_IPYTHON

        except ImportError as e:
            # no IPython at all: fall back to the plain stdlib console
            print(e)
            shell = code.InteractiveConsole(gbls)
            shelltype = STYPE_CODE_INTERACT

    print(intro)

    if shelltype == STYPE_IPYTHON:
        ipsh.mainloop()

    elif shelltype == STYPE_CODE_INTERACT:
        shell.interact()

    else:
        print("SORRY, NO INTERACTIVE OPTIONS AVAILABLE!!  wtfo?")


if __name__ == "__main__":
    # script entry point: load the workspace(s) named on the commandline
    # (local file, remote server workspace, or shared workspace), publish the
    # helper objects into globals and drop into an interactive shell.
    from vivisect.symboliks.common import *
    from vivisect.symboliks.effects import *
    parser = argparse.ArgumentParser()
    parser.add_argument('-O', '--option', dest='options', action='append', default=[], help='workspace/architecture options (typically for BLOB and IHEX files)')
    parser.add_argument('-S', '--server', dest='server', default=None, help='remote workspace server')
    parser.add_argument('-s', '--shared', dest='shared', default=None, help='remote shared workspace (host:port)')
    parser.add_argument('workspace', nargs="*", help="workspace file, or hyphen ('-') if using a remote workspace and you want the dialog")
    args = parser.parse_args()

    # vw now starts a gui in the background for popup dialogs
    try:
        import vqt.main as vqt_main
        import vqt.application as vqta
        import vqt.colors as vqt_colors
        import vivisect.qt.main as vq_main
        import vivisect.qt.remote as vq_remote
        if not hasattr(vqt_main, 'qapp'):
            vqt_main.startup(css=vqt_colors.qt_matrix)

    except Exception as e:
        # the GUI is optional; only the popup dialogs are lost without it
        print("error: %r" % e)

    vw = None
    vws = []
    mcanvs = []

    # a server or shared workspace given without a workspace name means "use the dialog"
    if not args.workspace and (args.server is not None or args.shared is not None):
        args.workspace = ['-']

    for vwname in args.workspace:
        if args.server is not None:
            wsserver = viv_server.connectToServer(args.server)

            if vwname == '-':
                # the server name is passed explicitly so the dialog does not
                # depend on a previously stored connection
                vwname = getRemoteWorkspaceDialog(args.server)

            if vwname is None:
                vw = v_cli.VivCli()
            else:
                vw = viv_server.getServerWorkspace(wsserver, vwname)

        elif args.shared is not None:
            vw = v_cli.VivCli()
            uri = 'cobra://%s/vivisect.remote.client?msgpack=1' % args.shared
            server = cobra.CobraProxy(uri, msgpack=True)
            vw.initWorkspaceClient(server)

        else:
            vw = v_cli.VivCli()
            for optarg in args.options:
                # section:option:value -- the value itself may contain colons
                secname, optname, value = optarg.split(":", 2)
                vw.setOption(secname, optname, value)

            if vwname.endswith(".viv"):
                vw.loadWorkspace(vwname)
            else:
                vw.loadFromFile(vwname)

        # record the workspace that was actually loaded, and a canvas bound to it
        vws.append(vw)
        mcanv = e_memcanvas.StringMemoryCanvas(vw)
        mcanvs.append(mcanv)

    lcls = locals()
    gbls = globals()
    gbls['vw'] = vw
    gbls['vws'] = vws
    gbls['vprint'] = vprint

    setupVW(vw)

    # baseline for listNewVars(): everything defined before the shell starts
    gbls['__orig_globals'] = dict(gbls)
    interact(lcls, gbls, intro)

