#!/usr/bin/python

import argparse
import sys, warnings
import math
import re 



# Command-line interface definition.  The parser is built at module scope and
# consumed by processArgs().
parser = argparse.ArgumentParser(
        description='ROCm Profiler Script: a simple profiling tool for use with hcc profile output')

parser.add_argument('infile', metavar='HCC_PROF_FILE', type=argparse.FileType('r'),
        help="Input file to process")

parser.add_argument('-t', '--text_trace',
        action='store_true',
        help="Print profiled commands (kernel, copy, barrier, etc) in start time order to stdout")

parser.add_argument("-g", '--gui_trace',
        metavar='JSON_FILE', type=argparse.FileType('w'),
        help="Generate a JSON file suitable for viewing with chrome trace viewer in chrome://tracing")

# NOTE: argparse applies %-formatting to help strings, so a literal '%' must be
# written as '%%' in any text that is passed as help=.
ROI_HELP = ("ROI_POINTs can be specified with line numbers, start times, or search strings."
            " Supported options include @LINENUM, ^START_TIME, MATCH_NUM%%SEARCH_RE."
            "  Examples: @1234, ^45679.0123, 2%%MyMarker (finds 2nd instance of MyMarker)")

parser.add_argument('-r', '--roi_start',
        action="store",
        # The separating space keeps the two sentences from running together
        # ("...start point.ROI_POINTs...") in the rendered help.
        help=("Specify start point to use for Region-of-Interest. ROI is inclusive of the start point. " + ROI_HELP))

parser.add_argument('-R', '--roi_stop',
        action="store",
        help="Specify stop point to use for Region-of-Interest.  ROI is exclusive of the stop point. See roi_start for ROI_STOP format. ")

parser.add_argument('--info',
        action="store_true",
        help="Print detailed explanation of how to interpet rpt data.  Must specify an input file.")

parser.add_argument('--hide_app_text',
        action='store_true',
        help="Don't print output from the application on the text trace ; only show the profile records")

#parser.add_argument('--text_trace_verbose',
#        action='store_true',
#        help="Print verbose timeline");

parser.add_argument('--no_gaps',
        action="store_true",
        help="Exclude gaps.  Resources will print active commands and compute summaries only for time when they are busy")

#parser.add_argument('--trim_gaps',
#        action="store_true",
#        help="Trim leading and trailing gaps. ROI for a Resource ranges from start of its first command to end of its last")

parser.add_argument('--ignore_barriers',
        action="store_true",
        help="Ignore barrier commands.  Sometimes barrier commands are incorrectly tagged with activity that belongs to other kernel events.")

# Upper bounds (in us) of the histogram buckets used to classify idle gaps.
gaps_default = [10, 20, 50, 100, 1000, 10000, 100000]
parser.add_argument("--gaps",
        nargs='+', type=int, metavar="GAP",
        default=gaps_default,
        help="Size of histogram buckets used for gaps breakdown, specified in us.  Default=" + str(gaps_default))

parser.add_argument('--topn',
        type=int,
        default=20,
        # %(default)s is expanded by argparse, so the text can no longer drift
        # away from the real default (it used to claim 10 while the default was 20).
        help="print top N entries in summary.  Default=%(default)s.  -1=all")

parser.add_argument('--db',
        action="store_true",
        help="print raw log records as read from file")

#parser.add_argument('--summary_items',
#        action="store_true",
#        help='Number of summary items to print

# Longer description of the ROI syntax.  Not referenced by any add_argument()
# call above; note that it contains unescaped '%' characters, so it must not be
# passed as help= without escaping them first.
ROI_HELP2 = ("  @LINENUM : specify a line number from the input file.  Example: @1342.\n"
             "  ^TIME    : specify start time from the beginning of the file.  Example: ^55.12345\n"
             "  MATCH_NUM%SEARCH_RE : specify string in python regular expression format, rpt will use the profile record after the found text. This allows applications to embed markers using simple stderr print statements. Example: %MyStartMarker\n")


