#!/usr/bin/env python3

import sys
import math
import argparse

# Default values for the placeholders of the HPL.dat template below.
# Entries set to "FIXME" are placeholders only; main() computes the real
# value before the template is rendered.
hpl_defaults = {
    # Problem sizes (N): count and space-separated list.
    "Ns_sz": 1,
    "Ns": "FIXME",
    # Block sizes (NB): count and value (also the default of --nbs).
    "NBs_sz": 1,
    "NBs": 576,
    # Process grid dimensions P x Q.
    "Ps": "FIXME",
    "Qs": "FIXME",
    # Panels in recursion (default of --ndivs).
    "NDIVs": 2,
    # Broadcast algorithm selector (default of --bcasts).
    "BCASTs": 3,
}


# Template for the generated HPL input file. The {Ns_sz}, {Ns}, {NBs_sz},
# {NBs}, {Ps}, {Qs}, {NDIVs} and {BCASTs} placeholders are substituted with
# str.format() in main(); every other line is written out verbatim, so the
# text inside the literal must not be reformatted.
hpl_data_tmpl = """HPLinpack benchmark input file
Innovative Computing Laboratory, University of Tennessee
HPL.out      output file name (if any)
6            device out (6=stdout,7=stderr,file)
{Ns_sz}    # of problems sizes (N)
{Ns}         Ns
{NBs_sz}             # of NBs
{NBs}        NBs
0            PMAP process mapping (0=Row-,1=Column-major)
1            # of process grids (P x Q)
{Ps}           Ps
{Qs}           Qs
16.0         threshold
1            # of panel fact
2        PFACTs (0=left, 1=Crout, 2=Right)
1            # of recursive stopping criterium
2          NBMINs (>= 1)
1            # of panels in recursion
{NDIVs}            NDIVs
1            # of recursive panel fact.
0          RFACTs (0=left, 1=Crout, 2=Right)
1            # of broadcast
{BCASTs}          BCASTs (0=1rg,1=1rM,2=2rg,3=2rM,4=Lng,5=LnM)
1            # of lookahead depth
0            DEPTHs (>=0)
1            SWAP (0=bin-exch,1=long,2=mix)
192          swapping threshold
1            L1 in (0=transposed,1=no-transposed) form
0            U  in (0=transposed,1=no-transposed) form
1            Equilibration (0=no,1=yes)
8            memory alignment in double (> 0)
"""

def _process_grid(ranks):
    """Return (P, Q) with P * Q == ranks and Q the largest divisor <= sqrt(ranks)."""
    for q in range(int(math.sqrt(ranks)), 0, -1):
        p, remainder = divmod(ranks, q)
        if remainder == 0:
            return p, q
    # Only reachable for ranks < 1, where the range above is empty.
    raise ValueError("ranks must be a positive integer (got {})".format(ranks))


def _problem_sizes(ranks, memsize_mb, nbs, problems):
    """Return `problems` candidate N values, each a multiple of `nbs`."""
    # Side of the largest square matrix of 8-byte elements that fits in the
    # combined memory of all ranks (memsize_mb is per rank, in MB).
    approx = math.sqrt(float(ranks) * float(memsize_mb) * 1024 * 1024 / 8)
    blocks = math.floor(approx / nbs)
    max_val = blocks * nbs
    # Target 99% of the maximum block count, stepping by ~0.5% of the maximum.
    ideal_val = int(blocks * 0.99) * nbs
    step = round((0.005 * max_val) / nbs) * nbs

    half = (problems + 1) // 2
    # NOTE(review): the window starts `half` steps below ideal_val and is not
    # centred on it (a single problem yields ideal_val - step). This matches
    # the historical output and is kept on purpose; confirm whether a centred
    # window was intended.
    return [ideal_val + k * step for k in range(-half, -half + problems)]


def main(argv=None):
    """Generate an HPL.dat input file from the command-line arguments.

    The process grid (P x Q) is derived from nodes * gpu-per-node, and the
    problem sizes (N) from the per-GPU memory size and the block size.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from sys.argv[1:].

    The rendered file is printed to stdout when --stdout is given; otherwise
    it is written to the -o path and that path is printed. Invalid
    (non-positive) numeric arguments terminate the program through
    parser.error().
    """
    parser = argparse.ArgumentParser(description='HPLinpack benchmark input file generator')
    parser.add_argument("-o", dest='output', default='HPL.dat', help='Output file')
    parser.add_argument("--stdout", action='store_true', default=False)
    parser.add_argument('--nbs', type=int, default=hpl_defaults['NBs'])
    parser.add_argument('--ndivs', type=int, default=hpl_defaults['NDIVs'])
    parser.add_argument('--bcasts', type=int, default=hpl_defaults['BCASTs'])
    parser.add_argument('--problems', type=int, default=1, help='Total number of problems')
    parser.add_argument('--gpu-per-node', type=int, default=8)
    parser.add_argument('memsize', type=int, help='Total memory per GPU in MB')
    parser.add_argument('nodes', type=int, help='Total number of nodes')

    args = parser.parse_args(argv)

    # Reject values that would otherwise surface as NameError, a math domain
    # error or ZeroDivisionError further down.
    checked = (
        ('nodes', args.nodes),
        ('--gpu-per-node', args.gpu_per_node),
        ('memsize', args.memsize),
        ('--nbs', args.nbs),
        ('--problems', args.problems),
    )
    for name, value in checked:
        if value < 1:
            parser.error("{} must be a positive integer (got {})".format(name, value))

    ranks = args.nodes * args.gpu_per_node
    p, q = _process_grid(ranks)
    ns_list = _problem_sizes(ranks, args.memsize, args.nbs, args.problems)

    # Work on a copy so the module-level defaults are never mutated.
    cfg = dict(hpl_defaults)
    cfg['NBs'] = args.nbs
    cfg['NDIVs'] = args.ndivs
    cfg['BCASTs'] = args.bcasts
    cfg['Ps'] = p
    cfg['Qs'] = q
    cfg['Ns'] = " ".join(str(ns) for ns in ns_list)
    cfg['Ns_sz'] = args.problems

    cfg_blob = hpl_data_tmpl.format(**cfg)
    if args.stdout:
        print(cfg_blob)
        return

    # The context manager guarantees the file is flushed and closed.
    with open(args.output, "w") as out:
        out.write(cfg_blob)
    print(args.output)

# Run the generator only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
