from nile.api.v1 import (
    statface as ns,
    aggregators as na,
    datetime as nd,
    Record,
    Job,
    stream as nstream,
)
from maps.wikimap.stat.libs.common.lib.datetime_interval import (
    DatetimeInterval,
    datetime_intervals_overlap,
    make_interval,
)
from maps.wikimap.stat.arm.common.lib import (
    object_type,
    report_names,
    table_names,
    helpers,
)
from typing import (
    List,
    Optional,
)
from datetime import datetime, timedelta


def utc_from_timestamp_optional(timestamp: Optional[int]) -> Optional[datetime]:
    '''
    Convert a Unix timestamp to a naive UTC datetime, passing None through.

    Only None is treated as "no value"; a timestamp of 0 is converted
    like any other number.
    '''
    if timestamp is None:
        return None
    return datetime.utcfromtimestamp(timestamp)


'''
    Result schema example: TODO:
        object_id |      object_type      | region_id | fielddate  | length
       -----------+-----------------------+-----------+------------+--------
          234234  | \t_total_\tclosure\t  |   225     | 2019-08-01 |  150
          234234  | \t_total_\t           |   225     | 2019-08-01 |  1240
          11111   | \t_total_\tTclosure\t |   11111   | 2019-08-01 |  500.8
          11111   | \t_total_\t           |   11111   | 2019-08-01 |  56.15
'''


def add_fielddate(
    objects: nstream.Stream,
    dates: List[nd.Datetime]
) -> nstream.Stream:
    '''
    Expand every object into one record per report date it overlaps.

    For each input row an interval is built from `created_at` and
    `deleted_at` (Unix timestamps read as UTC; `deleted_at` may be None).
    For each date in `dates` a one-day interval starting at that date is
    built, and a record is emitted when `datetime_intervals_overlap`
    returns a non-None result for the pair.

    Input rows must provide: created_at, deleted_at, object_id,
    object_type, region_id, geometry_len_meters.

    Output records carry: fielddate (the date as a string), object_id,
    object_type, region_id, length (copied from geometry_len_meters).
    '''
    def object_fd_mapper(rows):
        # The day windows and their string labels depend only on `dates`,
        # so they are built once per mapper call rather than per row.
        day_windows = [
            (
                str(date.as_datetime().date()),
                DatetimeInterval(date.as_datetime(), date.as_datetime() + timedelta(days=1)),
            )
            for date in dates
        ]

        for row in rows:
            object_live_interval: Optional[DatetimeInterval] = make_interval(
                datetime.utcfromtimestamp(row.created_at),
                utc_from_timestamp_optional(row.deleted_at),
            )
            for fielddate, date_interval in day_windows:
                if datetime_intervals_overlap(
                    date_interval,
                    object_live_interval
                ) is not None:
                    yield Record(
                        fielddate=fielddate,
                        object_id=row.object_id,
                        object_type=row.object_type,
                        region_id=row.region_id,
                        length=row.geometry_len_meters,
                    )

    return objects.map(object_fd_mapper)


'''
    Result schema example:
        object_id |    object_type   | length | fielddate
       -----------+------------------+--------+-----------
          234234  |     closure      |   225  | 2019-08-01
          234234  | closure_template |   333  | 2019-08-01
          11111   |     closure      |   1111 | 2019-08-01
          11111   |     closure      |   50.8 | 2019-08-01
'''


def versions_add_fielddate(
    versions: nstream.Stream,
    dates: List[nd.Datetime]
) -> nstream.Stream:
    '''
    Pick, for every report date, the latest version of each object that is
    in effect on that date.

    Mapper stage, per input row:
      * rows whose `start_at` is None are skipped;
      * a closure interval is built from `start_at` / `end_at` and a version
        interval from `modified_at` / `next_version_at` (Unix timestamps
        read as UTC; the end of either interval may be None);
      * the overlap of the two is the effective interval;
      * for every date whose one-day interval overlaps the effective
        interval, the row is emitted with an added `fielddate` string.

    The mapped records are then grouped by (fielddate, object_id,
    object_type), only the record with the maximum `version` is kept, and
    the result is projected to fielddate, object_id, object_type and
    length (renamed from geometry_len_meters).
    '''
    def version_fd_mapper(rows):
        # The day windows and their string labels depend only on `dates`,
        # so they are built once per mapper call rather than per row.
        day_windows = [
            (
                str(date.as_datetime().date()),
                DatetimeInterval(
                    date.as_datetime(),
                    date.as_datetime() + timedelta(days=1),
                ),
            )
            for date in dates
        ]

        for row in rows:
            if row.start_at is None:
                continue
            closure_interval: Optional[DatetimeInterval] = make_interval(
                datetime.utcfromtimestamp(row.start_at),
                utc_from_timestamp_optional(row.end_at)
            )
            version_interval: Optional[DatetimeInterval] = make_interval(
                datetime.utcfromtimestamp(row.modified_at),
                utc_from_timestamp_optional(row.next_version_at),
            )
            # May be None; it is passed on to datetime_intervals_overlap
            # as-is, exactly as before.
            effective_interval = datetime_intervals_overlap(
                closure_interval,
                version_interval,
            )

            for fielddate, date_interval in day_windows:
                if datetime_intervals_overlap(
                    effective_interval,
                    date_interval,
                ) is not None:
                    yield Record(
                        row,
                        fielddate=fielddate,
                    )

    return versions.map(
        version_fd_mapper
    ).groupby(
        'fielddate',
        'object_id',
        'object_type',
    ).top(
        count=1,
        by='version',
        mode='max'
    ).project(
        'fielddate',
        'object_id',
        'object_type',
        length='geometry_len_meters',
    )


