#! /usr/bin/env python
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import py3nvml.py3nvml as nvml
from datetime import datetime
import re
import os
import pwd
from subprocess import Popen, PIPE
import argparse
from time import sleep
import sys

# --- Table geometry ---------------------------------------------------------
# Inner widths of the three GPU-table columns; they size the '+---+' rules.
COL1_WIDTH, COL2_WIDTH, COL3_WIDTH = 33, 21, 21
# Inner width of the single full-width rule drawn above the banner line.
WIDTH = 77
# Not read anywhere in the visible part of this file.
LONG_FORMAT = False
# Characters of a process row used by everything except the process name;
# the name column is sized as ``width - LEN_PROCESS_LESS_NAME``.
LEN_PROCESS_LESS_NAME = 51

# --- Command line -----------------------------------------------------------
parser = argparse.ArgumentParser(description='Print GPU stats')
parser.add_argument(
    '-l', '--loop',
    action='store', type=int, default=0,
    help='Loop period')
parser.add_argument(
    '-f', '--full',
    action='store_true',
    help='Print extended version')
parser.add_argument(
    '-w', '--width',
    type=int, default=77,
    help='Print width')
parser.add_argument(
    '--left',
    action='store_true',
    help='Prints left part of process name')

# --- Row templates: compact layout (one line per GPU) -----------------------
gpu_format_col1 = '| {:>3} {:4} {:>4} {:>4} {:>11}|'
gpu_format_col2 = ' {:>19} |'
gpu_format_col3 = ' {:>8} {:>10} |'

# --- Row templates: extended layout (two lines per GPU) ---------------------
# Names read gpu_longformat_col<line><column>.
gpu_longformat_col11 = '| {:>3} {:<22} {:>3}  |'
gpu_longformat_col12 = ' {:>15} {:>3} |'
gpu_longformat_col13 = ' {:>19} |'
gpu_longformat_col21 = '| {:>3} {:>4}   {:>3}     {:>11}|'
gpu_longformat_col22 = ' {:>19} |'
gpu_longformat_col23 = ' {:>8} {:>10} |'


def print_header(driver_version, long_format=False):
    """Print the timestamp, banner and column headings of the GPU table.

    Args:
        driver_version: Driver version shown in the banner line.
        long_format: When true, print the two-line extended headings instead
            of the single-line compact ones.

    Returns:
        The number of lines written to stdout (7 in long format, 6 otherwise).
    """
    def rule(fill):
        # One three-column table rule made of ``fill`` between '+' corners.
        widths = (COL1_WIDTH, COL2_WIDTH, COL3_WIDTH)
        return '+' + '+'.join(fill * w for w in widths) + '+'

    # Fixed preamble: timestamp, top border, banner, rule.
    print(datetime.now().strftime('%a %b %d %H:%M:%S %Y'))
    print('+' + '-' * WIDTH + '+')
    banner = '| ' + '{:<34}'.format('NVIDIA-SMI')
    banner += 'Driver Version: {:<26}'.format(driver_version) + '|'
    print(banner)
    print(rule('-'))
    preamble_lines = 4

    if long_format:
        print('| GPU  Name          Persistence-M| Bus-Id       Disp.A | Volatile Uncorr. ECC|')
        print('| Fan  Temp  Perf    Pwr:Usage/Cap|        Memory-Usage | GPU-Util  Compute M.|')
        heading_lines = 2
    else:
        headings = ''.join([
            gpu_format_col1.format('GPU', 'Fan', 'Temp', 'Perf', 'Pwr:Usage/Cap'),
            gpu_format_col2.format('Memory-Usage'),
            gpu_format_col3.format('GPU-Util', 'Compute M.'),
        ])
        print(headings)
        heading_lines = 1

    print(rule('='))
    return preamble_lines + heading_lines + 1


def print_proc_header():
    """Print the heading of the process table.

    Reads the module-level ``args`` (for the table width) and ``proc_format``
    (the process-row template).

    Returns:
        The constant 6. Note that only five lines are printed here; the
        extra one presumably accounts for the separator rule that ``main``
        prints just before this heading in the compact layout.
    """
    width = args.width
    blank_line = ' ' * (width + 2)
    print(blank_line)
    print(''.join(['+', '-' * width, '+']))
    print(''.join(['| Processes:', ' ' * (width - 22), 'GPU Memory |']))
    column_titles = ('GPU', 'Owner', 'PID', 'Uptime', 'Process Name', '   Usage', '')
    print(proc_format.format(*column_titles))
    print(''.join(['+', '=' * width, '+']))
    return 6


def enabled_str(x):
    """Map the string 'Enabled' to 'On' and anything else to 'Off'."""
    return 'On' if x == 'Enabled' else 'Off'


