# -*- coding: utf-8 -*-
import argparse
import csv
import pandas as pd

from datacloud.input_pipeline.input_checker.constants import delimiter_vars


def join_csv(inputs, outfile, sep=None):
    """Join two delimited tables on their key columns and write the result.

    Both files are loaded with pandas. Each table is left-merged with the
    columns that only the other table has, using ``external_id`` and
    ``retro_date`` as keys. The two merge results are then stacked and exact
    duplicate rows are dropped, so every key combination present in either
    file ends up in the output. For columns present in both files each side
    keeps its own values; rows whose shared values differ between the files
    therefore both survive the de-duplication.

    Output columns are ordered as: key columns, remaining columns shared by
    both inputs (in the first file's order), then columns that exist in only
    one input (sorted).

    Args:
        inputs: Sequence of exactly two paths to the tables to join.
        outfile: Path of the file to write the joined table to.
        sep: Field separator used for reading and writing. When ``None``
            the module-level ``delimiter`` set by the command-line entry
            point is used.

    Raises:
        ValueError: If ``inputs`` does not contain exactly two paths.
    """
    if len(inputs) != 2:
        raise ValueError('Expected 2 table paths')
    if sep is None:
        # Backward compatibility: the script entry point publishes the chosen
        # separator as the module-level name `delimiter`.
        sep = delimiter

    a = pd.read_csv(inputs[0], sep=sep)
    print(a.head())
    b = pd.read_csv(inputs[1], sep=sep)
    print(b.head())

    on = ['external_id', 'retro_date']

    # Bring over only the columns the left-hand table lacks, so the merge
    # never produces suffixed duplicates (`col_x` / `col_y`).
    cols_to_use = list(b.columns.difference(a.columns)) + on
    print(cols_to_use)
    merged1 = a.merge(b[cols_to_use], on=on, how='left')

    cols_to_use = list(a.columns.difference(b.columns)) + on
    print(cols_to_use)
    merged2 = b.merge(a[cols_to_use], on=on, how='left')

    final = (
        pd.concat([merged1, merged2], sort=True)
        .drop_duplicates()
        .reset_index(drop=True)
    )
    # Release the intermediate frames before building the reordered copy.
    del merged1
    del merged2

    fin_cols = final.columns
    print(fin_cols)
    # Explicit set methods: the `&` / `^` operators on Index objects no
    # longer mean intersection / symmetric difference in current pandas.
    same_cols = a.columns.intersection(b.columns)
    print(same_cols)
    # `same_cols` is always a subset of `fin_cols`, so a plain difference
    # yields the columns that exist in only one of the inputs.
    rest_cols = list(fin_cols.difference(same_cols))
    print(rest_cols)

    # Deterministic ordering (no dependence on set iteration order).
    shared_non_key = [col for col in same_cols if col not in on]
    new_order = on + shared_non_key + rest_cols
    # `reindex_axis` was removed from pandas; `reindex(columns=...)` is the
    # supported equivalent.
    final = final.reindex(columns=new_order)
    print(final.head())

    final.to_csv(outfile, index=False, sep=sep)


def concat_csv(inputs, outfile, sep=None):
    """Concatenate delimited files row-wise into one file.

    The output header is the union of all input headers, in the order the
    column names are first encountered. Rows are copied in input order;
    columns a row's source file does not have are written as empty strings
    (the ``csv.DictWriter`` default).

    Args:
        inputs: Iterable of paths to the files to concatenate. It is
            traversed twice, so it must be re-iterable (e.g. a list).
        outfile: Path of the file to write.
        sep: Field separator used for reading and writing. When ``None``
            the module-level ``delimiter`` set by the command-line entry
            point is used.
    """
    if sep is None:
        # Backward compatibility: the script entry point publishes the chosen
        # separator as the module-level name `delimiter`.
        sep = delimiter

    # First pass: collect the field names from the top line of each input,
    # keeping first-seen order so the output layout is reproducible.
    fieldnames = []
    seen = set()
    for filename in inputs:
        # newline='' lets the csv module handle line endings itself.
        with open(filename, 'r', newline='') as f_in:
            headers = next(csv.reader(f_in, delimiter=sep), None)
        if headers is None:
            # Completely empty file: it contributes no columns and no rows.
            continue
        for name in headers:
            if name not in seen:
                seen.add(name)
                fieldnames.append(name)

    print(fieldnames)

    # Second pass: copy the data rows under the combined header.
    with open(outfile, 'w', newline='') as f_out:
        writer = csv.DictWriter(f_out, fieldnames=fieldnames, delimiter=sep)
        writer.writeheader()
        for filename in inputs:
            with open(filename, 'r', newline='') as f_in:
                writer.writerows(csv.DictReader(f_in, delimiter=sep))


if __name__ == '__main__':
    # Script entry point: parse the command line, echo the effective
    # settings, then either join or concatenate the given files.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-l',
        '--list',
        nargs='+',
        help='<Required> List of input files',
        required=True,
    )
    cli.add_argument(
        '-o',
        '--output',
        help='<Required> Outputfile',
        required=True,
    )
    cli.add_argument(
        '-d',
        '--delimiter',
        default='tab',
        choices=delimiter_vars.keys(),
        help='delimiter of csv file',
    )
    cli.add_argument(
        '-j',
        '--join',
        action='store_true',
        help='Not only concat files, but join it on key columns'
    )
    args = cli.parse_args()

    inputs = args.list
    print(inputs)

    outfile = args.output
    print(outfile)

    # `delimiter` has to stay a module-level name: join_csv and concat_csv
    # look it up as a global. The command line carries a key of
    # `delimiter_vars`; the mapped value is what gets used as the separator.
    delimiter = delimiter_vars[args.delimiter]
    print(delimiter)

    # Plain concatenation is the default; -j/--join switches to the keyed join.
    if not args.join:
        concat_csv(inputs, outfile)
    else:
        join_csv(inputs, outfile)
