import argparse
import logging
import os

from crypta.lib.python import yaml_config
from crypta.lib.python.custom_ml.classification.model_training import (
    ExistingModelTrainValidateHelper,
    NewModelTrainValidateHelper,
)
from crypta.lib.python.logging import logging_helpers
from crypta.lib.python.proto_secrets import proto_secrets
import crypta.lib.python.yql.client as yql_helpers
from crypta.lib.python.yt import yt_helpers
from crypta.profile.services.train_custom_model.proto.config_pb2 import TConfig

logger = logging.getLogger(__name__)


def parse_args():
    """Parse the command-line arguments of the script.

    The only recognised option is ``--config``, which is mandatory; ``main()``
    hands its value to ``yaml_config.parse_config``.

    Returns:
        argparse.Namespace: namespace with a single ``config`` attribute.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--config', required=True)
    parsed = arg_parser.parse_args()
    return parsed


def main():
    """Run the custom model train/validate task described by the config file.

    Steps:
      1. Configure stdout logging on the root logger.
      2. Parse ``--config`` and load it as a ``TConfig`` message.
      3. Build a YT client (token taken from the config) and a YQL client
         (token taken from the ``YQL_TOKEN`` environment variable).
      4. Instantiate ``ExistingModelTrainValidateHelper`` when
         ``config.ModelName`` is truthy, otherwise
         ``NewModelTrainValidateHelper``, and call its ``run()`` method.
    """
    logging_helpers.configure_stdout_logger(logging.getLogger())

    cli_args = parse_args()
    config = yaml_config.parse_config(TConfig, cli_args.config)
    # Log the copy produced by get_copy_without_secrets, not the raw config.
    printable_config = proto_secrets.get_copy_without_secrets(config)
    logger.info('Config:\n{}'.format(printable_config))

    yt_settings = config.Yt
    yt_client = yt_helpers.get_yt_client(
        yt_proxy=yt_settings.Proxy,
        yt_pool=yt_settings.Pool,
        yt_token=yt_settings.Token,
    )
    yql_client = yql_helpers.create_yql_client(
        yt_proxy=yt_settings.Proxy,
        pool=yt_settings.Pool,
        token=os.getenv('YQL_TOKEN'),
    )

    def _or_none(value):
        # Falsy config values are handed to the helpers as None.
        return value if value else None

    # Keyword arguments that both helper classes receive with identical values.
    shared_kwargs = dict(
        yt_client=yt_client,
        yql_client=yql_client,
        output_dir=config.OutputDirPath,
        raw_sample_table=_or_none(config.SampleTablePath),
        raw_sample_file=_or_none(config.SampleFilePath),
        audience_id=_or_none(config.AudienceId),
        validate_segments=config.ValidateSegments,
        crypta_identifier_udf_url=config.CryptaIdentifierUdfUrl,
        logger=logger,
        send_results_to_api=config.SendResultsToApi,
    )

    model_name = config.ModelName
    if model_name:
        # A model name is set: use the existing-model helper.
        task = ExistingModelTrainValidateHelper(
            partner=config.Partner,
            login=config.Login,
            if_make_decision=config.IfMakeDecision,
            industry_model_name=model_name,
            logins_to_share=config.LoginsToShare,
            **shared_kwargs
        )
    else:
        # No model name: use the new-model helper with the output segment sizes.
        task = NewModelTrainValidateHelper(
            positive_output_segment_size=config.PositiveOutputSegmentSize,
            negative_output_segment_size=config.NegativeOutputSegmentSize,
            **shared_kwargs
        )

    task.run()

    logger.info('Completed successfully')