def processArgs():
    """Parse the command line and return the resulting argparse namespace.

    When the script is started without any arguments the help text is printed
    and the process exits.  The --gaps bucket boundaries are returned in
    ascending order; the gap bucketing code walks them from smallest to
    largest and treats the last entry as the largest bucket.
    """
    if len(sys.argv) <= 1:
        parser.print_help()
        # sys.exit() instead of the interactive-only exit() helper, which is
        # injected by the site module and is absent under `python -S`.
        sys.exit()

    args = parser.parse_args(sys.argv[1:])

    # Sort a copy so the module-level default list is never reordered in place.
    args.gaps = sorted(args.gaps)

    return args


#---
class TextLogRecord:
    """Free-form text lines that precede a profile record in the input.

    Consecutive lines are accumulated together with the inclusive range of
    input line numbers (startLineNum..stopLineNum) they were read from.
    """

    def __init__(self, lineNum, textLine):
        self.text = [textLine]
        self.startLineNum = lineNum
        self.stopLineNum = lineNum

    def append(self, textLine):
        """Add the next consecutive input line to this record."""
        self.text.append(textLine)
        self.stopLineNum += 1

    def isInRange(self, targetLineNum):
        """Return True when targetLineNum is one of the lines held here."""
        return self.startLineNum <= targetLineNum <= self.stopLineNum

    def printMe(self, showLines=False):
        """Write the held lines to stdout, each prefixed with "> ".

        The lines are written as stored (they normally still carry their own
        newline).  With showLines the block is bracketed by "@<line>---"
        markers for the first and last line number.
        """
        # sys.stdout.write is used instead of print so the method behaves the
        # same under Python 2 and Python 3 without adding extra newlines.
        if showLines:
            sys.stdout.write("@%d---\n" % (self.startLineNum))
        for l in self.text:
            sys.stdout.write("> %s" % (l,))
        if showLines:
            sys.stdout.write("@%d---\n\n" % (self.stopLineNum))

    def getLineText(self, lineNum):
        """Return the text stored for input line lineNum, or None if this
        record does not hold that line."""
        index = lineNum - self.startLineNum
        # Explicit bounds check: a negative index would otherwise silently
        # wrap around and return a line from the end of the list.
        if 0 <= index < len(self.text):
            return self.text[index]
        return None

    def search(self, targetRe):
        """
        Search the held lines for a compiled regular expression.

        Returns (lineNum, lineText) for the first line where targetRe matches,
        or (None, None) if no line matches.
        """
        for (li, l) in enumerate(self.text):
            if targetRe.search(l):
                return (li + self.startLineNum, l)

        return (None, None)



