from __future__ import absolute_import

import os
import sys
import errno
import argparse
from collections import deque

from ya.skynet.util import logging

log = logging.getLogger('scp')  # shared logger used by every protocol helper below

# Largest number of bytes moved by a single read/write while streaming file contents.
CHUNK_SIZE = 64 * 1024
# Length limit passed to readline() for protocol lines and peer messages.
LINE_SIZE = 4 * 1024


def err(msg, stdout=sys.stdout):
    """Report a fatal error and terminate the process with exit status 1.

    The message is logged locally and sent to the peer as an error record:
    the byte 0x02, the message, then a newline.
    """
    log.error(msg)
    record = b''.join((b'\x02', bytes(msg), b'\n'))
    stdout.write(record)
    stdout.flush()
    sys.exit(1)


def warn(msg, stdout=sys.stdout):
    """Report a non-fatal problem locally and to the peer.

    The peer receives a warning record: the byte 0x01, the message, then a
    newline.  Returns None, so callers can write ``return warn(...)``.
    """
    log.warning(msg)
    record = b''.join((b'\x01', bytes(msg), b'\n'))
    stdout.write(record)
    stdout.flush()


def ok(stdout=sys.stdout):
    """Send a single 0x00 byte (the protocol's OK response) and flush it."""
    out = stdout
    out.write(b'\0')
    out.flush()


def expect_ok(warn=True, stdin=sys.stdin):
    """Read one response byte from the peer and act on it.

    * 0x00 (OK): return True.
    * 0x02 (error): log the peer's message line and exit with status 1.
    * 0x01 (warning): log the peer's message line; return False when *warn*
      is true, otherwise exit with status 1.
    * anything else, including EOF: log it and exit with status 1.
    """
    status = stdin.read(1)
    if status == b'\x00':
        log.debug("got OK")
        return True

    if status == b'\x02':
        log.error("remote side error: %s" % (stdin.readline(LINE_SIZE),))
        sys.exit(1)

    if status != b'\x01':
        log.error("remote side isn't ready: unexpected response %r" % (status,))
        sys.exit(1)

    log.warning("remote side warning: %s" % (stdin.readline(LINE_SIZE),))
    if not warn:
        sys.exit(1)
    return False


def process_timestamp(stdin=sys.stdin):
    """Parse the payload of a ``T`` command and acknowledge it with OK.

    The caller has already consumed the leading ``T``; the rest of the line
    is ``<mtime> <mtime_usec> <atime> <atime_usec>``.  Despite the ``_msec``
    variable names, the sub-second fields are divided by 1e6 (microseconds).
    Any malformed line or field aborts the transfer through ``err()``.

    :returns: ``(atime, mtime)`` as float seconds, the order ``os.utime``
        expects for its *times* argument.
    """
    line = stdin.readline(LINE_SIZE).rstrip(b'\r\n')
    fields = line.split(b' ', 3)
    if len(fields) != 4:
        # Plain tuple unpacking would raise ValueError and crash without
        # telling the peer what went wrong.
        err("invalid timestamp format: %r" % (line,))
    mtime, mtime_msec, atime, atime_msec = fields
    log.debug('timestamp %r, %r, %r, %r', mtime, mtime_msec, atime, atime_msec)

    try:
        mtime = int(mtime)
    except ValueError:
        err("invalid mtime format: %r" % (mtime,))

    try:
        mtime += float(mtime_msec) / 1e6
    except ValueError:
        err("invalid mtime_msec format: %r" % (mtime_msec,))

    try:
        atime = int(atime)
    except ValueError:
        err("invalid atime format: %r" % (atime,))

    try:
        atime += float(atime_msec) / 1e6
    except ValueError:
        err("invalid atime_msec format: %r" % (atime_msec,))

    ok()

    return atime, mtime