'''
    regions table has several rows for every region_id
    (one row per nesting level; plus special region category):
        region_id |      region_tree          |  region_name
       -----------+---------------------------+---------------------------
          177796  | \t10000\t                 | Земля
          177796  | \t10000\t225\t            | Земля/Россия
          177796  | \t10000\t225\t102444\t    | Земля/Россия/Сев-кав. округ
          177796  | \t10000\t225\t999999000\t | Земля/Россия/0+
          ...

    Result schema example (the length column is also kept; it is omitted below):
        object_id |      object_type      |  region   | fielddate
       -----------+-----------------------+-----------+-----------
          234234  | \t_total_\tclosure\t  | \t10\t    | 2019-08-01
          11111   | \t_total_\tTclosure\t | \t10\t5\t | 2019-08-01
          222222  | \t_total_\ttemplate\t | \t10\t5\t | 2019-08-02
'''


def join_regions(
    objects: nstream.Stream,
    regions: nstream.Stream
) -> nstream.Stream:
    '''
    Attach region trees to dated object records.

    Left-joins `objects` with `regions` on region_id, keeps fielddate,
    object_id, object_type and length, exposes `region_tree` under the name
    `region`, and finally runs every record through
    helpers.fix_unknown_regions (presumably to handle objects with no
    matching region row; confirm against the helper).
    '''
    with_regions = objects.join(regions, by='region_id', type='left')
    trimmed = with_regions.project(
        'fielddate',
        'object_id',
        'object_type',
        'length',
        region='region_tree',
    )
    return trimmed.map(helpers.fix_unknown_regions)


def versions_join_regions(
    versions: nstream.Stream,
    objects: nstream.Stream,
    regions: nstream.Stream
) -> nstream.Stream:
    '''
    Attach object attributes and region trees to dated version records.

    Joins `versions` with `objects` on (object_id, object_type) using the
    default join type, then left-joins the result with `regions` on
    region_id. The output keeps fielddate, object_id, length, object_type,
    template_id, category and fake, exposes `region_tree` under the name
    `region`, and is passed through helpers.fix_unknown_regions.
    '''
    with_objects = versions.join(objects, by=['object_id', 'object_type'])
    with_regions = with_objects.join(regions, by='region_id', type='left')
    trimmed = with_regions.project(
        'fielddate',
        'object_id',
        'length',
        'object_type',
        'template_id',
        'category',
        'fake',
        region='region_tree',
    )
    return trimmed.map(helpers.fix_unknown_regions)


# Columns that make_testable_job groups by in both of its aggregations and
# uses as the key when joining the two aggregated streams.
MEASURES = ['fielddate', 'object_type', 'region']


def make_testable_job(
    job: Job,
    dates: List[nd.Datetime],
    dump_date_string: str
) -> nstream.Stream:
    '''
    Build the part of the job graph that computes the counts, without publishing.

    Reads the object and action tables for `dump_date_string` plus the
    regions table, and produces one stream keyed by MEASURES with:
      * count / length                     - over all dated objects;
      * count_effective / length_effective - over the versions selected by
        versions_add_fielddate.
    The two aggregates are combined with a full join, so a key present in
    only one of them is still kept.

    Intermediate streams are labelled ('objects', 'versions', 'regions',
    'versions_ofd', 'versions_region', 'versions_measured',
    'versions_count', 'objects_ofd', 'objects_measured', 'objects_count',
    'all_count').
    '''
    objects_table = job.table(
        table_names.OBJECT_TABLE_PREFIX + dump_date_string
    ).label('objects')
    versions_table = job.table(
        table_names.ACTION_TABLE_PREFIX + dump_date_string
    ).label('versions')
    regions_table = job.table(table_names.REGIONS_TABLE_PATH).label('regions')

    # Effective objects: versions expanded per date, enriched with object
    # attributes and regions, then reduced to the measured columns.
    versions_dated = versions_table.call(
        versions_add_fielddate, dates
    ).label('versions_ofd')
    versions_regioned = versions_dated.call(
        versions_join_regions, objects_table, regions_table
    ).label('versions_region')
    versions_typed = versions_regioned.call(object_type.modify_obj_type)
    effective_rows = versions_typed.project(
        'object_id', 'length', *MEASURES
    ).label('versions_measured')

    effective_totals = effective_rows.groupby(*MEASURES).aggregate(
        count_effective=na.count(),
        length_effective=na.sum('length'),
    ).label('versions_count')

    # All objects: types are rewritten first, then dates and regions are added.
    objects_typed = objects_table.call(object_type.modify_obj_type)
    objects_dated = objects_typed.call(add_fielddate, dates).label('objects_ofd')
    object_rows = objects_dated.call(
        join_regions, regions_table
    ).label('objects_measured')

    object_totals = object_rows.groupby(*MEASURES).aggregate(
        count=na.count(),
        length=na.sum('length'),
    ).label('objects_count')

    # Merge both aggregates on the shared key columns.
    return object_totals.join(
        effective_totals, by=MEASURES, type='full'
    ).label('all_count')


def make_job(
    job: Job,
    dates: List[nd.Datetime],
    statface_client: ns.StatfaceClient,
    dump_date_string: str
) -> None:
    '''
    Build the counting graph and publish its result to Statface.

    The stream returned by make_testable_job is published to the daily
    report at report_names.OBJECTS_COUNT using `statface_client`.
    '''
    counts_stream = make_testable_job(job, dates, dump_date_string)

    daily_report = (
        ns.StatfaceReport()
        .path(report_names.OBJECTS_COUNT)
        .scale('daily')
        .client(statface_client)
    )

    counts_stream.publish(daily_report, allow_change_job=True)
