package ru.yandex.tours.geo.base.region

import java.io.InputStream

import org.joda.time.DateTimeZone
import ru.yandex.extdata.common.meta.DataType
import ru.yandex.tours.extdata.{DataDef, DataTypes}
import ru.yandex.tours.geo.base.{region, TreeBase, Region}
import ru.yandex.tours.geo.base.export.XmlParser
import ru.yandex.tours.util.Collections._

import scala.collection.mutable

/** Compiles a collection of regions into a region tree and provides access to the
  * parents and children of each region.
  *
  * Links between regions come from two sources:
  *  - each region's `parentId`, and
  *  - each region's `childrenIds` list.
  *
  * Because of the second source a region may be reachable from several ancestors,
  * so the structure is a graph rather than a strict tree. The traversals below visit
  * every region at most once and stop on cycles in the data instead of recursing forever.
  *
  * @param regions all regions; the first one is used to locate [[root]]
  */
class Tree(val regions: Iterable[Region]) extends TreeBase[region.Id, Region] with Serializable {

  private val id2region = (for (region <- regions) yield region.id -> region).toMap

  /** child id -> regions that list this id in their `childrenIds` */
  private val id2metaParent = (for {
    region <- regions
    child <- region.childrenIds
  } yield {
    child -> region
  }).toMultiMap

  private val iso2region = (for {
    region <- regions
    iso <- region.isoCode.toIterable
  } yield iso -> region).toMap

  /** region id -> regions listed in its `childrenIds` (when present in the tree)
    * plus regions whose `parentId` points to it */
  private val id2children = (for {
    region <- regions
    pair <- {
      val children = for {
        childId <- region.childrenIds
        child <- id2region.get(childId)
      } yield region.id -> child
      children :+ (region.parentId -> region)
    }
  } yield pair).toMultiMap

  /** Number of distinct region ids in the tree. */
  def size: Id = id2region.size

  def region(id: Id): Option[Region] = id2region.get(id)

  /** Region referenced by `region.parentId`, if it is present in the tree. */
  def parent(region: Region): Option[Region] = id2region.get(region.parentId)

  /** `parentId` parent of the region with the given id, if both are present in the tree. */
  def parent(id: Id): Option[Region] = id2region.get(id).flatMap(r => id2region.get(r.parentId))

  /** Direct children of `id`: regions listed in its `childrenIds` and regions whose `parentId` is `id`. */
  def children(id: Id): Set[Region] = id2children.getOrElse(id, mutable.Set.empty).toSet

  def children(region: Region): Set[Region] = children(region.id)

  def getByIso(iso: String): Option[Region] = iso2region.get(iso)

  /** Topmost `parentId` ancestor of the first region in `regions`; `None` for an empty tree. */
  val root: Option[Region] = regions.headOption.map { region =>
    pathToRoot(region).last
  }

  /** Lazily traverses `region` and everything reachable from it through [[children]].
    *
    * `region` itself comes first. Each region is passed to the callback at most once per
    * traversal, even if it is reachable through several ancestors or the data has a cycle.
    */
  def findChildren(region: Region): Traversable[Region] = {
    new Traversable[Region] {
      override def foreach[U](f: (Region) => U): Unit = {
        // fresh per traversal so the Traversable can be iterated repeatedly
        val visited = mutable.HashSet.empty[Id]
        def traverse(r: Region): Unit = {
          if (visited.add(r.id)) {
            f(r)
            children(r).foreach(child => traverse(child))
          }
        }
        traverse(region)
      }
    }
  }

  /** `region` together with every region reachable from it through [[children]]. */
  def allChildren(region: Region): Set[Region] = {
    findChildren(region).toSet
  }

  /** [[pathToRoot(region:*]] for the region with the given id; `Nil` if the id is unknown. */
  def pathToRoot(regionId: Int): List[Region] =
    region(regionId).fold(List.empty[Region])(r => pathToRoot(r))

