#!/usr/bin/env python

import os
import pickle

import libmxnet


# To get a .info MatrixNet model from a binary one, use: arcadia/quality/relev_tools/mx_ops bin2info <matrixnet.bin> <matrixnet.info>
def TestInfo():
    data = [1029790.0, 0.0, 0.0, 0.0, 0.0, 91000.0, 46.912100000000002, -1.0, -1.0, 1280000000.0, 0.0, 154.78800000000001, 19.77, 0.30866700000000002, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0083999999999999995, 1.0, 0.0, -230.25899999999999, 0.990676, 1415260.0, 0.0, 13011.0, 1444.0, 0.34999999999999998, 0.62307900000000005, 0.0, 0.0, -1.0, 0.0, 0.0, 1594.0, 238.0, 23774.0, 205.0, 191.0, 48.0, 215.0, 1.0, 1329338624.0, 1329338624.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.6326, 4.5612599999999999, 7.1691200000000004, 6.8840599999999998, 14323.0, 1823.0, 1011.0, 0.070585800000000004, 0.55457999999999996, 14246.0, 3788.0, 2349.0, 0.16488800000000001, 0.620116, 0.65059, 814.0, 0.0, 0.0, 46.0, 0.377056, 314.0, 68.0, 142.0, 71.0, 19.0, 88.0, 1.0, 1329427584.0, 1329427584.0, 0.27800000000000002, 30.491499999999998, 0.165432, 0.040194500000000001]
    o = libmxnet.TMXNetInfo("ip25.info")
    res = o.Calculate([data])
    print res, res == [0.8053852333738295]

    o.SetProperty("hello", "world")
    o.Save("ip25.copy.info")
    oCopy = libmxnet.TMXNetInfo("ip25.copy.info")
    res = oCopy.Calculate([data])
    print res, res == [0.8053852333738295]
    print oCopy.GetProperty("hello"), oCopy.GetProperty("hello") == "world"
    os.remove("ip25.copy.info")

    c = oCopy.CopyTruncated(1000)
    res = c.Calculate([data])
    print res, res == [0.8051807285657016]

    d = libmxnet.TMXNetInfo()
    d.Load('ip25.info')
    c = pickle.loads(pickle.dumps(d))
    res = c.Calculate([data])
    print res, res == [0.8053852333738295]