def copy_file(dest, dirstack, preserve=False, times=None, stdin=sys.stdin):
    """Receive one file announced by a ``C`` command.

    The caller has already consumed the leading ``C``; the rest of the line
    is ``<octal mode> <length> <filename>``.  The target is:

    * at top level (empty *dirstack*): ``dest/filename`` when *dest* is an
      existing directory, otherwise *dest* itself;
    * inside a received directory: ``dirstack[-1]/filename``.

    Sequence: OK once the target is open, then exactly *length* bytes are
    read from *stdin* into it, then a final OK.  If the target cannot be
    opened, a warning replaces the first OK and no data is read. Under the
    scp protocol the sender is expected to skip the data in that case.  A
    local failure after the first OK is reported as a warning instead of the
    final OK; the announced bytes are still drained so the stream stays in
    sync.  Malformed headers, unsafe names, EOF in the middle of the data,
    and EPIPE abort the process via ``err()``.

    :param times: ``(atime, mtime)`` passed to ``os.utime`` when *preserve*
        is set.
    """
    line = stdin.readline(LINE_SIZE).rstrip(b'\r\n')
    fields = line.split(b' ', 2)
    if len(fields) != 3:
        err('invalid file command: %r' % (line,))
    mode, length, filename = fields
    log.debug('copy %r, %r, %r', mode, length, filename)

    try:
        mode = int(mode, 8) & 0o777
    except ValueError:
        err('invalid file mode: %r' % (mode,))

    try:
        length = int(length)
    except ValueError:
        err('invalid length: %r' % (length,))
    if length < 0:
        err('invalid length: %r' % (length,))

    # The name is chosen by the peer: refuse anything that could write
    # outside the destination directory.
    if not filename or filename == b'..' or b'/' in filename:
        err('unexpected filename: %r' % (filename,))

    if dest and not dirstack:
        if os.path.isdir(dest):
            target = os.path.join(dest, filename)
        else:
            target = dest
    else:
        target = filename
        if dirstack:
            target = os.path.join(dirstack[-1], filename)

    try:
        fd = None
        try:
            fd = os.open(target, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode)
            outf = os.fdopen(fd, 'wb')
        except Exception as e:
            if fd is not None:
                os.close(fd)
            return warn("copy %r failed: %s" % (target, e))
        ok()

        write_error = None
        leftbytes = length
        with outf:
            while leftbytes > 0:
                toread = min(leftbytes, CHUNK_SIZE)
                data = stdin.read(toread)
                if not data:
                    # The sender is gone; an empty read would otherwise loop forever.
                    err("copy %r failed: unexpected end of input" % (target,))
                leftbytes -= len(data)
                log.debug('tried to read %d bytes, %d actually read, %d left', toread, len(data), leftbytes)
                if write_error is None:
                    try:
                        outf.write(data)
                    except EnvironmentError as e:
                        # Keep draining the announced bytes so they are not
                        # misread as protocol commands afterwards.
                        write_error = e
            if write_error is not None:
                raise write_error
            os.fchmod(outf.fileno(), mode)

        if preserve:
            os.utime(target, times)
    except Exception as e:
        if isinstance(e, EnvironmentError) and e.errno == errno.EPIPE:
            err("copy %r failed: %s" % (target, e))
        else:
            return warn("copy %r failed: %s" % (target, e))

    ok()


