#include "infer.h"

#include "constructors_set.h"

#include <infra/yasm/aldan/common/exceptions.h>

#include <util/generic/hash.h>
#include <util/generic/algorithm.h>
#include <util/generic/xrange.h>

using namespace NYasm::NAldan;

namespace {
    class TContext {
    public:
        TContext(TRegistry& registry, TTypeConstructorSet& nonGeneric)
            : Registry(registry)
            , NonGeneric(nonGeneric)
        {
        }

        bool IsGeneric(const TTypeConstructor::TPtr& v) {
            return !OccursInMultipleTypes(v, NonGeneric.begin(), NonGeneric.end());
        }

        bool OccursInType(const TTypeConstructor::TPtr& v, const TTypeConstructor::TPtr& typeExpr) {
            auto prunedTypeExpr = typeExpr->Prune();
            if (*v == *typeExpr) {
                return true;
            } else if (prunedTypeExpr->IsComposite()) {
                const auto& operatorTypes(prunedTypeExpr->CastUnsafe<TTypeOperator>().GetTypes());
                return OccursInMultipleTypes(v, operatorTypes.begin(), operatorTypes.end());
            }
            return false;
        }

        TTypeConstructor::TPtr Clone(const TTypeConstructor::TPtr& t) {
            TTypeConstructorMap<TTypeConstructor::TPtr> mapping;
            return Clone(t, mapping);
        }

        TTypeConstructor::TPtr GetType(const TString& name, TTypeConstructor::TPtr argumentType) {
            TTypeConstructor::TPtr result;

            result = GetTypeFromRegistry(name, argumentType);
            if (result) {
                return result;
            }

            if (argumentType) {
                ythrow TSymbolNotFound(name) << "No function " << name << "(" << argumentType->ToString() << ") defined";
            } else {
                ythrow TSymbolNotFound(name) << "Can't find symbol " << name;
            }
        }

        void Unify(TTypeConstructor::TPtr lhs, TTypeConstructor::TPtr rhs) {
            lhs = lhs->Prune();
            rhs = rhs->Prune();
            if (lhs->GetType() == TTypeConstructor::Variable) {
                if (*lhs != *rhs) {
                    if (OccursInType(lhs, rhs)) {
                        ythrow TUnifyError() << "recursive unification";
                    }
                    lhs->CastUnsafe<TTypeVariable>().SetInstance(std::move(rhs));
                }
            } else if (lhs->IsComposite() && rhs->GetType() == TTypeConstructor::Variable) {
                Unify(std::move(rhs), std::move(lhs));
            } else if (lhs->GetType() == TTypeConstructor::Tuple && rhs->GetType() == TTypeConstructor::List) {
                Unify(std::move(rhs), std::move(lhs));
            } else if (lhs->GetType() == TTypeConstructor::List && rhs->GetType() == TTypeConstructor::Tuple) {
                auto& typedLeft = lhs->CastUnsafe<TListConstructor>();
                auto& typedRight = rhs->CastUnsafe<TTupleConstructor>();
                if (typedRight.GetTypes().empty()) {
                    ythrow TUnifyError() << rhs->ToString() << " can't be converted to " << lhs->ToString();
                }
                for (const auto& tp : typedRight.GetTypes()) {
                    Unify(tp, typedLeft.GetValueType());
                }
            } else if (lhs->IsComposite() && rhs->IsComposite()) {
                auto& typedLeft = lhs->CastUnsafe<TTypeOperator>();
                auto& typedRight = rhs->CastUnsafe<TTypeOperator>();
                const auto& leftTypes(typedLeft.GetTypes());
                const auto& rightTypes(typedRight.GetTypes());
                if (typedLeft.GetName() != typedRight.GetName() || leftTypes.size() != rightTypes.size()) {
                    ythrow TUnifyError() << lhs->ToString() << " not matching to " << rhs->ToString();
                }
                for (const auto& index : xrange(leftTypes.size())) {
                    Unify(leftTypes[index], rightTypes[index]);
                }
            } else {
                ythrow TUnifyError() << "not unified";
            }
        }