def TestInfoSlices():
    data = [1029790.0, 0.0, 0.0, 0.0, 0.0, 91000.0, 46.912100000000002, -1.0, -1.0, 1280000000.0, 0.0, 154.78800000000001, 19.77, 0.30866700000000002, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0083999999999999995, 1.0, 0.0, -230.25899999999999, 0.990676, 1415260.0, 0.0, 13011.0, 1444.0, 0.34999999999999998, 0.62307900000000005, 0.0, 0.0, -1.0, 0.0, 0.0, 1594.0, 238.0, 23774.0, 205.0, 191.0, 48.0, 215.0, 1.0, 1329338624.0, 1329338624.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.6326, 4.5612599999999999, 7.1691200000000004, 6.8840599999999998, 14323.0, 1823.0, 1011.0, 0.070585800000000004, 0.55457999999999996, 14246.0, 3788.0, 2349.0, 0.16488800000000001, 0.620116, 0.65059, 814.0, 0.0, 0.0, 46.0, 0.377056, 314.0, 68.0, 142.0, 71.0, 19.0, 88.0, 1.0, 1329427584.0, 1329427584.0, 0.27800000000000002, 30.491499999999998, 0.165432, 0.040194500000000001]
    no_slice = libmxnet.TMXNetInfo("ip25.info")
    with_slice = libmxnet.TMXNetInfo("ip25.info")
    with_custom_slice = libmxnet.TMXNetInfo("ip25.info")
    with_slice.SetProperty('Slices', 'web_production[0;93)')
    with_custom_slice.SetProperty('Slices', 'custom_web_production[0;93)')
    try:
      no_slice.Calculate([data], "web_production[0;93)")
    except Exception as e:
      print(str(e))
    else:
      raise Exception('Should fail on giving slices to formula without slices')
    try:
      with_slice.Calculate([data])
    except Exception as e:
      print(str(e))
    else:
      raise Exception('Should fail on calculating formula with slices without data slices')
    try:
      with_custom_slice.Calculate([data], "custom_web_production[0;93)", disallow_custom_slices=True)
    except Exception as e:
      print(str(e))
    else:
      raise Exception('Should fail with unknown slices error when called with check flag')

    with_slice.Calculate([data], "web_meta[0;93)", skip_slices_check=True)
    with_slice.Calculate([data], "custom_web_production[0;93)", skip_slices_check=True)
    try:
      with_slice.Calculate([data], "custom_web_production[0;93)", disallow_custom_slices=True)
    except Exception as e:
      print(str(e))
    else:
      raise Exception('Should fail with missing slice in fromula error when called with a disallow_custom_slices flag')

    try:
      with_slice.Calculate([data[:10]], 'web_production[0;10)')
    except Exception as e:
      print(str(e))
    else:
      raise Exception('Should fail on calculating formula without required factors')
    try:
      with_slice.Calculate([data[:10]], 'web_production[0;93)')
    except Exception as e:
      print(str(e))
    else:
      raise Exception('Should fail on calculating formula without required data')

    try:
      with_slice.Calculate([data+data], 'web_production[0;93)')
    except Exception as e:
      print(str(e))
    else:
      raise Exception('Should fail on calculating formula with more factors, than slices describe')

    with_slice.SetProperty('Slices', 'web_meta[0;0) web_production[0;93)')
    no_slice_result = no_slice.Calculate([data])
    print no_slice.Calculate([data]), no_slice.Calculate([data]) == with_slice.Calculate([data],  'web_production[0;93)')
    assert no_slice.Calculate([data]) == with_slice.Calculate([data],  'web_production[0;93)'), 'Fromulas calculation with&withot slices should be euqal'
    print no_slice.Calculate([data]), no_slice.Calculate([data]) == with_custom_slice.Calculate([data],  'custom_web_production[0;93)')
    assert no_slice.Calculate([data]) == with_custom_slice.Calculate([data],  'custom_web_production[0;93)'), 'Fromulas calculation with&with custom slices should be euqal'

    test_slices = 'web_meta[0;%d) web_production[%d;93)'
    test_custom_slices = 'custom_web_meta[0;%d) custom_web_production[%d;93)'
    for i in range(0, 93, 13):
        with_slice.SetProperty('Slices', test_slices %(i,i))
        with_custom_slice.SetProperty('Slices', test_custom_slices %(i,i))
        data_with_zeroes = data[:i] + [0]*10 + data[i:] + [0]*20

        with_slice_result = with_slice.Calculate([data_with_zeroes], 'web_meta[0;%d) pairwise_clicks[%d;%d) web_production[%d;%d)' % (i, i, i+10, i+10, 93+30))
        print test_slices % (i, i), no_slice_result == with_slice_result, with_slice_result
        assert no_slice_result == with_slice_result, 'Results must be equal'

        with_custom_slice_result = with_custom_slice.Calculate([data_with_zeroes], 'custom_web_meta[0;%d) custom_pairwise_clicks[%d;%d) custom_web_production[%d;%d)' % (i, i, i+10, i+10, 93+30))
        print test_slices % (i, i), no_slice_result == with_slice_result, with_slice_result
        assert no_slice_result == with_slice_result, 'Results must be equal'


