/*
 * Copyright 2023 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define UNSUBTYPING_DEBUG 0

#include <cstddef>

#if !UNSUBTYPING_DEBUG
#include <unordered_map>
#include <unordered_set>
#endif

#include "ir/subtype-exprs.h"
#include "ir/type-updating.h"
#include "ir/utils.h"
#include "pass.h"
#include "support/index.h"
#include "wasm-traversal.h"
#include "wasm-type.h"
#include "wasm.h"

#if UNSUBTYPING_DEBUG
#include "support/insert_ordered.h"
#endif

// Compute and use the minimal subtype relation required to maintain module
// validity and behavior. This minimal relation will be a subset of the original
// subtype relation. Start by walking the IR and collecting pairs of types that
// need to be in the subtype relation for each expression to validate. For
// example, a local.set requires that the type of its operand be a subtype of
// the local's type. Casts do not generate subtypings at this point because it
// is not necessary for the cast target to be a subtype of the cast source for
// the cast to validate.
//
// From that initial subtype relation, we then start finding new subtypings that
// are required by the subtypings we have found already. These transitively
// required subtypings come from two sources.
//
// The first source is type definitions. Consider these type definitions:
//
//   (type $A (sub (struct (ref $X))))
//   (type $B (sub $A (struct (ref $Y))))
//
// If we have determined that $B must remain a subtype of $A, then we know that
// $Y must remain a subtype of $X as well, since the type definitions would not
// be valid otherwise. Similarly, knowing that $Y must remain a subtype of $X
// may transitively require other subtypings as well based on their type
// definitions.
//
// The second source of transitive subtyping requirements is casts. Although
// casting from one type to another does not necessarily require that those
// types are related, we do need to make sure that we do not change the
// behavior of casts by removing subtype relationships they might observe. For
// example, consider this module:
//
// (module
//  ;; original subtyping: $bot <: $mid <: $top
//  (type $top (sub (struct)))
//  (type $mid (sub $top (struct)))
//  (type $bot (sub $mid (struct)))
//
//  (func $f
//   (local $top (ref $top))
//   (local $mid (ref $mid))
//
//   ;; Requires $bot <: $top
//   (local.set $top (struct.new $bot))
//
//   ;; Cast $top to $mid
//   (local.set $mid (ref.cast (ref $mid) (local.get $top)))
//  )
// )
//
// The only subtype relation directly required by the IR for this module is $bot
// <: $top. However, if we optimized the module so that $bot <: $top was the
// only subtype relation, we would change the behavior of the cast. In the
// original module, a value of type (ref $bot) is cast to (ref $mid). The cast
// succeeds because in the original module, $bot <: $mid. If we optimize so that
// we have $bot <: $top and no other subtypings, though, the cast will fail
// because the value of type (ref $bot) no longer inhabits (ref $mid). To
// prevent the cast's behavior from changing, we need to ensure that $bot <:
// $mid.
//
// The set of subtyping requirements generated by a cast from $src to $dest is
// that for every known remaining subtype $v of $src, if $v <: $dest in the
// original module, then $v <: $dest in the optimized module. In other words,
// for every type $v of values we know can flow into the cast, if the cast would
// have succeeded for values of type $v before, then we know the cast must
// continue to succeed for values of type $v. These requirements arising from
// casts can also generate transitive requirements because we learn about new
// types of values that can flow into casts as we learn about new subtypes of
// cast sources.
//
// Starting with the initial subtype relation determined by walking the IR,
// repeatedly search for new subtypings by analyzing type definitions and casts
// until we reach a fixed point. This is the minimal subtype relation that
// preserves module validity and behavior that can be found without a more
// precise analysis of types that might flow into each cast.

namespace wasm {

namespace {

// Container aliases used throughout this file. The implementation is selected
// by UNSUBTYPING_DEBUG, which is defined at the top of the file.
#if UNSUBTYPING_DEBUG
// Debug configuration: insertion-ordered containers (presumably so that
// iteration order, and therefore any debugging output, is reproducible from
// run to run - confirm against support/insert_ordered.h).
template<typename K, typename V> using Map = InsertOrderedMap<K, V>;
template<typename T> using Set = InsertOrderedSet<T>;
#else
// Normal configuration: the standard hash containers.
template<typename K, typename V> using Map = std::unordered_map<K, V>;
template<typename T> using Set = std::unordered_set<T>;
#endif
// A tree (or rather a forest) of heap types in which every type has at most
// one recorded supertype. Nodes are created lazily the first time a type is
// mentioned to any member function. Supertypes can be queried and set in
// constant time (plus one lookup in `indices`), and both the chain of
// supertypes above a type and the subtree of subtypes below it can be
// iterated.
struct TypeTree {
  struct Node {
    // The type represented by this node.
    HeapType type;
    // The index of the parent (supertype) in the list of nodes. Set to the
    // index of this node if there is no parent.
    Index parent;
    // The index of this node in the parent's list of children, if any, enabling
    // O(1) updates.
    Index indexInParent = 0;
    // The indices of the children (subtypes) in the list of nodes.
    std::vector<Index> children;

    // Create a root node: a node is its own parent until a supertype is set.
    Node(HeapType type, Index index) : type(type), parent(index) {}
  };

  // Node storage. Nodes refer to each other by index into this vector rather
  // than by pointer, so the links survive the vector growing.
  std::vector<Node> nodes;
  // The index in `nodes` of the node for each type seen so far.
  Map<HeapType, Index> indices;

  // Record `super` as the supertype of `sub`, unlinking `sub` from its
  // previously recorded supertype if it had one. Nodes are created for both
  // types if they do not exist yet. `sub` and `super` must be different types.
  void setSupertype(HeapType sub, HeapType super) {
    auto childIndex = getIndex(sub);
    auto parentIndex = getIndex(super);
    // A node that is its own parent denotes a root, so a type cannot be
    // recorded as its own supertype. Doing so would also put the node in its
    // own list of children and make subtype iteration loop forever.
    assert(childIndex != parentIndex);
    // Take references only after both getIndex calls, since getIndex may grow
    // `nodes`.
    auto& childNode = nodes[childIndex];
    auto& parentNode = nodes[parentIndex];
    // Remove sub from its old supertype if necessary.
    if (auto oldParentIndex = childNode.parent; oldParentIndex != childIndex) {
      auto& oldParentNode = nodes[oldParentIndex];
      // Move sub to the back of its parent's children and then pop it.
      auto& children = oldParentNode.children;
      assert(children[childNode.indexInParent] == childIndex);
      auto& swappedNode = nodes[children.back()];
      assert(swappedNode.indexInParent == children.size() - 1);
      // Swap the indices in the parent's child vector.
      std::swap(children[childNode.indexInParent], children.back());
      // Swap the index in the kept child.
      swappedNode.indexInParent = childNode.indexInParent;
      children.pop_back();
    }
    // Link sub in as the last child of its new parent.
    childNode.parent = parentIndex;
    childNode.indexInParent = parentNode.children.size();
    parentNode.children.push_back(childIndex);
  }

  // Return the recorded supertype of `type`, or nullopt if none has been set.
  // Note that this creates a (root) node for `type` if it had none.
  std::optional<HeapType> getSupertype(HeapType type) {
    auto index = getIndex(type);
    auto parentIndex = nodes[index].parent;
    if (parentIndex == index) {
      return std::nullopt;
    }
    return nodes[parentIndex].type;
  }

  // Input iterator that starts at a type and follows parent links up to the
  // root of its tree. The starting type itself is the first element visited.
  struct SupertypeIterator {
    using value_type = const HeapType;
    using difference_type = std::ptrdiff_t;
    using reference = const HeapType&;
    using pointer = const HeapType*;
    using iterator_category = std::input_iterator_tag;

    TypeTree* parent;
    // Index of the current node, or nullopt for the end iterator.
    std::optional<Index> index;

    // Comparisons do not modify the iterator, so they are const-qualified.
    // This allows const iterators to be compared and keeps the operators
    // symmetric. Only the position is compared, not the owning tree.
    bool operator==(const SupertypeIterator& other) const {
      return index == other.index;
    }
    bool operator!=(const SupertypeIterator& other) const {
      return !(*this == other);
    }
    const HeapType& operator*() const { return parent->nodes[*index].type; }
    const HeapType* operator->() const { return &*(*this); }
    SupertypeIterator& operator++() {
      auto parentIndex = parent->nodes[*index].parent;
      if (parentIndex == *index) {
        // A self-parented node is a root, so we have run out of supertypes.
        index = std::nullopt;
      } else {
        index = parentIndex;
      }
      return *this;
    }
    SupertypeIterator operator++(int) {
      auto it = *this;
      ++(*this);
      return it;
    }
  };

  // Range over a type followed by all of its recorded supertypes.
  struct Supertypes {
    TypeTree* parent;
    Index index;
    SupertypeIterator begin() { return {parent, index}; }
    SupertypeIterator end() { return {parent, std::nullopt}; }
  };

  // Iterate over `type` itself and then each of its supertypes in order from
  // nearest to furthest. Creates a node for `type` if it had none.
  Supertypes supertypes(HeapType type) { return {this, getIndex(type)}; }

  // Input iterator performing a pre-order depth-first traversal of the subtree
  // rooted at a type. The root type itself is the first element visited.
  struct SubtypeIterator {
    using value_type = const HeapType;
    using difference_type = std::ptrdiff_t;
    using reference = const HeapType&;
    using pointer = const HeapType*;
    using iterator_category = std::input_iterator_tag;

    TypeTree* parent;

    // DFS stack of (node index, child index) pairs. The child index is the
    // position of the next child of that node still to be visited. The node on
    // top of the stack is the current element; an empty stack is the end
    // iterator.
    std::vector<std::pair<Index, Index>> stack;

    // Comparisons do not modify the iterator, so they are const-qualified.
    // Only the traversal state is compared, not the owning tree.
    bool operator==(const SubtypeIterator& other) const {
      return stack == other.stack;
    }
    bool operator!=(const SubtypeIterator& other) const {
      return !(*this == other);
    }
    const HeapType& operator*() const {
      return parent->nodes[stack.back().first].type;
    }
    const HeapType* operator->() const { return &*(*this); }
    SubtypeIterator& operator++() {
      while (true) {
        if (stack.empty()) {
          return *this;
        }
        auto& [index, childIndex] = stack.back();
        auto& children = parent->nodes[index].children;
        if (childIndex == children.size()) {
          // All children of this node have been visited; backtrack.
          stack.pop_back();
        } else {
          // Descend into the next unvisited child. The child index is advanced
          // before the push because the push may invalidate the references
          // into the stack.
          auto child = children[childIndex++];
          stack.push_back({child, 0u});
          return *this;
        }
      }
    }
    SubtypeIterator operator++(int) {
      auto it = *this;
      ++(*this);
      return it;
    }
  };

  // Range over a type and all of its recorded transitive subtypes.
  struct Subtypes {
    TypeTree* parent;
    Index index;
    SubtypeIterator begin() { return {parent, {std::make_pair(index, 0u)}}; }
    SubtypeIterator end() { return {parent, {}}; }
  };

  // Iterate over `type` itself and then every type recorded transitively below
  // it. Creates a node for `type` if it had none. The tree must not be modified
  // during the iteration.
  Subtypes subtypes(HeapType type) { return {this, getIndex(type)}; }

private:
  // Return the index of the node for `type`, creating a new root node for it
  // if this is the first time the type has been seen.
  Index getIndex(HeapType type) {
    auto [it, inserted] = indices.insert({type, nodes.size()});
    if (inserted) {
      nodes.emplace_back(type, nodes.size());
    }
    return it->second;
  }
};

struct Unsubtyping : Pass {
  // (sub, super) pairs that we have discovered but not yet processed. Used as
  // a stack: noteSubtype pushes onto the back and run pops from the back until
  // nothing is left.
  std::vector<std::pair<HeapType, HeapType>> work;

  // Record the type tree with supertype and subtype relations in such a way
  // that we can add new supertype relationships in constant time. This holds
  // the subtype relation computed so far and is what rewriteTypes consults to
  // pick each type's new declared supertype.
  TypeTree types;

  // Map from cast source types to their destinations. Filled in by
  // analyzeModule and consulted by processCasts.
  Map<HeapType, std::vector<HeapType>> casts;

  // Pass entry point. Does nothing unless the GC feature is enabled. Otherwise
  // seeds the worklist from public types and module code, processes the
  // worklist until it is empty, rewrites the module's types to use the
  // resulting supertypes, and finally refinalizes the IR.
  void run(Module* wasm) override {
    Module& module = *wasm;
    if (!module.features.hasGC()) {
      return;
    }

    // Seed the subtype relation with everything that is directly needed for
    // the code and the public types to stay valid.
    analyzePublicTypes(module);
    analyzeModule(module);

    // Processing one pair may discover further pairs, so keep going until the
    // worklist runs dry, i.e. until a fixed point is reached.
    while (!work.empty()) {
      const std::pair<HeapType, HeapType> next = work.back();
      work.pop_back();
      process(next.first, next.second);
    }

    rewriteTypes(module);

    // Once source and target types of a cast are no longer related, the type
    // of the cast may be refinable. TODO: Experiment with running this only
    // after checking whether it is necessary.
    ReFinalize refinalize;
    refinalize.run(getPassRunner(), wasm);
  }

  // Queue the requirement sub <: super for later processing. Reflexive pairs
  // and pairs whose subtype is a bottom type are dropped; other basic heap
  // types are kept because they can still interact with casts.
  void noteSubtype(HeapType sub, HeapType super) {
    const bool uninteresting = sub == super || sub.isBottom();
    if (!uninteresting) {
      work.emplace_back(sub, super);
    }
  }

  // Queue the heap type requirements implied by sub <: super for value types.
  // Tuples are handled element by element; anything that is not a pair of
  // reference types implies nothing.
  void noteSubtype(Type sub, Type super) {
    if (!sub.isTuple()) {
      if (sub.isRef() && super.isRef()) {
        noteSubtype(sub.getHeapType(), super.getHeapType());
      }
      return;
    }
    assert(super.isTuple() && sub.size() == super.size());
    const size_t arity = sub.size();
    for (size_t lane = 0; lane != arity; ++lane) {
      noteSubtype(sub[lane], super[lane]);
    }
  }

  // Public types must keep the supertypes they declare, so note each public
  // type's declared supertype as a required subtyping.
  void analyzePublicTypes(Module& wasm) {
    for (auto type : ModuleUtils::getPublicHeapTypes(wasm)) {
      auto declared = type.getDeclaredSuperType();
      if (!declared) {
        continue;
      }
      noteSubtype(type, *declared);
    }
  }

  // Walk every function body and all module-level code, gather the subtype
  // constraints and casts found there, and use them to seed the worklist and
  // the `casts` map.
  void analyzeModule(Module& wasm) {
    // Everything observed while walking some piece of code.
    struct Info {
      // (source, target) pairs for casts.
      Set<std::pair<HeapType, HeapType>> casts;

      // Observed (sub, super) subtype constraints.
      Set<std::pair<HeapType, HeapType>> subtypings;
    };

    // Receives the notifications made by SubtypingDiscoverer while walking
    // code and records the interesting ones in `info`.
    struct Collector
      : ControlFlowWalker<Collector, SubtypingDiscoverer<Collector>> {
      Info& info;
      Collector(Info& info) : info(info) {}

      // The heap type overload is where constraints are actually recorded;
      // the remaining noteSubtype overloads all funnel into it.
      void noteSubtype(HeapType sub, HeapType super) {
        const bool uninteresting = sub == super || sub.isBottom();
        if (!uninteresting) {
          info.subtypings.insert({sub, super});
        }
      }
      void noteSubtype(Type sub, Type super) {
        if (!sub.isTuple()) {
          // Only a pair of reference types relates two heap types.
          if (sub.isRef() && super.isRef()) {
            noteSubtype(sub.getHeapType(), super.getHeapType());
          }
          return;
        }
        // Tuples are related element by element.
        assert(super.isTuple() && sub.size() == super.size());
        const size_t arity = sub.size();
        for (size_t lane = 0; lane != arity; ++lane) {
          noteSubtype(sub[lane], super[lane]);
        }
      }
      void noteSubtype(Type sub, Expression* super) {
        noteSubtype(sub, super->type);
      }
      void noteSubtype(Expression* sub, Type super) {
        noteSubtype(sub->type, super);
      }
      void noteSubtype(Expression* sub, Expression* super) {
        noteSubtype(sub->type, super->type);
      }
      void noteNonFlowSubtype(Expression* sub, Type super) {
        // A non-flow constraint is purely static: the expression's type has to
        // be a subtype of |super| for validation, but no value travels
        // anywhere because of it, so it cannot interact with casts. When
        // |super| is a basic type such a constraint can be dropped entirely,
        // because this pass only ever removes subtyping between user-defined
        // types and subtyping with respect to basic types never changes.
        //
        // Strictly speaking this is a shortcut. A fully precise approach would
        // tag every constraint as flow-based or not and consult only the
        // flow-based ones when working out the effects of casts, rather than
        // special-casing basic types. It is sufficient in practice because
        // the only non-trivial use of |noteNonFlowSubtype| is RefEq, whose
        // bound is the basic type eqref. The other uses are trivial, e.g. the
        // target of a CallRef is compared to itself, and constraints of the
        // form A :> A are ignored. If SubtypingDiscoverer starts using
        // |noteNonFlowSubtype| in other ways, this may need to be generalized.
        const bool basicSuper = super.isRef() && super.getHeapType().isBasic();
        if (!basicSuper) {
          // Not ignorable, so treat it like any other constraint.
          noteSubtype(sub, super);
        }
      }
      void noteCast(HeapType src, HeapType dst) {
        // Nothing to learn from a cast to the same type, nor from a cast
        // between unrelated types, which can never succeed.
        if (src == dst) {
          return;
        }
        if (HeapType::isSubType(dst, src)) {
          // A downcast whose outcome depends on the subtype relation.
          info.casts.insert({src, dst});
        } else if (HeapType::isSubType(src, dst)) {
          // An upcast always succeeds, but only as long as src <: dst is
          // preserved.
          info.subtypings.insert({src, dst});
        }
      }
      void noteCast(Expression* src, Type dst) {
        if (!src->type.isRef() || !dst.isRef()) {
          return;
        }
        noteCast(src->type.getHeapType(), dst.getHeapType());
      }
      void noteCast(Expression* src, Expression* dst) {
        if (!src->type.isRef() || !dst->type.isRef()) {
          return;
        }
        noteCast(src->type.getHeapType(), dst->type.getHeapType());
      }
    };

    // Gather per-function information in parallel, skipping imports since they
    // have no body to walk.
    ModuleUtils::ParallelFunctionAnalysis<Info> analysis(
      wasm, [&](Function* func, Info& funcInfo) {
        if (func->imported()) {
          return;
        }
        Collector(funcInfo).walkFunctionInModule(func, &wasm);
      });

    // Fold the per-function results into a single set of facts.
    Info merged;
    for (auto& entry : analysis.map) {
      Info& funcInfo = entry.second;
      merged.casts.insert(funcInfo.casts.begin(), funcInfo.casts.end());
      merged.subtypings.insert(funcInfo.subtypings.begin(),
                               funcInfo.subtypings.end());
    }

    // Module-level code contributes constraints too, as do the globals and
    // element segments themselves. The module is set again after the walk
    // before they are visited.
    Collector moduleCollector(merged);
    moduleCollector.walkModuleCode(&wasm);
    moduleCollector.setModule(&wasm);
    for (auto& global : wasm.globals) {
      moduleCollector.visitGlobal(global.get());
    }
    for (auto& segment : wasm.elementSegments) {
      moduleCollector.visitElementSegment(segment.get());
    }

    // Hand everything over to the processing loop: constraints go onto the
    // worklist and casts are indexed by their source type.
    for (const auto& constraint : merged.subtypings) {
      noteSubtype(constraint.first, constraint.second);
    }
    for (const auto& cast : merged.casts) {
      casts[cast.first].push_back(cast.second);
    }
  }

  // Establish the requirement sub <: super in the type tree. If `sub` has no
  // recorded supertype yet, `super` simply becomes its supertype. If it
  // already has one, the old and new supertypes are ordered relative to each
  // other and the tree ends up with the chain that satisfies both. Whenever
  // `sub` receives a new immediate supertype, the subtypings implied by the
  // type definitions and by casts are noted for later processing.
  void process(HeapType sub, HeapType super) {
    auto oldSuper = types.getSupertype(sub);
    if (oldSuper) {
      // We already had a recorded supertype. The new supertype might be
      // deeper, shallower, or equal to the old supertype. We must recursively
      // note the relationship between the old and new supertypes.
      if (super == *oldSuper) {
        // Nothing new to do here.
        return;
      }
      if (HeapType::isSubType(*oldSuper, super)) {
        // sub <: oldSuper <: super
        processDescribed(sub, *oldSuper, super);
        noteSubtype(*oldSuper, super);
        // We already handled sub <: oldSuper, so we're done.
        return;
      }
      // sub <: super <: oldSuper
      // Eagerly process super <: oldSuper first. This ensures that sub and
      // super will already be in the same tree when we process them below, so
      // when we process casts we will know that we only need to process up to
      // oldSuper.
      processDescribed(sub, super, *oldSuper);
      process(super, *oldSuper);
    }

    // Reparent sub directly under super. This happens both when sub had no
    // supertype before and when super sits between sub and its old supertype.
    types.setSupertype(sub, super);

    // We have a new supertype. Find the implied subtypings from the type
    // definitions and casts.
    processDefinitions(sub, super);
    processCasts(sub, super, oldSuper);
  }

  // Keep described types in step with their descriptors when `mid` is about
  // to be placed between `sub` and `super`.
  void processDescribed(HeapType sub, HeapType mid, HeapType super) {
    // The chain being set up is sub <: mid <: super. Suppose sub describes C,
    // super describes A, and A is currently the immediate supertype of C.
    // Placing mid between sub and super on its own would leave us with this:
    //
    // A -> super
    // ^     ^
    // |    mid
    // |     ^
    // C -> sub
    //
    // That is not allowed: the descriptor of C's immediate supertype has to be
    // the immediate supertype of C's descriptor. The remedy is to take the
    // type B described by mid and place it between A and C as well:
    //
    // A -> super
    // ^     ^
    // B -> mid
    // ^     ^
    // C -> sub
    //
    // This is done eagerly, before sub <: mid <: super itself is established.
    // That way, if establishing it recursively establishes other subtypings,
    // they can rely on the described types already being set up correctly.
    auto subDescribed = sub.getDescribedType();
    auto superDescribed = super.getDescribedType();
    if (!subDescribed || !superDescribed) {
      return;
    }
    const bool directlyRelated =
      types.getSupertype(*subDescribed) == superDescribed;
    if (!directlyRelated) {
      return;
    }
    auto midDescribed = mid.getDescribedType();
    assert(midDescribed);
    process(*subDescribed, *midDescribed);
  }

  // Note the subtypings that the definitions of `sub` and `super` require in
  // order for sub <: super to be a valid declaration. Nothing is required
  // when `super` is a basic type.
  void processDefinitions(HeapType sub, HeapType super) {
    if (super.isBasic()) {
      return;
    }
    switch (sub.getKind()) {
      case HeapTypeKind::Func: {
        auto subSig = sub.getSignature();
        auto superSig = super.getSignature();
        // Parameters are related in the opposite direction to results.
        noteSubtype(superSig.params, subSig.params);
        noteSubtype(subSig.results, superSig.results);
        break;
      }
      case HeapTypeKind::Struct: {
        const auto& subFields = sub.getStruct().fields;
        const auto& superFields = super.getStruct().fields;
        // Only the fields that the supertype has are compared; any further
        // fields of the subtype are not constrained by it.
        size_t position = 0;
        for (const auto& superField : superFields) {
          noteSubtype(subFields[position].type, superField.type);
          ++position;
        }
        break;
      }
      case HeapTypeKind::Array: {
        noteSubtype(sub.getArray().element.type,
                    super.getArray().element.type);
        break;
      }
      case HeapTypeKind::Cont:
        WASM_UNREACHABLE("TODO: cont");
      case HeapTypeKind::Basic:
        WASM_UNREACHABLE("unexpected kind");
    }
    // When both types have descriptors, the descriptors must be related in
    // the same direction.
    auto subDesc = sub.getDescriptorType();
    if (!subDesc) {
      return;
    }
    auto superDesc = super.getDescriptorType();
    if (superDesc) {
      noteSubtype(*subDesc, *superDesc);
    }
  }

  // Note the subtypings needed to preserve the behavior of casts now that
  // `sub` has been placed directly below `super`. `oldSuper` is the supertype
  // `sub` had before, if any.
  void
  processCasts(HeapType sub, HeapType super, std::optional<HeapType> oldSuper) {
    // We are either attaching the one tree rooted at `sub` under a new
    // supertype in another tree, or we are reparenting `sub` below a
    // descendent of `oldSuper` in the same tree. In the former case, we must
    // evaluate `sub` and all its subtypes against all its new supertypes and
    // their cast destinations. In the latter case, `sub` and all its subtypes
    // must have already been evaluated against `oldSuper` and its supertypes,
    // so we only need to additionally evaluate them against supertypes up to
    // `oldSuper`.
    //
    // The set of relevant cast destinations does not depend on which subtype
    // is being considered, so gather it once up front instead of re-walking
    // the supertype chain for every subtype. Look sources up with find()
    // rather than operator[] so that sources without casts do not get empty
    // entries added to the map.
    std::vector<HeapType> dests;
    for (auto src : types.supertypes(super)) {
      if (oldSuper && src == *oldSuper) {
        break;
      }
      if (auto it = casts.find(src); it != casts.end()) {
        dests.insert(dests.end(), it->second.begin(), it->second.end());
      }
    }
    if (dests.empty()) {
      // No casts can observe the new subtypes, so skip the subtree walk.
      return;
    }
    // A cast that succeeded for a type in the original module must keep
    // succeeding for it, so the type must remain a subtype of the destination.
    for (auto type : types.subtypes(sub)) {
      for (auto dst : dests) {
        if (HeapType::isSubType(type, dst)) {
          noteSubtype(type, dst);
        }
      }
    }
  }

  // Rewrite the module's types so that each type declares the supertype
  // recorded for it in `types`, or no supertype at all if nothing (or only a
  // basic type) was recorded.
  void rewriteTypes(Module& wasm) {
    struct Rewriter : GlobalTypeRewriter {
      Unsubtyping& parent;
      Rewriter(Unsubtyping& parent, Module& wasm)
        : GlobalTypeRewriter(wasm), parent(parent) {}
      std::optional<HeapType> getDeclaredSuperType(HeapType type) override {
        auto recorded = parent.types.getSupertype(type);
        if (!recorded || recorded->isBasic()) {
          // Basic types are never returned as declared supertypes.
          return std::nullopt;
        }
        return recorded;
      }
    };
    Rewriter rewriter(*this, wasm);
    rewriter.update();
  }
};

} // anonymous namespace

// Create a new instance of the unsubtyping pass. The caller receives a pointer
// to a freshly allocated object.
Pass* createUnsubtypingPass() {
  Pass* pass = new Unsubtyping();
  return pass;
}

} // namespace wasm
