#!/usr/bin/env python

import sys, glob

# Cheap pre-check: look for NVIDIA device nodes (/dev/nvidia0 ... /dev/nvidia9)
# before attempting anything heavier. An empty match list means there is
# nothing to report, so print the notice and stop with a zero exit status.
devices = glob.glob("/dev/nvidia[0-9]")
if not devices:
    print("CUDA not available or no CUDA-capable GPU found.")
    sys.exit(0)


# Number of CUDA cores per streaming multiprocessor (SM), keyed by the
# device's compute capability as a (major, minor) tuple. Capabilities that
# are not listed here are reported as "unknown" by the lookup code.
cc_cores_per_SM_dict = {
    (2, 0): 32,
    (2, 1): 48,
    (3, 0): 192,
    (3, 2): 192,
    (3, 5): 192,
    (3, 7): 192,
    (5, 0): 128,
    (5, 2): 128,
    (5, 3): 128,
    (6, 0): 64,
    (6, 1): 128,
    (6, 2): 128,
    (7, 0): 64,
    (7, 2): 64,
    (7, 5): 64,
    (8, 0): 64,
    (8, 6): 128,
    (8, 7): 128,
    (8, 9): 128,
    (9, 0): 128,
    (10, 0): 128,
    (12, 0): 128,
}

def summarize_cores(compute_capability, sm_count, cores_table):
    """Return ``(cores_per_sm, total_cores)`` for a device.

    Args:
        compute_capability: ``(major, minor)`` pair identifying the device
            architecture; any two-item sequence is accepted.
        sm_count: Number of streaming multiprocessors on the device.
        cores_table: Mapping from ``(major, minor)`` tuples to the number of
            CUDA cores per SM.

    Returns:
        A 2-tuple. When the compute capability is present in ``cores_table``
        both items are ints (cores per SM, and cores per SM times
        ``sm_count``). Otherwise both items are the string ``"unknown"`` so
        they can be printed in place of the numbers.
    """
    cores_per_sm = cores_table.get(tuple(compute_capability))
    if cores_per_sm is None:
        return "unknown", "unknown"
    return cores_per_sm, cores_per_sm * sm_count


try:
    from numba import cuda

    device = cuda.get_current_device()
    ctx = cuda.current_context()
    meminfo = ctx.get_memory_info()
    compute_capability = device.compute_capability
    sms = device.MULTIPROCESSOR_COUNT
    cores_per_sm, total_cores = summarize_cores(
        compute_capability, sms, cc_cores_per_SM_dict
    )

    # The device name is handled as either str or bytes; decode the latter.
    gpu_name = device.name if isinstance(device.name, str) else device.name.decode()
    cc_text = ".".join(map(str, compute_capability))

    print(f"                 GPU Name: {gpu_name}")
    print(f"       Compute Capability: {cc_text:>7}")
    print(f"Streaming Multiprocessors: {sms:>7}")
    print(f"        CUDA Cores per SM: {cores_per_sm:>7}")
    print(f"         Total CUDA Cores: {total_cores:>7}")
    print(f"             Total Memory: {meminfo.total / 1024 / 1024:>7.0f} mb")
    print(f"              Free Memory: {meminfo.free / 1024 / 1024:>7.0f} mb")
except ImportError as exc:
    # Raised by the import above: report it as such instead of implying that
    # the GPU itself is missing. The stdout line is kept unchanged for anything
    # that parses it; the actual reason goes to stderr.
    print("CUDA not available or no CUDA-capable GPU found.")
    print(f"(numba could not be imported: {exc})", file=sys.stderr)
except Exception as exc:
    # Best-effort report: any failure while querying the device ends here
    # rather than crashing, but the underlying error is no longer discarded.
    print("CUDA not available or no CUDA-capable GPU found.")
    print(f"({type(exc).__name__}: {exc})", file=sys.stderr)
