import vh3

from . import ops


# @vh3.decorator.graph("https://nirvana.yandex-team.ru/flow/b3f6a085-62b1-4fa8-8f3d-e61257bdd9fd")
# def train_traffic_signs_detector() -> None:
#     """
#     Train traffic signs detector
#     """
#     traffic_signs_yt_to_tf_record_result: ops.TrafficSignsDatasetPrepareOutput = ops.traffic_signs_dataset_prepare(
#         join_speed_limits=True,
#         augm_flip_horz=True,
#         max_ram=500,
#         arcadia_revision=7247797,
#         arcadia_patch="",
#         yt_token="maps_mrc_yt_hahn_token",
#         input_table="//home/maps/core/mrc/datasets/traffic_signs/traffic_signs_mapsmrc_train",
#         input_test_table="//home/maps/core/mrc/datasets/traffic_signs/traffic_signs_mapsmrc_test",
#         input_table_cache_date="2020-08-25",
#         min_size=30,
#         min_amount=400,
#         train_part=1.01,
#         filter_class_table="//home/maps/core/mrc/datasets/traffic_signs/extended_set_of_classes_joined_speed",
#     )
#     train_traffic_signs_faster_rcnn_result: ops.TrainFasterRcnnOutput = ops.train_faster_rcnn(
#         label_map=traffic_signs_yt_to_tf_record_result.label_map,
#         train_data=[traffic_signs_yt_to_tf_record_result.train_tfrecord],
#         config_url="https://paste.yandex-team.ru/617263/text",
#         detection_ckpt_url="https://proxy.sandbox.yandex-team.ru/700051276",
#         save_every_steps=30000,
#         gpu_count=4,
#         tf_research_url="https://proxy.sandbox.yandex-team.ru/671977871",
#         num_epoches=10,
#         optimizer="adam_optimizer",
#         lr_init=5e-05,
#         lr_decay_epoches=3,
#         lr_decay_factor=0.5,
#     )
#     test_faster_rcnn_result: vh3.Binary = ops.test_faster_rcnn(
#         gdef_tgz=train_traffic_signs_faster_rcnn_result.gdef_tgz,
#         test_data=[traffic_signs_yt_to_tf_record_result.test_tfrecord],
#         label_map=traffic_signs_yt_to_tf_record_result.label_map,
#         max_ram=4000,
#         tf_research_url="https://proxy.sandbox.yandex-team.ru/671977871",
#         score_thr=0.7,
#         bbox_cmp_func="iou",
#         bbox_cmp_thr=0.5,
#         bbox_min_size=30,
#         nms_thr=0.9,
#     )


@vh3.decorator.graph()
def new_train() -> None:
    # Graph with two active steps:
    #   1. build the traffic_signs_yt_to_tfrecord tool from Arcadia at a pinned revision;
    #   2. run that tool through the dataset-preparation op on the train/test tables.
    # The training steps that would consume the prepared dataset are still
    # commented out at the bottom of this function.

    # --- Step 1: build the converter binary -------------------------------
    converter_target = "maps/wikimap/mapspro/services/mrc/tools/traffic_signs_yt_to_tfrecord"
    converter_artifact = (
        "maps/wikimap/mapspro/services/mrc/tools/traffic_signs_yt_to_tfrecord/traffic_signs_yt_to_tfrecord"
    )
    converter_build = ops.build_arcadia_project(
        targets=converter_target,
        arts=converter_artifact,
        arcadia_revision=7247797,
    )

    # --- Step 2: prepare the dataset --------------------------------------
    train_table = "//home/maps/core/mrc/datasets/traffic_signs/traffic_signs_mapsmrc_train"
    test_table = "//home/maps/core/mrc/datasets/traffic_signs/traffic_signs_mapsmrc_test"
    class_filter_table = "//home/maps/core/mrc/datasets/traffic_signs/extended_set_of_classes_joined_speed"

    # The op's outputs are not bound to a name yet; the commented-out
    # assignment target below is kept for when the training steps are enabled.
    # traffic_signs_yt_to_tf_record_result: ops.TrafficSignsDatasetPrepareOutput =
    ops.traffic_signs_prepare_dataset(
        # The build step's `arcadia_project` output is the executable to run.
        executable=converter_build.arcadia_project,
        yt_token="maps_mrc_yt_hahn_token",
        mr_default_cluster="hahn",
        input_table=train_table,
        input_test_table=test_table,
        # Alternative input kept for reference:
        # input_table="//home/maps/core/mrc/datasets/traffic_signs/traffic_signs_mapsmrc_israel",
        # input_test_table="//home/maps/core/mrc/datasets/traffic_signs/traffic_signs_mapsmrc_israel",
        input_table_cache_date="2020-08-25",
        min_size=30,
        min_amount=400,
        # NOTE(review): train_part is above 1.0 -- presumably so that every row of
        # input_table lands in the train split, since a separate test table is
        # supplied; confirm against the tool's semantics.
        train_part=1.01,
        filter_class_table=class_filter_table,
        max_ram=1500,
        max_disk=150000,
    )

    # --- Not enabled yet: export training code, pack data, train ----------
    # code = ops.arc_export(
    #     arc_token="maps_mrc_arc_token",
    #     path="maps/wikimap/mapspro/services/mrc/eye/experiments/signs_detection/learning/faster_rcnn/srcs",
    #     reference="users/quoter/mrc_train_signs_detector_in_nirvana",
    # )
    # training_data_tar = ops.create_tar_archive(
    #     files=[traffic_signs_yt_to_tf_record_result.train_tfrecord,
    #            traffic_signs_yt_to_tf_record_result.label_map]
    # )
    # training_result = ops.python_3_deep_learning(
    #     gpu_count=4,
    #     ttl=48000,
    #     max_disk=204800,
    #     max_ram=32000,
    #     gpu_max_ram=10500,
    #     gpu_type='CUDA_8_0',
    #     script=code,
    #     data=training_data_tar,
    #     **vh3.block_args(name="Train"),
    # )
