import gzip
import io
import logging
import os
import re
import tarfile
from contextlib import closing
from urllib import urlopen

from sandbox.common.errors import TaskFailure
from sandbox import sdk2

from sandbox.projects.maps.MapsDownloadCachesTask.compiled_proto.region_list_pb2 import RegionList, DownloadedRegion


# Region list endpoint. The first placeholder receives the requested layer
# names joined with "%2C" (a URL-encoded comma), the second one the locale.
# The response body is decoded as a gzip-compressed serialized RegionList.
_REGIONS_URL_FORMAT = "http://core-mobile-cacheinfo.maps.yandex.net/mapkit/2.x/regionlist?l={}&lang={}&scale=2"


class MapsOfflineCachesResource(sdk2.Resource):
    """Tar.gz archive holding downloaded offline map caches plus per-region metadata."""


class MapsDownloadCachesTask(sdk2.Task):
    """ Download offline caches of selected regions into a tar.gz resource.

    A region is included when its ID is listed in Region IDs, its country
    matches the Countries regex, or its name matches the Region names regex.
    """

    class Parameters(sdk2.Task.Parameters):
        # Comma-separated cache types to request and to pack into the archive.
        layers = sdk2.parameters.String(
            "Layers (comma-separated)", default="vmap2fb,search2,driving", required=True
        )
        locale = sdk2.parameters.String("Locale", required=True)
        # Selects the archive directory layout (see the path construction in _download_regions).
        legacy_format = sdk2.parameters.Bool("Use legacy files emplacement format (for mapkit < 172.2.0)", default=False, required=True)
        # Region filters: at least one of the three must be set.
        region_ids = sdk2.parameters.String("Region IDs (comma-separated)", required=False)
        countries_regex = sdk2.parameters.String("Countries (regex)", required=False)
        region_names_regex = sdk2.parameters.String("Region names (regex)", required=False)
        # NOTE(review): int default on a String parameter; presumably sdk2 coerces it — confirm.
        ttl = sdk2.parameters.String("Time to live (1-999 days or inf)", default=30, required=True)

    def on_execute(self):
        """Create the output resource, fill it with downloaded caches and report its link."""
        resource = MapsOfflineCachesResource(
            self, "Downloaded offline caches", "offline_cache.tar.gz", ttl=self.Parameters.ttl
        )
        resource_data = sdk2.ResourceData(resource)
        archive_path = str(resource_data.path.absolute())

        _download_regions(archive_path, self.Parameters)

        resource_data.ready()
        link = resource.http_proxy
        self.set_info("Downloaded offline caches link: {}".format(link))


def _download_regions(output_filename, params):
    """Download the caches of all matching regions into a tar.gz archive.

    A region is included when it matches ANY of the configured filters: its ID
    is listed in ``params.region_ids``, its country matches
    ``params.countries_regex`` or its name matches ``params.region_names_regex``.
    For every included region the archive receives a serialized
    ``DownloadedRegion`` metadata file and one cache file per requested layer.

    :param output_filename: path of the tar.gz archive to create.
    :param params: task parameters (``layers``, ``locale``, ``legacy_format``,
        ``region_ids``, ``countries_regex``, ``region_names_regex``).
    :raises TaskFailure: if no filter is set, no layer is given, the region
        list cannot be fetched/parsed, or no region matches the filters.
    """
    if not params.region_ids and not params.region_names_regex and not params.countries_regex:
        raise TaskFailure(
            "One of the following fields must be filled: Region IDs, Countries, Region names"
        )

    region_ids = _parse_region_ids(params.region_ids)
    layers = _parse_layers(params.layers)
    if not layers:
        raise TaskFailure("Layers must contain at least one layer name")
    locale = params.locale
    legacy_format = params.legacy_format
    countries_regex = _compile_regex_if_not_empty(params.countries_regex)
    region_names_regex = _compile_regex_if_not_empty(params.region_names_regex)

    region_list = _fetch_region_list(layers, locale)

    processed_count = 0
    with tarfile.open(name=output_filename, mode="w:gz") as offline_cache_tar_gz:
        for region in region_list.regions:
            # %-formatting (not str.format) so that unicode protobuf strings
            # such as non-ASCII region names do not raise UnicodeEncodeError on py2.
            region_description = "region with id=%s, country=%s and name=%s" % (
                region.id, region.country, region.name
            )
            if not _region_matches(region, region_ids, countries_regex, region_names_regex):
                logging.info("Skipped %s", region_description)
                continue
            logging.info("Processing %s", region_description)
            processed_count += 1

            metadata = _convert_region_to_downloaded_region(region).SerializeToString()
            _add_to_tar(offline_cache_tar_gz, _metadata_path(region, locale, legacy_format), metadata)

            for f in region.files:
                if f.cache_type not in layers:
                    continue
                cache_path = _cache_path(region, f.cache_type, locale, legacy_format)
                logging.info("Downloading region %s %s %s...", region.id, f.cache_type, cache_path)
                with closing(urlopen(f.download_url)) as response:
                    cache = response.read()
                _add_to_tar(offline_cache_tar_gz, cache_path, cache)

    if not processed_count:
        raise TaskFailure(
            "No regions matched the given Region IDs, Countries or Region names"
        )


