import json
import time
import datetime

from builtins import RecursionError

from threading import Thread

import boto3
from botocore.exceptions import ClientError

from bebo.logger import get_logger
from bebo.config import cfg as config

# Fallback logger used by SQS when the caller does not pass a `log` kwarg.
default_log = get_logger('bebo-python-commons')

# Maximum number of retries SQS.create_queue makes while AWS reports
# QueueDeletedRecently. Each retry sleeps 10 seconds, and the check
# `tries > MAX_RETRIES` allows MAX_RETRIES + 1 attempts, so the total wait is
# about 110 seconds (not 10 minutes). NOTE(review): presumably sized to outlast
# AWS's delay before a deleted queue's name can be reused -- confirm.
MAX_RETRIES = 10

def get_queue_arn_from_queue_name(queue_name, region=None, account_id='302024746872'):
    """Build the ARN of the SQS queue called ``queue_name``.

    Args:
        queue_name: name of the SQS queue.
        region: AWS region of the queue; defaults to ``config.AWS_REGION``.
        account_id: AWS account that owns the queue. The default keeps the
            previously hard-coded account. TODO: move the account id into config.

    Returns:
        ``'arn:aws:sqs:<region>:<account_id>:<queue_name>'``.
    """
    if region is None:
        region = config.AWS_REGION
    return 'arn:aws:sqs:{}:{}:{}'.format(region, account_id, queue_name)

def get_queue_url_from_queue_name(queue_name, region=None, account_id='302024746872'):
    """Build the HTTPS URL of the SQS queue called ``queue_name``.

    Args:
        queue_name: name of the SQS queue.
        region: AWS region of the queue; defaults to ``config.AWS_REGION``.
        account_id: AWS account that owns the queue. The default keeps the
            previously hard-coded account. TODO: move the account id into config.

    Returns:
        ``'https://sqs.<region>.amazonaws.com/<account_id>/<queue_name>'``.
    """
    if region is None:
        region = config.AWS_REGION
    return 'https://sqs.{}.amazonaws.com/{}/{}'.format(region, account_id, queue_name)

class Runner(Thread):
    """Thread that repeatedly calls ``loop_func`` and passes results to ``callback_func``.

    Args:
        runner_id: suffix used in the thread name (``SQSRunnerThread-<id>``).
        loop_func: zero-argument callable returning a batch of items (may be empty).
        callback_func: called with the batch whenever ``loop_func`` returns a truthy value.
        log: logger used to report a failed iteration.
        sleep: seconds to wait after a failed iteration before trying again.
    """

    def __init__(self, runner_id, loop_func, callback_func, log, sleep=3):
        Thread.__init__(self, name='SQSRunnerThread-{}'.format(runner_id))
        self.loop_func = loop_func
        self.callback_func = callback_func
        self.sleep_timeout = sleep
        self.run_loop = True
        self.log = log

    def run(self):
        """Loop until :meth:`stop` is called; ordinary exceptions are logged, not fatal."""
        while self.run_loop:
            try:
                items = self.loop_func()
                if items:
                    self.callback_func(items)
            except KeyboardInterrupt:
                self.run_loop = False
                raise
            except Exception:
                self.log.exception("failed to run SQS Runner Loop")
                # Back off so a persistent failure (e.g. AWS unreachable) does
                # not become a busy loop that pins a CPU and floods the log.
                if self.run_loop and self.sleep_timeout:
                    time.sleep(self.sleep_timeout)

    def stop(self):
        """Ask the loop to exit once the current iteration finishes."""
        self.run_loop = False

