import json
import subprocess
import argparse
from collections import defaultdict

from prettytable import PrettyTable
import yt.wrapper as yt

from infra.capacity_planning.utils.gpu_scaler.config import (
    COMMAND_POD_SET,
    COMMAND_GPU_LIST_DC,
    PATH,
    CLUSTER,
    HEADERS
)

from infra.capacity_planning.library.python.yt_tables import create_table


def get_data_by_command_dc_name(command):
    """Run a shell command and decode its standard output as JSON.

    :param command: command line executed through the shell; it must come from
        trusted configuration, never from user input.
    :returns: the decoded JSON value.
    :raises subprocess.CalledProcessError: if the command exits non-zero
        (the captured stderr is available on the exception's ``stderr``).
    :raises json.JSONDecodeError: if stdout is not valid JSON.
    """
    # stdout and stderr are captured separately so that warnings or progress
    # output written to stderr cannot corrupt the JSON payload, and a failing
    # command is reported as such instead of as a JSON parse error.
    result = subprocess.run(
        command,
        shell=True,
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True,
    )
    return json.loads(result.stdout)


class Table:
    """Join per-nanny downscaling settings with GPU pod counts, per DC.

    Rows live in ``table_dict`` keyed by ``(nanny_name, dc_label)``. Downscaling
    data comes from the ``COMMAND_POD_SET`` output (see ``first_r``), pod/GPU
    counts from the ``COMMAND_GPU_LIST_DC`` output (see ``second_r``). Both
    commands are executed from ``__init__``.
    """

    def __init__(self, token, headers=None, dcs=('sas', 'vla')):
        """Collect data for every DC and build the joined row set.

        :param token: YT token passed to ``yt.YtClient`` by ``upload_data_to_yt``.
        :param headers: column order for the rendered table; defaults to HEADERS.
        :param dcs: lower-case DC names substituted into the config commands;
            the upper-cased name is stored in the ``dc`` column.
        """
        self.token = token
        self.headers = headers if headers else HEADERS
        self.table_dict = defaultdict(dict)
        self.table = PrettyTable(self.headers)
        # All downscaling rows are created before any pods are counted:
        # second_r reads ``pod_share`` from the row first_r created, and this
        # order also fixes the row order of the output.
        for dc in dcs:
            self.first_r(get_data_by_command_dc_name(COMMAND_POD_SET.format(dc)), dc.upper())
        for dc in dcs:
            self.second_r(get_data_by_command_dc_name(COMMAND_GPU_LIST_DC.format(dc)), dc.upper())
        self.filter_empty_rows()

    @staticmethod
    def _format_days(days):
        """Render a days-of-week list: 'everyday' for 7 entries, else a ', '-joined string.

        Non-list values (already formatted) are returned unchanged.
        """
        if isinstance(days, list):
            return 'everyday' if len(days) == 7 else ', '.join(str(day) for day in days)
        return days

    def first_r(self, nannies, dc):
        """Create a row for every nanny that has at least one downscaling period.

        :param nannies: iterable of ``[name, data]`` pairs decoded from the
            COMMAND_POD_SET JSON; only ``data['downscaling']['periods'][0]``
            is reported.
        :param dc: DC label stored in the ``dc`` column.
        """
        for nanny in nannies:
            nanny_name = nanny[0]
            nanny_data = nanny[1]
            downscaling = nanny_data.get('downscaling') or {}
            periods = downscaling.get('periods')
            if not periods:
                # No downscaling configured, or an empty period list.
                continue
            period = periods[0]
            start = period['start_time']
            end = period['finish_time']
            self.table_dict[(nanny_name, dc)] = {
                'nanny': nanny_name,
                'dc': dc,
                'pod_share': period['pod_share'],
                'start_time': f'{start["hours"]}:{start["minutes"]}:{start["seconds"]}',
                'finish_time': f'{end["hours"]}:{end["minutes"]}:{end["seconds"]}',
                # Normalized once here so the printed table and the YT upload
                # always show the same string.
                'days_of_week': self._format_days(period['days_of_week']),
            }

    def second_r(self, nannies, dc):
        """Count pods and GPUs per nanny and derive ``scaled_gpu``.

        :param nannies: iterable of pod records decoded from the
            COMMAND_GPU_LIST_DC JSON; index 1 is the nanny name and index 3 the
            list of the pod's GPUs.
        :param dc: DC label stored in the ``dc`` column.
        """
        for pod in nannies:
            # Underscores become dashes, presumably so names match the keys
            # created by first_r — TODO confirm against the command output.
            nanny_name = pod[1].replace('_', '-')
            gpu_per_pod = len(pod[3])
            # setdefault works both on the initial defaultdict and on the plain
            # dict that filter_empty_rows leaves behind.
            row = self.table_dict.setdefault((nanny_name, dc), {})
            pods = row.get('pods', 0) + 1
            # Sum each pod's real GPU count rather than multiplying the pod
            # count by the last pod's GPU count (wrong for mixed pod sizes).
            total_gpu = row.get('total_gpu', 0) + gpu_per_pod
            pod_share = row.get('pod_share')
            row['nanny'] = nanny_name
            row['dc'] = dc
            row['pods'] = pods
            # NOTE(review): GPU count of the most recently seen pod; presumably
            # all pods of a service are identical — confirm.
            row['gpu_per_pod'] = gpu_per_pod
            row['total_gpu'] = total_gpu
            row['scaled_gpu'] = pod_share * total_gpu if pod_share else None

    def filter_empty_rows(self):
        """Keep only rows whose present values are all truthy.

        This drops, for example, pod rows without a downscaling ``pod_share``
        (their ``scaled_gpu`` is None).

        NOTE(review): only keys present in a row are checked, so a row created
        by first_r that never matched a pod (no 'pods'/'total_gpu' keys) is kept
        and rendered with blank GPU columns. Confirm this is intended.
        """
        self.table_dict = {
            key: row for key, row in self.table_dict.items() if all(row.values())
        }

    def get_table(self):
        """Fill ``self.table`` with one row per entry (columns in header order) and return it.

        Existing rows are cleared first, so repeated calls do not duplicate
        rows. Columns missing from a row are rendered as ''.
        """
        self.table.clear_rows()
        for row in self.table_dict.values():
            self.table.add_row([row.get(column, '') for column in self.headers])
        return self.table

    def upload_data_to_yt(self):
        """Create the table at PATH on CLUSTER (via create_table) and write every row to it.

        Rows are copied before ``days_of_week`` is stringified, so ``table_dict``
        itself is not modified.
        """
        yt_client = yt.YtClient(CLUSTER, token=self.token)
        schema = [
            'nanny:string',
            'dc:string',
            'pods:int64',
            'gpu_per_pod:int64',
            'total_gpu:int64',
            'pod_share:double',
            'scaled_gpu:double',
            'start_time:string',
            'finish_time:string',
            'days_of_week:string',
        ]
        create_table(client=yt_client, path=PATH, schema=schema)
        rows = []
        for value in self.table_dict.values():
            row = dict(value)
            if 'days_of_week' in row:
                row['days_of_week'] = self._format_days(row['days_of_week'])
            rows.append(row)
        yt_client.write_table(PATH, rows)


def parse_args(argv=None):
    """Parse the command-line options of this script.

    :param argv: argument list to parse; ``None`` (the default) lets argparse
        read ``sys.argv[1:]``, so the no-argument call behaves as before.
    :returns: namespace with ``token`` (str or None) and ``update_yt_table`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--token', type=str, default=None,
        help='YT token used when uploading the table',
    )
    parser.add_argument(
        '--update-yt-table', action='store_true',
        help='also write the collected rows to the YT table',
    )
    return parser.parse_args(argv)


def main():
    """Script entry point: build the GPU table, optionally upload it to YT, then print it."""
    options = parse_args()
    report = Table(token=options.token)
    # The upload happens before printing, matching the flag-driven flow.
    if options.update_yt_table:
        report.upload_data_to_yt()
    rendered = report.get_table()
    print(rendered)