  /** Chain of `parentId` ancestors, starting with `region` itself and ending at the topmost
    * reachable ancestor. The walk stops if a parent repeats (a cycle in the data), so it
    * always terminates. The result is never empty.
    */
  def pathToRoot(region: Region): List[Region] = {
    @scala.annotation.tailrec
    def climb(current: Region, reversedPath: List[Region]): List[Region] =
      parent(current) match {
        case Some(p) if !reversedPath.exists(_.id == p.id) => climb(p, p :: reversedPath)
        case _ => reversedPath.reverse
      }
    climb(region, region :: Nil)
  }

  /** all parents including meta regions (and the region itself); empty if the id is unknown */
  def findParents(regionId: Int): Set[Region] =
    region(regionId).fold(Set.empty[Region])(r => findParents(r))

  /** All parents including meta regions.
    *
    * This is the closure of `region` under both `parentId` links and "listed in someone's
    * `childrenIds`" links, and it includes `region` itself. Each region is expanded only
    * once, which also makes the search terminate on cyclic data.
    */
  def findParents(region: Region): Set[Region] = {
    @scala.annotation.tailrec
    def gather(pending: List[Region], found: Set[Region]): Set[Region] = pending match {
      case Nil => found
      case r :: rest if found.contains(r) => gather(rest, found)
      case r :: rest =>
        val metaParents: List[Region] = id2metaParent.getOrElse(r.id, Seq.empty).toList
        gather(parent(r).toList ::: metaParents ::: rest, found + r)
    }
    gather(region :: Nil, Set.empty)
  }

  /** First time zone defined on the path from the region to the root; `None` if the id is unknown. */
  def getTimeZone(regionId: Int): Option[DateTimeZone] = region(regionId).flatMap(r => getTimeZone(r))

  /** First time zone defined on the path from `region` (inclusive) to the root. */
  def getTimeZone(region: Region): Option[DateTimeZone] = pathToRoot(region).flatMap(_.timeZone).headOption

  /** Closest region on the path to root, starting with `region` itself, whose type is one of `type`. */
  def parent(region: Region, `type`: Type*): Option[Region] = pathToRoot(region).find(r => `type`.contains(r.`type`))
  def parent(regionId: Int, `type`: Type*): Option[Region] = region(regionId).flatMap(parent(_, `type`: _*))

  /** Nearest Country or OverseasLand on the path to root (the region itself if it qualifies). */
  def country(region: Region): Option[Region] = parent(region, Types.Country, Types.OverseasLand)
  def country(regionId: Int): Option[Region] = parent(regionId, Types.Country, Types.OverseasLand)
}

object Tree extends DataDef[Tree] {
  override def dataType: DataType = DataTypes.regions

  /** Parses a region tree from `is` and validates it.
    *
    * Validation requires the tree to be non-empty. It also requires every region except
    * `Types.MetaRegion` ones to be reachable from [[Tree.root]] through [[Tree.children]] links.
    *
    * @throws IllegalArgumentException if the tree is empty or has unreachable non-meta regions
    */
  override def parse(is: InputStream): Tree = {
    val result = XmlParser.parse(is, RegionBoundaries.empty).get
    require(result.size > 0, "Region tree should not be empty")
    val visited = mutable.HashSet.empty[Region]
    // a non-empty tree always has a root; foreach avoids a bare Option.get
    result.root.foreach(root => dfs(root, result, visited))
    val malformed = result.regions.filter { region =>
      !visited.contains(region) && region.`type` != Types.MetaRegion
    }
    require(malformed.isEmpty, s"Malformed geobase: regions without parent\n ${malformed.mkString("\n ")}")
    result
  }

  /** Adds `region` and every region reachable from it via `tree.children` to `visited`.
    * Regions already in `visited` are not expanded again, so cycles terminate.
    */
  private def dfs(region: Region, tree: Tree, visited: mutable.HashSet[Region]): Unit = {
    visited += region
    for {
      child <- tree.children(region)
      if !visited.contains(child)
    } dfs(child, tree, visited)
  }

  /** A tree with no regions (and therefore no root). */
  def empty: Tree = new Tree(Iterable.empty[Region])
}