import os
import uuid
import logging

import dateutil.parser
import dateutil.tz

from sandbox import sdk2
from sandbox.common.errors import TaskError
from sandbox.sandboxsdk.environments import PipEnvironment

from sandbox.projects.offerwall.common.parameters import UUID, Datetime
from sandbox.projects.offerwall.common.utils import (
    CSVReader, Enum, LogErrorReader, LogS3Retrieve, S3Retrieve,
    cached_property
)
from sandbox.projects.offerwall import OfferwallPromoCodeInputFile
from sandbox.projects.offerwall.common.environments import (
    YandexInternalRootCAEnvironment)
from sandbox.projects.offerwall.common.tasks import SQLAlchemyTask

# Module-level logger used by the promo code task below.
logger = logging.getLogger('offerwall.promocodes')

# Allowed values of the task's ``mode`` parameter. They select how rows that
# clash with an existing (advertiser_id, code) pair are inserted:
#   STRICT - plain INSERT without an ON CONFLICT clause;
#   UPDATE - ON CONFLICT DO UPDATE of start_at/end_at, only when the existing
#            row belongs to the same offer;
#   NEW    - ON CONFLICT DO NOTHING (existing rows are kept as they are).
# NOTE(review): text=True presumably makes the members plain strings, since
# they are compared with a String parameter value - confirm in utils.Enum.
Mode = Enum('Mode', 'STRICT UPDATE NEW', text=True)
# Local directory the input CSV files are retrieved into from S3.
PROMOCODES_PATH = 'promocodes'


