#!/usr/bin/python

####################################
#
# TMPTOP
#
# Prints the top consumers of tmpfs space on Solaris
# Uses DTrace
# 
# ----------------------------------
#
# (c) Sergey Klyaus, Tune-IT, 2012
# myaut@tune-it.ru
#
####################################

import subprocess
import sys
import os
import getopt
from string import Template

usage = """tmptop by Sergey Klyaus, myaut@tune-it.ru
Reports statistics for tmpfs

Usage: tmptop [-?] [-h] [-k user,proc,mntpt,file] [-s resv/s,unresv/s,resv,unresv,delta,total] [-l limit] [interval [count]]

  -?          print this help and exit
  -h          print sizes in human-readable units (b, K, M, G, ...)
  -k key      group statistics by user, proc, mntpt or file (default: file)
  -s param    sort rows by this column, largest first (default: delta)
  -l limit    print at most this many rows per report (default: 0, no limit)
  interval    reporting interval in seconds (default: 2)
  count       number of reports before exiting (default: 0, run forever)"""

# Set by the -h option in get_args() and read by human_format().  A module
# level flag avoids threading the option through every formatter.
USE_HUMAN_FORMAT = False

# str.partition()/str.rpartition() are not available in Python 2.4 (the
# version shipped with Solaris 10), so emulate them here.
def str_partition(str, substr, reverse=False):
    """Split *str* around the first (or, with *reverse*, the last) *substr*.

    Returns a ``(head, tail)`` pair with the separator removed.  When *substr*
    does not occur, this mirrors str.partition()/str.rpartition(): the whole
    string becomes the head for a forward search, or the tail for a reverse
    search.

    Note: the parameter name shadows the builtin ``str``; it is kept so that
    the signature stays unchanged for existing callers.
    """
    if reverse:
        idx = str.rfind(substr)
    else:
        idx = str.find(substr)

    if idx < 0:
        # Without this guard idx == -1 would chop the last character off the
        # head and return the whole string again as the tail.
        if reverse:
            return ('', str)
        return (str, '')

    return (str[:idx], str[idx + len(substr):])

def parse_stat_str(line):
    """Split one aggregation line printed by DTrace into its parts.

    The first space-separated word is the parameter name and the text after
    the last space is the value.  Everything in between (stripped) is the key
    string that is handed to a formatter.

    Returns a ``(param, value, keys)`` tuple of strings.
    """
    param, rest = str_partition(line, ' ')
    middle, value = str_partition(rest, ' ', reverse=True)
    return (param, value, middle.strip())

def human_format(num, show_sign=True):
    """Format a byte count for display.

    When the global USE_HUMAN_FORMAT is false, return ``int(num)`` as a plain
    decimal string (*show_sign* is ignored).  Otherwise scale the value to
    b/K/M/G/T/P with one decimal place.  If *show_sign* is true, prefix '+'
    for positive values and '-' for negative ones (zero gets no sign).
    """
    if not USE_HUMAN_FORMAT:
        return "%d" % int(num)

    if num < 0:
        sign = "-"
        num = -num
    elif num == 0:
        sign = ""
    else:
        sign = "+"

    if not show_sign:
        sign = ''

    num = float(num)

    units = ['b', 'K', 'M', 'G', 'T', 'P']
    for unit in units[:-1]:
        if num < 1024.0:
            return "%s%.1f%s" % (sign, num, unit)
        num /= 1024.0

    # Anything this large is shown in the biggest unit instead of falling off
    # the loop and returning None.
    return "%s%.1f%s" % (sign, num, units[-1])
    
##########
# Key formatters
#
# update() refreshes the formatter's cached lookup data (e.g. the process list)
# parse_and_format(key_line) parses key_line and returns a (main_key, fields)
# tuple, where fields is the sequence of formatted key columns
# 
# keys is a global dict mapping -k option values to formatter instances

class KeyFormatter:
    """Base class for grouping-key formatters.

    Subclasses set ``dtrace_str`` to the D expression(s) used as aggregation
    keys and ``header`` to the column titles of the formatted key fields.
    """
    dtrace_str = ''
    header = [""]

    def parse_and_format(self, key_line):
        """Return ``(main_key, formatted_fields)`` for a raw key string.

        ``main_key`` identifies the report row.  ``formatted_fields`` must be
        a sequence with one entry per ``header`` column, because the printed
        row is extended with it.
        """
        # The trailing comma matters: ``(key_line)`` is just the string
        # itself, and extending a row with it would add one column per
        # character.
        return (key_line, (key_line,))

    def update(self):
        """Refresh cached lookup data; the base class has none."""
        pass