def pushd(dest, dirstack, preserve=False, times=None, stdin=sys.stdin, stdout=sys.stdout):
    """Enter a directory announced by a ``D`` command.

    The caller has already consumed the leading ``D``; the rest of the line
    is ``<octal mode> <ignored> <dirname>``.  The target is:

    * at top level: *dest* itself when it is not an existing directory,
      otherwise ``dest/dirname``;
    * inside a received directory: ``dirstack[-1]/dirname``.

    The directory is created when missing.  When *preserve* is set, its mode
    and ``(atime, mtime)`` *times* are applied.  On success the target is
    pushed onto *dirstack* and OK is sent.  An existing non-directory target
    is reported as a warning and nothing is pushed.  Malformed lines, unsafe
    names and creation failures abort via ``err()``.  *stdout* is unused.
    """
    line = stdin.readline(LINE_SIZE).rstrip(b'\r\n')
    fields = line.split(b' ', 2)
    if len(fields) != 3:
        err('invalid directory command: %r' % (line,))
    mode, _, dirname = fields
    log.debug('pushd %r, %r', mode, dirname)
    try:
        mode = int(mode, 8) & 0o777
    except ValueError:
        err('invalid file mode: %r' % (mode,))

    # The name is chosen by the peer: refuse anything that could escape the
    # destination directory.
    if not dirname or dirname == b'..' or b'/' in dirname:
        err('unexpected directory name: %r' % (dirname,))

    if not dirstack and not os.path.isdir(dest):
        target = dest
    elif not dirstack:
        target = os.path.join(dest, dirname)
    else:
        target = os.path.join(dirstack[-1], dirname)

    try:
        exists = os.path.exists(target)
        if exists and not os.path.isdir(target):
            return warn("%r is not a dir" % (target,))
        elif not exists:
            os.mkdir(target, mode)

        if preserve or not exists:  # mkdir's mode is filtered through the umask
            os.chmod(target, mode)
        if preserve:
            os.utime(target, times)
    except Exception as e:
        log.exception('%r create failed', dirname, exc_info=sys.exc_info())
        err("%r create failed: %s" % (dirname, e))

    dirstack.append(target)
    ok()


def send_timestamp(stat, stdout=sys.stdout):
    """Send a ``T`` record built from *stat* and wait for the reply.

    The record is ``T<mtime> <mtime_usec> <atime> <atime_usec>``, where the
    sub-second parts are truncated to whole microseconds.

    :returns: the result of ``expect_ok()``: True on OK, False on a warning.
    """
    mtime = int(stat.st_mtime)
    mtime_msec = int((stat.st_mtime - mtime) * 1e6)
    atime = int(stat.st_atime)
    atime_msec = int((stat.st_atime - atime) * 1e6)
    log.debug("sending timestamp %s %s %s %s", mtime, mtime_msec, atime, atime_msec)
    stdout.write('T%d %d %d %d\n' % (mtime, mtime_msec, atime, atime_msec))
    # Flush before blocking on the reply. A buffered T line never reaches
    # the receiver, and then both sides wait for each other forever.
    stdout.flush()
    return expect_ok()


def send_file(source, preserve=False, stdout=sys.stdout):
    """Send one regular file with a ``C`` record.

    Sequence:
    1. An optional ``T`` record (when *preserve*).
    2. ``C<mode> <size> <basename>``.
    3. Wait for the receiver's reply; a warning means it refused the file, so
       no data is sent.
    4. Exactly ``st_size`` bytes of data.
    5. OK, then wait for the receiver's final acknowledgement.

    A local failure before the header is reported with a warning only.  A
    failure after the header was accepted still delivers the promised byte
    count (zero padding), then a warning, then waits for the receiver's
    acknowledgement, so both sides stay in sync.
    """
    log.debug("sending file %r", source)
    size = 0
    header_sent = False
    try:
        with open(source, 'rb') as f:
            stat = os.fstat(f.fileno())

            if preserve and not send_timestamp(stat):
                return

            size = stat.st_size
            stdout.write('C%04o %d %s\n' % (stat.st_mode & 0o777, size, os.path.basename(source)))
            stdout.flush()
            header_sent = True
            if not expect_ok():
                # The receiver could not open its target and reads no data.
                return

            while size > 0:
                data = f.read(min(size, CHUNK_SIZE))
                if not data:
                    # The file shrank after fstat(); an empty read would
                    # otherwise make this loop spin forever.
                    raise IOError("file shrank by %d bytes while being sent" % (size,))
                stdout.write(data)
                size -= len(data)
                log.debug("wrote %d bytes", len(data))
            stdout.flush()
    except Exception as e:
        if header_sent:
            # The receiver reads exactly st_size bytes: pad what is missing.
            log.warning("failed to send file, sending remaining %d bytes as zeroes", size)
            while size > 0:
                chunk = min(CHUNK_SIZE, size)
                stdout.write('\x00' * chunk)
                size -= chunk
            stdout.flush()
        warn("failed to send %r: %s" % (source, e))
        if header_sent:
            # The receiver still acknowledges the (padded) file.
            expect_ok()
    else:
        ok()
        expect_ok()