        TTypeConstructor::TPtr Analyze(TDefinition& definition, TTypeConstructor::TPtr argumentType=nullptr) {
            switch (definition.GetType()) {
                case TDefinition::Identifier: {
                    auto& casted(definition.CastUnsafe<TIdentifierDefinition>());
                    TTypeConstructor::TPtr result;
                    if (casted.SuggestedType) {
                        result = Clone(casted.SuggestedType);
                    }
                    if (!result) {
                        result = GetType(casted.Name, std::move(argumentType));
                    }
                    casted.SetInferredType(Clone(result));
                    return result;
                }
                case TDefinition::Apply: {
                    auto& casted(definition.CastUnsafe<TApplyDefinition>());
                    auto argumentType = Analyze(*casted.Argument);
                    auto functionType = Analyze(*casted.Function, argumentType);
                    auto resultType = MakeIntrusive<TTypeVariable>();
                    auto newFunctionType = MakeIntrusive<TFunctionConstructor>(argumentType, resultType);
                    Unify(newFunctionType, functionType);
                    casted.SetInferredType(Clone(resultType));
                    return resultType;
                }
                case TDefinition::Lambda: {
                    auto& casted(definition.CastUnsafe<TLambdaDefinition>());
                    auto argumentType = MakeIntrusive<TTypeVariable>();
                    TRegistry newRegistry;
                    newRegistry.insert(Registry.begin(), Registry.end());
                    newRegistry[casted.Name].emplace_back(argumentType);
                    TTypeConstructorSet newNonGeneric;
                    newNonGeneric.insert(NonGeneric.begin(), NonGeneric.end());
                    newNonGeneric.emplace(argumentType);
                    TContext newContext(newRegistry, newNonGeneric);
                    auto resultType = newContext.Analyze(*casted.Body);
                    auto newFunctionType = MakeIntrusive<TFunctionConstructor>(argumentType, resultType);
                    casted.SetInferredType(newContext.Clone(newFunctionType));
                    return newFunctionType;
                }
                case TDefinition::Let: {
                    auto& casted(definition.CastUnsafe<TLetDefinition>());
                    auto definitionType = Analyze(*casted.Definition);
                    TRegistry newRegistry;
                    newRegistry.insert(Registry.begin(), Registry.end());
                    newRegistry[casted.Name].emplace_back(definitionType);
                    TContext newContext(newRegistry, NonGeneric);
                    auto resultType = newContext.Analyze(*casted.Body);
                    casted.SetInferredType(newContext.Clone(resultType));
                    return resultType;
                }
                case TDefinition::Tuple: {
                    auto& casted(definition.CastUnsafe<TTupleDefinition>());
                    TVector<TTypeConstructor::TPtr> newTypes(Reserve(casted.Arguments.size()));
                    for (const auto& tp : casted.Arguments) {
                        newTypes.emplace_back(Analyze(*tp));
                    }
                    auto newTupleType = MakeIntrusive<TTupleConstructor>(std::move(newTypes));
                    casted.SetInferredType(Clone(newTupleType));
                    return newTupleType;
                }
                case TDefinition::List: {
                    auto& casted(definition.CastUnsafe<TListDefinition>());
                    TTypeConstructor::TPtr valueType = MakeIntrusive<TTypeVariable>();
                    for (const auto& tp : casted.Arguments) {
                        Unify(valueType, Analyze(*tp));
                    }
                    valueType = valueType->Prune();
                    if (valueType->GetType() == TTypeConstructor::Variable) {
                        ythrow TInferError() << "list type can't be deduced";
                    }
                    auto newListType = MakeIntrusive<TListConstructor>(valueType);
                    casted.SetInferredType(Clone(newListType));
                    return newListType;
                }
            }
        }

