# -*- coding: utf-8 -*-
import xml.etree.ElementTree as et
from datacloud.dev_utils.tvm.tvm_utils import TVMManager
from datacloud.dev_utils.tvm import tvm_id
import requests
from yandex.maps.proto.common2.response_pb2 import Response
from enum import Enum, unique


@unique
class BadAddrTypes(Enum):
    """Reasons why coordinates could not be produced for an address.

    Each member's value is the same string as its name.
    """

    # The address passed in was empty (falsy).
    empty_addr = 'empty_addr'

    # The service answered, but the response held no usable object.
    no_addr_found = 'no_addr_found'

class CoordNotFoundException(LookupError):
    """Signals that no coordinates could be obtained for an address.

    Attributes:
        bad_addr_type: the reason for the failure; an object exposing a
            ``value`` attribute, which is used as the textual form of
            the exception.
    """

    def __init__(self, bad_addr_type):
        """Remember the failure reason on the exception instance."""
        self.bad_addr_type = bad_addr_type

    def __str__(self):
        """Return the ``value`` of the stored failure reason."""
        reason = self.bad_addr_type
        return reason.value

    def __repr__(self):
        """Return exactly the same text as ``__str__``."""
        return self.__str__()


class AddrsResolver:
    """Resolve free-text addresses to ``(lon, lat)`` via an HTTP service.

    A request is sent to ``server`` (optionally through ``TVMManager``),
    and the reply is parsed either as a protobuf ``Response`` or as XML.
    Successful results can be kept in an in-process dict so that repeated
    lookups of the same address do not hit the network again.

    NOTE(review): the cache (``self.memory``) is unbounded and is never
    evicted; confirm this is acceptable for long-running callers.
    """

    # Known service endpoints.
    TESTING = 'http://addrs-testing.search.yandex.net/search/stable/yandsearch'
    PRODUCTION = 'http://addrs.yandex.ru:17140/yandsearch'

    # Default TVM source/destination ids handed to ``TVMManager``.
    TVM_SRC = tvm_id.DATACLOUD.test
    TVM_DST = tvm_id.GEOCODER.prod
    TVM_DST_TESTING = tvm_id.GEOCODER.test

    def __init__(
            self,
            server=TESTING,
            origin='xprod-datacloud',
            memory_on=True,
            use_tvm=True,
            use_pb=True,
            tvm_src=None,
            tvm_dst=None,
            tvm_log_level=None):
        """Configure the resolver.

        Args:
            server: URL the requests are sent to (defaults to ``TESTING``).
            origin: value of the ``origin`` request parameter.
            memory_on: when true, cache resolved coordinates per address.
            use_tvm: when true, send requests through ``TVMManager``;
                otherwise use a plain ``requests.get``.
            use_pb: when true, request and parse a protobuf response;
                otherwise parse the response as XML.
            tvm_src: TVM source id; any falsy value falls back to
                ``TVM_SRC``.
            tvm_dst: TVM destination id. When ``None`` it is chosen from
                ``server``: ``TVM_DST_TESTING`` for the ``TESTING`` URL,
                ``TVM_DST`` for anything else.
            tvm_log_level: forwarded to ``TVMManager`` as ``log_level``.
        """
        self.server = server
        self.origin = origin
        self.memory_on = memory_on
        self.memory = {}

        self.tvm_manager = TVMManager(log_level=tvm_log_level)
        self.tvm_src = tvm_src or self.TVM_SRC

        # An explicit destination always wins; otherwise follow the server.
        if tvm_dst is not None:
            self.tvm_dst = tvm_dst
        elif self.server == self.TESTING:
            self.tvm_dst = self.TVM_DST_TESTING
        else:
            self.tvm_dst = self.TVM_DST

        self.use_tvm = use_tvm
        self.use_pb = use_pb

    def send_request(self, addr):
        """Send one lookup request for ``addr`` and return the raw body.

        Args:
            addr: address text, sent as the ``text`` request parameter.

        Returns:
            The ``content`` attribute of the HTTP response object.
        """
        request_params = {
            'lang': 'ru',
            'results': 1,
            'origin': self.origin,
            'text': addr
        }
        if self.use_pb:
            request_params['ms'] = 'pb'

        if not self.use_tvm:
            return requests.get(self.server, request_params).content

        # The extra ``tvm`` parameter is only added for the testing server.
        if self.server == self.TESTING:
            request_params['tvm'] = 1

        return self.tvm_manager.tvm_get_request(
            self.server,
            self.tvm_src,
            self.tvm_dst,
            params=request_params
        ).content

    def parse_pb_response(self, response):
        """Extract ``(lon, lat)`` from a serialized protobuf ``Response``.

        Args:
            response: serialized protobuf message bytes.

        Returns:
            ``(lon, lat)`` of the first geometry point of the first geo
            object in the reply.

        Raises:
            CoordNotFoundException: with ``BadAddrTypes.no_addr_found``
                when the reply contains no geo objects.
        """
        pb_response = Response()
        pb_response.ParseFromString(response)

        geo_objects = pb_response.reply.geo_object
        if not geo_objects:
            raise CoordNotFoundException(BadAddrTypes.no_addr_found)

        point = geo_objects[0].geometry[0].point
        return point.lon, point.lat

    def parse_xml_response(self, response):
        """Extract ``(lon, lat)`` from an XML response.

        Looks at every element whose tag mentions ``GeoObject`` (but not
        ``GeoObjectCollection``), collects their direct children whose tag
        mentions ``Point``, and reads the text of the first child of the
        first such point as two space-separated numbers.

        Args:
            response: XML document as text or bytes.

        Returns:
            ``(lon, lat)`` as a pair of floats.

        Raises:
            CoordNotFoundException: with ``BadAddrTypes.no_addr_found``
                when no ``Point`` element is found.
        """
        xml_et = et.fromstring(response)

        # Element.iter() is used instead of getiterator(), which was
        # removed in Python 3.9. Lists are built eagerly because filter()
        # is lazy on Python 3 and supports neither len() nor indexing.
        geo_objects = [
            obj for obj in xml_et.iter()
            if 'GeoObject' in obj.tag and 'GeoObjectCollection' not in obj.tag
        ]
        points = [
            child
            for obj in geo_objects
            for child in obj
            if 'Point' in child.tag
        ]

        if not points:
            raise CoordNotFoundException(BadAddrTypes.no_addr_found)

        # The first child of the point element carries "lon lat" as text.
        text = list(points[0])[0].text
        lon, lat = text.split(' ')
        return float(lon), float(lat)

    def parse_response(self, response):
        """Parse ``response`` as protobuf or XML, depending on ``use_pb``."""
        if self.use_pb:
            return self.parse_pb_response(response)
        return self.parse_xml_response(response)

    def resolve_addr(self, addr):
        """Return ``(lon, lat)`` for ``addr``.

        Args:
            addr: address text to resolve.

        Returns:
            The coordinate pair produced by ``parse_response``; served
            from the cache when ``memory_on`` is set and ``addr`` was
            already resolved.

        Raises:
            CoordNotFoundException: with ``BadAddrTypes.empty_addr`` for a
                falsy ``addr``, or with ``BadAddrTypes.no_addr_found``
                when the response holds no result.
        """
        if not addr:
            raise CoordNotFoundException(BadAddrTypes.empty_addr)

        if self.memory_on and addr in self.memory:
            return self.memory[addr]

        response = self.send_request(addr)
        coord = self.parse_response(response)

        # Only successful lookups get here, so failures are never cached.
        if self.memory_on:
            self.memory[addr] = coord
        return coord

    def resolve_addr_or_none(self, addr):
        """Resolve ``addr`` without raising ``CoordNotFoundException``.

        Despite the name, this never returns ``None``: on failure it
        returns the ``BadAddrTypes`` reason carried by the exception.

        Args:
            addr: address text to resolve.

        Returns:
            ``(lon, lat)`` on success, otherwise the ``bad_addr_type`` of
            the caught ``CoordNotFoundException``.
        """
        try:
            return self.resolve_addr(addr)
        except CoordNotFoundException as cnf_ex:
            return cnf_ex.bad_addr_type
