#!/usr/bin/env python3

import argparse
import re
import sys

# Separator between a key and its value, including surrounding blanks.
re_kv_sep = re.compile(r"\s*=\s*")
# First non-whitespace character.
re_non_space = re.compile(r"\S")
# A single word character (letter, digit or underscore).
re_word = re.compile(r"\w")
# Trailing list index in a path element, e.g. "subnets[2]" -> "2".
re_list = re.compile(r"\[(\d+)\]$")
# Trailing map key in a path element, e.g. 'tags["env"]' or 'tags[env]' -> "env".
# The key may not contain quotes, "]" or blanks; a greedy `\S+` would keep
# the closing quote as part of the key.
re_map = re.compile(r'\["?([^"\]\s]+)"?\]$')


# Characters that end an unquoted token: whitespace, the "," and "="
# separators, the start of a "#" comment, and closing brackets.
_BARE_WORD_END = " \t\n\r\f\v,=#}])"


def _find_closing_quote(content, start):
    """Return the index of the first unescaped '"' at or after *start*, or -1.

    A quote preceded by an odd number of backslashes (counted back no
    further than *start*) is escaped and does not close the string.
    """
    idx = content.find('"', start)
    while idx >= 0:
        backslashes = 0
        j = idx - 1
        while j >= start and content[j] == "\\":
            backslashes += 1
            j -= 1
        if backslashes % 2 == 0:
            return idx
        idx = content.find('"', idx + 1)
    return -1


def _starts_bare_word(content, idx):
    """Return True if an unquoted token starts at *idx*: a word char, or '-' before one."""
    if content[idx] == "-":
        idx += 1
        if idx == len(content):
            return False
    c = content[idx]
    return c.isalnum() or c == "_"  # the same characters as the regex `\w`


def _ends_bare_word(content, idx):
    """Return True if the character at *idx* terminates an unquoted token."""
    return content[idx] in _BARE_WORD_END or content.startswith(("//", "/*"), idx)


def find_block(content, offt):
    """Locate the first value block at or after index *offt* in *content*.

    A block is one of:
      * a double-quoted string (backslash-escaped quotes do not end it);
      * a heredoc, ``<<TERM`` or ``<<-TERM``, ending at a line that holds
        only ``TERM`` (surrounding blanks allowed);
      * a bracketed group ``{...}``, ``[...]`` or ``(...)``, nesting included;
      * an unquoted token such as ``42``, ``-1`` or ``true``, optionally
        followed by a bracketed group as in ``f(x)``.  The token ends at
        whitespace, ``,``, ``=``, a comment, a closing bracket, or the end
        of *content*.

    ``#``, ``//`` and ``/* */`` comments are skipped while searching.

    Returns:
        ``[block_begin, content_begin, content_end, block_end]``, where
        ``content[content_begin:content_end]`` is the value without its
        delimiters and ``content[block_begin:block_end]`` includes them.
        Returns None if no complete block is found (e.g. an unterminated
        string, comment or heredoc).
    """
    max_idx = len(content)
    idx = offt
    # block begin, content begin, content end, block end
    result = [-1, -1, -1, -1]
    # What ends the current block: '"', a heredoc terminator, '}', ']', ')',
    # or ' ' for an unquoted token; None until a block has started.
    block_border = None
    bare_token = False  # the block started as an unquoted token
    bbp_depth = [0, 0, 0]  # open '{', '[' and '(' counts
    c_prev = ""
    while idx < max_idx:
        c = content[idx]
        if block_border == " " and _ends_bare_word(content, idx):
            result[2:4] = [idx, idx]
            return result
        if c == '"':
            if not block_border:
                block_border = c
                result[0:2] = [idx, idx + 1]
            idx = _find_closing_quote(content, idx + 1)
            if idx < 0:
                return None
            if block_border == c:
                result[2:4] = [idx, idx + 1]
                return result
        elif c == '#' or (c == '/' and c_prev == '/'):
            # line comment: jump to its newline
            idx = content.find('\n', idx)
            if idx < 0:
                return None
            c = '\n'
        elif c == '*' and c_prev == '/':
            # block comment: search from idx + 1 so "/*/" does not close itself
            idx = content.find("*/", idx + 1) + 1
            if idx < 1:
                return None
            c = ' '  # the closing '/' must not pair with a following '/' or '*'
        elif c == '<' and c_prev == '<':
            line_end = content.find('\n', idx)
            if line_end < 0:
                return None
            # `<<TERM` or `<<-TERM`; trailing blanks and '\r' are not part of TERM
            term = content[idx + 1:line_end].strip()
            if term.startswith('-'):
                term = term[1:]
            if term:
                if not block_border:
                    block_border = term
                    result[0:2] = [idx - 1, line_end + 1]
                # the terminator must stand on a line of its own
                closing = re.compile(
                    r"^[ \t]*" + re.escape(term) + r"[ \t]*\r?$", re.M
                ).search(content, line_end + 1)
                if not closing:
                    return None
                term_end = content.index(term, closing.start()) + len(term)
                if block_border == term:
                    # content stops before the newline preceding the terminator line
                    result[2:4] = [closing.start() - 1, term_end]
                    return result
                idx = term_end - 1
                c = ' '
        elif c in "{[(":
            b_idx = "{[(".index(c)
            bbp_depth[b_idx] += 1
            if not block_border:
                block_border = "}])"[b_idx]
                result[0:2] = [idx, idx + 1]
            elif block_border == " ":
                # unquoted token followed by a bracket, e.g. a call `f(x)`
                block_border = "}])"[b_idx]
        elif c in "}])":
            bbp_depth["}])".index(c)] -= 1
            if c == block_border and bbp_depth == [0, 0, 0]:
                # a token such as `f(x)` keeps its closing bracket as content
                content_end = idx + 1 if bare_token else idx
                result[2:4] = [content_end, idx + 1]
                return result
        elif not block_border and _starts_bare_word(content, idx):
            block_border = " "
            bare_token = True
            result[0:2] = [idx, idx]
        idx += 1
        c_prev = c
    if block_border == " ":
        # unquoted token running to the end of the input
        result[2:4] = [max_idx, max_idx]
        return result
    return None