#---
class ProfileLogRecord:
    """One "profile:" line of the input, parsed into its fields.

    enqTime/startTime/stopTime hold the integers read from the log; printMe()
    divides them by 1e6 for the Start(ms) column and by 1e3 for microsecond
    columns.  `valid` is False when the line could not be parsed, in which
    case the other attributes may be only partially populated.
    """

    def __init__(self, lineNum, line, text):
        """
        lineNum : line number of `line` in the input file.
        line    : the raw "profile:" line.
        text    : TextLogRecord with the application text seen since the
                  previous record, or None.
        """
        if args.db:
            print("raw: " + line)

        self.lineNum = lineNum
        self.text = text

        self.device = None
        self.queue = None
        self.cmdNum = None
        self.sizeBytes = None
        self.acqFence = None
        self.releaseFence = None

        # these structures filled in after sorting the records:
        self.criticalTime = None
        self.gapFromPrev = 0      # gap in ns from previous command.

        try:
            self._parseLine(line)
            self.valid = True
        except (ValueError, IndexError):
            print("warning: ignored unparsable profile line: " + line.rstrip("\n"))
            self.valid = False

    def _parseLine(self, line):
        """Populate the record from the ';'-separated fields of `line`.

        Raises ValueError or IndexError when a mandatory field (type, name,
        duration, start/stop time, operation tag) is missing or malformed.
        """
        fields = [f.strip() for f in line.split(";")]
        self.type = fields[0].split(":")[1].strip()
        self.name = fields[1]
        self.time = float(fields[2].split(" ")[0])

        try:
            enqTime = int(fields[3])
            startTime = int(fields[4])
            stopTime = int(fields[5])
            if enqTime > startTime:
                raise ValueError("enqueue time is later than start time")
            fi = 6
        except (ValueError, IndexError):
            # older traces did not have enq time, so derive it from the start time here.
            startTime = int(fields[3])
            stopTime = int(fields[4])
            enqTime = startTime - 10
            fi = 5
        self.enqTime = enqTime
        self.startTime = startTime
        self.stopTime = stopTime

        opTag = fields[fi]
        try:
            (self.device, self.queue, self.cmdNum) = [int(f) for f in opTag.lstrip("#").split(".")]
        except ValueError:
            print("warning: unknown opid tag= " + opTag)
            self.device = 0
            self.queue = 0
            self.cmdNum = 0   # was misspelled 'cmdnum', leaving cmdNum as None

        if self.type == 'barrier':
            try:
                (depcnt, self.acqFence, self.releaseFence) = self.name.split(",")
                # Convert before comparing: the raw field is a string, and a
                # string-vs-int comparison is always true on Python 2.
                self.depcnt = int(depcnt)
            except ValueError:
                # legacy trace formats.
                self.depcnt = 0
            if self.depcnt > 0:
                # hack to prevent barriers from appearing to start before their
                # dependents, which hides the dependents on the critical path.
                # Really should do better critical path analysis here.
                self.startTime = self.stopTime

        if self.isCopyCmd():
            copyField = fi + 1
            sizeField = fields[copyField] if copyField < len(fields) else ""
            try:
                self.sizeBytes = int(sizeField.split()[0])
                self.name = self.name + "_" + str(self.sizeBytes) + "_bytes"
            except (ValueError, IndexError):
                # A missing or malformed size only costs the name suffix; the
                # record itself stays usable.
                print("warning: can't find copy size in " + (sizeField or "<missing field>"))

    def isCopyCmd(self):
        return self.type in ("copy", "copyslo")

    def isGpuCmd(self):
        # commands sent to GPU queues:
        return self.type in ("kernel", "barrier")

    def isInRange(self, targetLineNum):
        """True when targetLineNum is this record's line or one of the text
        lines attached to it."""
        if targetLineNum == self.lineNum:
            return True
        return self.text is not None and self.text.isInRange(targetLineNum)

    def gapName(self):
        """Return the histogram bucket label for gapFromPrev ("" if no gap).

        args.gaps holds the bucket upper bounds in us while gapFromPrev is in
        ns, hence the factor of 1000.
        """
        if self.gapFromPrev <= 0:
            return ""

        for (gi, g) in enumerate(args.gaps):
            if self.gapFromPrev < g * 1000:
                if gi == 0:
                    return "gap <" + str(g) + "us"
                return "gap " + str(args.gaps[gi - 1]) + "us-" + str(g) + "us"

        return "gap >=" + str(args.gaps[-1]) + "us"

    @staticmethod
    def printHeader(printInfo):
        """Print the column header of the text trace; with printInfo also
        print an explanation of every column first."""
        if printInfo:
            print("Display each of the profiling events on a single line.  Fields:")

            print("  Resource : The resource this command executed on.  May be GPU# for a specific GPU or DATA to indicate data transfer command")
            print("  Start(ms) : Start time of the command, displayed in ms.  Command start from time 0")
            print("  +Time(us) : Time which the command contributed to the 'critical' path ")
            print("  Type  : Type of the command (kernel, barrier, copy)")
            print("  Dev.Q.Cmd : The command sequm is a unique mononotically increasing number per-queue, so the triple of device.queue.cmdseqnum uniquely identifies the command during the process execution.")
            print("  LineNum : Line number from the input HCC_PROF_FILE")
            print("  Name  : Name of the record.  For kernels, the kernel name.  For copy commands, the name contains the direction and size.  For barrier commands the name contains dependency counts and acquire and release fences.")

        print("%4s %16s %8s %7s %10s %10s %-30s" %
              ("Resource", "Start(ms)", "+Time(us)", "Type", "#Dev.Q.Cmd", "LineNum", "Name"))

    def printOneLine(self, startTime, lineNum):
        """Print whatever belongs to input line lineNum: this record itself,
        or the matching line of the attached application text."""
        if lineNum == self.lineNum:
            # Previously called an undefined global printMe(..., false).
            self.printMe(startTime, False)
        elif self.text is not None:
            t = self.text.getLineText(lineNum)
            if t is not None:
                print(t)

    def printMe(self, timeOffset=0, showText=False):
        """Print this record as one text-trace line.

        timeOffset is subtracted from the start time before display.  With
        showText the attached application text is printed first (unless
        --hide_app_text).  A preceding idle gap gets its own "gap" line unless
        --no_gaps is set.
        """
        if showText and not args.hide_app_text and self.text is not None:
            self.text.printMe()

        resourceName = Resource.getName(self)

        if (not args.no_gaps) and self.gapFromPrev > 0:
            # One "==" per histogram bucket, up to and including the first
            # bucket this gap fits in, so longer gaps draw longer markers.
            markerCnt = 0
            for g in args.gaps:
                markerCnt += 1
                if self.gapFromPrev < g * 1000:
                    break
            markers = markerCnt * "=="
            print("%4s %16.6f: %+10.2f %-7s %10s %-30s" %
                  (resourceName,
                   (self.startTime - self.gapFromPrev - timeOffset) / 1000000.0,
                   self.gapFromPrev / 1000.0,
                   "gap", "", markers))

        if self.criticalTime is None:
            time = 0
        else:
            time = self.criticalTime

        print("%4s %16.6f: %+10.2f %-7s #%d.%d.%d %8d: %-30s" %
              (resourceName,
               (self.startTime - timeOffset) / 1000000.0, time / 1000.0,
               self.type,
               self.device, self.queue, self.cmdNum,
               self.lineNum, self.name))

    def printJSON(self, file, timeOffset=0):
        """Write this record to `file` as a complete ("ph":"X") trace event
        followed by ",\\n".  Times are divided by 1000 on output.

        timeOffset is accepted for symmetry with printMe() but is not applied.
        """
        import json  # function-scope import: json is not imported at file level

        tid = self.queue
        # json.dumps quotes and escapes the name so that quotes or backslashes
        # in it cannot corrupt the generated file.
        file.write('{ "pid":1, "tid":%d, "ts":%d, "dur":%d, "ph":"X", "name":%s, "args":{"dev.queue.op":"%d.%d.%d", "stop":%d } }' %
                   (tid, self.startTime // 1000, (self.stopTime - self.startTime) // 1000, json.dumps(self.name),
                    self.device, self.queue, self.cmdNum, self.stopTime // 1000))
        file.write(',\n')

    def update(self, maxStopTime, gapFromPrev):
        """Compute criticalTime and record the gap to the previous command.

        maxStopTime : latest stop time seen so far on the same resource, or
                      None when this is the first command on it.
        gapFromPrev : startTime - maxStopTime; only stored when positive.
        """
        duration = self.stopTime - self.startTime

        if maxStopTime is not None and self.startTime < maxStopTime:
            # Overlaps the previous command: only the part after maxStopTime
            # counts, and nothing at all if this command finished first.
            self.criticalTime = max(0, self.stopTime - maxStopTime)
        else:
            self.criticalTime = duration

        # TODO - remove me, use the logic above.
        assert self.criticalTime <= duration  # can't contribute more than stop-start
        if gapFromPrev > 0:
            self.gapFromPrev = gapFromPrev



class ProfileSummaryRecord:
    """A summary record that accumulates stats for ProfileLogRecords that share the same name."""

    def __init__(self, name):
        self.type = None    # Type - gap, kernel, copy, etc.
        self.name = name
        self.refs = 0       # Number of records that share this ProfileSummaryRecord

        self.totalCriticalTime = 0   # sum of all critical times
        self.firstRefTime = None
        self.lastRefTime = 0
        self.minTime = None
        self.maxTime = 0

    def setFields(self, type, criticalTime, startTime, stopTime):
        """Fold one occurrence into the running totals, min/max critical time
        and first/last reference times."""
        self.type = type
        self.refs += 1
        self.totalCriticalTime += criticalTime

        if self.minTime is None or criticalTime < self.minTime:
            self.minTime = criticalTime
        if criticalTime > self.maxTime:
            self.maxTime = criticalTime

        if self.firstRefTime is None or startTime < self.firstRefTime:
            self.firstRefTime = startTime
        if stopTime > self.lastRefTime:
            self.lastRefTime = stopTime

    def addRecord(self, summaryDict, logRec):
        """Account for logRec in this summary.

        If logRec is preceded by a gap, the gap is also accumulated into the
        summaryDict entry named after its histogram bucket (created on demand).
        """
        self.setFields(logRec.type, logRec.criticalTime, logRec.startTime, logRec.stopTime)

        if logRec.gapFromPrev > 0:
            gapName = logRec.gapName()
            # Look up first so a new summary object is only built when the
            # bucket has not been seen before.
            gapSummary = summaryDict.get(gapName)
            if gapSummary is None:
                gapSummary = ProfileSummaryRecord(gapName)
                summaryDict[gapName] = gapSummary
            gapSummary.addGapRecord(logRec)

    def addGapRecord(self, logRec):
        """Accumulate the idle gap that precedes logRec as a "gap" entry."""
        self.setFields("gap", logRec.gapFromPrev, logRec.startTime - logRec.gapFromPrev, logRec.startTime)

    @staticmethod
    def printHeader():
        print("%14s %11s %8s %8s %8s %8s  %-30s" %
              ("Total(%)", "Time(us)", "Calls", "Avg(us)", "Min(us)", "Max(us)", "Name"))

    def printMe(self, totalTimeNs):
        """Print one summary row; totalTimeNs is the denominator for the
        percentage column.  Gap rows are suppressed under --no_gaps."""
        if args.no_gaps and self.type == "gap":
            return

        timeUs = self.totalCriticalTime / 1000.0

        # Guard the divisions: the total can legitimately be 0 (for example a
        # region of interest with no measurable duration).
        percent = self.totalCriticalTime * 100.0 / totalTimeNs if totalTimeNs else 0.0
        avgUs = timeUs / self.refs if self.refs else 0.0
        minUs = (self.minTime or 0) / 1000.0

        print("%13.2f%% %10.1f %8d %8.1f %8.1f %8.1f  %-30s" %
              (percent, timeUs,
               self.refs,
               avgUs, minUs, self.maxTime / 1000.0, self.name))


class Resource:
    """
    Profile information for a specific resource
    Resource can be a GPU, Bus, or CPU (eventually)
    """

    def __init__(self, resourceName):
        self.name = resourceName

        # dictionary indexed by log record name
        # Each dictionary entry is a ProfileSummaryRecord containing
        # summary info for ProfileLogRecords that share the same name
        self.summaryByName = {}

        # Running tracker for the last stop time we've seen
        # Assumes records are all added in time order.
        self.maxStopTime = None

        # Time that this resource was busy.  Excludes GAP time.
        self.busyTime = 0

        # Not assigned anywhere in this class; only the commented-out
        # trim_gaps code in selectMaxTime() refers to them.
        self.firstLogRecord = None
        self.lastLogRecord = None

        self.opsPerQueue = {}  # queue ID -> count of how many ops were sent to that queue.

    @staticmethod
    def getName(lr):
        """Return the resource name a log record belongs to: "GPU<device>"
        for kernel/barrier commands, "DATA" for copies, None (after printing a
        warning) for anything else."""
        if lr.isGpuCmd():
            return "GPU" + str(lr.device)
        elif lr.isCopyCmd():
            return "DATA"
        else:
            print("warning:unknown type='%s'" % (lr.type))
            return None

    def selectMaxTime(self, roiTime):
        """Pick the denominator used for the summary percentages."""
        # Users may have different interpretations of the gap
        if args.no_gaps:
            return self.busyTime
        #elif args.trim_gaps:
        #    return self.lastLogRecord.stopTime - self.firstLogRecord.startTime
        else:
            return roiTime

    def addLogRecord(self, lr, roiStartTime):
        """Account for one log record; records must be added in start-time order.

        Updates lr's critical time and gap, prints it when --text_trace is
        set, and folds it into the per-name summary, busy time and per-queue
        op counts.
        """
        gapFromPrev = 0
        if self.maxStopTime is not None:
            gapFromPrev = lr.startTime - self.maxStopTime

        lr.update(self.maxStopTime, gapFromPrev)

        if args.text_trace:
            lr.printMe(roiStartTime, True)

        # Explicit None test: ordering comparisons against None only work on
        # Python 2.
        if self.maxStopTime is None or lr.stopTime > self.maxStopTime:
            self.maxStopTime = lr.stopTime

        summary = self.summaryByName.get(lr.name)
        if summary is None:
            summary = ProfileSummaryRecord(lr.name)
            self.summaryByName[lr.name] = summary
        summary.addRecord(self.summaryByName, lr)
        self.busyTime += lr.criticalTime

        self.opsPerQueue[lr.queue] = self.opsPerQueue.get(lr.queue, 0) + 1

    def printSummary(self, roiTime):
        """Print the summary table, largest total critical time first,
        limited to args.topn rows (a negative topn means all rows)."""
        ranked = sorted(self.summaryByName.values(),
                        key=lambda rec: rec.totalCriticalTime, reverse=True)

        maxToPrint = args.topn
        if maxToPrint < 0 or maxToPrint > len(ranked):
            maxToPrint = len(ranked)

        # roiTime can be 0 (e.g. start and stop of the region coincide).
        busyPercent = self.busyTime * 100.0 / roiTime if roiTime else 0.0
        print("Resource=%s Showing %d/%d records  %6.2f%% busy" %
              (self.name, maxToPrint, len(ranked), busyPercent))

        time = self.selectMaxTime(roiTime)

        ProfileSummaryRecord.printHeader()
        # Slicing with the clamped count keeps the rows in step with the
        # "Showing N/M" line, including for --topn 0.
        for rec in ranked[:maxToPrint]:
            rec.printMe(time)

    def printJSON(self, file):
        """Write one "thread_name" metadata event per queue seen on this
        resource, each followed by ",\\n"."""
        for q in sorted(self.opsPerQueue):
            file.write('{ "name": "thread_name", "ph": "M", "pid": 1, "tid":%d ,\n' % (q))
            file.write('  "args": { "name" : "stream:%d" }\n' % (q))
            file.write('},\n')



class FileParser:
    """Reads a profile log, keeps its records sorted by start time, resolves
    the region of interest (ROI) and drives the per-resource reports."""

    def __init__(self, file):
        """Parse `file` (an iterable of lines) and resolve the ROI from the
        command-line arguments.

        Raises Exception if the input holds no usable profile record or the
        ROI arguments cannot be resolved.
        """
        # Parsing pass - extract profile lines from file, parse, and add to profileLogRecords array
        self.profileLogRecords = []
        self.resources = {}

        text = None
        p = None

        for (lineNum, line) in enumerate(file, 1):
            if line.startswith("profile:"):
                p = ProfileLogRecord(lineNum, line, text)
                if p.valid and not (args.ignore_barriers and p.type == 'barrier'):
                    self.profileLogRecords.append(p)
                    text = None  # reset to accumulate new text.
                    continue

            # not a valid profile line, accumulate lines into text structure.
            if text is None:
                text = TextLogRecord(lineNum, line)
            else:
                text.append(line)

        # Text after the final profile line is attached to the last profile
        # line that was parsed.
        if text is not None and p is not None:
            p.text = text

        # Sort all records by start time:
        self.profileLogRecords.sort(key=lambda rec: rec.startTime)

        self.setRoi(True)

    def setRoi(self, show):
        """Resolve --roi_start/--roi_stop into record indices and records;
        with `show` also print the ROI boundaries and duration."""
        records = self.profileLogRecords
        if not records:
            raise Exception("no valid profile records found in input")

        baseTime = records[0].startTime
        (self.roiStartIndex, roiStartLine) = self.processRoi(records, args.roi_start, baseTime, 0)
        (self.roiStopIndex, roiStopLine) = self.processRoi(records, args.roi_stop, baseTime, len(records) - 1)

        # This depends on user input, so report it as an error instead of an
        # assert (asserts disappear under `python -O`).
        if self.roiStartIndex > self.roiStopIndex:
            raise Exception("roi_start is located after roi_stop", args.roi_start, args.roi_stop)

        self.roiStart = records[self.roiStartIndex]
        self.roiStop = records[self.roiStopIndex]

        if show:
            sys.stdout.write("ROI_START: ")
            self.roiStart.printMe(self.roiStart.startTime)
            sys.stdout.write("ROI_STOP : ")
            self.roiStop.printMe(self.roiStart.startTime)  # yes, this should be startTime
            print("ROI_TIME=%8.3f secs" % ((self.roiStop.stopTime - self.roiStart.startTime) / 1000000000.0))

    def generateResources(self):
        """Distribute the ROI records over one Resource per resource name.

        NOTE(review): the stop record itself is excluded (ROI is exclusive of
        the stop point), which also drops the very last record when no
        --roi_stop is given - confirm whether that is intended.
        """
        self.resources = {}
        for lr in self.profileLogRecords[self.roiStartIndex:self.roiStopIndex]:
            resourceName = Resource.getName(lr)

            if resourceName not in self.resources:
                self.resources[resourceName] = Resource(resourceName)
            self.resources[resourceName].addLogRecord(lr, self.roiStart.startTime)

    def generateJSON(self, file):
        """Write the ROI records plus the per-queue metadata to `file` as a
        JSON document with a "traceEvents" array."""
        file.write('{\n')
        file.write('"traceEvents": [\n')
        for lr in self.profileLogRecords[self.roiStartIndex:self.roiStopIndex]:
            lr.printJSON(file)

        # add resource info like thread/stream names:
        for resource in self.resources.values():
            resource.printJSON(file)

        # Every event above ends with a ',' so one more array element is
        # needed.  It has to be an object: the bare "meta_event":"rpt" pair
        # written previously is not valid inside a JSON array.
        file.write('{ "name": "process_name", "ph": "M", "pid": 1, "args": { "name": "rpt" } }\n')

        file.write('],\n')  # end traceEvents
        file.write('"meta_name":"rpt"\n')
        file.write('}\n')  # end of document

    @staticmethod
    def processRoi(sortedProfileLogRecords, roiArg, roiStartTime, default):
        """Translate one ROI argument into (recordIndex, lineNum).

        roiArg forms:
          @LINENUM             record whose own line or attached text covers LINENUM
          ^TIME                first record starting at least TIME after roiStartTime,
                               in the units of the record timestamps
          [MATCH_NUM]%REGEX    MATCH_NUM-th record (default first) whose attached
                               text matches REGEX
        Returns (default, None) when roiArg is None; lineNum is None for ^TIME.
        Raises Exception when the expression is malformed or nothing matches.
        """
        if roiArg is None:
            return (default, None)

        roiArgMatch = re.search("([0-9]*)([@^%])(.*)", roiArg)
        if roiArgMatch is None:
            # No marker character at all.
            raise Exception("Invalid roi expression=", roiArg)

        (countStr, markerChar, roiStr) = roiArgMatch.groups()
        leadingNumbers = int(countStr) if countStr else 1

        if markerChar == '@':
            lineNum = int(roiStr)
            # search for first record with exact match on line number:
            for (lri, lr) in enumerate(sortedProfileLogRecords):
                if lr.isInRange(lineNum):
                    return (lri, lineNum)

            raise Exception("roi line number invalid ", roiArg)

        elif markerChar == '^':
            # float() so fractional times such as the documented ^45679.0123 parse.
            startTime = float(roiStr)
            # search for first record at or past the ROI start time
            for (lri, lr) in enumerate(sortedProfileLogRecords):
                if (lr.startTime - roiStartTime) >= startTime:
                    return (lri, None)

            raise Exception("roi time out-of-range=", roiArg)

        else:
            # markerChar == '%': search regular expression strings
            roiRe = re.compile(roiStr)
            for (lri, lr) in enumerate(sortedProfileLogRecords):
                if lr.text:
                    (lineNum, line) = lr.text.search(roiRe)
                    if lineNum is not None:
                        leadingNumbers -= 1
                        if leadingNumbers <= 0:
                            print("info: %%roi regular expression matched @%d '%s'" % (lineNum, line.rstrip()))
                            return (lri, lineNum)

            raise Exception("roi search string not found=", roiStr)

    def printSummary(self):
        """Print the summary table of every resource, each preceded by a
        blank line."""
        roiTime = self.roiStop.stopTime - self.roiStart.startTime
        for r in self.resources:
            print("")
            self.resources[r].printSummary(roiTime)

            
def main():
    """Run the tool: parse the input, optionally print the text trace and
    write the chrome-tracing JSON file, then print the per-resource summary."""
    p = FileParser(args.infile)

    if args.text_trace:
        ProfileLogRecord.printHeader(args.info)

    # Also prints the text trace lines when --text_trace is set.
    p.generateResources()

    if args.gui_trace:
        try:
            p.generateJSON(args.gui_trace)
        finally:
            # Close the file argparse opened so the trace is flushed to disk;
            # "-g -" maps to stdout, which the summary below still needs.
            if args.gui_trace is not sys.stdout:
                args.gui_trace.close()

    p.printSummary()


# The command line is parsed at module scope: the classes above read the
# resulting namespace through the global name `args`.
args = processArgs()

if __name__ == "__main__":
    main()

# TODO
#roi=
#  #= device.queue.op sequence number
#  @= absolute time in ns,us,ms.  Default is ns.    Example: @17.5ms, @5500000

# Fix barrier deps automatically with new format.

# Documentation:
# - add to --help in args.  Explain ROI.
# Example what a Resource is.  
#   - GPU or Data or Host
#   - All commands are executed by a Resource.
#   -  Some commands may be executed in the background or overlap other commands.
#      Show latency of commands.
# 
# - add print_time that prints appropriate time ms, seconds, etc.

# How can a barrier start earlier than the commands it depends on?
#- could be the deps are not correctly set, maybe it starts before it should?
#- or it suffers from hw-induced false dependencies so it never finishes.

# add -autorange opt

# Per-resource text trace
# Add synchronization points to display.
# Handle dependencies.

# Data to and from
# align headers
# Compute total APIs and total time for each resource
# Show Min/Max/Avg bandwidth?