def send_dir(dirname, preserve=False, stdout=sys.stdout):
    """Recursively send directory *dirname* as a ``D`` ... ``E`` section.

    Sequence:
    1. An optional ``T`` record (when *preserve*).
    2. ``D<mode> 0 <basename>``, then wait for the reply.
    3. Every entry, sent via ``send(..., recursive=True)``.
    4. ``E``; a warning reply to ``E`` aborts the process.

    The directory is stat'ed and listed before anything is sent. An
    unreadable directory is therefore reported to the receiver as a warning,
    instead of being created there empty and silently treated as sent.
    """
    log.debug("entering dir %r", dirname)
    try:
        stat = os.stat(dirname)
    except Exception as e:
        return warn("cannot stat %r: %s" % (dirname, e))

    try:
        paths = os.listdir(dirname)
    except Exception as e:
        return warn("cannot listdir %r: %s" % (dirname, e))

    if preserve and not send_timestamp(stat):
        log.warning("send timestamp failed, skipping dir %r", dirname)
        return

    stdout.write(b'D%04o 0 %s\n' % (stat.st_mode & 0o777, os.path.basename(dirname)))
    stdout.flush()
    if not expect_ok():
        log.warning("pushd send failed")
        return

    for path in paths:
        send(os.path.join(dirname, path), recursive=True, preserve=preserve)

    log.debug("exiting dir %r", dirname)
    stdout.write(b'E\n')
    stdout.flush()
    expect_ok(warn=False)


def send(source, recursive=False, preserve=False, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr):
    """Send one local path: a directory (only when *recursive*) or a regular file.

    Missing paths, directories without *recursive*, and other non-regular
    files are reported to the receiver as warnings.  The *stdin*, *stdout*
    and *stderr* parameters are accepted but unused.
    """
    # Drop trailing slashes so basename() gives the entry's real name, but
    # keep a lone '/' (and an empty string) as they are.
    stripped = source.rstrip('/')
    source = stripped if stripped else source[:1]

    if not os.path.exists(source):
        return warn("%r doesn't exist" % (source,))

    if recursive and os.path.isdir(source):
        send_dir(source, preserve=preserve)
    elif os.path.isfile(source):
        send_file(source, preserve=preserve)
    else:
        warn("%r is not a regular file" % (source,))


def sender(sources, recursive=False, preserve=False, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr):
    """Run the source (``-f``) side: wait for the receiver, then send each path.

    A warning in place of the receiver's initial OK aborts the process.  The
    stream parameters are accepted but unused.
    """
    expect_ok(warn=False)  # the receiver announces readiness with a single OK
    for path in sources:
        send(path, recursive=recursive, preserve=preserve)