class UserFormatter(KeyFormatter):
    """Group statistics by user id and show the matching login name."""
    dtrace_str = "uid"
    header = ["UID", "USERNAME"]

    def update(self):
        """Reload the uid -> username map from /etc/passwd."""
        self.user_dict = {}

        # try/finally rather than ``with`` keeps Python 2.4 compatibility
        # while still closing the file on every refresh.
        passwd_file = open('/etc/passwd')
        try:
            for passwd_line in passwd_file:
                passwd_entry = passwd_line.split(':')

                # Skip blank, comment or otherwise malformed lines instead of
                # crashing with IndexError.
                if len(passwd_entry) < 3:
                    continue

                self.user_dict[passwd_entry[2]] = passwd_entry[0]
        finally:
            passwd_file.close()

    def parse_and_format(self, key_line):
        """Return ``(uid, (padded_uid, username))``; unknown uids show '???'."""
        uid = key_line
        username = self.user_dict.get(uid, "???")

        return (uid, ("%8s" % uid, username))
    
class ProcFormatter(KeyFormatter):
    """Group statistics by process id and show the command name."""
    dtrace_str = "pid, execname"
    header = ["PID", "CMD"]

    def update(self):
        """Reload the pid -> command map from ``ps`` output."""
        self.proc_dict = {}

        ps_proc = subprocess.Popen("ps -e -o pid,comm | tail +2", stdout=subprocess.PIPE, shell=True)

        for ps_line in ps_proc.stdout.readlines():
            # split(None, 1) tolerates any amount of padding between the
            # columns and keeps a command containing spaces in one piece.
            fields = ps_line.strip().split(None, 1)

            if len(fields) != 2:
                # blank or malformed line
                continue

            (pid, comm) = fields
            self.proc_dict[pid] = comm

        # Reap the child so repeated updates do not leave zombies behind.
        ps_proc.stdout.close()
        ps_proc.wait()

    def parse_and_format(self, key_line):
        """Return ``(pid, (padded_pid, command))`` for a "pid execname" key."""
        key_list = key_line.split(None, 1)
        pid = key_list[0]

        try:
            # Try to get commandline from ps output
            comm = self.proc_dict[pid]
        except KeyError:
            # Not in the ps snapshot (e.g. the process already exited):
            # fall back to the execname reported by DTrace.
            if len(key_list) > 1:
                execname = key_list[1]
            else:
                execname = "???"
            comm = execname + " <dead>"

        return (pid, ("%8s" % pid, comm))
            
    
class MountFormatter(KeyFormatter):
    """Group statistics by tmpfs mount and show its current usage from df."""
    dtrace_str = "((struct tmount*) arg0)->tm_vfsp, stringof(((struct tmount*) arg0)->tm_mntpath)"
    header = ["VFSP", "USED", "MNTPATH"]

    def update(self):
        """Reload the mount point -> used bytes map from ``df -F tmpfs -k``."""
        self.df_dict = {}

        df_proc = subprocess.Popen("df -F tmpfs -k", stdout=subprocess.PIPE, shell=True)

        for df_line in df_proc.stdout.readlines():
            df_line = df_line.strip()

            if not df_line or 'Filesystem' in df_line:
                # ignore blank lines and the header
                continue

            # Columns: filesystem, kbytes, used, avail, capacity, mount point.
            # maxsplit=5 keeps a mount point containing spaces in one piece
            # (maxsplit=6 could yield seven fields and break the unpacking).
            fields = df_line.split(None, 5)
            if len(fields) != 6:
                continue

            try:
                used_kb = int(fields[2])
            except ValueError:
                continue

            self.df_dict[fields[5]] = used_kb * 1024

        # Reap the child so repeated updates do not leave zombies behind.
        df_proc.stdout.close()
        df_proc.wait()

    def parse_and_format(self, key_line):
        """Return ``(mntpath, (hex_vfsp, used, mntpath))`` for a DTrace key."""
        (vfsp_str, mntpath) = str_partition(key_line, ' ')
        vfsp = "0x%X" % int(vfsp_str)
        mntpath = mntpath.strip()

        try:
            used = human_format(self.df_dict[mntpath], False)
        except KeyError:
            # Mount point not reported by df (e.g. unmounted meanwhile).  Only
            # KeyError is caught so real errors and Ctrl-C are not swallowed.
            used = '?'

        return (mntpath, (vfsp, used, mntpath))
    
