#!/usr/bin/python

import re
import sys
import getopt

####################################
#
# PMAPSUMMARY
#
# Analyses pmap output and summarizes it
# 
# ----------------------------------
#
# (c) Sergey Klyaus, Tune-IT, 2012
# myaut@tune-it.ru
#
####################################

# int() converter for numeric pmap columns
def mint(s):
    '''Convert a numeric pmap column to int; the '-' placeholder becomes 0.'''
    return 0 if s == '-' else int(s)
    
class SegException(Exception):
    '''Raised when the mapped-file column of a pmap line cannot be parsed or classified.'''

util_info = '''PMAP summary

Analyses information collected by ps/pmap utils and summarizes
process segments (data, code, heap, etc.)
For Solaris 10/11

Usage:  pmapsummary.py {-r|-v} [-s no|only|once|all] [-c COL1,COL2] [-S SORTCOL] {pmap-collect.out|-}
        pmapsummary.py -h
            -r is for RSS (Resident Set Size)
            -v is for VSZ (Virtual Size)
            -s specifies how shared segments are accounted:
                 no   - exclude shared segments
                 once - count a shared segment only in the first process
                        that reported it (default)
                 all  - count shared segments in every process that maps them
                 only - show shared segments only
               The total line is printed only in 'no' and 'once' modes
            -c specifies which columns to show (comma-separated)
            -S specifies the column to sort by (ascending; default is PID)
            Use '-' as the file name to read from standard input

Columns:
    ANON - anonymous memory
    HEAP - heap
    DISM - dynamic intimate shared memory (SHM)
    MMAP - mmapped devices and files
    CODE - mmapped program text from binary and shared libraries
    DATA - mmapped program data from binary and shared libraries
    STK - thread stacks
    TOTAL - all of above

Program text, DISM and anonymous memory with 's' permission are considered to be shared.
    NOTE: although mmapped segments may also be shared, this cannot be determined
from pmap -x output, so all MMAP segments are assumed to be non-shared.

To collect pmap-collect.out, run (from bash, which is required for the =~ operator):
    ( ps -eo pid,rss,vsz,comm |
        while read PID RSS VSZ COMM
        do
            if [[ $COMM =~ "ora" ]]; then
                echo --
                echo $PID $VSZ $RSS
                pmap -x $PID;
            fi
        done ) > /tmp/pmap-collect.out

(c) Sergey Klyaus, Tune-IT, 2012
Report bugs to myaut@tune-it.ru
'''
    
##################
# Basic objects & utils    
    
