# -----------------------------------------
# Sextant
# Copyright 2014, Ensoft Ltd.
# Author: Patrick Stevens
# -----------------------------------------

#!/usr/bin/python3

import re
import argparse
import os.path
import subprocess
import logging


class ParsedObject():
    """
    Represents a function as parsed from an objdump disassembly.

    Attributes:
      name: the verbatim symbol name, like '__libc_start_main@plt'.
      position: the virtual memory location in hex, like '08048320', exactly
        as extracted from the dump ('' if the object was built from a name
        only).
      canonical_position: position stripped of leading 0s, so it should be a
        unique id ('' if there is no position).
      assembler_section: the section the object was found in, like '.text'.
      what_do_i_call: a list of tuples describing the targets of the 'call'
        instructions in the body; ('address', 'name') when the operand was
        recognised as such, otherwise the space-separated operand pieces.
      original_code: the assembler lines of the body (everything after the
        header line), in case it's useful.

    Two ParsedObjects compare equal when they have the same name.
    """

    # 'call' followed by exactly one more character ('q', 'l', a space, ...),
    # then padding spaces, then the operand (group 1) up to the end of line.
    # The direct-only variant refuses any operand containing '*', which is how
    # objdump marks a call through a function pointer (eg 'call *%eax').
    _CALL_DIRECT_ONLY = re.compile(r'\tcall. +([^\*]+)\n')
    _CALL_ANY = re.compile(r'\tcall. +(.+)\n')

    @staticmethod
    def get_canonical_position(position):
        """Return the hex address `position` without its leading zeros."""
        return position.lstrip('0')

    def __eq__(self, other):
        # Equality is by name only, so that the same function seen at two
        # addresses is treated as one object.
        try:
            return self.name == other.name
        except AttributeError:
            # Not something with a name: let Python fall back to its default
            # comparison instead of blowing up.
            return NotImplemented

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Must agree with __eq__; defining __eq__ alone would make instances
        # unhashable under Python 3. Do not rename an object while it is
        # stored in a set or used as a dict key.
        return hash(self.name)

    def __init__(self, input_lines=None, assembler_section='', function_name='',
                 ignore_function_pointers=True, function_pointer_id=None):
        """
        Create a new ParsedObject given the definition-lines from objdump.

        input_lines is the entire definition, each line ending in a newline:
          a header line followed by the body, eg

            080482f0 <puts@plt>:
             80482f0:   ff 25 00 a0 04 08       jmp    *0x804a000
             80482f6:   68 00 00 00 00          push   $0x0
             80482fb:   e9 e0 ff ff ff          jmp    80482e0 <_init+0x30>

          (objdump separates the columns with tabs.) If it is empty or None,
          the object carries only a name: position and canonical_position are
          '' and original_code and what_do_i_call are empty.
        assembler_section is for instance '.init' in
          'Disassembly of section .init:'.
        function_name is used if we want to give this function a custom name;
          otherwise the name is taken from between the angle brackets of the
          header line.
        ignore_function_pointers=True will pretend that calls to (eg) *eax do
          not exist; setting to False makes us create stubs for those calls,
          recorded as ('0', '_._function_pointer_<id>').
        function_pointer_id is a one-element list holding the next stub id.
          It is incremented in place, so passing the same list to several
          ParsedObjects keeps the stub names unique across them. If it is
          omitted, numbering starts from 1.
        """
        if input_lines is None:
            # get around Python's inability to pass in empty lists by value
            input_lines = []

        # Give every attribute a value up front so that __str__, __repr__ and
        # address lookups also work on objects built from a name alone.
        self.name = function_name
        self.what_do_i_call = []
        self.position = ''
        self.assembler_section = assembler_section
        self.original_code = []

        if input_lines:
            header = input_lines[0]
            if not self.name:
                self.name = re.search(r'<.+>', header).group(0).strip('<>')
            self.position = re.search(r'^[0-9a-f]+', header).group(0)
            self.original_code = input_lines[1:]

        self.canonical_position = ParsedObject.get_canonical_position(self.position)

        if not input_lines:
            return

        if ignore_function_pointers:
            call_regex = ParsedObject._CALL_DIRECT_ONLY
        else:
            call_regex = ParsedObject._CALL_ANY
            if not function_pointer_id:
                function_pointer_id = [1]

        for line in input_lines:
            # we'll catch call and callq for the moment
            found = call_regex.search(line)
            if found is None:
                continue

            # Take the operand from the capture group rather than slicing at a
            # fixed column, so unusual padding cannot eat its first character.
            called = found.group(1).lstrip(' ').rstrip('\n')
            if not called:
                continue

            if called[0] == '*' and not ignore_function_pointers:
                # we have a function pointer, which we'll want to give a distinct name
                address = '0'
                name = '_._function_pointer_' + str(function_pointer_id[0])
                function_pointer_id[0] += 1
                self.what_do_i_call.append((address, name))
                continue

            # we're not on a function pointer
            called_split = called.split(' ')
            if len(called_split) != 2:
                # the format of the "what do i call" is not recognised as a name/address pair
                self.what_do_i_call.append(tuple(called_split))
                continue

            address, name = called_split
            name = name.strip('<>')
            # we still want to remove address offsets like +0x09 from the end of name
            offset_free = re.match(r'^.+(?=\+0x[a-f0-9]+$)', name)
            if offset_free is not None:
                name = offset_free.group(0)
            self.what_do_i_call.append((address, name.strip('<>')))

    def __str__(self):
        if self.position:
            return 'Memory address ' + self.position + ' with name ' + self.name + ' in section ' + str(
                self.assembler_section)
        else:
            return 'Name ' + self.name

    def __repr__(self):
        # Roughly re-creates the objdump text this object was parsed from.
        out_str = 'Disassembly of section ' + self.assembler_section + ':\n\n' + self.position + ' ' + self.name + ':\n'
        return out_str + '\n'.join([' ' + line for line in self.original_code])