class FileFormatter(KeyFormatter):
    """Group statistics by tmpfs vnode and show the file's current size."""
    dtrace_str = "((struct tmpnode*) arg1)->tn_vnode, stringof(((struct tmpnode*) arg1)->tn_vnode->v_path)"
    header = ["VP", "SIZE", "FILENAME"]

    def parse_and_format(self, key_line):
        """Return ``(raw_vp, (hex_vp, size, filename))`` for a DTrace key line.

        The size is looked up on disk at report time; '?' is shown when the
        path cannot be stat()ed (for example, the file was removed).
        """
        raw_vp, path_part = str_partition(key_line, ' ')
        hex_vp = "0x%X" % int(raw_vp)
        path = path_part.strip()

        try:
            size_str = human_format(os.path.getsize(path), False)
        except os.error:
            size_str = '?'

        return (raw_vp, (hex_vp, size_str, path))

# Registry of grouping-key formatters, selected with the -k option.
keys = {}
keys["user"] = UserFormatter()
keys["proc"] = ProcFormatter()
keys["mntpt"] = MountFormatter()
keys["file"] = FileFormatter()
    
##########
# Param formatters
# 
# formats gathered values

def float_formatter(param, fieldlen):
    """Right-align *param* as a float with two decimals in *fieldlen* columns."""
    return "%*.2f" % (int(fieldlen), param)

def size_formatter(num, fieldlen):
    """Right-align the human_format()ed byte count *num* in *fieldlen* columns."""
    return "%*s" % (int(fieldlen), human_format(num))

    
# Formatter used for each reported column, keyed by column name.
param_formatters = {
    "resv/s": float_formatter,
    "unresv/s": float_formatter,
    "resv": size_formatter,
    "unresv": size_formatter,
    "delta": size_formatter,
    "total": size_formatter,
}
# Column order of every report (also the set of valid -s sort keys).
print_params = ["resv/s", "unresv/s", "resv", "unresv", "delta", "total"]

##########
# Stats processing
#
# commit_stat builds a dict of all stats gathered during the last period
# process_stat processes statistics: calculates delta, accumulates total delta
# print_stat prints output

def commit_stat(stat_strings, formatter):
    """Build the per-key stats dict for one interval.

    Each line in *stat_strings* is split by parse_stat_str(), and its key is
    formatted by *formatter*.  The result maps every formatter main key to a
    dict holding the formatted key fields under "key", plus one integer per
    reported parameter.  Every column in print_params starts at 0.
    """
    stat = {}

    for stat_str in stat_strings:
        (param, value, key_string) = parse_stat_str(stat_str)
        (main_key, key_fields) = formatter.parse_and_format(key_string)

        entry = stat.get(main_key)
        if entry is None:
            entry = {"key": key_fields}
            for name in print_params:
                entry[name] = 0
            stat[main_key] = entry

        # Several DTrace keys can collapse onto one main key (e.g. a pid that
        # exec()ed a new program is reported with two execnames).  Add the
        # values up instead of letting the last line overwrite earlier ones.
        entry[param] = entry.get(param, 0) + int(value)

    return stat

total_stat = {}  # running sum of "delta" per main key across all intervals

def process_stat(stat, interval):
    """Post-process one interval's stats in place.

    The raw DTrace values are converted to integers, and unreservations are
    made negative.  Entries where both are zero are dropped (presumably keys
    left over after the D script's clear()).  The call counts are divided by
    *interval* to get per-second rates.  The interval's net change is stored
    as "delta", and the running total accumulated in ``total_stat`` is
    stored as "total".
    """
    for main_key in list(stat.keys()):
        total_stat.setdefault(main_key, 0)

        entry = stat[main_key]
        entry["resv"] = int(entry["resv"])
        entry["unresv"] = -int(entry["unresv"])

        if not entry["resv"] and not entry["unresv"]:
            # nothing reserved or released in this interval: evict entry
            del stat[main_key]
            continue

        entry["resv/s"] = float(entry["resv/s"]) / interval
        entry["unresv/s"] = float(entry["unresv/s"]) / interval
        entry["delta"] = entry["resv"] + entry["unresv"]

        total_stat[main_key] += entry["delta"]
        entry["total"] = total_stat[main_key]
    
        
