import logging

from sandbox import sdk2

from sandbox.projects.common.task_env import TinyRequirements
from sandbox.common.types.misc import DnsType


class SpacyModel(sdk2.Resource):
    """Sandbox resource type for a spaCy model.

    Carries two string attributes identifying the model. Nothing in the
    visible part of this module creates instances of it; presumably the
    helper library or another task does — TODO confirm.
    """
    # Model name, presumably a spaCy package name such as the entries in
    # SpacyModelSync's `models_to_sync` parameter (e.g. 'ru_core_news_sm') — confirm.
    model = sdk2.resource.Attributes.String('Model name')
    # Model version string; format is not fixed by this file.
    version = sdk2.resource.Attributes.String('Model version')


class SpacyModelSync(sdk2.Task):
    """Sandbox task that runs ``yt_sync_models`` for a set of spaCy models.

    The task itself holds no sync logic. ``on_execute`` reads the task
    parameters declared below and forwards them to
    ``geoproduct.aml.libs.spacy_helper.yt_sync_models``.
    """

    class Requirements(TinyRequirements):
        # DNS64 resolution for the task host; presumably needed to reach
        # IPv4-only endpoints during the sync — TODO confirm.
        dns = DnsType.DNS64

    class Parameters(sdk2.Task.Parameters):
        with sdk2.parameters.Group('YT parameters') as yt_parameters:
            # Secret whose `.value()` is passed on as the YT token.
            yt_token = sdk2.parameters.YavSecret('YT OAuth token', required=True)
            # Cluster names, forwarded as `yt_proxies`.
            clusters = sdk2.parameters.List(
                'Clusters', value_type=sdk2.parameters.String, default=['hahn', 'arnold']
            )

        with sdk2.parameters.Group('Sync parameters') as sync_parameters:
            # Forwarded as `min_spacy_version`.
            min_version = sdk2.parameters.String('Min Spacy version', default='3.2.0')
            # Names of the models to sync, forwarded unchanged.
            models_to_sync = sdk2.parameters.List(
                'Models to sync',
                value_type=sdk2.parameters.String,
                default=['ru_core_news_sm', 'ru_core_news_md', 'ru_core_news_lg'],
            )
            # Forwarded as `overwrite`; its exact effect is defined by yt_sync_models.
            overwrite = sdk2.parameters.Bool('Overwrite existing models', default=False)

    def on_execute(self):
        """Set up INFO-level root logging, then call ``yt_sync_models``."""
        # Imported inside the method, so the module only has to be importable
        # when the task actually executes.
        from geoproduct.aml.libs.spacy_helper import yt_sync_models

        # basicConfig() does nothing if the root logger already has handlers;
        # the level change below applies either way.
        # NOTE(review): this lowers the *root* logger to INFO for the whole
        # process — confirm that is intended in the Sandbox runtime.
        logging.basicConfig()
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.INFO)

        params = self.Parameters
        sync_kwargs = dict(
            # YT connection settings
            yt_token=params.yt_token.value(),
            yt_proxies=params.clusters,
            # what to sync
            min_spacy_version=params.min_version,
            models_to_sync=params.models_to_sync,
            overwrite=params.overwrite,
        )
        yt_sync_models(**sync_kwargs)
