import os
import sys
import logging
from time import sleep
from typing import Dict, Any, List
from zenyatta.aws import get_aws_metadata, boto_client
from zenyatta.common.errors import ZenyattaError
from zenyatta.common.util import check_response
from botocore.exceptions import ClientError


class EMR:
    """EMR for Zenyatta.

    Wraps an EMR client (obtained from ``boto_client('emr', role_arn)``)
    together with one active cluster. The cluster is resolved once, at
    construction time, and cached on ``self.cluster``. The step and
    instance-group methods operate on that cluster's ``Id``.
    """

    # Name prefix used to select a cluster when no explicit name is given.
    DEFAULT_NAME_PREFIX = 'zenyatta'

    def __init__(self, emr_name: str = None, role_arn: str = None):
        """
        Args:
            emr_name: substring of the target cluster's name. None selects the
                first active cluster whose name starts with DEFAULT_NAME_PREFIX.
            role_arn: forwarded to ``boto_client`` (presumably a role to assume
                for the EMR client -- confirm against boto_client).

        Raises:
            ZenyattaError: if no matching active cluster exists.
        """
        self.client, _ = boto_client('emr', role_arn)
        self.cluster_name = emr_name
        self.cluster = self.fetch_emr_cluster()

    def _list_all(self, operation, result_key: str, **kwargs) -> List[Dict]:
        """
        Drain a Marker-paginated EMR ``list_*`` call.

        Calls ``operation(**kwargs)`` repeatedly, feeding each response's
        ``Marker`` back in, and concatenates the ``result_key`` items of
        every page. Returns None as soon as ``check_response`` rejects a
        page, which mirrors the implicit-None convention of the public
        methods below.
        """
        items = []
        while True:
            response = operation(**kwargs)
            if not check_response(response):
                return None
            items.extend(response.get(result_key) or [])
            marker = response.get('Marker')
            if not marker:
                return items
            # kwargs is this call's private dict, so mutating it is safe.
            kwargs['Marker'] = marker

    def fetch_emr_cluster(self) -> Any:
        """
        Fetch an active cluster (STARTING, BOOTSTRAPPING, RUNNING or WAITING).

        If ``self.cluster_name`` is set, the first active cluster whose name
        contains it is returned. If no name was given, the first active
        cluster whose name starts with DEFAULT_NAME_PREFIX ('zenyatta') is
        returned instead. Every page of ``list_clusters`` is searched.

        Returns:
            The cluster summary dict from ``list_clusters``, or None if
            ``check_response`` rejects a response.

        Raises:
            ZenyattaError: if no matching active cluster is found.
        """
        clusters = self._list_all(
            self.client.list_clusters,
            'Clusters',
            ClusterStates=['STARTING', 'BOOTSTRAPPING', 'RUNNING', 'WAITING'])
        if clusters is None:
            return None

        if self.cluster_name is None:
            # `None in name` would raise TypeError; fall back to the default prefix.
            wanted = self.DEFAULT_NAME_PREFIX
            matches = [c for c in clusters if c['Name'].startswith(wanted)]
        else:
            wanted = self.cluster_name
            matches = [c for c in clusters if wanted in c['Name']]

        if not matches:
            raise ZenyattaError("no EMR cluster is found under the name of {}"
                                .format(wanted))
        if len(matches) > 1:
            # Ambiguous selection: the first match is used, so surface it.
            logging.warning("more than one emr clusters are found under the name of {}."
                            .format(wanted))
        return matches[0]

    def list_cluster_steps(self, steps: List = None) -> List[Dict]:
        """
        List the steps of the cluster, across all result pages.

        Args:
            steps: optional list of step IDs to filter on. None or an empty
                list returns all steps of the cluster.

        Returns:
            The ``Steps`` entries from ``list_steps``, or None if
            ``check_response`` rejects a response.
        """
        # Copy the caller's list so it is never shared or mutated (and avoid a
        # mutable default argument).
        step_ids = [] if steps is None else list(steps)
        return self._list_all(self.client.list_steps, 'Steps',
                              ClusterId=self.cluster['Id'], StepIds=step_ids)

    def add_step_to_cluster(self, steps: List[Dict]) -> List:
        """
        Add a list of step definitions to the cluster.

        Returns:
            The ``StepIds`` from ``add_job_flow_steps``, or None if
            ``check_response`` rejects the response.
        """
        response = self.client.add_job_flow_steps(JobFlowId=self.cluster['Id'], Steps=steps)
        if check_response(response):
            return response.get('StepIds')

    def get_step_status(self, stepId: str) -> List[Dict]:
        """
        Get the step summary (or summaries) whose ``Id`` equals ``stepId``.

        Returns:
            A list with the matching step entries. It is empty if no step
            matches or the listing failed.
        """
        all_steps = self.list_cluster_steps() or []
        return [step for step in all_steps if step['Id'] == stepId]

    def list_instance_groups(self) -> List[Dict]:
        """
        Return the instance groups of the cluster, across all result pages.

        Returns:
            The ``InstanceGroups`` entries, or None if ``check_response``
            rejects a response.
        """
        return self._list_all(self.client.list_instance_groups, 'InstanceGroups',
                              ClusterId=self.cluster['Id'])

    def resize_instance_group(self, groupType: str = 'CORE', new_cnt: int = 3) -> bool:
        """
        Resize the first instance group of the given type.

        Args:
            groupType: the ``InstanceGroupType`` to resize (default 'CORE').
            new_cnt: the requested instance count (default 3).

        Returns:
            True if ``check_response`` accepts the modify response, else False.
            Failures are logged rather than raised (best-effort).
        """
        instance_groups = self.list_instance_groups() or []
        group = next((g for g in instance_groups
                      if g.get('InstanceGroupType') == groupType), None)
        if group is None:
            logging.warning("failed to resize emr instance group: no {} instance group found"
                            .format(groupType))
            return False
        try:
            response = self.client.modify_instance_groups(InstanceGroups=[
                {'InstanceGroupId': group['Id'], 'InstanceCount': new_cnt}])
            return bool(check_response(response))
        except Exception as e:
            # Deliberate best-effort boundary: check_response is opaque, so
            # catch broadly but log with traceback instead of hiding at info.
            logging.exception("failed to resize emr instance group: {}".format(e))
            return False

    def get_cluster_description(self, cluster_id: str) -> Dict:
        """
        Describe a cluster by ID.

        Returns:
            The ``Cluster`` dict from ``describe_cluster``, or None if
            ``check_response`` rejects the response.

        Raises:
            ZenyattaError: wrapping (and chained from) any botocore ClientError.
        """
        try:
            response = self.client.describe_cluster(ClusterId=cluster_id)
        except ClientError as e:
            raise ZenyattaError("client error {}".format(e)) from e
        if check_response(response):
            return response['Cluster']
