import logging
import smtplib
from datetime import datetime, timedelta
from email.header import Header
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.mime.base import MIMEBase
from sandbox.sandboxsdk import environments
from sandbox import sdk2


# Moscow time (MSK) offset from UTC, in hours. Used as the default report
# timezone and to translate a report day into MSK hourly log-table names.
MOSCOW_UTC_OFFSET = 3


class YabsServerSSPReportByPages(sdk2.Task):
    """Sandbox task that e-mails a daily SSP statistics report.

    For one SSPID and one calendar day it runs YQL queries over the hourly
    ``logs/bs-*-log/1h`` tables and mails the result either as plain text or
    as a CSV attachment.

    Report kinds (``Parameters.group_by_page``):
      * ``"No"``  - day totals: requests, impressions, revenue;
      * ``"Yes"`` - one row per page/app and PartnerStatID.
    """

    class Parameters(sdk2.Task.Parameters):
        # group_by_page: "No" or "Yes" (see on_execute).
        # stat_format: "" (plain text) or "csv".
        yql_token_vault_name = sdk2.parameters.String("YQL Token vault name", default="yql_token", required=True)
        db = sdk2.parameters.String("Yt Cluster", default="hahn", required=True)
        recipients = sdk2.parameters.String("EMail recipients (by comma)", default="ssp-int@yandex-team.ru", required=True)
        process_delay = sdk2.parameters.Integer("Process delay (in hours)", default=30, required=True)
        utc_offset = sdk2.parameters.Integer("Offset from UTC time (in hours)", default=MOSCOW_UTC_OFFSET, required=True)
        sspid = sdk2.parameters.Integer("SSPID", required=True)
        group_by_page = sdk2.parameters.String("Group by page", default="No", required=True)
        stat_format = sdk2.parameters.String("Format", default="", required=True)

    class Requirements(sdk2.Task.Requirements):
        environments = [environments.PipEnvironment('yql', version='1.2.91'), environments.PipEnvironment('yandex-yt')]

    @staticmethod
    def datetime_range(start, end, delta):
        """Yield ``start, start + delta, ...`` while the value is before ``end``.

        Args:
            start: first value yielded.
            end: exclusive upper bound.
            delta: a ``timedelta``, or a dict of ``timedelta`` keyword
                arguments such as ``{"hours": 1}``.

        Raises:
            ValueError: if ``delta`` is not positive and ``start < end``
                (the range would never terminate).
        """
        if not isinstance(delta, timedelta):
            delta = timedelta(**delta)
        if start < end and delta <= timedelta(0):
            raise ValueError("delta must be positive")

        current = start
        while current < end:
            yield current
            current += delta

    def GetProcessDate(self):
        """Compute the report day and its hourly log-table bounds.

        Returns:
            ``(start_table, end_table, report_date)`` where ``report_date`` is
            the ``YYYY-MM-DD`` local day (``utc_offset``) that was current
            ``process_delay`` hours ago. ``start_table`` / ``end_table`` are
            ``YYYY-MM-DDTHH:00:00`` names of that day's first hour and of the
            hour 23 hours later, both in Moscow time (the hourly logs are
            addressed in MSK, as the ``on_execute`` log message says).
        """
        log_frmt = '%Y-%m-%dT%H:00:00'
        date_frmt = '%Y-%m-%d'

        # Local wall-clock time `process_delay` hours ago, truncated to midnight.
        shift = timedelta(hours=self.Parameters.utc_offset - self.Parameters.process_delay)
        process_date = (datetime.utcnow() + shift).replace(hour=0, minute=0, second=0)
        # The same instant expressed on the Moscow clock.
        process_date_msk = process_date + timedelta(hours=MOSCOW_UTC_OFFSET - self.Parameters.utc_offset)

        return (
            process_date_msk.strftime(log_frmt),
            (process_date_msk + timedelta(hours=23)).strftime(log_frmt),
            process_date.strftime(date_frmt),
        )

    def GetEmailBodySummaryByDates(self, report):
        """Render the day-totals report.

        Args:
            report: rows from ``GetReportByDates``; ``report[0]["Hits"]``,
                ``report[1]["Shows"]`` and ``report[2]["Cost"]`` are read.

        Returns:
            Indented plain text when ``stat_format`` is ``""``, or a
            ``;``-separated header plus one data line when it is ``"csv"``.

        Raises:
            Exception: for any other ``stat_format``.
        """
        if self.Parameters.stat_format == "":
            template = """
                Date: {process_date}
                Requests: {hits}
                Impressions: {shows}
                Revenue: {cost}
            """
        elif self.Parameters.stat_format == "csv":
            template = """Date;Requests;Impressions;Revenue
{process_date};{hits};{shows};{cost}
            """
        else:
            raise Exception('Unknown report format')

        return template.format(
            process_date=self.process_date,
            hits=report[0]["Hits"],
            shows=report[1]["Shows"],
            cost=report[2]["Cost"],
        )

    def GetEmailBodySummaryByPagesDates(self, report):
        """Render the per-page report: a header line, then one line per row.

        Fields are joined with ``;`` in ``csv`` mode. In any other mode each
        field is on its own line. Cells are formatted by ``_preprocess_value``.

        Args:
            report: rows (dicts) from ``GetReportByPagesDates``.
        """
        fields = ["Date", "AppName", "Hits", "Bids", "Wins", "Impressions", "Cost", "CPM", "PartnerStatID"]
        delimiter = ";" if self.Parameters.stat_format == "csv" else "\n"

        rows = [
            delimiter.join([self._preprocess_value(name, row[name]) for name in fields])
            for row in report
        ]
        return delimiter.join(fields) + "\n" + "\n".join(rows)

    def _preprocess_value(self, fname, fvalue):
        """Format one report cell as text.

        Falsy values (``None``, ``0``, empty) become ``'0'``. ``Cost`` and
        ``CPM`` values use a decimal comma instead of a decimal point.
        """
        fvalue = str(fvalue or '0')

        if fname in ('Cost', 'CPM'):
            # str.replace returns a new string; the result must be kept.
            fvalue = fvalue.replace(".", ",")

        return fvalue

    def SendEmail(self, body_txt):
        """Mail ``body_txt`` to ``Parameters.recipients`` via yabacks.yandex.ru:25.

        In ``csv`` mode the text is sent as a ``yandex-stat-<date>.csv``
        attachment with an empty plain-text body. Otherwise it is the body.

        Raises:
            Exception: if the SMTP exchange raises ``smtplib.SMTPException``.
        """
        frm = 'Yandex RTB Reporter <ssp-int@yandex-team.ru>'
        # Tolerate "a@x, b@y" and stray commas in the parameter.
        to = [addr.strip() for addr in self.Parameters.recipients.split(',') if addr.strip()]
        subject = "Yandex Daily RTB Stats {process_date}".format(process_date=self.process_date)
        is_csv = self.Parameters.stat_format == "csv"

        # An attachment belongs in multipart/mixed; multipart/alternative asks
        # clients to show only one of the parts. No set_charset() on the
        # container: on a multipart it adds "Content-Transfer-Encoding: base64",
        # which RFC 2045 forbids. Each part declares its own charset instead.
        msg = MIMEMultipart('mixed' if is_csv else 'alternative')
        msg['From'] = frm
        msg['To'] = ', '.join(to)
        msg['Subject'] = Header(subject, 'UTF-8').encode()

        if is_csv:
            msg.attach(MIMEText("", 'plain', 'utf8'))

            # MIMEText applies a transfer encoding for utf-8, so non-ASCII
            # page names survive serialization and smtplib.
            attachment = MIMEText(body_txt, 'csv', 'utf8')
            filename = "yandex-stat-%s.csv" % self.process_date
            attachment.add_header('Content-Disposition', 'attachment; filename="%s"' % filename)
            msg.attach(attachment)
        else:
            msg.attach(MIMEText(body_txt, 'plain', 'utf8'))

        try:
            srv = smtplib.SMTP('yabacks.yandex.ru', port=25)
            try:
                srv.sendmail(frm, to, msg.as_string())
            finally:
                # Close the session even when sendmail fails.
                srv.quit()
        except smtplib.SMTPException:
            raise Exception('No e-mails was sent. Internal SMTP exception occured')

    def on_execute(self):
        """Sandbox entry point: build the configured report and mail it.

        Raises:
            Exception: if ``group_by_page`` is neither ``"No"`` nor ``"Yes"``.
        """
        self.yql_token = sdk2.task.Vault.data(self.author, self.Parameters.yql_token_vault_name)

        logging.info("Started!")

        self.process_date_start, self.process_date_end, self.process_date = self.GetProcessDate()
        logging.info("Gonna run for {0} UTC {1} (log from {2} to {3} MSK)".format(self.process_date, self.Parameters.utc_offset, self.process_date_start, self.process_date_end))

        if self.Parameters.group_by_page == "No":
            body_txt = self.GetEmailBodySummaryByDates(self.GetReportByDates())
        elif self.Parameters.group_by_page == "Yes":
            body_txt = self.GetEmailBodySummaryByPagesDates(self.GetReportByPagesDates())
        else:
            raise Exception("Unknown report type")

        logging.info("Sending email...")
        self.SendEmail(body_txt)

        logging.info("Done!")

    def GetReportByPagesDates(self):
        """Query per-page statistics for the report day.

        Returns one row per page name/token and PartnerStatID, with columns
        Date, AppName, Hits, Bids, Wins, Impressions, Cost (ssppricecur / 1e6),
        CPM and PartnerStatID, ordered by PartnerStatID, AppName.
        """
        yql_request = '''
            use {db};

            select
                `Date`,
                AppName,
                Hits,
                Bids,
                Wins,
                Impressions,
                Math::Round(Cost / 1000000.0, -4) as Cost,
                Math::Round(Cost / Impressions / 1000.0, -4) as CPM,
                PartnerStatID
            from (
                select
                    IF(PageToken is not NULL and PageToken != '' or PageName = 'dsp.yandex.ru', PageToken, PageName) as AppName,
                    PartnerStatID,
                    SUM(Hits) as Hits,
                    SUM_IF(1, Bids > 0) as Bids,
                    SUM_IF(1, Impressions > 0) as Impressions,
                    SUM(Cost) as Cost,
                    SUM(Wins) as Wins,
                    '{process_date}' as `Date`
                from (
                    select
                        PartnerStatID,
                        RtbLog.bidreqid,
                        RtbLog.pageid,
                        Page.Name as PageName,
                        RtbLog.pagetoken as PageToken,
                        1 as Hits,
                        sum(DspLog.bids) as Bids,
                        sum(SspLog.wins) as Wins,
                        sum(SspLog.ssppricecur) as Cost,
                        sum(EventLog.impressions) as Impressions,
                        sum(EventLog.clicks) as Clicks
                    from
                        range(`logs/bs-rtb-log/1h`, `{process_date_start}`, `{process_date_end}`) as RtbLog
                        left join (
                            select
                                bidreqid,
                                count_if(countertype == '0') as bids,
                                count_if(countertype == '1') as block_impressions
                            from
                                range(`logs/bs-dsp-log/1h`, `{process_date_start}`, `{process_date_end}`)
                            where
                                sspid == '{sspid}'
                            group by
                                bidreqid
                        ) as DspLog on RtbLog.bidreqid == DspLog.bidreqid
                        left join (
                            select
                                bidreqid,
                                count_if(countertype == '1') as impressions,
                                count_if(countertype == '2') as clicks
                            from
                                range(`logs/bs-chevent-log/1h`, `{process_date_start}`, `{process_date_end}`)
                            where
                                sspid == '{sspid}'
                            group by
                                rtbbidreqid as bidreqid
                        ) as EventLog on RtbLog.bidreqid == EventLog.bidreqid
                        left join (
                            select
                                bidreqid,
                                sum(cast(win as Int64)) as wins,
                                sum(cast(ssppricecur as Int64)) as ssppricecur
                            from
                                range(`logs/bs-ssp-log/1h`, `{process_date_start}`, `{process_date_end}`)
                            where
                                sspid == '{sspid}'
                            group by
                                bidreqid
                        ) as SspLog on RtbLog.bidreqid == SspLog.bidreqid
                        left join `home/yabs/dict/Page` as Page on Page.PageID == cast(RtbLog.pageid as Uint32)
                    where
                        RtbLog.sspid == '{sspid}'
                    group by
                        RtbLog.bidreqid, RtbLog.pageid, Page.Name, RtbLog.pagetoken, RtbLog.partnerstatid as PartnerStatID
                ) as joined_table
                group by
                    PageName, PageToken, PartnerStatID
                order by
                    PartnerStatID, AppName
            ) as total_table;

        '''.format(
            # Query the same cluster the YQL client is bound to.
            db=self.Parameters.db,
            process_date_start=self.process_date_start,
            process_date_end=self.process_date_end,
            process_date=self.process_date,
            sspid=self.Parameters.sspid
        )
        logging.info("Prepared request " + yql_request)

        return self.yql_request_execute(yql_request)

    def GetReportByDates(self):
        """Query day totals for the SSP's pages (via SSPPageMapping).

        Returns three single-row results, in order: ``Hits`` (rtb-log rows),
        ``Shows`` (dsp-log rows with countertype 1) and ``Cost`` (sum of
        ssp-log ssppricecur / 1e6).
        """
        yql_request = '''
            use {db};

            select count(*) as Hits from range(`logs/bs-rtb-log/1h`, `{process_date_start}`, `{process_date_end}`) as Log
            inner join `home/yabs/dict/SSPPageMapping` as PageMapping
            on cast(Log.pageid as Uint32) == PageMapping.PageID
            where PageMapping.SSPID == {sspid};

            select count(*) as Shows from range(`logs/bs-dsp-log/1h`, `{process_date_start}`, `{process_date_end}`) as Log
            inner join `home/yabs/dict/SSPPageMapping` as PageMapping
            on cast(Log.pageid as Uint32) == PageMapping.PageID
            where PageMapping.SSPID == {sspid} and cast(Log.countertype as Uint32) == 1;

            select sum(cast(ssppricecur as Double)) / 1e6 as Cost from range(`logs/bs-ssp-log/1h`, `{process_date_start}`, `{process_date_end}`) as Log
            inner join `home/yabs/dict/SSPPageMapping` as PageMapping
            on cast(Log.pageid as Uint32) == PageMapping.PageID
            where PageMapping.SSPID == {sspid};

        '''.format(
            # Query the same cluster the YQL client is bound to.
            db=self.Parameters.db,
            process_date_start=self.process_date_start,
            process_date_end=self.process_date_end,
            sspid=self.Parameters.sspid
        )
        logging.info("Prepared request " + yql_request)

        return self.yql_request_execute(yql_request)

    def yql_request_execute(self, request):
        """Run ``request`` on ``Parameters.db`` and return all result rows.

        Rows of every table returned by ``query.get_results()`` are
        concatenated, in order, into one list of ``{column_name: value}``
        dicts. ``GetReportByDates`` relies on this to read its per-statement
        results as ``report[0]``, ``report[1]`` and ``report[2]``.

        Raises:
            Exception: carrying the YQL error messages if the query failed.
        """
        from yql.api.v1.client import YqlClient
        client = YqlClient(db=self.Parameters.db, token=self.yql_token)
        query = client.query(request, syntax_version=1)
        query.run()
        query.wait_progress()

        if not query.is_success:
            raise Exception('\n'.join([str(err) for err in query.errors]))

        result = []
        for table in query.get_results():
            table.fetch_full_data()
            # Every table must be read here, not only the last one.
            columns = [column_name for column_name, _column_type in table.columns]
            result.extend(dict(zip(columns, row)) for row in table.rows)

        logging.info("RESULT: {}".format(result))

        return result