class Parser:
    """
    Manipulates the output of objdump: splits it into per-symbol blocks and
    turns each block into a ParsedObject.

    Attributes:
      source_string_list: the objdump output, one newline-terminated string
        per line.
      parsed_objects: the ParsedObjects found by create_objects().
      sections_to_view: the section names to keep ([] keeps every section).
      ignore_function_pointers: passed through to each ParsedObject.
      pointer_identifier: one-element list holding the next function-pointer
        stub id; shared by every ParsedObject this parser creates.
    """

    # First line of a symbol's block, eg '080482f0 <puts@plt>:'.
    _OBJECT_HEADER = re.compile(r'[0-9a-f]+ <.+>:\n')
    # Section banner; group 1 is the section name, eg '.text'.
    _SECTION_HEADER = re.compile(r'Disassembly of section (.+):\n')

    def __init__(self, input_file_location='', file_contents=None, sections_to_view=None, ignore_function_pointers=False):
        """Creates a new Parser, given an input file path. That path should be an output from objdump -D.
        Alternatively, supply file_contents, as a list of each line of the objdump output. We expect newlines
         to have been stripped from the end of each of these lines. If both are given, the file wins; if
         neither is given, the parser simply has nothing to parse.
        sections_to_view makes sure we only use the specified sections. Both [] and None (the default)
         mean 'all sections'.
        ignore_function_pointers=True skips calls through function pointers; False records a stub for each.
        """
        if file_contents is None:
            file_contents = []

        if sections_to_view is None:
            sections_to_view = []

        self.source_string_list = []
        if input_file_location:
            # 'with' closes the file even if reading it raises (eg a
            # UnicodeDecodeError on a binary file).
            with open(input_file_location, 'r') as file_to_read:
                # The final line of a file may lack its newline; every
                # pattern we match against expects one.
                self.source_string_list = [line if line.endswith('\n') else line + '\n'
                                           for line in file_to_read]
        elif file_contents:
            self.source_string_list = [string + '\n' for string in file_contents]
        self.parsed_objects = []
        self.sections_to_view = sections_to_view
        self.ignore_function_pointers = ignore_function_pointers
        self.pointer_identifier = [1]

    def _store_object(self, collected, block_lines, section):
        """Append a ParsedObject for block_lines to collected, if there is a block and its section is wanted."""
        if not block_lines:
            return
        if self.sections_to_view == [] or section in self.sections_to_view:
            collected.append(ParsedObject(input_lines=block_lines, assembler_section=section,
                                          ignore_function_pointers=self.ignore_function_pointers,
                                          function_pointer_id=self.pointer_identifier))

    def create_objects(self):
        """ Go through the source_string_list, getting object names (like 'main') along with the corresponding
         definitions, and put them into parsed_objects.

         A block begins at a line like '080482f0 <puts@plt>:' and ends at the next blank line, block header,
         section banner or the end of the input. Lines outside any block (such as the 'file format' preamble
         objdump prints first) are ignored. If several blocks share a name, only the first is kept.
         Does nothing if sections_to_view has been set to None. """
        if self.sections_to_view is None:
            return

        collected = []
        current_block = None  # lines of the block being read; None while between blocks
        current_section = ''

        for line in self.source_string_list:
            if Parser._OBJECT_HEADER.match(line):
                # we are a starting line; finish any block that was not
                # closed by a blank line first
                self._store_object(collected, current_block, current_section)
                current_block = [line]
            elif line.startswith('Disassembly of section'):
                self._store_object(collected, current_block, current_section)
                current_block = None
                banner = Parser._SECTION_HEADER.match(line)
                # Capture the name rather than lstrip('section '), which
                # strips a set of characters and so mangles names such as
                # 'constants'.
                current_section = banner.group(1) if banner else ''
            elif line == '\n':
                # we now need to stop parsing the current block, and store it
                self._store_object(collected, current_block, current_section)
                current_block = None
            elif current_block is not None:
                current_block.append(line)

        # A dump read from a file usually ends on an instruction line rather
        # than a blank one, so the last block is still open here.
        self._store_object(collected, current_block, current_section)

        # clear duplicates: this is so that if we jump into the function at an offset, we still register
        # it as being the old function, not some new function at a different address with the same name
        self.parsed_objects = []
        seen_names = set()
        for obj in collected:
            if obj.name not in seen_names:
                seen_names.add(obj.name)
                self.parsed_objects.append(obj)

        # by this point, each object contains a self.what_do_i_call which is a list of tuples
        #  ('address', 'name') if the address and name were recognised, or else (thing1, thing2, ...)
        # where the instruction was call thing1 thing2 thing3... .

    def object_lookup(self, object_name='', object_address=''):
        """Returns a list of the objects with name object_name and/or address object_address (at least one
        must be given; if both are given, an object must match both). Leading zeros in object_address are
        ignored. If no such objects are found, or neither argument is given, returns None."""

        if object_name == '' and object_address == '':
            return None

        candidates = self.parsed_objects

        if object_name != '':
            candidates = [obj for obj in candidates if obj.name == object_name]

        if object_address != '':
            wanted_position = ParsedObject.get_canonical_position(object_address)
            candidates = [obj for obj in candidates if obj.canonical_position == wanted_position]

        if not candidates:
            return None

        return candidates