def parse_kv(content, offt):
    """Parse the ``key = value`` entries of the ``{ ... }`` block opening at *offt*.

    *offt* is the index of the opening ``{``.  Entries may be separated by
    newlines or commas; blank space and ``#``, ``//`` and ``/* */`` comments
    between entries are skipped.  A key followed by ``{`` instead of ``=``
    (a nested block such as ``validation { ... }``) is skipped so that later
    entries are still found.

    Returns:
        A dict mapping each key (without quotes) to the
        ``[block_begin, content_begin, content_end, block_end]`` location
        find_block returned for its value.  Parsing stops at the closing
        ``}``, at the end of *content*, or at the first entry it cannot
        understand, returning the entries collected so far.
    """
    # Matches any run of blanks, commas and comments (always succeeds).
    skip_filler = re.compile(r"(?:[\s,]|#[^\n]*|//[^\n]*|/\*.*?\*/)*", re.S).match
    entries = {}
    max_idx = len(content)
    idx = offt + 1  # step into dict
    while idx < max_idx:
        # Skipping comments and commas here keeps them from being handed to
        # find_block, which ignores an unmatched "}" and would scan past it.
        idx = skip_filler(content, idx).end()
        if idx >= max_idx or content[idx] == "}":
            break
        key_block = find_block(content, idx)
        if not key_block:
            break

        # match() anchors at the end of the key instead of scanning the rest of the file
        m = re_kv_sep.match(content, key_block[3])
        if m:
            value_block = find_block(content, m.end())
            if not value_block:
                break
            entries[content[key_block[1]:key_block[2]]] = value_block
            idx = value_block[3]
            continue

        # No "=": accept a nested block (e.g. `validation { ... }`) and skip it.
        m = re_non_space.search(content, key_block[3])
        if not m or content[m.start()] != "{":
            break
        nested = find_block(content, m.start())
        if not nested:
            break
        idx = nested[3]
    return entries


def parse_list(content, offt):
    """Parse the items of the ``[ ... ]`` list whose opening bracket is at *offt*.

    Items may be separated by commas (a trailing comma is allowed); blank
    space and ``#``, ``//`` and ``/* */`` comments between items are skipped.

    Returns:
        A list of the ``[block_begin, content_begin, content_end, block_end]``
        locations find_block returned for each item, in order.  Parsing stops
        at the closing ``]``, at the end of *content*, or at the first item
        find_block cannot locate.
    """
    # Matches any run of blanks, commas and comments (always succeeds).
    skip_filler = re.compile(r"(?:[\s,]|#[^\n]*|//[^\n]*|/\*.*?\*/)*", re.S).match
    items = []
    max_idx = len(content)
    idx = offt + 1  # step into list
    while idx < max_idx:
        # Commas and comments are skipped before the "]" check, so a trailing
        # comma does not hand "]" to find_block, which ignores an unmatched
        # closer and would scan past the end of the list.
        idx = skip_filler(content, idx).end()
        if idx >= max_idx or content[idx] == "]":
            break
        item_block = find_block(content, idx)
        if not item_block:
            break
        items.append(item_block)
        idx = item_block[3]
    return items