# Matches an optionally negative decimal number at the start of a string.
# Compiled once at import time instead of on every call.
_NUM_PREFIX_RE = re.compile(r'^-?[\d.]+')


def get_num_from_str(s):
    """Return the number at the start of ``s`` rounded to the nearest integer.

    Halves are rounded up (``'12.5'`` -> 13, ``'-2.5'`` -> -2).

    Args:
        s: A string such as ``'12.4W'`` or ``'-3C'``.

    Returns:
        The rounded leading number, or 0 when ``s`` is not a string, does not
        start with a number, or the matched text is not a valid number
        (e.g. ``'1.2.3'``).
    """
    try:
        value = float(_NUM_PREFIX_RE.match(s).group())
        # Floor of (value + 0.5) rounds half up for negative values as well;
        # int() alone truncates toward zero and would turn -2.7 into -2.
        return int((value + 0.5) // 1)
    except (AttributeError, TypeError, ValueError, OverflowError):
        # AttributeError: no match; TypeError: ``s`` is not a string;
        # ValueError/OverflowError: matched text is not a finite number.
        return 0


def try_get_info(f, h, default='N/A'):
    """Call ``f(h)`` and fall back to ``default`` if NVML does not support it.

    Args:
        f: Callable taking the handle as its only argument.
        h: Handle passed to ``f``.
        default: Value returned when ``f`` raises
            ``nvml.NVMLError_NotSupported``.

    Returns:
        The result of ``f(h)``, or ``default``. Any other exception raised by
        ``f`` propagates to the caller.
    """
    try:
        return f(h)
    except nvml.NVMLError_NotSupported:
        return default


def print_gpu_info(index, long_format=False):
    """Print the table entry for one GPU.

    Args:
        index: NVML device index of the GPU.
        long_format: When true, print the extended entry (two data lines
            plus a closing rule); otherwise print a single compact line.

    Returns:
        The number of lines written to stdout: 3 in long format, 1 otherwise.
        A GPU that NVML reports as lost gets a 'DEAD' placeholder entry of
        the same height.
    """
    try:
        handle = nvml.nvmlDeviceGetHandleByIndex(index)
    except nvml.NVMLError_GpuIsLost:
        # GPU is dead: keep the table shape but leave the data cells empty.
        empty_cells = gpu_format_col2.format('') + gpu_format_col3.format('', '')
        if not long_format:
            print('| {:>3}          DEAD               |'.format(index) + empty_cells)
            return 1
        print('| {:>3}                             |'.format(index) + empty_cells)
        print('|              DEAD               |' + empty_cells)
        print('+' + '-'*33 + '+' + '-'*21 + '+' + '-'*21 + '+')
        return 3

    def on_off(flag):
        # NVML reports these switches as integers; 0 means off.
        return 'Off' if flag == 0 else 'On'

    def query_ecc(dev):
        return nvml.nvmlDeviceGetTotalEccErrors(
            dev, nvml.NVML_MEMORY_ERROR_TYPE_UNCORRECTED, nvml.NVML_VOLATILE_ECC)

    def query_temp(dev):
        return nvml.nvmlDeviceGetTemperature(dev, nvml.NVML_TEMPERATURE_GPU)

    # Queries are issued in a fixed order; unsupported ones yield a default.
    min_number = try_get_info(nvml.nvmlDeviceGetMinorNumber, handle)
    prod_name = try_get_info(nvml.nvmlDeviceGetName, handle)
    pers_mode = on_off(try_get_info(nvml.nvmlDeviceGetPersistenceMode, handle, 0))

    pci_info = try_get_info(nvml.nvmlDeviceGetPciInfo, handle)
    bus_id = pci_info.busId.decode('utf-8') if pci_info != 'N/A' else '-'

    disp_active = on_off(try_get_info(nvml.nvmlDeviceGetDisplayActive, handle, 0))
    ecc_error = try_get_info(query_ecc, handle)
    fan_speed = try_get_info(nvml.nvmlDeviceGetFanSpeed, handle)
    perf_state = try_get_info(nvml.nvmlDeviceGetPerformanceState, handle)
    temp = try_get_info(query_temp, handle, -1)

    # Power readings are divided by 1000 for display in watts; the -1000
    # default therefore shows up as -1.
    power_draw = try_get_info(nvml.nvmlDeviceGetPowerUsage, handle, -1000) // 1000
    power_lim = try_get_info(nvml.nvmlDeviceGetPowerManagementLimit, handle, -1000) // 1000

    mem_info = try_get_info(nvml.nvmlDeviceGetMemoryInfo, handle)
    if mem_info != 'N/A':
        # Shift by 20 bits to show the byte counts as MiB.
        used, total = mem_info.used >> 20, mem_info.total >> 20
    else:
        used = total = 0

    util = try_get_info(nvml.nvmlDeviceGetUtilizationRates, handle)
    gpu_util = util.gpu if util != 'N/A' else 0

    compute_mode = try_get_info(nvml.nvmlDeviceGetComputeMode, handle, -1)
    mode = {
        0: 'Default',
        1: 'Excl. Thrd',
        2: 'Prohibtd',
        3: 'Excl. Proc',
    }.get(compute_mode, 'Unknown')

    # Cells shared by both layouts.
    fan_cell = '{}%'.format(fan_speed)
    temp_cell = '{}C'.format(temp)
    power_cell = '{:>4}W /{:>4}W '.format(power_draw, power_lim)
    memory_cell = '{:>5}MiB / {:>5}MiB'.format(used, total)
    util_cell = '{}%'.format(gpu_util)

    if not long_format:
        print(gpu_format_col1.format(
            min_number, fan_cell, temp_cell, perf_state, power_cell), end='')
        print(gpu_format_col2.format(memory_cell), end='')
        print(gpu_format_col3.format(util_cell, mode))
        return 1

    # First line: identity, bus and ECC columns.
    print(gpu_longformat_col11.format(min_number, prod_name, pers_mode), end='')
    print(gpu_longformat_col12.format(bus_id, disp_active), end='')
    print(gpu_longformat_col13.format(ecc_error))
    # Second line: fan/temperature/power, memory and utilisation columns.
    print(gpu_longformat_col21.format(
        fan_cell, temp_cell, 'P{}'.format(perf_state), power_cell), end='')
    print(gpu_longformat_col22.format(memory_cell), end='')
    print(gpu_longformat_col23.format(util_cell, mode))
    print('+' + '-'*33 + '+' + '-'*21 + '+' + '-'*21 + '+')
    return 3


def cut_proc_name(name, maxlen, left=False):
    """Shorten ``name`` to at most ``maxlen`` characters, marking the cut.

    Args:
        name: Process name (command line) to display.
        maxlen: Maximum number of characters available.
        left: When true keep the left part of the name and append '..';
            otherwise keep the right part and prepend '..'.

    Returns:
        ``name`` unchanged if it already fits; otherwise a string of exactly
        ``maxlen`` characters (empty when ``maxlen`` is zero or negative).
    """
    if len(name) <= maxlen:
        return name

    keep = maxlen - 2  # characters of the name left after the '..' marker
    if keep <= 0:
        # No room for any part of the name. Slicing with a zero or negative
        # count would return (almost) the whole name, so emit only as much
        # of the marker as fits.
        return '..'[:max(maxlen, 0)]
    if left:
        return name[:keep] + '..'
    return '..' + name[-keep:]


def get_uname_pid(pid):
    """Return the user name of the owner of process ``pid``.

    The owner is taken from the ownership of the ``/proc/<pid>`` directory.

    Args:
        pid: Numeric process id.

    Returns:
        The user name, or ``'???'`` when the process directory cannot be
        inspected or its uid has no password-database entry.
    """
    try:
        # /proc/PID is owned by the process creator, so its uid identifies
        # the user.
        uid = os.stat("/proc/%d" % pid).st_uid
        return pwd.getpwuid(uid)[0]
    except (OSError, KeyError, TypeError):
        # OSError: no such process / no procfs; KeyError: unknown uid;
        # TypeError: ``pid`` is not a number. KeyboardInterrupt and
        # SystemExit are deliberately not caught so Ctrl-C still works.
        return '???'


def get_pname(id):
    """Return the command line of process ``id`` as reported by ``ps``.

    Args:
        id: Process id; it is formatted into the ``ps`` argument list.

    Returns:
        The stripped output of ``ps -o cmd= <id>``, or an empty string when
        ``ps`` cannot be run. Bytes that are not valid UTF-8 are replaced
        rather than discarding the whole name.
    """
    try:
        sess = Popen(['ps', '-o', 'cmd=', '{}'.format(id)], stdout=PIPE, stderr=PIPE)
        stdout, _ = sess.communicate()
        return stdout.decode('utf-8', 'replace').strip()
    except (OSError, ValueError):
        # OSError: ``ps`` missing or not executable; ValueError: argument
        # list rejected by Popen. KeyboardInterrupt is deliberately not
        # caught so Ctrl-C while waiting on ``ps`` still stops the program.
        return ""


def get_uptime(pid):
    """Return the ``etime`` column that ``ps`` reports for process ``pid``.

    Args:
        pid: Process id; converted with ``str`` for the ``ps`` command line.

    Returns:
        The stripped output of ``ps -q <pid> -o etime=``, or ``'?'`` when
        ``ps`` cannot be run or its output cannot be decoded.
    """
    try:
        sess = Popen(['ps', '-q', str(pid), '-o', 'etime='], stdout=PIPE, stderr=PIPE)
        stdout, _ = sess.communicate()
        return stdout.decode('utf-8').strip()
    except (OSError, ValueError):
        # OSError: ``ps`` missing or not executable; ValueError: bad
        # arguments or undecodable output. KeyboardInterrupt is deliberately
        # not caught so Ctrl-C while waiting on ``ps`` still stops the
        # program.
        return '?'


def main(full=False, left=False):
    """Print one complete snapshot: the GPU table, then the process table.

    Reads the module-level ``args`` (table width) and ``proc_format``
    (process-row template).

    Args:
        full: Print the extended layout (three lines per GPU).
        left: When a process name is too long, keep its left part instead
            of its right part.

    Returns:
        The number of lines written to stdout.
    """
    driver_version = nvml.nvmlSystemGetDriverVersion()
    header_lines = print_header(driver_version, full)

    # Print the top info - GPU stats
    gpu_lines = 0
    num_gpus = nvml.nvmlDeviceGetCount()
    for i in range(num_gpus):
        gpu_lines += print_gpu_info(i, full)

    # Print the bottom info - Process stats
    if not full:
        # The extended layout closes every GPU entry with its own rule; the
        # compact layout needs a single closing rule here.
        print('+' + '-' * COL1_WIDTH +
              '+' + '-' * COL2_WIDTH +
              '+' + '-' * COL3_WIDTH + '+')
    proc_header_lines = print_proc_header()
    if full:
        # print_proc_header() reports 6 lines but prints 5. The surplus
        # matches the closing rule printed above in the compact layout, but
        # in the extended layout nothing is printed for it. Dropping it here
        # keeps the returned total equal to the lines actually on screen,
        # so a caller rewinding the cursor by that amount does not drift.
        proc_header_lines -= 1

    name_width = args.width - LEN_PROCESS_LESS_NAME
    proc_lines = 0
    for i in range(num_gpus):
        try:
            h = nvml.nvmlDeviceGetHandleByIndex(i)
        except nvml.NVMLError:
            # Skip GPUs whose handle cannot be obtained (e.g. a lost GPU).
            continue

        procs = try_get_info(nvml.nvmlDeviceGetComputeRunningProcesses, h)
        if procs == 'N/A':
            continue

        min_number = try_get_info(nvml.nvmlDeviceGetMinorNumber, h, i)
        for p in procs:
            uname = get_uname_pid(p.pid)
            procname = get_pname(p.pid)
            uptime = get_uptime(p.pid)
            if p.usedGpuMemory is None:
                # Memory usage is not available for this process; show a
                # placeholder and pad the unit so the row keeps its width.
                mem_used, mem_unit = 'N/A', '   '
            else:
                mem_used, mem_unit = p.usedGpuMemory >> 20, 'MiB'
            print(proc_format.format(
                min_number, uname, p.pid, uptime,
                cut_proc_name(procname, name_width, left),
                mem_used, mem_unit))
            proc_lines += 1
    print('+' + '-' * args.width + '+')

    # The trailing 1 is the closing rule printed just above.
    return header_lines + gpu_lines + proc_header_lines + proc_lines + 1


if __name__ == '__main__':
    args = parser.parse_args()

    # Width left for the process-name column. A negative value would make
    # the format spec below invalid and zero leaves no room for the name,
    # so reject such widths with a clear message up front.
    name_width = args.width - LEN_PROCESS_LESS_NAME
    if name_width < 1:
        parser.error('--width must be at least {}'.format(LEN_PROCESS_LESS_NAME + 1))
    proc_format = '| {:>3}  {:>11}  {:>7}  {:>10}  {: <' + str(name_width) + '}  {:>5}{:3<} |'

    nvml.nvmlInit()
    try:
        print_lines = main(args.full, args.left)

        if args.loop > 0:
            try:
                while True:
                    sleep(args.loop)
                    # Move the cursor back to the first line of the previous
                    # snapshot so the new one overwrites it in place.
                    sys.stdout.write("\033[F" * print_lines)
                    print_lines_new = main(args.full, args.left)
                    if print_lines_new < print_lines:
                        # The new snapshot is shorter: blank out the leftover
                        # lines of the old one, then return to the end of the
                        # new snapshot.
                        sys.stdout.write((' '*(args.width+2)+'\n')*(print_lines - print_lines_new))
                        sys.stdout.write("\033[F" * (print_lines - print_lines_new))
                    print_lines = print_lines_new
            except KeyboardInterrupt:
                # Ctrl-C is the normal way to leave loop mode.
                pass
    finally:
        # Always release NVML, including when a snapshot raised.
        nvml.nvmlShutdown()