def stat_formatted(d):
    """Return the printable columns for one stat entry *d*.

    The numeric columns come first, in ``print_params`` order and eight
    characters wide, followed by the formatter's key fields from d["key"].
    """
    columns = [param_formatters[name](d[name], 8) for name in print_params]
    columns.extend(d["key"])
    return columns
        
def print_stat(stat, formatter, sortkey, limit):
    """Print one report: a header line, one row per entry, then a blank line.

    Rows are sorted by the *sortkey* column in descending order.  When
    *limit* is positive, only the first *limit* rows (the largest values)
    are shown.
    """
    sorted_stat_list = sorted(stat.items(), key=lambda item: item[1][sortkey], reverse=True)

    if limit > 0:
        # The list is sorted largest-first, so the top entries are at the
        # front; slicing from the end would show the smallest ones instead.
        sorted_stat_list = sorted_stat_list[:limit]

    # print header
    header = [param.upper() for param in print_params]
    header.extend(formatter.header)
    print(" ".join(["%8s" % h for h in header]))

    for (_, entry) in sorted_stat_list:
        print(" ".join(stat_formatted(entry)))

    print("")
        
##########
# DTrace utilities
        
# D script template, filled in by run_dtrace():
#   ${key}      - aggregation key expression(s), i.e. a formatter's dtrace_str
#   ${interval} - period of the tick probes, e.g. "2s"
#   ${limit}    - 1 when a report count was requested, 0 to run until stopped
#   ${count}    - number of reports to print before exit(0) (when limit is 1)
# tmp_resv (only when arg3 is non-zero) and tmp_unresv add their arg2 sizes,
# rounded to page granularity by btoprtob(), into @s and count calls in @c.
# Every tick prints both aggregations between "BEGIN" and "END" marker lines,
# which parse_dtrace_output() looks for, and then clears them.
dtrace_tmpl = Template('''int lcount;
int limit;

#define PAGESHIFT   `_pageshift
#define PAGEOFFSET  `_pageoffset
#define btoprtob(x)    ((((x) + PAGEOFFSET) >> PAGESHIFT) << PAGESHIFT) 

BEGIN {
    limit = ${limit};
    lcount = ${count};
}

tmp_resv:entry /arg3/ { @s["resv", ${key}] = sum(btoprtob(arg2)); @c["resv/s", ${key}] = count(); }
tmp_unresv:entry { @s["unresv", ${key}] = sum(btoprtob(arg2)); @c["unresv/s", ${key}] = count();  }

tick-${interval} 
/!limit || lcount-- > 0/
{ printf("BEGIN\\n");
    printa(@s); 
    printa(@c); 
    printf("END\\n");
    
    clear(@s);
    clear(@c);
}

tick-${interval} 
/limit && lcount == 0/
{
    exit(0);
}''')

        
def format_interval(secs):
    """Return a DTrace tick-probe period string such as ``"2s"``."""
    return str(int(secs)) + "s"

def run_dtrace(key_str, interval, count):
    """Start dtrace with the tmpfs script and return its Popen object.

    *key_str* is the D aggregation key expression, *interval* the report
    period in seconds and *count* the number of reports (0 = unlimited).
    The script is written to dtrace's stdin; the caller reads its stdout.
    """
    # The D script only enforces the report count when "limit" is 1.
    limit = str(int(count != 0))

    script = dtrace_tmpl.substitute(key=key_str,
                                    interval=format_interval(str(interval)),
                                    limit=limit,
                                    count=str(count))

    proc = subprocess.Popen('dtrace -qCs /dev/stdin',
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE,
                            shell=True)

    proc.stdin.write(script)
    proc.stdin.close()

    return proc