class OfferwallPromoCodeTask(SQLAlchemyTask):
    """Load promo codes for one offer from a CSV file stored in S3.

    ``on_prepare`` downloads the CSV and registers it as an
    ``OfferwallPromoCodeInputFile`` resource. ``on_execute`` validates the
    rows and inserts them into ``core_promocode`` in batches, all inside a
    single database transaction. The ``mode`` parameter selects how
    (advertiser_id, code) conflicts are handled (see ``insert_stmt``).
    """

    class Requirements(SQLAlchemyTask.Requirements):
        # These packages are installed into the task environment at run time,
        # which is presumably why sqlalchemy/marshmallow are imported lazily
        # inside the properties below instead of at module level.
        environments = (
            PipEnvironment('sqlalchemy'),
            PipEnvironment('psycopg2-binary'),
            PipEnvironment('marshmallow'),
            YandexInternalRootCAEnvironment(),
        )
        disk_space = 128  # NOTE(review): unit not visible here (presumably MiB)

    class Parameters(SQLAlchemyTask.Parameters):
        offer_id = UUID(label='Offer id', required=True)
        csv_file_name = sdk2.parameters.String(
            label='CSV file name', required=True)
        # Fallback values for rows whose start_at / end_at cells are empty.
        start_at = Datetime(label='Start date', required=True)
        end_at = Datetime(label='End date', required=True)
        # Maximum number of on_execute runs before the task gives up.
        max_retry = sdk2.parameters.Integer(
            label='Max retry', required=False, default=3)

        with sdk2.parameters.Group('DB settings') as db_settings:
            # Passed to CSVReader as ``size``; presumably rows per INSERT batch.
            db_batch_size = sdk2.parameters.Integer(
                label='Batch size', required=True, default=50000)

        # Conflict policy for rows clashing on (advertiser_id, code).
        with sdk2.parameters.String('Mode', required=True) as mode:
            mode.values[Mode.STRICT] = mode.Value(Mode.STRICT, default=True)
            mode.values[Mode.UPDATE] = mode.Value(Mode.UPDATE)
            mode.values[Mode.NEW] = mode.Value(Mode.NEW)

        kill_timeout = 60 * 60  # presumably seconds, i.e. one hour

    @cached_property
    def declared_models(self):
        """Declare the SQLAlchemy models used by the task.

        Returns:
            A tuple ``(Advertiser, Offer, PromoCode)`` of declarative classes
            mapped onto the existing ``core_*`` tables. Only the columns this
            task reads or writes are declared.
        """
        from sqlalchemy import Column, Integer, ForeignKey, String, DateTime, Boolean
        from sqlalchemy.ext.declarative import declarative_base
        from sqlalchemy.dialects.postgresql import UUID
        from sqlalchemy.orm import relationship

        Base = declarative_base()

        class Advertiser(Base):
            __tablename__ = 'core_advertiser'

            id = Column(Integer, primary_key=True)

        # NOTE(review): ondelete='SET NULL' contradicts nullable=False on the
        # FK columns below; it only matters if this metadata is ever used to
        # emit DDL. It is kept as-is to mirror the existing schema.
        # noinspection PyUnresolvedReferences
        class Offer(Base):
            __tablename__ = 'core_offer'

            id = Column(UUID, primary_key=True, default=uuid.uuid4)
            advertiser_id = Column(ForeignKey(
                Advertiser.__table__.c.id, ondelete='SET NULL'), nullable=False)
            advertiser = relationship(Advertiser)

        # noinspection PyUnresolvedReferences
        class PromoCode(Base):
            __tablename__ = 'core_promocode'

            id = Column(Integer, primary_key=True)
            offer_id = Column(ForeignKey(
                Offer.__table__.c.id, ondelete='SET NULL'), nullable=False)
            offer = relationship(Offer)
            advertiser_id = Column(ForeignKey(
                Advertiser.__table__.c.id, ondelete='SET NULL'), nullable=False)
            advertiser = relationship(Advertiser)
            code = Column(String)
            start_at = Column(DateTime)
            end_at = Column(DateTime)
            is_available = Column(Boolean, default=True)
            limit = Column(Integer, default=0)

        return Advertiser, Offer, PromoCode

    @cached_property
    def csv_validator(self):
        """Build the marshmallow schema class used to validate CSV rows.

        Returns:
            A ``Schema`` subclass (not an instance). Its constructor takes an
            extra ``post_load`` keyword: a mapping merged into every
            successfully deserialized row.
        """
        import marshmallow
        from marshmallow import Schema, fields, ValidationError

        class CSVValidator(Schema):

            def __init__(self, *args, **kwargs):
                # Extra columns (offer_id, advertiser_id) to add to each row.
                self.post_load = kwargs.pop('post_load', None)
                super(CSVValidator, self).__init__(*args, **kwargs)

            code = fields.String(required=True, allow_none=False)
            # Empty cells are dropped by ``filter_empty``, so the task-level
            # start/end dates are used for them via ``missing``.
            start_at = fields.Method(
                deserialize='convert_datetime', required=True, allow_none=True,
                missing=self.Parameters.start_at)
            end_at = fields.Method(
                deserialize='convert_datetime', required=True, allow_none=True,
                missing=self.Parameters.end_at)

            @marshmallow.pre_load(pass_many=True)
            def filter_empty(self, data, many):
                """Drop falsy cells so that field ``missing`` defaults apply."""
                if many:
                    return [self.filter(item) for item in data]
                return self.filter(data)

            @staticmethod
            def filter(data):
                """Return a copy of ``data`` without falsy values."""
                return dict((key, value) for key, value in data.items() if value)

            @staticmethod
            def convert_datetime(data):
                """Parse ``data``, discard any tz it carries, and tag it as UTC.

                Raises:
                    ValidationError: if dateutil cannot parse ``data``.
                """
                UTC = dateutil.tz.gettz('UTC')
                try:
                    return dateutil.parser.parse(
                        data, ignoretz=True).replace(tzinfo=UTC)
                except (ValueError, OverflowError) as e:
                    # Format ``e`` itself: ``e.message`` is deprecated in
                    # Python 2 and does not exist in Python 3.
                    raise ValidationError('{}: {}'.format(e, data))

            @marshmallow.post_load
            def update_with_post_load(self, data):
                """Merge the constructor's ``post_load`` mapping into the row."""
                if self.post_load:
                    data.update(self.post_load)
                return data

            @marshmallow.validates_schema(skip_on_field_errors=True)
            def validate_dates(self, data):
                """Reject rows whose end_at is not strictly after start_at."""
                # NOTE(review): a parsed cell is UTC-aware. If a Datetime
                # parameter default is naive, mixing the two would raise
                # TypeError here - confirm what the Datetime parameter returns.
                if data['end_at'] <= data['start_at']:
                    raise ValidationError(
                        'Should be greater than start_at', ['end_at'])
                return data

        return CSVValidator

    @cached_property
    def extra(self):
        """Fetch ``(offer_id, advertiser_id)`` for the configured offer.

        The resulting row is merged into every validated CSV row.

        Raises:
            ValueError: if no offer with ``Parameters.offer_id`` exists.
        """
        from sqlalchemy import select

        _, Offer, _ = self.declared_models
        query = select([Offer.id.label('offer_id'), Offer.advertiser_id])
        query = query.where(Offer.id == str(self.Parameters.offer_id))
        extra = self.connection.execute(query).fetchone()

        if not extra:
            raise ValueError(
                'Offer id {} not found!'.format(self.Parameters.offer_id)
            )

        return extra

    @cached_property
    def insert_stmt(self):
        """Build the PostgreSQL INSERT statement for the selected mode.

        * STRICT - plain INSERT with no ON CONFLICT clause.
        * UPDATE - on an (advertiser_id, code) conflict, update start_at and
          end_at, but only when the existing row has the same offer_id.
        * NEW    - on an (advertiser_id, code) conflict, skip the row.

        Raises:
            ValueError: if ``Parameters.mode`` is not a known mode (previously
                ``None`` was returned and ``execute`` failed obscurely).
        """
        from sqlalchemy.dialects.postgresql import insert

        _, _, PromoCode = self.declared_models
        insert_stmt = insert(PromoCode)
        mode = self.Parameters.mode
        conflict_target = [PromoCode.advertiser_id, PromoCode.code]

        if mode == Mode.STRICT:
            return insert_stmt

        if mode == Mode.UPDATE:
            return insert_stmt.on_conflict_do_update(
                index_elements=conflict_target,
                where=PromoCode.offer_id == insert_stmt.excluded.offer_id,
                set_=dict(start_at=insert_stmt.excluded.start_at,
                          end_at=insert_stmt.excluded.end_at)
            )

        if mode == Mode.NEW:
            return insert_stmt.on_conflict_do_nothing(
                index_elements=conflict_target
            )

        raise ValueError('Unknown mode: {}'.format(mode))

    @cached_property
    def retrieve(self):
        """Return a logging S3 retriever that downloads into PROMOCODES_PATH.

        Creates the local directory if needed.
        """
        # EAFP instead of exists()+mkdir(): tolerates the directory appearing
        # concurrently, but still fails if the path exists and is not a dir.
        try:
            os.mkdir(PROMOCODES_PATH)
        except OSError:
            if not os.path.isdir(PROMOCODES_PATH):
                raise

        return LogS3Retrieve(
            retrieve=S3Retrieve(
                endpoint_url=self.Parameters.aws_endpoint_url,
                bucket=self.Parameters.aws_bucket,
                path=PROMOCODES_PATH
            ),
            logger=logger
        )

    def on_prepare(self):
        """Download the CSV once and publish it as an input-file resource."""
        super(OfferwallPromoCodeTask, self).on_prepare()
        # memoize_stage makes this body run only on the first prepare.
        with self.memoize_stage.prepare:
            self.Context.runs = 0
            self.Context.csv_file_path = self.retrieve(self.Parameters.csv_file_name)
            sdk2.ResourceData(
                OfferwallPromoCodeInputFile(
                    task=self,
                    description='',
                    path=self.Context.csv_file_path,
                    released=self.Parameters.release_type,
                    ttl=self.Parameters.ttl
                )
            ).ready()
            # noinspection PyArgumentList
            self.Context.save()

    def on_execute(self):
        """Validate the CSV and insert its rows in one DB transaction.

        Raises:
            TaskError: once the task has been executed more than
                ``Parameters.max_retry`` times.
        """
        # Count the run and persist it before doing any work, so that runs
        # which crash are still counted toward max_retry.
        self.Context.runs += 1
        # noinspection PyArgumentList
        self.Context.save()
        if self.Context.runs > self.Parameters.max_retry:
            raise TaskError(
                'Run {} exceeds max_retry={}'.format(
                    self.Context.runs, self.Parameters.max_retry)
            )

        validator = self.csv_validator(
            strict=self.Parameters.mode == Mode.STRICT,
            post_load=self.extra
        )

        counter = 0
        # The reader consumes the file lazily, so all reading must happen
        # inside this ``with``; the handle is closed even on failure.
        with open(self.Context.csv_file_path) as csv_file:
            reader = LogErrorReader(
                CSVReader(
                    csv_file,
                    schema=validator, size=self.Parameters.db_batch_size
                ),
                logger
            )

            with sdk2.helpers.NoTimeout():
                with self.connection.begin():
                    while reader.has_next():
                        rows = reader.read()
                        if not rows:
                            continue
                        self.connection.execute(self.insert_stmt, rows)
                        counter += len(rows)

        # Rows submitted, not rows inserted: NEW/UPDATE modes may skip rows.
        logger.info('Submitted %d promo code rows for offer %s',
                    counter, self.Parameters.offer_id)