import sys
import json
import random
import codecs
from collections import defaultdict
from itertools import combinations
from itertools import product
from itertools import groupby

def sample_documents_from_single_serpset(data, max_depth, n_candidates, with_page_screenshot, with_serp_params):
    """Sample ``n_candidates`` image documents from every SERP in ``data``.

    For each SERP that has a ``'components'`` list, the first ``max_depth``
    components are scanned and turned into flat document dicts.  The
    resulting list is split into ``n_candidates`` contiguous spans of
    (almost) equal size and one document is drawn at random from each span.

    Args:
        data: list of SERP dicts.  It is shuffled IN PLACE.
        max_depth: how many leading components of each SERP to consider.
        n_candidates: number of documents to draw per SERP.
        with_page_screenshot: if true, components lacking
            ``'normalizations.images_page_screenshot'`` are skipped and the
            value is stored under ``'page_screenshot_mds'``.
        with_serp_params: if true, copy the ``serp_query_param.*`` values
            (``queryfresh``, ``query_date``, ``actuality``) from the SERP
            into each document; absent values become ``None``.

    Returns:
        A list of document dicts, ``n_candidates`` per SERP that had at
        least that many usable components.  SERPs with fewer are reported
        on stderr and skipped.

    Raises:
        KeyError: if a SERP lacks one of the ``query`` fields, or a component
            that passed the filter lacks ``'url.imageUrl'`` or
            ``'componentUrl'``/``'pageUrl'`` (these are read unchecked).
    """
    random.shuffle(data)

    out_data = []

    # Numeric device code from the SERP -> symbolic name.
    device_map = {
        0: 'DESKTOP',
        1: 'ANDROID',
        2: 'IPHONE',
        3: 'UNKNOWN',
        4: 'WINDOWS_PHONE'
    }

    for serp in data:
        if 'components' not in serp:
            continue

        query = serp['query']
        query_text = query['text']
        query_region_id = query['regionId']
        query_country = query['country']
        # Codes missing from the map are reported as 'UNKNOWN'.
        query_device = device_map.get(query['device'], 'UNKNOWN')

        this_query_docs = []

        for comp in serp['components'][:max_depth]:
            if 'url.mimcaMdsUrl' not in comp or 'long.MIMCA_CRC64' not in comp:
                continue
            if with_page_screenshot and 'normalizations.images_page_screenshot' not in comp:
                continue

            current_doc = {
                'query': query_text,
                'region_id': query_region_id,
                'country': query_country,
                'device': query_device,
                'image_url': comp['url.imageUrl'],
                'page_url': comp['componentUrl']['pageUrl'],
                'image_crc_long': comp['long.MIMCA_CRC64'],
                'image_mds': comp['url.mimcaMdsUrl'],
                'title': comp.get('text.title', ''),
                'snippet': comp.get('text.snippet', '')
            }
            if with_page_screenshot:
                current_doc['page_screenshot_mds'] = comp['normalizations.images_page_screenshot']

            if 'url.imageBigThumbHref' in comp:
                current_doc['image_thumb'] = comp['url.imageBigThumbHref']

            if with_serp_params:
                current_doc['queryfresh'] = serp.get('serp_query_param.queryfresh')
                current_doc['query_date'] = serp.get('serp_query_param.query_date')
                current_doc['actuality'] = serp.get('serp_query_param.actuality')

            this_query_docs.append(current_doc)

        n = len(this_query_docs)
        if n < n_candidates:
            # sys.stderr.write behaves the same on Python 2 and 3, unlike
            # the ``print >>stream`` statement.
            sys.stderr.write("Not enough candidates for %s %s %s\n"
                             % (query_text, query_region_id, query_device))
            continue

        for span_i in range(n_candidates):
            # Floor division keeps the span bounds integral on every Python
            # version; true division would produce float slice indices.
            start = n * span_i // n_candidates
            end = n * (span_i + 1) // n_candidates
            out_data.append(random.choice(this_query_docs[start:end]))

    return out_data