    private:
        template <class Iterator>
        bool OccursInMultipleTypes(const TTypeConstructor::TPtr& v, Iterator f, Iterator l) {
            return AnyOf(f, l, [&](const TTypeConstructor::TPtr& other) {
                return OccursInType(v, other);
            });
        }

        TTypeConstructor::TPtr Clone(const TTypeConstructor::TPtr& t, TTypeConstructorMap<TTypeConstructor::TPtr>& mapping) {
            auto pruned = t->Prune();
            switch (pruned->GetType()) {
                case TTypeConstructor::Variable: {
                    if (IsGeneric(pruned)) {
                        TTypeConstructorMap<TTypeConstructor::TPtr>::insert_ctx ctx;
                        auto it = mapping.find(pruned, ctx);
                        if (it == mapping.end()) {
                            it = mapping.emplace_direct(ctx, pruned, MakeIntrusive<TTypeVariable>());
                        }
                        return it->second;
                    } else {
                        return pruned;
                    }
                }
                case TTypeConstructor::Operator: {
                    const auto& casted(pruned->CastUnsafe<TTypeOperator>());
                    return MakeIntrusive<TTypeOperator>(casted.GetName(), CloneMultiple(casted.GetTypes(), mapping));
                }
                case TTypeConstructor::Function: {
                    const auto& casted(pruned->CastUnsafe<TFunctionConstructor>());
                    return MakeIntrusive<TFunctionConstructor>(Clone(casted.GetFromType(), mapping), Clone(casted.GetToType(), mapping));
                }
                case TTypeConstructor::Tuple: {
                    const auto& casted(pruned->CastUnsafe<TTupleConstructor>());
                    return MakeIntrusive<TTupleConstructor>(CloneMultiple(casted.GetTypes(), mapping));
                }
                case TTypeConstructor::List: {
                    const auto& casted(pruned->CastUnsafe<TListConstructor>());
                    return MakeIntrusive<TListConstructor>(Clone(casted.GetValueType(), mapping));
                }
            }
        }

        TVector<TTypeConstructor::TPtr> CloneMultiple(const TVector<TTypeConstructor::TPtr>& incoming,
                                                      TTypeConstructorMap<TTypeConstructor::TPtr>& mapping) {
            TVector<TTypeConstructor::TPtr> result(Reserve(incoming.size()));
            for (const auto& tp : incoming) {
                result.emplace_back(Clone(tp, mapping));
            }
            return result;
        }

        TTypeConstructor::TPtr GetTypeFromRegistry(const TString& name, TTypeConstructor::TPtr argumentType) {
            auto variants = Registry.find(name);
            if (variants != Registry.end()) {
                for (const auto& variant : variants->second) {
                    if (argumentType && variant->GetType() == TTypeConstructor::Function) {
                        try {
                            Unify(Clone(variant->CastUnsafe<TFunctionConstructor>().GetFromType()), Clone(argumentType));
                            return Clone(variant);
                        } catch(const TInferError&) {
                        }
                    } else {
                        return Clone(variant);
                    }
                }
            }
            return nullptr;
        }

        TRegistry& Registry;
        TTypeConstructorSet& NonGeneric;
    };
}

// Entry point: infers the type of `definition` using the symbols in `registry`
// and returns the inferred type recorded on the definition itself.
TTypeConstructor::TPtr NYasm::NAldan::AnalyzeType(TDefinition& definition, const TRegistry& registry) {
    // The context needs a mutable registry reference, so work on a private copy
    // of the caller's (const) registry.
    TRegistry workingRegistry;
    workingRegistry.reserve(registry.size());
    workingRegistry.insert(registry.begin(), registry.end());

    // The top-level scope starts with no non-generic type variables.
    TTypeConstructorSet topLevelNonGeneric;

    TContext rootContext(workingRegistry, topLevelNonGeneric);
    rootContext.Analyze(definition);

    return definition.GetInferredType();
}