def TestMC():
    data = [[0.000000, 1.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 1.260240, 0.503056, 0.310251, 0.000000, 0.000000, 0.504337, 0.170233, 0.164565, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.631407, 0.700000, 0.011163, 2.500000, 0.065813, 1.000000, 0.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, -0.400620, 0.719731,
            1.512280, 0.064035, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 15.000000, -24.672300, -21.235800, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 19958500.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.002650, 0.999117, 0.017827, 0.142857, 0.999440, 0.999941, 0.999386, 0.490903, 0.161938, 0.750000, 0.000000, 0.000000, 0.166667, 0.857143, 0.000000,
            0.000000, 0.000000, 0.666667, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            19958500.000000, 31720.000000, 0.000000, 2.000000, 0.987137, 551384000.000000, 0.000000, 0.087560, 0.962000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
            0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
            [0.000000,1.000000,0.000000,0.020233,0.000000,0.004134,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,
             1.543660,0.629181,0.418007,0.000000,-1.550580,0.709701,0.355865,0.167587,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,
             0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,
             0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,-31.718300,-18.380900,0.000000,0.000000,0.000000,
             0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,320705.000000,0.000000,1.000000,0.000000,0.000000,
             0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,1.000000,1.000000,0.000000,1.000000,
             0.000000,0.074380,0.991803,0.033482,0.997773,0.063917,0.627624,0.999970,0.999997,0.999968,0.604443,0.212278,0.928571,0.000000,0.000000,0.442560,0.999230,
             0.658537,0.976190,0.037436,0.999915,0.072727,0.982143,0.659155,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.622737,0.000000,0.000000,0.000000,
             0.000000,0.500000,0.000000,0.000000,0.000000,0.249015,0.183167,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,320705.000000,404.000000,0.000000,
             0.000000,0.000000,553877000.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,
             0.000000,0.000000,0.922000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000]]
    o = libmxnet.TMXNetMC("example.mnmc")
    categValues = o.CategValues()
    print categValues, categValues == [0.10000000149011612, 0.30000001192092896, 0.44999998807907104, 0.8999999761581421]

    o.SetProperty("hello", "world")
    print o.GetProperty("hello"), o.GetProperty("hello") == "world"

    res  = o.Calculate(data)
    print res, res == [0.42792228437236096, 0.10826492871766381]

    categs =  o.CalculateCategs(data)
    print categs, categs == [[0.063086534329469382, 0.0012146822080084183, 0.93528815064100246, 0.00041063282151977118],
                              [0.96654259369616147, 0.02508181965075533, 0.0076709025910396385, 0.0007046840620435628]]


def TestSplits():
    data = [0.0, 0.0, 0.0, 0.0, 91000.0, 46.912100000000002, -1.0, -1.0, 1280000000.0, 0.0, 154.78800000000001, 19.77, 0.30866700000000002, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0083999999999999995, 1.0, 0.0, -230.25899999999999, 0.990676, 1415260.0, 0.0, 13011.0, 1444.0, 0.34999999999999998, 0.62307900000000005, 0.0, 0.0, -1.0, 0.0, 0.0, 1594.0, 238.0, 23774.0, 205.0, 191.0, 48.0, 215.0, 1.0, 1329338624.0, 1329338624.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.6326, 4.5612599999999999, 7.1691200000000004, 6.8840599999999998, 14323.0, 1823.0, 1011.0, 0.070585800000000004, 0.55457999999999996, 14246.0, 3788.0, 2349.0, 0.16488800000000001, 0.620116, 0.65059, 814.0, 0.0, 0.0, 46.0, 0.377056, 314.0, 68.0, 142.0, 71.0, 19.0, 88.0, 1.0, 1329427584.0, 1329427584.0, 0.27800000000000002, 30.491499999999998, 0.165432, 0.040194500000000001]
    data = [[d] + data for d in [1029790.0, 270051, 270053, 548652, 548654, 886775, 886777]]
    o = libmxnet.TMXNetInfo("ip25.info")
    d = o.Calculate(data)
    print d
    o1, o2 = o.SplitBySpecifiedFactors([0])
    d1, d2 = o1.Calculate(data), o2.Calculate(data)
    print [d1[i] + d2[i] for i in xrange(len(d))]


def TestLoadPart():
    data = [0.0, 0.0, 0.0, 0.0, 91000.0, 46.912100000000002, -1.0, -1.0, 1280000000.0, 0.0, 154.78800000000001, 19.77, 0.30866700000000002, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0083999999999999995, 1.0, 0.0, -230.25899999999999, 0.990676, 1415260.0, 0.0, 13011.0, 1444.0, 0.34999999999999998, 0.62307900000000005, 0.0, 0.0, -1.0, 0.0, 0.0, 1594.0, 238.0, 23774.0, 205.0, 191.0, 48.0, 215.0, 1.0, 1329338624.0, 1329338624.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.6326, 4.5612599999999999, 7.1691200000000004, 6.8840599999999998, 14323.0, 1823.0, 1011.0, 0.070585800000000004, 0.55457999999999996, 14246.0, 3788.0, 2349.0, 0.16488800000000001, 0.620116, 0.65059, 814.0, 0.0, 0.0, 46.0, 0.377056, 314.0, 68.0, 142.0, 71.0, 19.0, 88.0, 1.0, 1329427584.0, 1329427584.0, 0.27800000000000002, 30.491499999999998, 0.165432, 0.040194500000000001]
    data = [[d] + data for d in [1029790.0, 270051, 270053, 548652, 548654, 886775, 886777]]
    o = libmxnet.TMXNetInfo("ip25.info")
    o1 = libmxnet.TMXNetInfo()
    o2 = libmxnet.TMXNetInfo()
    o1.LoadInfoPart("ip25.info", 0, 1000)
    o2.LoadInfoPart("ip25.info", 1000, 2000)
    d = o.Calculate(data)
    d1 = o1.Calculate(data)
    d2 = o2.Calculate(data)
    print [(d1[i] + d2[i] - d[i]) < 1e-7 for i in xrange(len(d))]
    o1.LoadBinPart("ip25.bin", 0, 1000)
    o2.LoadBinPart("ip25.bin", 1000, 2000)

    b1 = o1.Calculate(data)
    b2 = o2.Calculate(data)
    print [b1[i] - d1[i] < 1e-7 for i in xrange(len(d1))]
    print [b2[i] - d2[i] < 1e-7 for i in xrange(len(d1))]


    o.SetProperty("hello", "world")
    o.Save("ip25.copy.info")
    oCopy = libmxnet.TMXNetInfo()
    oCopy.LoadInfoPart("ip25.copy.info", 0, 1000)
    print oCopy.GetProperty("hello"), oCopy.GetProperty("hello") == "world"
    os.remove("ip25.copy.info")


if __name__ == "__main__":
    # Run every check in a fixed order; each one prints its own results.
    for run_check in (TestInfo, TestMC, TestSplits, TestLoadPart, TestInfoSlices):
        run_check()