class SQS():
    """SQS queue fed by an SNS topic and drained by background Runner threads.

    On construction the queue is loaded. If AWS reports it missing, it is
    created with a policy that lets ``sns_topic_arn`` send to it, and it is
    then subscribed to that topic. ``start()`` launches Runner threads that
    poll the queue and hand decoded payloads to ``process_messages``.

    Args:
        process_messages: callable receiving a list of decoded message payloads.
        sns_topic_arn: ARN of the SNS topic that feeds the queue.

    Keyword Args:
        runner_count (int): number of polling threads (default 1).
        queue_name (str): queue name; defaults to ``<REGION>_<SERVICE>_<HOSTNAME>``.
        log: logger to use (default: the module-level logger).
        sns_destination_region (str): region for the SNS client; defaults to
            ``us-west-2`` when ``config.ENV == "bebo-prod"``, else ``us-west-1``.
    """

    def __init__(self, process_messages, sns_topic_arn, **kwargs):
        self.runner_count = kwargs.get('runner_count', 1)
        queue_name = kwargs.get('queue_name')
        if queue_name is None:
            # Built lazily so callers that pass queue_name do not need
            # REGION/SERVICE/HOSTNAME to be present in config.
            queue_name = '{}_{}_{}'.format(config.REGION, config.SERVICE, config.HOSTNAME)
        self.queue_name = queue_name
        self.queue_url = get_queue_url_from_queue_name(self.queue_name)
        self.queue_arn = get_queue_arn_from_queue_name(self.queue_name)

        self.log = kwargs.get('log', default_log)

        sns_destination_region = kwargs.get("sns_destination_region", None)
        if not sns_destination_region:
            if config.ENV == "bebo-prod":
                sns_destination_region = "us-west-2"
            else:
                sns_destination_region = "us-west-1"

        region = config.AWS_REGION

        self.log.info("Initialize SQS Queues[{}]: {}".format(region, self.queue_url))
        self.sqs = boto3.resource('sqs', region_name=region)
        self.sns = boto3.client('sns', region_name=sns_destination_region)

        self.sns_topic_arn = sns_topic_arn

        self.process_messages = process_messages

        self.runners = []
        self.queue = self.sqs.Queue(self.queue_url)
        self.ensure_queue_present()

    def ensure_queue_present(self):
        """Create (and subscribe) the queue if AWS reports that it does not exist.

        Raises:
            ClientError: for any AWS error other than a missing queue.
        """
        try:
            self.queue.load()
        except ClientError as e:
            if e.response['Error']['Code'] != 'AWS.SimpleQueueService.NonExistentQueue':
                raise
            self.log.info("Queue {} does not exist -- creating".format(self.queue.url))
            self.create_queue()

    def create_queue(self, tries=0):
        """Create the queue with an SNS SendMessage policy, then subscribe it to the topic.

        While AWS answers ``QueueDeletedRecently``, waits 10 seconds and retries.

        Args:
            tries: number of retries already made (used by the recursion).

        Raises:
            RecursionError: once more than MAX_RETRIES retries have been made.
            ClientError: for any other AWS error.
        """
        if tries > MAX_RETRIES:
            raise RecursionError('gave up creating queue {} after {} retries'.format(
                self.queue_name, MAX_RETRIES))

        try:
            self.log.info('Creating queue: {}'.format(self.queue_name))
            self.sqs.create_queue(
                QueueName=self.queue_name,
                Attributes={
                    'Policy': json.dumps({
                        "Version": "2012-10-17",
                        "Id": '{}/SQSDefaultPolicy'.format(self.queue_arn),
                        "Statement": [
                            {
                                # time.time() is a true epoch value; the previous
                                # naive datetime.utcnow().timestamp() was read as
                                # local time and skewed by the host's UTC offset.
                                "Sid": 'Sid{}'.format(int(time.time())),
                                "Effect": "Allow",
                                "Principal": {
                                    "AWS": "*"
                                },
                                "Action": "SQS:SendMessage",
                                "Resource": self.queue_arn,
                                # Only the configured SNS topic may send to this queue.
                                "Condition": {
                                    "ArnEquals": {
                                        "aws:SourceArn": self.sns_topic_arn
                                    }
                                }
                            }
                        ]
                    })
                }
            )

            self.log.info('created queue: {}'.format(self.queue_name))

            return self.subscribe_queue_to_sns(self.queue_arn)
        except ClientError as e:
            if e.response['Error']['Code'] == 'AWS.SimpleQueueService.QueueDeletedRecently':
                self.log.info('Retrying creating queue {} in 10s'.format(self.queue_url))
                time.sleep(10)
                return self.create_queue(tries + 1)
            raise

    def subscribe_queue_to_sns(self, queue_arn):
        """Subscribe the queue identified by ``queue_arn`` to ``self.sns_topic_arn``."""
        subscription_response = self.sns.subscribe(
            TopicArn=self.sns_topic_arn,
            Protocol='sqs',
            Endpoint=queue_arn
        )
        subscription_arn = subscription_response['SubscriptionArn']
        self.log.info('created subscription from {} to {} -- subscriptionArn: {}'.format(self.sns_topic_arn, queue_arn, subscription_arn))

    def get_messages(self):
        """Poll the queue once and return the decoded SNS payloads.

        Each SQS message body is a JSON SNS envelope whose ``Message`` field is
        itself JSON. That inner document is what gets returned. A message is
        deleted right after it decodes. A message that cannot be decoded is
        logged and left on the queue. Decoding is handled per message so that
        one bad message cannot abort the batch and lose payloads that were
        already deleted earlier in it.
        """
        self.ensure_queue_present()
        messages = self.queue.receive_messages(MaxNumberOfMessages=10,
                                               VisibilityTimeout=1,
                                               WaitTimeSeconds=20)
        items = []
        for message in messages:
            try:
                content = json.loads(json.loads(message.body)['Message'])
            except (ValueError, KeyError, TypeError):
                self.log.exception('failed to decode SQS message {}'.format(message.message_id))
                continue
            items.append(content)
            message.delete()

        return items

    def start(self):
        """Start Runner threads until ``runner_count`` of them are polling.

        Stopped or exited runners are dropped and replaced. A Python thread can
        only be started once, so re-starting them (e.g. ``stop()`` followed by
        ``start()``) would raise RuntimeError.
        """
        self.runners = [r for r in self.runners if r.run_loop and r.is_alive()]
        runners_needed = self.runner_count - len(self.runners)
        if runners_needed > 0:
            self.log.info("SQS.starting %s runners", runners_needed)
        for _ in range(runners_needed):
            runner = Runner(len(self.runners), self.get_messages, self.process_messages, self.log)
            runner.daemon = True
            runner.start()
            self.runners.append(runner)

    def stop(self):
        """Signal every runner to exit after its current poll (does not join them)."""
        for runner in self.runners:
            runner.stop()
