import argparse
import logging
import time
import zlib
import os
import grequests
import math

import yt.wrapper as yt
import travel.hotels.proto2.bus_messages_pb2 as bus_messages_pb2
from ticket_parser2.api.v1 import TvmClient, TvmClientStatus, TvmApiClientSettings

# Path + query-string template of the booking-flow endpoint called for every
# Expedia offer token. get_url_and_message fills in `{token}` and prepends the
# API host. customerIp / customerUserAgent are hard-coded values sent with
# every request (presumably placeholders identifying this validator — confirm).
api_address = "/api/booking_flow/v1/get_order_info_by_token?" \
              "token={token}&customerIp=127.0.0.1&customerUserAgent=yaTravel/tokenValidator"


def get_last_row_id(table):
    """Return the ``flushed_row_count`` attribute of tablet 0 of *table*.

    ``initialize`` uses this value as the row index to start fast-forwarding from.
    """
    flushed_count_path = table + "/@tablets/0/flushed_row_count"
    return yt.get(flushed_count_path)


def select(table, start, limit):
    """Read up to *limit* rows of tablet 0 of *table*, starting at row index *start*.

    The lower bound is inclusive and the upper bound exclusive:
    ``start <= $row_index < start + limit``.

    :return: list of row dicts with ``MessageType``, ``Codec``, ``Bytes`` and
        ``$row_index`` keys.
    """
    end = start + limit
    query = (
        "MessageType, Codec, Bytes, [$row_index] FROM [{table}] "
        "WHERE [$row_index] >= {lo} AND [$row_index] < {hi} AND [$tablet_index] = 0"
    ).format(table=table, lo=start, hi=end)
    return list(yt.select_rows(query))


def initialize(table, batch_size):
    """Fast-forward past the rows already present in *table*.

    Starts at ``get_last_row_id(table)`` and pages forward in chunks of
    *batch_size* rows. It stops when a page comes back with fewer than
    *batch_size* rows.

    :param table: YT path of the bus table.
    :param batch_size: page size passed to ``select``.
    :return: the ``$row_index`` of the first row not yet seen, i.e. where the
        caller should start listening.
    """
    logging.info("Initializing: bus fast forward")
    start = get_last_row_id(table)

    while True:
        rows = select(table, start, batch_size)
        if rows:
            # select() is inclusive of `start`, so resume right after the last
            # row seen. Resuming *at* it would re-read that row and, with
            # batch_size == 1, never advance at all.
            start = rows[-1]["$row_index"] + 1
        # An empty or short page means we have caught up with the table.
        if len(rows) < batch_size:
            logging.info("Initialized, began listening")
            return start


def get_url_and_message(row, host):
    """Build the token-validation API URL for the Expedia offer stored in a bus row.

    :param row: a row from ``select`` (``MessageType``, ``Codec``, ``Bytes`` keys).
    :param host: API host; ``http://`` is prepended when it does not start
        with ``http``.
    :return: ``(url, message)`` for the first Expedia (operator 40) offer whose
        landing URL contains ``book?token=``. ``message`` is the decoded
        ``TSearcherMessage``. Returns ``(None, None)`` when the row is not a
        codec-1 ``TSearcherMessage`` or has no such offer.
    """
    if not host.startswith("http"):
        host = "http://" + host
    if row["MessageType"] != "ru.yandex.travel.hotels.TSearcherMessage" or row["Codec"] != 1:
        return None, None

    # Codec 1 payloads are zlib-compressed serialized TSearcherMessage protos.
    payload = zlib.decompress(row["Bytes"])
    message = bus_messages_pb2.TSearcherMessage()
    message.ParseFromString(payload)

    marker = "book?token="
    for offer in message.Response.Offers.Offer:
        if offer.OperatorId != 40:  # expedia
            continue
        landing_url = offer.LandingInfo.LandingPageUrl
        # Split at the first marker occurrence; everything after it is the token.
        _, found, token = landing_url.partition(marker)
        if found:
            return host + api_address.format(token=token), message
        logging.warning("No marker in url: %s", landing_url)
    return None, None


