import copy
import json, urllib
import collections

import pandas as pd
import numpy as np

from plotly.offline import download_plotlyjs, init_notebook_mode, iplot

init_notebook_mode(connected=True)

## --- Config
# Query result; the script reads the columns action, action_idx, next_action
# and count from each row (one row per action -> next_action transition).
df = datasets['Query 4']
title = 'MWeb Pageview Referrer'
# Nodes to start from. Node names always carry a step suffix,
# "<action>_<action_idx>", so entries must include it, e.g.
# ['channel_profile_1']. When empty, rows with action_idx == 1 are treated as
# path starts when counting total users.
begin = []
# Next-actions that are shown WITHOUT a step suffix, so every depth merges
# into a single terminal node.
end = ['branch_install']
# Intended to drop edges from a begin node into 'wall_passed'.
no_auto_login = False
# Depth limit for the link walk.
max_depth = 4
# Links carrying fewer than this share of total users are dropped.
min_value = 0.01 # percent
# "percent" scales link values to % of total users; otherwise raw counts.
style = "percent"

## --- End Config

def isnan(x):
  """Return True if ``x`` is a missing value (NaN or None).

  NaN is the only value that compares unequal to itself, so ``x != x``
  detects it without failing on non-float inputs such as strings. ``None``
  is also treated as missing, since a NULL in an object-dtype column may
  arrive as ``None`` rather than NaN.
  """
  return x is None or x != x

# Base Sankey trace; per-chart fields (node labels, links) are added to a copy.
data_template = {
  'type': 'sankey',
  'domain': {'x': [0, 1], 'y': [0, 1]},
  'orientation': 'h',
  'valueformat': 'd',
  'valuesuffix': ' users',
  'node': {
    'pad': 10,
    'thickness': 30,
    'line': {'color': 'black', 'width': 0.5},
  },
}

# Base figure layout; 'title' is a placeholder.
layout_template = {
  'title': 'Title',
  'height': 500,
  'width': 1000,
  'font': {'size': 10},
}

# Per-link accumulators for the Sankey trace.
label = []        # node names; reassigned from the tree keys once it is built
source = []       # index into `label` of each link's source node
target = []       # index into `label` of each link's target node
value = []        # user count on each link
link_labels = []  # hover text for each link

# Users entering at a begin node (summed while reading the rows).
total_users = 0

# node -> list of [child, count] edges. Indexing a missing key creates an
# empty list, which is how nodes without outgoing edges get registered.
tree = collections.defaultdict(list)

# node -> total users arriving at it over all incoming edges.
node_total = collections.defaultdict(int)

# Turn each transition row into a tree edge s -> t weighted by its user count.
for row in df.itertuples():
  n_users = int(row.count)
  s = f'{row.action}_{row.action_idx}'

  # A row starts a path when it matches an explicit `begin` node or, with no
  # `begin` configured, when it is a first step (action_idx == 1).
  is_begin = (s in begin) if begin else (row.action_idx == 1)

  if is_begin:
    total_users += n_users

  if isnan(row.next_action):
    t = 'end'
  elif row.next_action in end:
    # Terminal actions drop the step suffix so all depths share one node.
    t = row.next_action
  else:
    t = f'{row.next_action}_{row.action_idx + 1}'

  # Optionally drop begin -> wall_passed edges. Compare the raw action name:
  # `t` only equals 'wall_passed' when it is also listed in `end`; otherwise
  # it carries a step suffix and would never match.
  if no_auto_login and is_begin and row.next_action == 'wall_passed':
    continue
  tree[s].append([t, n_users])
  tree[t]  # touch the defaultdict so `t` is registered as a node
  node_total[t] += n_users

# Node order (and therefore node indices) follows first appearance.
label = list(tree.keys())

vis = {}


def safe_pct(num, den):
  """Return ``num / den * 100``, or 0.0 when ``den`` is 0 instead of raising."""
  return num / den * 100 if den else 0.0


def build(parent, d=0):
  """Walk depth-first from ``parent``, appending one Sankey link per kept edge.

  Each edge in ``tree[parent]`` whose user count is at least ``min_value``
  percent of ``total_users`` is appended to the module-level ``source``,
  ``target``, ``value`` and ``link_labels`` lists, and the child is then
  visited at depth ``d + 1``. Each node is expanded at most once (tracked in
  ``vis``). A node deeper than ``max_depth`` is marked visited but emits no
  links.

  Args:
    parent: node name; must be present in ``label``.
    d: depth of ``parent`` in the current walk.
  """
  if parent in vis:
    return
  vis[parent] = True

  # The depth limit applies to every child equally, so check it once.
  if d > max_depth:
    return

  edges = tree[parent]
  parent_total = sum(v for _, v in edges)
  threshold = min_value * total_users / 100  # min_value is a percent

  for child, v in edges:
    if v < threshold:
      continue

    source.append(label.index(parent))
    target.append(label.index(child))
    value.append(v)

    # safe_pct keeps a zero denominator (e.g. total_users == 0 because `begin`
    # matched nothing) from crashing the whole chart.
    link_labels.append(
      f"{v} users, {safe_pct(v, parent_total):.2f}% of source, "
      f"{safe_pct(v, node_total[child]):.2f}% of target, "
      f"{safe_pct(v, total_users):.2f}% of total"
    )

    build(child, d + 1)

# Pick the roots of the walk. With an explicit `begin`, only those nodes are
# roots. Otherwise use the first-step nodes (action_idx == 1), the same rule
# used when counting total_users. Building every node as a depth-0 root would
# re-expand nodes that max_depth or min_value had pruned.
if begin:
  roots = [k for k in label if k in begin]
else:
  first_steps = {
    f'{r.action}_{r.action_idx}' for r in df.itertuples() if r.action_idx == 1
  }
  roots = [k for k in label if k in first_steps]

for k in roots:
  build(k)

# Fill deep copies so the templates themselves are never mutated.
data = copy.deepcopy(data_template)
layout = copy.deepcopy(layout_template)

layout['title'] = f"{title} ({total_users} users)"
data['node']['label'] = label

if style == "percent":
  # Express each link as a share of all users. Guard the division so a run
  # where no begin node matched (total_users == 0) yields zeros, not NaN/inf.
  if total_users:
    value = [v / total_users * 100 for v in value]
  else:
    value = [0.0] * len(value)
  data['valueformat'] = '.2f'
  data['valuesuffix'] = "% users"

data['link'] = dict(
    source = source,
    target = target,
    value = value,
    label = link_labels
)

fig = dict(data=[data], layout=layout)
iplot(fig, validate=False)