def _parse_region_ids(raw_region_ids):
    """Parse a comma-separated list of region IDs into a set of ints.

    Tokens are stripped; non-numeric tokens are logged and ignored.
    Returns None when *raw_region_ids* is empty or None.
    """
    if not raw_region_ids:
        return None
    region_ids = set()
    for token in raw_region_ids.split(","):
        token = token.strip()
        if token.isdigit():
            region_ids.add(int(token))
        elif token:
            logging.warning("Ignoring malformed region ID %r", token)
    return region_ids


def _parse_layers(raw_layers):
    """Parse a comma-separated list of layer names, dropping blanks and surrounding whitespace."""
    return {layer.strip() for layer in raw_layers.split(",") if layer.strip()}


def _fetch_region_list(layers, locale):
    """Fetch and parse the RegionList for the given layers and locale.

    :raises TaskFailure: if the response is not valid gzip data or parses to
        an empty message.
    """
    # sorted() makes the request URL deterministic (set order is arbitrary).
    regions_url = _REGIONS_URL_FORMAT.format("%2C".join(sorted(layers)), locale)
    with closing(urlopen(regions_url)) as response:
        compressed = response.read()
    try:
        serialized = gzip.GzipFile(fileobj=io.BytesIO(compressed)).read()
    except (IOError, EOFError) as e:
        # Typically an error page instead of a gzip payload.
        logging.error("Region list response from %s is not valid gzip data: %s", regions_url, e)
        raise TaskFailure(
            "Could not get RegionList. You have probably misspelled Locale or Layers"
        )
    region_list = RegionList()
    region_list.ParseFromString(serialized)
    if not region_list.ListFields():
        raise TaskFailure(
            "Could not get RegionList. You have probably misspelled Locale or Layers"
        )
    return region_list


def _region_matches(region, region_ids, countries_regex, region_names_regex):
    """Return True if *region* satisfies at least one of the optional filters."""
    if region_ids and region.id in region_ids:
        return True
    if countries_regex and countries_regex.search(region.country):
        return True
    return bool(region_names_regex and region_names_regex.search(region.name))


def _metadata_path(region, locale, legacy_format):
    """Return the archive path of a region's metadata file in the chosen layout."""
    if legacy_format:
        return os.path.join("mapkit", "metadata", "{}_{}".format(region.id, locale), "region.pb")
    return os.path.join("mapkit", "offline_caches", str(region.id), "region.pb")


def _cache_path(region, cache_type, locale, legacy_format):
    """Return the archive path of one layer's cache file in the chosen layout."""
    if legacy_format:
        return os.path.join("mapkit", cache_type, "{}_{}".format(region.id, locale), "region.fb")
    return os.path.join("mapkit", "offline_caches", str(region.id), cache_type + ".fb")


def _add_to_tar(tar, name, content):
    buf = io.BytesIO(content)
    tarinfo = tarfile.TarInfo(name=name)
    tarinfo.size = len(buf.getvalue())
    tar.addfile(tarinfo=tarinfo, fileobj=buf)
    logging.info("Added file with name={}".format(name))


def _compile_regex_if_not_empty(regex):
    return re.compile(regex) if regex else None


def _convert_region_to_downloaded_region(region):
    """Build the DownloadedRegion metadata message stored alongside a region's caches.

    Copies name, country, cities, center point, size and release time from
    *region*. NOTE(review): region.id and region.files are not copied;
    presumably intentional since the ID is already part of the archive path — confirm.
    """
    downloaded = DownloadedRegion()
    # Scalar fields.
    downloaded.name = region.name
    downloaded.country = region.country
    downloaded.release_time = region.release_time
    # Repeated and nested-message fields.
    downloaded.cities.extend(region.cities)
    downloaded.center_point.CopyFrom(region.center_point)
    downloaded.size.CopyFrom(region.size)
    return downloaded