def _find_message(res, message_map):
    """Return the searcher message that *res* was requested for, or None.

    ``res.url`` is the final URL after redirects, so the URLs of the redirect
    responses kept in ``res.history`` are tried too.
    """
    for candidate in [res] + list(getattr(res, "history", None) or []):
        message = message_map.get(candidate.url)
        if message is not None:
            return message
    return None


def _check_price(res, message):
    """Compare the API grand total of a 200 response with the cached Expedia price.

    :return: False when the message is unknown, the price cannot be read, or the
        API price is higher than the offer price. True otherwise.
    """
    if message is None:
        logging.error("No offer message found for response url %s", res.url)
        return False

    offer_price = -1
    for offer in message.Response.Offers.Offer:
        if offer.OperatorId == 40:  # expedia
            offer_price = offer.Price.Amount
            break

    try:
        data = res.json()
        amount = data["rate_info"]["hotel_charges"]["totals"]["grand"]["amount"]
        # Round the API total up to a whole number before comparing.
        amount = int(math.ceil(float(amount)))
    except (ValueError, KeyError, TypeError, OverflowError) as err:
        # A malformed body must not kill the whole polling loop.
        logging.error("Cannot read price from response for %s: %r", res.url, err)
        return False

    if amount < offer_price:
        logging.warning("Price mismatch: got price {} expected offer price is {}".format(amount, offer_price))
        logging.warning(data["rate_info"])
    elif amount > offer_price:
        logging.error("Price mismatch: got price {} expected offer price is {}".format(amount, offer_price))
        logging.error(message)
        return False
    return True


def handle_result(res, message_map):
    """Check one booking-flow API response against the offer it was built from.

    :param res: a response from ``grequests.map``; ``None`` when the request
        produced no response at all.
    :param message_map: request URL -> ``TSearcherMessage`` the URL was built from.
    :return: True when the API returned 200 and its price does not exceed the
        cached Expedia offer price; False for every other outcome (all logged).
    """
    if res is None:
        logging.error("NONE result!")
        return False

    message = _find_message(res, message_map)
    status = res.status_code

    if status == 200:
        return _check_price(res, message)

    if status == 404:
        logging.error(message)
        try:
            body = res.json()
        except ValueError:
            # The 404 body is not guaranteed to be JSON.
            body = res.text
        landing = get_landing_from_message(message) if message is not None else None
        logging.error("{}\n{}".format(body, landing))
        return False

    if status == 409:
        landing = get_landing_from_message(message) if message is not None else None
        logging.warning("Expired offer\n{}".format(landing))
        return False

    # 500 and any other unexpected status are reported the same way.
    logging.error("Error, status code is {}".format(status))
    logging.error(message)
    return False


def get_landing_from_message(message):
    """Return the landing page URL of the first Expedia offer in *message*.

    :param message: a ``TSearcherMessage``, or ``None`` (e.g. when a response
        could not be matched back to the message it was requested for).
    :return: the ``LandingPageUrl`` string, or ``None`` when *message* is
        ``None`` or holds no Expedia (operator 40) offer.
    """
    if message is None:
        return None
    offers = message.Response.Offers
    if offers is None:
        return None
    for offer in offers.Offer:
        if offer.OperatorId == 40:  # expedia
            return offer.LandingInfo.LandingPageUrl
    return None


def make_tvm_client(client_id, secret, destinations):
    """Create a ``TvmClient`` that can issue service tickets for *destinations*.

    :param client_id: TVM id of this client.
    :param secret: TVM secret of this client.
    :param destinations: alias -> TVM id mapping passed as ``dsts``.
    :return: the client, once its status is ``TvmClientStatus.Ok``.
    :raises Exception: when the client reports any other status.
    """
    client = TvmClient(TvmApiClientSettings(
        self_client_id=client_id,
        self_secret=secret,
        dsts=destinations,
        enable_service_ticket_checking=True,
        enable_user_ticket_checking=False,
    ))
    if client.status == TvmClientStatus.Ok:
        return client
    raise Exception("tvm client has bad status: " + TvmClient.status_to_string(client.status))