class Segment:
    '''One mapping line of `pmap -x` output, classified by segment type.

    The segment type (one of the SEG_* column indices) is derived from the
    "Mapped File" column and, for mapped binaries, the permission column.

    Sharing is tracked by reference: the first segment seen for a sharing
    key acts as the parent, and later duplicates are attached to it with
    add_reference().  A duplicate gets shared == True and parent set; the
    parent keeps shared == False but has a non-empty refs list.
    '''
    SEG_UNKNOWN = None

    # Column indices into the per-process `mem` lists.  TOTAL lives in the
    # last slot and is addressed as -1.
    SEG_TOTAL   = -1
    SEG_ANON    = 0
    SEG_HEAP    = 1
    SEG_DISM    = 2
    SEG_MMAP    = 3
    SEG_CODE    = 4
    SEG_DATA    = 5
    SEG_STACK   = 6
    # NOTE(review): unused in this file; index 7 is the TOTAL slot.
    SEG_SHARED  = 7

    SEG_NAMES   = ['ANON', 'HEAP', 'DISM',
                   'MMAP', 'CODE', 'DATA',
                   'STK', 'TOTAL']
    SEG_TYPE_COUNT  = 8

    STACK_NORMAL    = 0
    STACK_ALT       = 1

    # Class-level defaults; __init__ assigns per-instance values.
    vsz = 0
    rss = 0

    shared = False

    seg_type = -1

    # Patterns for tags in the "Mapped File" column (shared by all instances).
    dism_re = re.compile(r'\[ dism shmid=(.*?) \]')
    dev_re = re.compile(r'dev:(\d+),(\d+) ino:(\d+)')
    stack_re = re.compile(r'\[ (alt)?stack( tid=(.*?))? \]')

    def __init__(self, addr, vsz, rss, perm, mapf):
        '''Build a segment from parsed pmap columns and classify it.

        addr -- start address (int), vsz/rss -- sizes as printed by pmap,
        perm -- the Mode column (e.g. 'r-x--'), mapf -- the Mapped File column.
        Raises SegException if mapf/perm cannot be classified.
        '''
        # Sharing bookkeeping, filled in by add_reference().
        self.refs = []
        self.parent = None
        self.shared = False
        # Only meaningful for anonymous segments; set in parse_map_file().
        self.anon_shared = False

        self.addr = addr
        self.vsz = vsz
        self.rss = rss
        self.perm = perm
        self.mapf = mapf

        self.parse_map_file(mapf)

    def parse_map_file(self, mapf):
        '''Set seg_type and type-specific attributes from the mapped-file column.

        Raises SegException for tags or permissions that are not recognized.
        '''
        if mapf.startswith('[') and mapf.endswith(']'):
            if 'heap' in mapf:
                self.seg_type = Segment.SEG_HEAP
            elif 'anon' in mapf:
                # Anonymous memory with the 's' permission flag is shared.
                self.anon_shared = 's' in self.perm
                self.seg_type = Segment.SEG_ANON
            elif 'stack' in mapf:
                match = Segment.stack_re.match(mapf)
                if match is None:
                    raise SegException("Couldn't parse stack tag %s" % repr(mapf))

                self.seg_type = Segment.SEG_STACK
                if match.group(1) == 'alt':
                    self.stack_type = Segment.STACK_ALT
                else:
                    self.stack_type = Segment.STACK_NORMAL
                # The "tid=N" part is optional; 0 when absent.
                if match.group(3) is not None:
                    self.stack_tid = int(match.group(3))
                else:
                    self.stack_tid = 0
            elif 'dism' in mapf:
                match = Segment.dism_re.match(mapf)
                if match is None:
                    raise SegException("Couldn't parse dism tag %s" % repr(mapf))

                self.dism_id = match.group(1)
                self.seg_type = Segment.SEG_DISM
            else:
                # Without a type such a segment would be accounted into no
                # column at all, so report it instead of dropping it silently.
                # NOTE(review): other shared-memory tags (e.g. ISM) land here too.
                raise SegException("Unknown segment tag %s" % repr(mapf))
        elif mapf.startswith('dev'):
            match = Segment.dev_re.match(mapf)
            if match is None:
                raise SegException("Couldn't parse mmap dev tag %s" % repr(mapf))

            self.dev_major = match.group(1)
            self.dev_minor = match.group(2)
            self.dev_inode = match.group(3)

            self.seg_type = Segment.SEG_MMAP
        else:
            # Mapped file/binary: classified by its permissions.
            if self.perm == 'r-x--':
                self.seg_type = Segment.SEG_CODE
            elif self.perm == 'rwx--':
                self.seg_type = Segment.SEG_DATA
            else:
                raise SegException("Couldn't parse map file %s" % repr(mapf))

    def add_reference(self, seg):
        '''Attach `seg` as a duplicate of this (parent) segment.

        `seg` is marked shared and points back here via seg.parent.
        '''
        self.refs.append(seg)

        seg.shared = True
        seg.parent = self
        
class Process:
    '''A process (PID and command name) together with its pmap segments.'''
    # Which segment size account() sums
    PROC_RSS        = 0
    PROC_VSZ        = 1

    # Shared-segment accounting modes for account()
    SHR_NO          = 0
    SHR_ONCE        = 1
    SHR_ALL         = 2
    SHR_ONLY        = 3

    def __init__(self, pid):
        self.segments = []
        self.pid = pid

        self.name = ''

        # Per-column sums indexed by Segment.SEG_*; the last slot is TOTAL.
        self.mem = [0] * Segment.SEG_TYPE_COUNT

    def add(self, seg):
        '''Append a parsed Segment to this process.'''
        self.segments.append(seg)

    def _is_counted(self, seg, shared):
        '''Return True if `seg` is included under sharing mode `shared`.'''
        # A segment takes part in sharing if it is a duplicate (shared) or
        # a parent that other segments reference (refs non-empty).
        involved = bool(seg.shared or seg.refs)
        if shared == Process.SHR_ALL:
            return True
        if shared == Process.SHR_ONCE:
            # Parents are counted, duplicates are not.
            return not seg.shared
        if shared == Process.SHR_NO:
            return not involved
        if shared == Process.SHR_ONLY:
            return involved
        return False

    def account(self, what, shared=0):
        '''Sum this process's segment sizes per segment type into self.mem.

        what   -- Process.PROC_RSS or Process.PROC_VSZ: which size to sum.
        shared -- a Process.SHR_* mode:
                  SHR_ALL  - every segment;
                  SHR_ONCE - skip duplicates of a segment first seen elsewhere;
                  SHR_NO   - only segments not involved in sharing;
                  SHR_ONLY - only segments involved in sharing.
        Returns self.mem; its last slot holds the sum of the other slots.
        Sums are recomputed from zero on every call.
        '''
        # Reset so repeated calls do not accumulate.
        self.mem = [0] * len(self.mem)

        for seg in self.segments:
            if not self._is_counted(seg, shared):
                continue
            if what == Process.PROC_RSS:
                self.mem[seg.seg_type] += seg.rss
            elif what == Process.PROC_VSZ:
                self.mem[seg.seg_type] += seg.vsz

        self.mem[-1] = sum(self.mem[:-1])

        return self.mem

    def dump(self, printer):
        '''Print this process's row; processes without segments are skipped.'''
        if self.segments:
            printer.print_line(self.pid, self.mem, self.name)