def sample_support_point_candidates_foreach_serpset(serpsets_info, max_depth, candidates_per_serp, append_serp_name=False, with_page_screenshot=False, with_serp_params=False):
    """Load every serpset file and collect sampled candidate documents.

    For each entry of ``serpsets_info`` the file ``<id>.json`` (relative to
    the current working directory) is read as UTF-8 JSON and passed to
    ``sample_documents_from_single_serpset``.

    Args:
        serpsets_info: iterable of dicts with an ``'id'`` key (and a
            ``'name'`` key when ``append_serp_name`` is true).
        max_depth: forwarded to the per-serpset sampler.
        candidates_per_serp: forwarded as the sampler's ``n_candidates``.
        append_serp_name: if true, store the entry's ``'name'`` in each
            sampled document under ``'serp_name'``.
        with_page_screenshot: forwarded to the per-serpset sampler.
        with_serp_params: forwarded to the per-serpset sampler.

    Returns:
        One flat list with the sampled documents of all serpsets, in the
        order the serpsets were given.
    """
    all_documents = []

    for info_elem in serpsets_info:
        filename = '{}.json'.format(info_elem['id'])
        sys.stderr.write("Opening %s\n" % filename)
        with codecs.open(filename, 'r', 'utf8') as f:
            # The stream is already decoded by codecs.open, so no ``encoding``
            # argument is needed (and json.load rejects it on Python 3.9+).
            serpset = json.load(f)
        sys.stderr.write("Sampling...\n")
        new_docs = sample_documents_from_single_serpset(
            serpset, max_depth, candidates_per_serp, with_page_screenshot, with_serp_params)
        if append_serp_name:
            for doc in new_docs:
                doc['serp_name'] = info_elem['name']
        all_documents.extend(new_docs)

    return all_documents


def generate_full_document_pairs(data, limit_docs_per_query=32):
    """Build every unordered pair of documents that share a query.

    Documents are grouped by ``(query, region_id, platform)``.  Only the
    whitelisted fields of each document are kept.  If a group holds more
    than ``limit_docs_per_query`` documents, a random sample of that size is
    used.  Each pair becomes one flat dict: the group key, ``device`` and
    ``queryfresh`` taken from the left document, and every kept field of the
    two documents prefixed with ``left_`` / ``right_``.

    Args:
        data: iterable of document dicts; each must have ``'query'``,
            ``'region_id'`` and ``'platform'``.
        limit_docs_per_query: maximum number of documents paired per group.

    Returns:
        A list of pair dicts.  ``device`` / ``queryfresh`` are ``None`` when
        the left document does not carry them.

    Raises:
        KeyError: if a document lacks one of the grouping keys.
    """
    # NOTE(review): the whitelist names 'image_crc', while the sampler in this
    # file emits 'image_crc_long' -- confirm which key is intended.
    allowed_fields = frozenset([
        'image_url', 'image_mds', 'page_url', 'image_crc',
        'page_screenshot_mds', 'title', 'snippet', 'device', 'queryfresh',
    ])

    data_dict = defaultdict(list)
    for elem in data:
        kept = {field: value for field, value in elem.items() if field in allowed_fields}
        data_dict[(elem['query'], elem['region_id'], elem['platform'])].append(kept)

    out_data = []
    for (query, region_id, platform), docs in data_dict.items():
        if len(docs) > limit_docs_per_query:
            docs = random.sample(docs, limit_docs_per_query)

        for left, right in combinations(docs, 2):
            out_elem = {
                'query': query,
                'region_id': region_id,
                'platform': platform,
                # Both fields are optional in the input documents, so a
                # plain subscript would raise KeyError when they are absent.
                'device': left.get('device'),
                'queryfresh': left.get('queryfresh'),
            }
            for field, value in left.items():
                out_elem['left_' + field] = value
            for field, value in right.items():
                out_elem['right_' + field] = value
            out_data.append(out_elem)

    return out_data


def wrap_left_right(data):
    """Nest ``left_*`` / ``right_*`` fields of flat pair dicts.

    Every key starting with ``left_`` (or ``right_``) is moved, with the
    prefix stripped, into a sub-dict stored under ``'left'`` (or
    ``'right'``).  All other keys stay at the top level.  The pair's
    ``'platform'`` value is additionally copied into both sub-dicts.

    Args:
        data: iterable of flat pair dicts, each with a ``'platform'`` key.

    Returns:
        A new list of nested dicts; the input dicts are not modified.
    """
    sides = ('left', 'right')
    wrapped = []

    for record in data:
        nested = {'left': {}, 'right': {}}

        for name, val in record.items():
            # "left_foo" -> ("left", "_", "foo"); names without an
            # underscore produce an empty separator and stay top-level.
            head, sep, tail = name.partition('_')
            if sep and head in sides:
                nested[head][tail] = val
            else:
                nested[name] = val

        for side in sides:
            nested[side]['platform'] = record['platform']

        wrapped.append(nested)

    return wrapped