def main():
    """Stream offers from the bus table and validate their Expedia booking tokens.

    Parses the command line and, unless ``--skip-authentication`` is set,
    creates a TVM client. It then fast-forwards past the rows already in
    ``--table`` and loops forever:

    - fetch the next batch of rows;
    - call the booking-flow API for every Expedia token found;
    - log a per-batch summary;
    - sleep ``--sleep`` seconds whenever the batch was not full.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action="store_true", default=False)
    parser.add_argument("--host", default="https://api.travel-balancer-test.yandex.net")
    parser.add_argument("--port", default=None)
    parser.add_argument("--table", default="//home/travel/prod/offer_bus")
    # type= is required: without it, values given on the command line arrive as
    # strings and break the row-index arithmetic in select() and time.sleep().
    parser.add_argument("--batch-size", default=10000, type=int)
    parser.add_argument("--sleep", default=10, type=float)
    parser.add_argument("--skip-authentication", action='store_true', default=False)
    parser.add_argument("--cli-tvm-service-id", default=2002574, type=int)
    parser.add_argument("--api-tvm-service-id", default=2002548, type=int)
    parser.add_argument("--cli-tvm-secret")
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    tvm_client = None
    if not args.skip_authentication:
        # argparse stores --cli-tvm-secret as args.cli_tvm_secret.
        secret = args.cli_tvm_secret
        if secret is None:
            if "TVM_SECRET" not in os.environ:
                logging.error("TVM secret must be passed either as environment variable 'TVM_SECRET' "
                              "or as --cli-tvm-secret parameter")
                return
            secret = os.environ['TVM_SECRET']
        tvm_client = make_tvm_client(args.cli_tvm_service_id, secret, {'api': args.api_tvm_service_id})

    start = initialize(args.table, args.batch_size)
    host = "{}:{}".format(args.host, args.port) if args.port is not None else args.host
    while True:
        logging.info("Fetching offers from cache")
        rows = select(args.table, start, args.batch_size)
        logging.info("%s rows fetched", len(rows))

        if rows:
            # select() is inclusive of `start`; continue after the last row
            # seen so it is not re-fetched (and re-validated) on every poll.
            start = rows[-1]["$row_index"] + 1

        url_message_map = {}
        urls_list = []
        for row in rows:
            url, message = get_url_and_message(row, host)
            if url is None:
                continue
            urls_list.append(url)
            url_message_map[url] = message

        # Re-read the service ticket for each batch instead of once at startup,
        # so a long-running loop does not keep sending a stale ticket.
        # NOTE(review): assumes TvmClient keeps its cached tickets fresh — confirm.
        auth_headers = {}
        if tvm_client is not None:
            auth_headers['X-Ya-Service-Ticket'] = tvm_client.get_service_ticket_for("api")

        pending = (grequests.get(url, headers=auth_headers) for url in urls_list)
        logging.info("Sending %d requests", len(urls_list))
        results = grequests.map(pending)
        num_ok = sum(1 for res in results if handle_result(res, url_message_map))
        num_fail = len(results) - num_ok
        logging.info("{} requests completed, {} ok, {} failed".format(len(urls_list), num_ok, num_fail))

        if len(rows) < args.batch_size:
            logging.info("Will sleep for %s seconds", args.sleep)
            time.sleep(args.sleep)


if __name__ == '__main__':
    # Timestamped INFO-level logging to stderr, then run the validator loop.
    logging.basicConfig(
        format="%(asctime)-15s | %(levelname)s | %(message)s",
        level=logging.INFO,
    )
    main()