class Printer:
    '''Prints table rows: a 6-wide first field, the selected segment
    columns 8 wide each, then a free-form trailing field.'''
    def __init__(self, col_names = None):
        # col_names: names from Segment.SEG_NAMES, or None for all columns.
        # An unknown name makes list.index() raise ValueError.
        if col_names is not None:
            self.columns = [Segment.SEG_NAMES.index(name) for name in col_names]
            self.col_count = len(self.columns)
        else:
            self.columns = range(Segment.SEG_TYPE_COUNT)
            self.col_count = Segment.SEG_TYPE_COUNT

    def print_line(self, first, segments, last):
        # `segments` is a full row indexed by segment type; pick the
        # selected columns and separate every field by a single space.
        picked = ['%8s' % str(segments[idx]) for idx in self.columns]
        fields = ['%6s' % str(first)] + picked + ['%s' % (last,)]
        print(' '.join(fields))
        
class SegmentCache(dict):
    '''Maps a sharing key to the first Segment seen with that key.

    A later segment with an equal key is not stored; it is attached to
    the stored one through that segment's add_reference() method.'''
    def add(self, seg, key):
        try:
            first = self[key]
        except KeyError:
            # First occurrence of this key: the segment becomes the parent.
            self[key] = seg
        else:
            first.add_reference(seg)

##################
# Pmap parser / summary dump            
            
class ProcessMap:
    '''Parses a combined ps/pmap collection file and aggregates memory usage.

    `state` in parse() tracks the position inside each record:
    '''
    # state | line
    #  -4   | --
    #  -3   | 18346 1028352 653640
    #  -2   | 18346:  oracleorcl (LOCAL=NO)
    #  -1   |          Address     Kbytes        RSS       Anon     Locked Mode   Mapped File
    #  0    | 0000000100000000     211656     174200          -          - r-x--  oracle

    LINE_DELIMITER  = -4
    LINE_PS_PROC    = -3
    LINE_PMAP_PROC  = -2
    LINE_PMAP_HEADER= -1
    LINE_PMAP_DATA  = 0

    def __init__(self, pmap_collect):
        # pmap_collect: any iterable of text lines (e.g. an open file or stdin).
        self.pmap_collect = pmap_collect
        self.processes = []

        # Grand totals over all processes, same layout as Process.mem.
        self.mem = [0] * Segment.SEG_TYPE_COUNT

        # Caches used to detect segments shared between processes
        self.code_cache = SegmentCache()
        self.dism_cache = SegmentCache()
        self.share_anon_cache = SegmentCache()

    def parse(self):
        '''Read pmap_collect, building Process objects and their Segments.'''
        state = 0
        proc = None

        for lnum, line in enumerate(self.pmap_collect, 1):
            line = line.strip()

            # Skip blank lines and pmap's "total Kb ..." footer.  Only a
            # *leading* "total" is the footer; a process name or library
            # path may legitimately contain the word.
            if not line or line.startswith('total'):
                continue

            if line.startswith('--'):
                # Record delimiter.  pmap's "-----" separator above the
                # footer also matches; that is harmless.
                state = ProcessMap.LINE_DELIMITER
                continue

            state += 1

            if proc is None and state != ProcessMap.LINE_PS_PROC:
                # Lines before the first record have no process to attach to.
                continue

            if state == ProcessMap.LINE_PS_PROC:
                pid = int(line.split()[0])
                proc = Process(pid)

                self.processes.append(proc)
            elif state == ProcessMap.LINE_PMAP_PROC:
                # "PID:  command args" -- split on the first ':' only, since
                # the arguments may contain colons themselves.
                proc.name = line.partition(':')[2].strip()
            elif state >= ProcessMap.LINE_PMAP_DATA:
                seg = self.parse_seg_line(line, lnum)
                if seg is not None:
                    self.add_shared_segment(seg)
                    proc.add(seg)

    def account(self, what, shared):
        '''Account every process (see Process.account) and sum into self.mem.'''
        # Reset so repeated calls do not accumulate.
        self.mem = [0] * len(self.mem)
        for proc in self.processes:
            pmem = proc.account(what, shared)
            self.mem = [acc + val for acc, val in zip(self.mem, pmem)]

    def parse_seg_line(self, line, lnum):
        '''Parse one `pmap -x` data line into a Segment.

        Expected columns: Address Kbytes RSS Anon Locked Mode Mapped-File.
        Malformed lines and unclassifiable segments are reported on stderr
        and yield None; any other error is reported and re-raised.
        '''
        ls = line.split(None, 6)

        if len(ls) != 7:
            sys.stderr.write('Parse error on line %d [%s]\n' % (lnum, line))
            return None

        try:
            addr = int(ls[0], 16)
            vsz = mint(ls[1])
            rss = mint(ls[2])
            perm = ls[5]
            mapf = ls[6]

            return Segment(addr, vsz, rss, perm, mapf)
        except SegException as se:
            sys.stderr.write('%s on line %d [%s]\n' % (se, lnum, line))
            return None
        except Exception as e:
            sys.stderr.write('%s on line %d [%s]\n' % (e, lnum, line))
            raise

    def add_shared_segment(self, seg):
        '''Register seg in the cache matching its type to detect sharing.

        Code is keyed by mapped file, DISM by shmid, shared anonymous
        memory by start address; other segment types are never shared.
        '''
        if seg.seg_type == Segment.SEG_CODE:
            self.code_cache.add(seg, seg.mapf)
        elif seg.seg_type == Segment.SEG_DISM:
            self.dism_cache.add(seg, seg.dism_id)
        elif seg.seg_type == Segment.SEG_ANON and seg.anon_shared:
            self.share_anon_cache.add(seg, seg.addr)

    def sort(self, col = None):
        '''Sort processes ascending by PID, or by mem[col] when col is given.'''
        if col is None:
            sort_key = lambda proc: proc.pid
        else:
            sort_key = lambda proc: proc.mem[col]
        self.processes = sorted(self.processes, key=sort_key)

    def dump(self, printer, total=True):
        '''Print the header, one row per process and optionally the grand total.'''
        printer.print_line('PID', Segment.SEG_NAMES, 'NAME')

        for proc in self.processes:
            proc.dump(printer)

        if total:
            printer.print_line('--->', self.mem, '<--- total')

