from collections import namedtuple
from datetime import datetime
from typing import Dict, List, Tuple
from mythril.laser.plugin.builder import PluginBuilder
from mythril.laser.plugin.interface import LaserPlugin
from mythril.laser.ethereum.svm import LaserEVM
from mythril.laser.ethereum.state.global_state import GlobalState
from datetime import datetime
import logging
# One timed execution of a single instruction.
#   start_time: datetime  (captured by the "pre" instruction hook)
#   end_time: datetime    (captured by the "post" instruction hook)
_InstrExecRecord = namedtuple("_InstrExecRecord", "start_time end_time")

# Aggregated timings of every recorded execution of one opcode.
#   total_time: float  (sum of all durations, in seconds)
#   total_nr: int      (number of recorded executions)
#   min_time: float    (shortest duration, in seconds)
#   max_time: float    (longest duration, in seconds)
_InstrExecStatistic = namedtuple(
    "_InstrExecStatistic", "total_time total_nr min_time max_time"
)

# Maps an instruction opcode to the records of all of its executions.
_InstrExecRecords = Dict[str, List[_InstrExecRecord]]

# Maps an instruction opcode to the statistic of its execution times.
_InstrExecStatistics = Dict[str, _InstrExecStatistic]

# Module-level logger; the profiling report is emitted through it.
log = logging.getLogger(__name__)
class InstructionProfilerBuilder(PluginBuilder):
    """Plugin builder that creates :class:`InstructionProfiler` instances."""

    name = "instruction-profiler"

    def __call__(self, *args, **kwargs):
        """Return a fresh :class:`InstructionProfiler`.

        Positional and keyword arguments are accepted but not used.
        """
        return InstructionProfiler()
class InstructionProfiler(LaserPlugin):
    """Performance profile for the execution of each instruction.

    Instruction hooks take a timestamp before and after each instruction
    and store the pair in ``self.records`` under the instruction's opcode.
    When symbolic execution stops, a per-opcode summary is written to the
    module logger at INFO level.
    """

    def __init__(self):
        self._reset()

    def _reset(self):
        """Discard all collected records and the pending start timestamp."""
        # opcode -> list of _InstrExecRecord
        self.records = dict()
        # Timestamp taken by the most recent "pre" hook; None until one ran.
        self.start_time = None

    def initialize(self, symbolic_vm: LaserEVM) -> None:
        """Register the profiling hooks on ``symbolic_vm``.

        :param symbolic_vm: the virtual machine whose ``instr_hook`` and
            ``laser_hook`` decorators are used to install the callbacks.
        """

        # NOTE(review): the hook factories below take an opcode and return
        # the actual callback; presumably ``instr_hook(..., None)`` invokes
        # them once per opcode -- confirm against LaserEVM.instr_hook.
        @symbolic_vm.instr_hook("pre", None)
        def get_start_time(op_code: str):
            def start_time_wrapper(global_state: GlobalState):
                self.start_time = datetime.now()

            return start_time_wrapper

        @symbolic_vm.instr_hook("post", None)
        def record(op_code: str):
            def record_opcode(global_state: GlobalState):
                end_time = datetime.now()
                # Create the per-opcode list on first use, then append.
                self.records.setdefault(op_code, []).append(
                    _InstrExecRecord(self.start_time, end_time)
                )

            return record_opcode

        @symbolic_vm.laser_hook("stop_sym_exec")
        def print_stats():
            total, stats = self._make_stats()
            lines = ["Total: {} s\n".format(total)]
            for op in sorted(stats):
                stat = stats[op]
                # The total is 0 when every measured duration is 0 (coarse
                # clock resolution); report a 0 % share instead of dividing
                # by zero.
                share = stat.total_time * 100 / total if total else 0.0
                lines.append(
                    "[{:12s}] {:>8.4f} %,  nr {:>6},  total {:>8.4f} s,  avg {:>8.4f} s,  min {:>8.4f} s,  max {:>8.4f} s\n".format(
                        op,
                        share,
                        stat.total_nr,
                        stat.total_time,
                        stat.total_time / stat.total_nr,
                        stat.min_time,
                        stat.max_time,
                    )
                )
            log.info("".join(lines))

    def _make_stats(self) -> Tuple[float, _InstrExecStatistics]:
        """Aggregate ``self.records`` into per-opcode statistics.

        :return: a tuple ``(total_time, stats)`` where ``total_time`` is the
            sum of all recorded durations in seconds and ``stats`` maps each
            opcode to its :data:`_InstrExecStatistic`.
        """
        stats = dict()
        total_time = 0
        for op, records in self.records.items():
            # Subtract the datetimes directly: the resulting timedelta is
            # exact to the microsecond, whereas subtracting two epoch floats
            # loses precision on the very short durations measured here.
            times = [(r.end_time - r.start_time).total_seconds() for r in records]
            if not times:
                # min()/max() raise on an empty sequence; nothing to report.
                continue
            stat = _InstrExecStatistic(
                total_time=sum(times),
                total_nr=len(times),
                min_time=min(times),
                max_time=max(times),
            )
            total_time += stat.total_time
            stats[op] = stat
        return total_time, stats