import vh3

from . import ops


@vh3.decorator.graph()
def run(arc_token: vh3.Secret = vh3.Factory(lambda: vh3.context.arc_token),
        yt_token: vh3.Secret = vh3.Factory(lambda: vh3.context.yt_token),
        pulsar_token: vh3.Secret = vh3.Factory(lambda: vh3.context.pulsar_token)) -> None:
    """Build the traffic-signs detector training graph and register the run in Pulsar.

    Graph steps:
      1. Export the learning code from Arcadia ``trunk`` and merge it into a
         single tar archive that is used as the training script.
      2. Prepare train/test tfrecords from the MRC traffic-signs YT tables.
      3. Pack each tfrecord into its own tar archive.
      4. Run ``$SOURCE_CODE_PATH/__main__.py`` from the exported code in a
         Python 3 deep-learning block with 4 GPUs.
      5. Add a Pulsar instance for model ``FasterRCNNMain`` 2.0.0 and update
         it with the training block's ``logs`` output.

    Args:
        arc_token: Secret passed to the Arcadia export op.
        yt_token: Secret passed as ``yt_token`` to the dataset-preparation op.
        pulsar_token: Secret passed to both Pulsar ops.
    """
    code = ops.arc_export_not_deterministic(
        arc_token=arc_token,
        path="maps/wikimap/mapspro/services/mrc/eye/experiments/signs_detection/learning",
        reference="trunk",
        max_ram=500,
    )
    merged_srcs: vh3.Binary = ops.merge_tar_archives(tar_archives=[code])

    # Common YT directory for all traffic-signs tables used below.
    traffic_signs_root = "//home/maps/core/mrc/datasets/traffic_signs"

    # NOTE(review): arcadia_revision is pinned here while the training code
    # above is exported from trunk — confirm the mismatch is intentional.
    data = ops.traffic_signs_dataset_prepare(
        arcadia_revision=7247797,
        # Use the graph's configurable YT secret (like arc_token/pulsar_token)
        # instead of a hard-coded secret name; otherwise the yt_token
        # parameter and the context value set in main() are silently ignored.
        yt_token=yt_token,
        input_table=f"{traffic_signs_root}/traffic_signs_mapsmrc_train",
        input_test_table=f"{traffic_signs_root}/traffic_signs_mapsmrc_test",
        input_table_cache_date="2022-05-11",
        min_size=30,
        min_amount=400,
        # NOTE(review): a value > 1 presumably sends all of input_table to the
        # train split (the test split comes from input_test_table) — confirm.
        train_part=1.01,
        join_speed_limits=True,
        augm_flip_horz=True,
        filter_class_table=(
            f"{traffic_signs_root}/extended_set_of_classes_joined_speed"
            "_add_warning_uneven_road_ahead_and_prescription_road_hump"
        ),
        max_ram=500
    )

    train_data_tar: vh3.Binary = ops.create_tar_archive_10(
        file0=data.train_tfrecord,
        name0="train_tfrecord",
        max_disk=150000,
        ttl=1000
    )

    test_data_tar: vh3.Binary = ops.create_tar_archive_10(
        file0=data.test_tfrecord,
        name0="test_tfrecord",
        max_disk=150000,
        ttl=1000
    )

    train = ops.python_3_deep_learning(
        base_layer="PYDL_V5_GPU",
        pip=["pydl-defaults==2.4", "pillow", "tqdm", "numpy", "pandas", "scipy"],
        gpu_count=4,
        ttl=48000,
        max_disk=204800,
        max_ram=64000,
        gpu_max_ram=42000,
        gpu_type='CUDA_8_0',
        script=merged_srcs,
        run_command="python3 $SOURCE_CODE_PATH/__main__.py",
        data=(train_data_tar, test_data_tar),
    )

    timestamp = ops.get_timestamp_now()

    # NOTE(review): presumably block_args feeds the timestamp into the op's
    # dynamic options (e.g. to make the instance unique) — confirm in vh3 docs.
    instance_builder = ops.pulsar_add_instance(
        pulsar_token=pulsar_token,
        model_name="FasterRCNNMain",
        model_version="2.0.0",
        dataset_name="mrc_train_full",
        **vh3.block_args(dynamic_options=timestamp),
    )

    # NOTE(review): instance_id="1" looks like a placeholder that is presumably
    # overridden by the dynamic options from instance_builder.info — confirm.
    ops.pulsar_update_instance(
        tfevents=train.logs,
        pulsar_token=pulsar_token,
        instance_id="1",
        **vh3.block_args(dynamic_options=instance_builder.info),
    )


class GraphTokensContext(vh3.DefaultContext):
    """vh3 context carrying the secrets that ``run`` reads via ``vh3.context``.

    Each field is resolved lazily by the ``vh3.Factory`` defaults of ``run``;
    the concrete values are supplied when building the ``vh3.Profile`` in
    ``main``.
    """
    # Secret for the Arcadia export op.
    arc_token: vh3.Secret
    # Secret for YT access.
    yt_token: vh3.Secret
    # Secret for the Pulsar add/update instance ops.
    pulsar_token: vh3.Secret


def main():
    """Build the training workflow with the robot's secrets and launch it.

    The graph is assembled by calling ``run()`` inside the built workflow
    instance; the instance is started only after the ``with`` block exits.
    """
    # Secret names looked up by GraphTokensContext / run(); order kept as-is.
    secrets = {
        'arc_token': 'maps_mrc_arc_token',
        'yt_token': 'maps_mrc_yt_hahn_token',
        'pulsar_token': 'robot_maps_core_mrc_pulsar_token',
    }

    profile = vh3.Profile(GraphTokensContext, **secrets, quota='geo-common')
    builder = profile.load_defaults()

    with builder.build(vh3.WorkflowInstance) as workflow_instance:
        run()
    workflow_instance.run()


if __name__ == '__main__':
    main()