##################
# Command line 

# 'h' must be in the option string: the usage text advertises -h, and
# without it getopt raises GetoptError for "-h".
try:
    options, args = getopt.getopt(sys.argv[1:], 'hrvc:s:S:')
except getopt.GetoptError as ge:
    sys.stderr.write("%s\nRun with -h for usage\n" % ge)
    sys.exit(1)

col_names = None
sort_col = None
shared = Process.SHR_ONCE

options = dict(options)

if '-h' in options or len(args) == 0:
    print(util_info)
    sys.exit(0)

if '-s' in options:
    # List positions match Process.SHR_* (NO=0, ONCE=1, ALL=2, ONLY=3).
    try:
        shared = ['no', 'once', 'all', 'only'].index(options['-s'])
    except ValueError as ve:
        sys.stderr.write("Invalid shared mode specified: %s\n" % ve)
        sys.exit(1)

if '-S' in options:
    try:
        sort_col = Segment.SEG_NAMES.index(options['-S'])
    except ValueError as ve:
        sys.stderr.write("Invalid column specified: %s\n" % ve)
        sys.exit(1)

if '-r' in options:
    what = Process.PROC_RSS
elif '-v' in options:
    what = Process.PROC_VSZ
else:
    sys.stderr.write("Please specify -r for RSS or -v for VSZ\n")
    sys.exit(1)

if '-c' in options:
    col_names = options['-c'].split(',')

try:
    printer = Printer(col_names)
except ValueError as ve:
    sys.stderr.write("Invalid column specified: %s\n" % ve)
    sys.exit(1)

# A grand total is only meaningful when every shared segment is counted at
# most once, i.e. in the 'no' and 'once' modes.
total = shared in (Process.SHR_NO, Process.SHR_ONCE)

# Open the input only after all options are validated; '-' means stdin.
if args[0] == '-':
    pmap_out = sys.stdin
else:
    try:
        pmap_out = open(args[0], 'r')
    except (IOError, OSError) as ioe:
        sys.stderr.write("Cannot open %s: %s\n" % (args[0], ioe))
        sys.exit(1)

pmap = ProcessMap(pmap_out)
try:
    pmap.parse()
finally:
    if pmap_out is not sys.stdin:
        pmap_out.close()
pmap.account(what, shared)
pmap.sort(sort_col)
pmap.dump(printer, total)