def get_parsed_objects(filepath, sections_to_view, not_object_file, readable=False, ignore_function_pointers=False):
    """
    Parse the file at filepath and return the list of ParsedObjects found.

    filepath is either a compiled object file, which is first run through
      'objdump -D', or (if not_object_file is true) a text file already
      holding objdump output.
    sections_to_view lists the sections to keep, like ['.text']; None or []
      keeps every section.
    readable=True prints a progress message, for use from the command line.
    ignore_function_pointers=True skips calls through function pointers.

    Returns False if not_object_file is true but the file cannot be decoded
      as text.
    Raises IOError if filepath is not an existing file. Errors from running
      objdump (OSError, subprocess.CalledProcessError) are not caught.
    """
    if sections_to_view is None:
        sections_to_view = []  # because we use None for "no sections"; the intent of not providing any sections
        # on the command line was to look at all sections, not none

    # first, check whether the given file exists
    if not os.path.isfile(filepath):
        # we'd like to use FileNotFoundError, but we might be running under
        # Python 2, which doesn't have it.
        raise IOError(filepath + ' is not found.')

    # now the file should exist
    if not_object_file:
        # the file already holds the text output of objdump
        try:
            p = Parser(input_file_location=filepath, sections_to_view=sections_to_view,
                       ignore_function_pointers=ignore_function_pointers)
        except UnicodeDecodeError:
            logging.error('File could not be parsed as a string. Did you mean to omit --not-object-file?')
            return False
    else:
        # we need first to run the object file through objdump
        objdump_file_contents = subprocess.check_output(['objdump', '-D', filepath])
        # Symbol names are not guaranteed to be valid UTF-8; substitute any
        # bad bytes rather than abandoning the whole dump.
        objdump_str = objdump_file_contents.decode('utf-8', errors='replace')

        p = Parser(file_contents=objdump_str.split('\n'), sections_to_view=sections_to_view,
                   ignore_function_pointers=ignore_function_pointers)

    if readable:  # if we're being called from the command line
        print('File read; beginning parse.')
    # file is now read, and we start parsing

    p.create_objects()
    return p.parsed_objects

def main(argv=None):
    """
    Command-line entry point: parse the given file and print, for each
    function found, its name followed by the names of the things it calls.

    argv is the list of command-line arguments to parse; None (the default)
      means use sys.argv[1:].

    Returns 0 on success, or 1 if the input file could not be read as text.
    Exits via argparse (SystemExit) if the arguments are invalid, including
      when --filepath is missing.
    """
    argumentparser = argparse.ArgumentParser(description="Parse the output of objdump.")
    # required=True: without a path there is nothing to do, and the old
    # behaviour was a TypeError on parsed.filepath[0].
    argumentparser.add_argument('--filepath', metavar="FILEPATH", help="path to input file", type=str, nargs=1,
                                required=True)
    argumentparser.add_argument('--not-object-file', help="import text objdump output instead of the compiled file", default=False,
                                action='store_true')
    argumentparser.add_argument('--sections-to-view', metavar="SECTIONS",
                                help="sections of disassembly to view, like '.text'; leave blank for 'all'",
                                type=str, nargs='*')
    argumentparser.add_argument('--ignore-function-pointers', help='whether to skip parsing calls to function pointers', action='store_true', default=False)

    parsed = argumentparser.parse_args(argv)

    filepath = parsed.filepath[0]
    sections_to_view = parsed.sections_to_view
    not_object_file = parsed.not_object_file
    readable = True
    function_pointers = parsed.ignore_function_pointers

    parsed_objs = get_parsed_objects(filepath, sections_to_view, not_object_file, readable, function_pointers)
    if parsed_objs is False:
        return 1

    if readable:
        for named_function in parsed_objs:
            print(named_function.name)
            # use [-1] to get the last element, since: either we are in ('address', 'name'), when we want
            # the last element, or else we are in (thing1, thing2, ...) so for the sake of argument we'll
            # take the last thing
            print([f[-1] for f in named_function.what_do_i_call])

    return 0

if __name__ == "__main__":
    # Hand main()'s result to the shell as the exit status (None counts as
    # success), so that a failed parse is visible to scripts calling us.
    raise SystemExit(main())