def parse_dtrace_output(output, formatter, interval, sortkey, limit):
    """Read dtrace output from *output* and print one report per interval.

    Each interval's aggregation lines arrive between "BEGIN" and "END"
    marker lines.  At every "END", the formatter's lookup data is refreshed
    and the collected lines are committed, processed and printed.  Returns
    when *output* reaches end-of-file.
    """
    stat_strings = []

    while True:
        raw_line = output.readline()

        if not raw_line:
            # readline() returns '' only at EOF (blank lines still carry a
            # newline), so stop here rather than spinning until the dtrace
            # process is reaped.
            break

        line = raw_line.strip()

        if not line:
            # skip blank lines
            continue

        if line == "BEGIN":
            stat_strings = []
        elif line == "END":
            formatter.update()

            stat = commit_stat(stat_strings, formatter)
            process_stat(stat, interval)
            print_stat(stat, formatter, sortkey, limit)

            # Make each report visible immediately, even when piped.
            sys.stdout.flush()
        else:
            stat_strings.append(line)

def usage_error(fault=None):
    """Print *fault* (if given) and the usage text to stderr, then exit.

    Exits with status 2 when *fault* describes a command-line error, and
    with status 0 when only help was requested (*fault* is None or empty).
    """
    if fault:
        sys.stderr.write("%s\n" % fault)

    sys.stderr.write("%s\n" % usage)

    if fault:
        sys.exit(2)
    sys.exit(0)
        
def get_args():
    """Parse sys.argv and return ``(interval, count, limit, sortkey, key)``.

    Side effect: ``-h`` sets the module-level USE_HUMAN_FORMAT flag.  Any
    invalid option or argument is reported through usage_error(), which
    exits the program.
    """
    global USE_HUMAN_FORMAT

    interval = 2
    count = 0
    sortkey = "delta"
    key = "file"
    limit = 0

    try:
        optlist, extargs = getopt.getopt(sys.argv[1:], 'hl:k:s:?')
    except getopt.GetoptError:
        # sys.exc_info() instead of "except ..., e" / "except ... as e" is
        # valid syntax on every Python version.
        usage_error(str(sys.exc_info()[1]))

    optlist = dict(optlist)

    if '-?' in optlist:
        #print usage
        usage_error(None)

    if '-h' in optlist:
        USE_HUMAN_FORMAT = True

    if '-k' in optlist:
        key = optlist['-k']

        if key not in keys:
            usage_error('Invalid grouping key %s, must be %s' % (key, '|'.join(keys.keys())))

    if '-l' in optlist:
        try:
            limit = int(optlist['-l'])
        except ValueError:
            usage_error('Limit must be numeric')

        if limit < 0:
            usage_error('Limit must not be negative')

    if '-s' in optlist:
        sortkey = optlist['-s']

        if sortkey not in print_params:
            usage_error('Invalid sorting key %s, must be %s' % (sortkey, '|'.join(print_params)))

    if len(extargs) > 2:
        usage_error('Extra arguments')

    try:
        if len(extargs) >= 1:
            interval = int(extargs[0])
        if len(extargs) == 2:
            count = int(extargs[1])
    except ValueError:
        usage_error('Interval and count must be numeric')

    # A non-positive interval yields an invalid tick probe and a division by
    # zero in process_stat().  A negative count makes the D script's
    # "lcount-- > 0" / "lcount == 0" predicates never fire, so nothing would
    # ever be printed and dtrace would never exit.
    if interval <= 0:
        usage_error('Interval must be positive')
    if count < 0:
        usage_error('Count must not be negative')

    return (interval, count, limit, sortkey, key)
        
if __name__ == "__main__":
    interval, count, limit, sortkey, key = get_args()

    formatter = keys[key]

    # dtrace_proc stays a module-level name so that other functions in this
    # module can reach the child process handle.
    dtrace_proc = run_dtrace(formatter.dtrace_str, interval, count)
    parse_dtrace_output(dtrace_proc.stdout, formatter, interval, sortkey, limit)

    # Propagate dtrace's exit status (e.g. when the D script fails to compile
    # or dtrace lacks privileges) instead of always exiting with 0.
    status = dtrace_proc.wait()
    if status < 0:
        # Popen reports "killed by signal N" as -N; use the shell's 128 + N.
        status = 128 - status
    sys.exit(status)