def receiver(dest, recursive=False, preserve=False, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr):
    """Run the sink (``-t``) side until the sender hangs up.

    Sends an initial OK, then dispatches one-byte commands read from *stdin*:

    * 0x00: ignored.
    * 0x01: the warning text is logged.
    * 0x02: the error text is logged and the process exits with status 1.
    * ``T``: timestamps for the next entry.
    * ``C``: a file.
    * ``D``: enter a directory (requires *recursive*).
    * ``E``: leave a directory.

    EOF exits with status 0; unknown commands abort via ``err()``.  A ``T``
    record applies only to the entry that immediately follows it.  With
    *preserve*, a directory's times are applied again when it is left,
    because writing files into it changes its mtime.  *stdout* and *stderr*
    are unused; the helpers called here use their own default streams.
    """
    dirstack = deque()
    # Times announced for each entry of dirstack (None when no T preceded the D).
    dirtimes = []
    times = None

    # inform sender that we're ready
    ok()

    while True:
        cmd = stdin.read(1)
        # responses
        if cmd == b'\x00':
            pass
        elif cmd == b'\x01':
            log.warning('warning response received: %s', stdin.readline(LINE_SIZE).rstrip(b'\r\n'))
        elif cmd == b'\x02':
            log.error('error response received: %s', stdin.readline(LINE_SIZE).rstrip(b'\r\n'))
            sys.exit(1)
        elif cmd == b'':  # finished
            sys.exit(0)

        # requests

        elif cmd == b'T':
            times = process_timestamp()
        elif cmd == b'E':
            log.debug('popd')
            if not dirstack:
                log.error("popd with empty stack")
                err("invalid command: cannot call 'popd' without preceding 'pushd'")
            elif stdin.read(1) != b'\n':
                err("invalid popd command finish")
            else:
                target = dirstack.pop()
                target_times = dirtimes.pop()
                if preserve and target_times is not None:
                    try:
                        os.utime(target, target_times)
                    except EnvironmentError as e:
                        log.warning("cannot set times on %r: %s", target, e)
                ok()
        elif cmd == b'C':
            copy_file(dest, dirstack, preserve=preserve, times=times)
            times = None
        elif cmd == b'D':
            if not recursive:
                err("got directory without -r option set")
            else:
                depth = len(dirstack)
                pushd(dest, dirstack, preserve=preserve, times=times)
                if len(dirstack) > depth:  # pushd pushes only on success
                    dirtimes.append(times)
                times = None
        else:
            err('protocol error, unknown command %r' % (cmd,))


def parse_args(argv=None):
    """Parse the scp-style command line.

    :param argv: arguments to parse. ``None`` (the default) means
        ``sys.argv[1:]``, as before. Unrecognised options are ignored.
    :returns: a namespace with these attributes:

        * ``mode``: ``'sender'`` for ``-f``, ``'receiver'`` for ``-t``, None
          if neither was given (the last one wins if both are).
        * ``recursive`` (``-r``), ``target_is_dir`` (``-d``),
          ``preserve`` (``-p``) and ``verbose`` (``-v``) flags.
        * ``files``: the positional paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', dest='mode', action='store_const', const='sender')
    parser.add_argument('-t', dest='mode', action='store_const', const='receiver')
    parser.add_argument('-r', '--recursive', action='store_true', default=False)
    parser.add_argument('-d', dest='target_is_dir', action='store_true', default=False)
    parser.add_argument('-p', dest='preserve', action='store_true', default=False)
    parser.add_argument('-v', '--verbose', action='store_true', default=False)
    parser.add_argument('files', nargs='*')
    return parser.parse_known_args(argv)[0]


def main(stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr):
    """Command-line entry point: ``-f`` runs the sender, ``-t`` the receiver.

    Logging goes to *stderr*, at DEBUG level with ``-v`` and WARNING
    otherwise.  The receiver requires exactly one destination; with ``-d``,
    that destination must be an existing directory.  Unexpected exceptions
    are logged and re-raised.  *stdin* and *stdout* are unused here.
    """
    args = parse_args()
    logging.initialize(
        logger=log,
        levelChars=1,
        handler=logging.StreamHandler(stderr),
        level=logging.DEBUG if args.verbose else logging.WARNING,
    )

    try:
        if args.mode is None:
            err("no mode specified, either -f or -t must be present")

        elif args.mode == 'sender':
            sender(recursive=args.recursive, preserve=args.preserve, sources=args.files)
        elif not args.files:
            err("no destination specified for receiver")
        elif len(args.files) > 1:
            err("too many destinations specified for receiver")
        else:
            dest = args.files[0]
            if args.target_is_dir and not os.path.isdir(dest):
                err("target should be a directory")
            receiver(recursive=args.recursive, preserve=args.preserve, dest=dest)
    except Exception as e:
        log.exception("execution failed with %r", e, exc_info=sys.exc_info())
        raise