def calc_document_scores(data):
    """Turn judged document pairs into per-document win counts.

    ``data`` is sorted IN PLACE by ``(query, region_id, platform)`` and
    processed group by group.  For every pair, ``2 * overlap`` votes are
    distributed: ``int(votes * probability + 0.01)`` go to the side named by
    ``label`` and the remainder to the other side.  Documents are identified
    by their ``'image_url'``.

    Args:
        data: list of pair dicts with ``'left'``/``'right'`` document dicts
            plus ``'label'``, ``'overlap'``, ``'probability'``, ``'device'``
            and the three grouping keys.

    Returns:
        A list with one dict per distinct document, grouped by query and
        ordered by descending ``'n_wins'`` inside each group.  Each dict has
        the group key, ``'device'``, ``'n_wins'`` and all fields of the
        document itself (document fields win on name clashes).

    Raises:
        KeyError: if a label is neither ``'left'`` nor ``'right'``, or a
            required key is missing.
    """
    opposite_side = {'left': 'right', 'right': 'left'}

    def group_key(elem):
        return (elem['query'], elem['region_id'], elem['platform'])

    # groupby only merges adjacent items, hence the sort by the same key.
    data.sort(key=group_key)

    scored = []

    for (query, region_id, platform), pairs in groupby(data, key=group_key):
        device = None
        wins_by_url = defaultdict(int)
        doc_by_url = {}

        for pair in pairs:
            doc_by_url[pair['left']['image_url']] = pair['left']
            doc_by_url[pair['right']['image_url']] = pair['right']

            # Expected to be identical for every pair of the group; the last
            # one seen is what ends up in the output.
            device = pair['device']

            winner = pair['label']  # 'left' or 'right'
            loser = opposite_side[winner]  # KeyError on any other label

            total_votes = pair['overlap'] * 2
            # The small epsilon guards against float products landing just
            # below an integer before truncation.
            winner_votes = int(total_votes * pair['probability'] + 0.01)

            wins_by_url[pair[winner]['image_url']] += winner_votes
            wins_by_url[pair[loser]['image_url']] += total_votes - winner_votes

        # Stable sort: documents with equal wins keep first-seen order.
        ranked = sorted(
            ((wins_by_url.get(url, 0), doc) for url, doc in doc_by_url.items()),
            key=lambda entry: entry[0],
            reverse=True,
        )

        for n_wins, doc in ranked:
            row = dict(query=query, region_id=region_id, platform=platform,
                       device=device, n_wins=n_wins)
            row.update(doc)
            scored.append(row)

    return scored


def pick_support_points(data, min_docs=3):
    """Pick three ranked support documents (best, middle, worst) per query.

    ``data`` is sorted IN PLACE by ``(query, region_id, platform)``.  Within
    each group the documents are ordered by descending ``'n_wins'``.  Groups
    with fewer than ``max(3, min_docs)`` documents are dropped.  From larger
    groups the first and last documents are taken, plus the document (other
    than the first) whose larger distance in wins to those two extremes is
    smallest; on ties the earliest such document is used.

    Args:
        data: list of scored document dicts with ``'n_wins'`` and the three
            grouping keys.
        min_docs: minimum group size required to emit support points.

    Returns:
        A list of copies of the chosen documents, each with an added
        ``'rank'`` of 1 (most wins), 2 (middle) or 3 (fewest wins).
    """
    def group_key(elem):
        return (elem['query'], elem['region_id'], elem['platform'])

    data.sort(key=group_key)

    picked = []

    for _, group in groupby(data, key=group_key):
        ranked = sorted(group, key=lambda doc: doc['n_wins'], reverse=True)

        if len(ranked) < max(3, min_docs):
            continue

        if len(ranked) > 3:
            top_wins = ranked[0]['n_wins']
            bottom_wins = ranked[-1]['n_wins']

            def worst_gap(idx):
                # Larger of the two distances to the extremes.
                wins = ranked[idx]['n_wins']
                return max(abs(wins - top_wins), abs(wins - bottom_wins))

            # min() returns the first index achieving the smallest gap.
            middle_idx = min(range(1, len(ranked)), key=worst_gap)
            chosen = [ranked[0], ranked[middle_idx], ranked[-1]]
        else:
            chosen = ranked

        picked.extend(dict(doc, rank=rank) for rank, doc in enumerate(chosen, 1))

    return picked
