# coding: utf-8

from __future__ import print_function
import logging

from psycopg2.extras import LogicalReplicationConnection
import psycopg2

from pymdb.replication.parser import Parser


# Logger for this module; messages are routed under the module's dotted name.
log = logging.getLogger(name=__name__)


class ReplicationStream(object):
    """Drain the messages currently available on a replication cursor.

    Each pass (``get_changes`` or ``skip_changes``) keeps calling
    ``read_message()`` and stops at the first ``None``, i.e. when no more
    messages are available from the server at the moment. It then
    acknowledges what was handed out (see ``_get_messages``).
    """

    # One parser shared by every stream; instantiated when the class is defined.
    parser = Parser()

    def __init__(self, cur):
        """Wrap *cur*, a replication cursor that is already streaming.

        :param cur: cursor on which ``start_replication`` has been called.
        """
        self._cur = cur

    def _get_messages(self):
        """Yield raw replication messages until none is available.

        When ``read_message()`` returns ``None``, send feedback with the last
        yielded message's ``data_start`` as ``flush_lsn``. Feedback is sent
        only after the consumer has asked for the next item, so every
        yielded message has already been processed when it is acknowledged.
        If the consumer stops iterating early, or no message was read, no
        feedback is sent.
        """
        last_msg = None
        while True:
            msg = self._cur.read_message()
            if msg is None:
                break
            last_msg = msg
            yield msg
        if last_msg is not None:
            self._cur.send_feedback(flush_lsn=last_msg.data_start)

    def skip_changes(self):
        """Consume and acknowledge all available messages without parsing them."""
        for _ in self._get_messages():
            pass

    def get_changes(self):
        """Yield the parsed form of every available message's payload.

        The generator simply ends once no more messages are available; it
        never yields ``None``. (``_get_messages`` never produces ``None``,
        so the former ``None`` check here was unreachable.)
        """
        for msg in self._get_messages():
            yield self.parser.parse_message(msg.payload)


class Replica(object):
    """Own a logical replication slot and the connection streaming from it.

    Constructing an instance (re)creates the slot: any existing slot with
    the same name is dropped, then a new ``test_decoding`` slot is created.
    ``terminate`` drops the slot again.
    """

    # Default slot name; can be overridden per instance via __init__.
    slot_name = 'big_brother'

    def __init__(self, dsn, slot_name=None):
        """Recreate the replication slot on the server reached via *dsn*.

        :param dsn: connection string passed to ``psycopg2.connect``.
        :param slot_name: optional slot name overriding the class-level
            default, so several instances can use distinct slots.
        """
        self._dsn = dsn
        if slot_name is not None:
            self.slot_name = slot_name
        # No streaming connection until start_replication() is called.
        self._replication_conn = None

        conn = self._make_connection()
        try:
            try:
                self._drop_slot(conn)
            except psycopg2.ProgrammingError as exc:
                # Expected when the slot does not exist yet.
                log.warning('Got %s - probably no such slot', exc)
            self._create_slot(conn)
        finally:
            # Close the admin connection even if slot creation fails.
            conn.close()

    def _make_connection(self):
        """Open a new logical-replication connection to ``self._dsn``."""
        return psycopg2.connect(self._dsn, connection_factory=LogicalReplicationConnection)

    def _drop_slot(self, conn):
        """Drop this instance's replication slot using *conn*."""
        conn.cursor().drop_replication_slot(slot_name=self.slot_name)

    def _create_slot(self, conn):
        """Create this instance's slot with the ``test_decoding`` output plugin."""
        conn.cursor().create_replication_slot(self.slot_name, output_plugin='test_decoding')

    def start_replication(self, start_lsn):
        """Open a fresh connection and start streaming from the slot.

        Any previously opened replication connection is closed first.

        :param start_lsn: LSN to start streaming from, passed through to
            ``cursor.start_replication``.
        :returns: a ``ReplicationStream`` wrapping the new cursor.
        """
        self.stop_replication()

        self._replication_conn = self._make_connection()

        cur = self._replication_conn.cursor()
        # decode=True: message payloads arrive as text rather than bytes.
        cur.start_replication(
            slot_name=self.slot_name,
            start_lsn=start_lsn,
            decode=True)

        return ReplicationStream(cur)

    def stop_replication(self):
        """Close the streaming connection, if any, and forget it."""
        conn, self._replication_conn = self._replication_conn, None
        if conn is not None:
            conn.close()

    def terminate(self):
        """Stop any active stream and drop the replication slot."""
        # The server refuses to drop a slot that is still active, so release
        # our own streaming connection first. NOTE(review): the server may
        # need a moment to notice the disconnect; confirm no retry is needed.
        self.stop_replication()
        conn = self._make_connection()
        try:
            self._drop_slot(conn)
        finally:
            conn.close()