def find_var_block(content, var):
    """Parse the body of the ``variable "<var>" { ... }`` block in *content*.

    *var* is matched literally (it is regex-escaped).  The ``variable``
    keyword must be the first word on its line, so a commented-out block
    (``# variable ...``) or a longer name such as ``my_variable`` is not
    mistaken for it.  The label may be quoted or bare.

    Returns:
        The dict parse_kv builds for the block body (keys mapped to value
        locations), or None if no matching block exists.
    """
    # "\ufeff" lets a UTF-8 byte-order mark precede a block on the first line.
    m = re.search(
        r'(?m)^[\ufeff \t]*variable(?:\s*"|\s+)' + re.escape(var) + r'"?\s*\{',
        content,
    )
    if m:
        begin = m.end() - 1  # index of the opening "{"
        return parse_kv(content, begin)
    return None


def _split_indices(path):
    """Split ``"name[1][2]"`` into ``("name", [1, 2])``; no suffix gives ``[]``."""
    indices = []
    m = re_list.search(path)
    while m:
        indices.insert(0, int(m.group(1)))
        path = path[:m.start()]
        m = re_list.search(path)
    return path, indices


def _select_items(content, block, indices, path):
    """Apply list *indices* to *block* in turn; print a message and return None on failure."""
    for index in indices:
        if content[block[0]] != "[":
            print("Cannot index '%s': value is not a list" % path)
            return None
        items = parse_list(content, block[0])
        if index >= len(items):
            print("List index %d out of range in '%s'" % (index, path))
            return None
        block = items[index]
    return block


def find_path_location(content, path_list):
    """Resolve *path_list* to the location of a value inside *content*.

    ``path_list[0]`` names a variable whose ``default`` value is the
    starting point; each later element is a map key to descend into.  Any
    element may end with one or more ``[N]`` suffixes selecting list items,
    e.g. ``["subnets[0]", "cidr"]`` or ``["cfg", "hosts[1][0]"]``.

    Returns:
        The ``[block_begin, content_begin, content_end, block_end]`` location
        of the selected value, or None after printing a message if the
        variable, its ``default``, a key, or a list index cannot be resolved.
    """
    var_name, indices = _split_indices(path_list[0])
    block_dict = find_var_block(content, var_name)
    if not block_dict:
        print("Variable block not found")
        return None
    if "default" not in block_dict:
        print("No default key in variable block")
        return None

    block = _select_items(content, block_dict["default"], indices, path_list[0])
    for path in path_list[1:]:
        if block is None:
            return None
        key, indices = _split_indices(path)
        if key:
            # only a "{ ... }" value can be descended into by key
            if content[block[0]] != "{":
                print("Cannot look up '%s': value is not a map" % path)
                return None
            entries = parse_kv(content, block[0])
            if key not in entries:
                print("Key '%s' not found" % key)
                return None
            block = entries[key]
        block = _select_items(content, block, indices, path)
    return block


def main():
    """Print, or replace, the value at a variable path in a Terraform file.

    ``-p`` is a dot-separated path (see find_path_location).  Without ``-v``
    the value's content (the text between its delimiters, e.g. without
    surrounding quotes) is printed; with ``-v`` that content is replaced and
    the file is rewritten in place.  Exits with status 1 on any failure.
    """
    parser = argparse.ArgumentParser(description="update terraform variables")
    parser.add_argument("-f", "--file", required=True, help="terraform variables file")
    parser.add_argument("-p", "--path", required=True, help="variable path to change")
    parser.add_argument("-v", "--value", default=None, help="new variable value")
    args = parser.parse_args()

    filename = args.file
    # newline="" keeps the file's own line endings (e.g. CRLF) on rewrite
    with open(filename, mode="r", encoding="utf-8", newline="") as file:
        content = file.read()
    if not content:
        print("Failed to read variables file")
        sys.exit(1)
    path_list = [a for a in args.path.split(".") if a]
    if not path_list:
        print("Too short variable path")
        sys.exit(1)

    block = find_path_location(content, path_list)
    if block is None:
        sys.exit(1)

    if args.value is None:
        print(content[block[1]:block[2]])
    else:
        new_content = content[:block[1]] + args.value + content[block[2]:]
        with open(filename, mode="w", encoding="utf-8", newline="") as file:
            file.write(new_content)


if __name__ == '__main__':
    main()
