/*
 * Copyright 2016 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Implementation of the shell interpreter execution environment
//

#ifndef wasm_shell_interface_h
#define wasm_shell_interface_h

#include "asmjs/shared-constants.h"
#include "interpreter/exception.h"
#include "ir/module-utils.h"
#include "shared-constants.h"
#include "support/name.h"
#include "support/utilities.h"
#include "wasm-interpreter.h"
#include "wasm.h"

namespace wasm {

struct ShellExternalInterface : ModuleRunner::ExternalInterface {
  // The underlying memory can be accessed through unaligned pointers which
  // isn't well-behaved in C++. WebAssembly nonetheless expects it to behave
  // properly. Avoid emitting unaligned load/store by checking for alignment
  // explicitly, and performing memcpy if unaligned.
  //
  // The allocated memory tries to have the same alignment as the memory being
  // simulated.
  class Memory {
    // Use char because it doesn't run afoul of aliasing rules.
    std::vector<char> memory;
    template<typename T> static bool aligned(const char* address) {
      static_assert(!(sizeof(T) & (sizeof(T) - 1)), "must be a power of 2");
      return 0 == (reinterpret_cast<uintptr_t>(address) & (sizeof(T) - 1));
    }

  public:
    Memory() = default;
    void resize(size_t newSize) {
      // Ensure the smallest allocation is large enough that most allocators
      // will provide page-aligned storage. This hopefully allows the
      // interpreter's memory to be as aligned as the memory being simulated,
      // ensuring that the performance doesn't needlessly degrade.
      //
      // The code is optimistic this will work until WG21's p0035r0 happens.
      const size_t minSize = 1 << 12;
      size_t oldSize = memory.size();
      memory.resize(std::max(minSize, newSize));
      if (newSize < oldSize && newSize < minSize) {
        std::memset(&memory[newSize], 0, minSize - newSize);
      }
    }
    template<typename T> void set(size_t address, T value) {
      if (aligned<T>(&memory[address])) {
        *reinterpret_cast<T*>(&memory[address]) = value;
      } else {
        std::memcpy(&memory[address], &value, sizeof(T));
      }
    }
    template<typename T> T get(size_t address) {
      if (aligned<T>(&memory[address])) {
        return *reinterpret_cast<T*>(&memory[address]);
      } else {
        T loaded;
        std::memcpy(&loaded, &memory[address], sizeof(T));
        return loaded;
      }
    }
  };

  std::map<Name, Memory> memories;
  std::map<Name, std::shared_ptr<ModuleRunner>> linkedInstances;

  ShellExternalInterface(
    std::map<Name, std::shared_ptr<ModuleRunner>> linkedInstances_ = {}) {
    linkedInstances.swap(linkedInstances_);
  }
  virtual ~ShellExternalInterface() = default;

  ModuleRunner* getImportInstanceOrNull(Importable* import) {
    auto it = linkedInstances.find(import->module);
    if (it == linkedInstances.end()) {
      return nullptr;
    }
    return it->second.get();
  }

  ModuleRunner* getImportInstance(Importable* import) {
    auto* ret = getImportInstanceOrNull(import);
    if (!ret) {
      Fatal() << "getImportInstance: unknown import: " << import->module.str
              << "." << import->base.str;
    }
    return ret;
  }

  void init(Module& wasm, ModuleRunner& instance) override {
    ModuleUtils::iterDefinedMemories(wasm, [&](wasm::Memory* memory) {
      auto shellMemory = Memory();
      shellMemory.resize(memory->initial << memory->pageSizeLog2);
      memories[memory->name] = shellMemory;
    });
  }

  Literal getImportedFunction(Function* import) override {
    // TODO: We should perhaps restrict the types with which the well-known
    // functions can be imported.
    if (import->module == SPECTEST && import->base.startsWith(PRINT)) {
      // Use a null instance because these are host functions.
      return Literal(
        std::make_shared<FuncData>(import->name,
                                   nullptr,
                                   [](const Literals& arguments) -> Flow {
                                     for (auto argument : arguments) {
                                       std::cout << argument << " : "
                                                 << argument.type << '\n';
                                     }
                                     return Flow();
                                   }),
        import->type);
    } else if (import->module == ENV && import->base == EXIT) {
      return Literal(std::make_shared<FuncData>(import->name,
                                                nullptr,
                                                [](const Literals&) -> Flow {
                                                  // XXX hack for torture tests
                                                  std::cout << "exit()\n";
                                                  throw ExitException();
                                                }),
                     import->type);
    } else if (auto* inst = getImportInstanceOrNull(import)) {
      return inst->getExportedFunction(import->base);
    }
    // This is not a known import. Create a literal for it, which is good enough
    // if it is never called (see the ref_func.wast spec test, which does that).
    std::cerr << "warning: getImportedFunction: unknown import: "
              << import->module.str << "." << import->name.str << '\n';
    return Literal::makeFunc(import->name, import->type);
  }

  int8_t load8s(Address addr, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    return memory.get<int8_t>(addr);
  }
  uint8_t load8u(Address addr, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    return memory.get<uint8_t>(addr);
  }
  int16_t load16s(Address addr, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    return memory.get<int16_t>(addr);
  }
  uint16_t load16u(Address addr, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    return memory.get<uint16_t>(addr);
  }
  int32_t load32s(Address addr, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    return memory.get<int32_t>(addr);
  }
  uint32_t load32u(Address addr, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    return memory.get<uint32_t>(addr);
  }
  int64_t load64s(Address addr, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    return memory.get<int64_t>(addr);
  }
  uint64_t load64u(Address addr, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    return memory.get<uint64_t>(addr);
  }
  std::array<uint8_t, 16> load128(Address addr, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    return memory.get<std::array<uint8_t, 16>>(addr);
  }

  void store8(Address addr, int8_t value, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    memory.set<int8_t>(addr, value);
  }
  void store16(Address addr, int16_t value, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    memory.set<int16_t>(addr, value);
  }
  void store32(Address addr, int32_t value, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    memory.set<int32_t>(addr, value);
  }
  void store64(Address addr, int64_t value, Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    memory.set<int64_t>(addr, value);
  }
  void store128(Address addr,
                const std::array<uint8_t, 16>& value,
                Name memoryName) override {
    auto it = memories.find(memoryName);
    assert(it != memories.end());
    auto& memory = it->second;
    memory.set<std::array<uint8_t, 16>>(addr, value);
  }
  bool
  growMemory(Name memoryName, Address /*oldSize*/, Address newSize) override {
    // Apply a reasonable limit on memory size, 1GB, to avoid DOS on the
    // interpreter.
    if (newSize > 1024 * 1024 * 1024) {
      return false;
    }
    auto it = memories.find(memoryName);
    if (it == memories.end()) {
      trap("growMemory on non-existing memory");
    }
    auto& memory = it->second;
    memory.resize(newSize);
    return true;
  }
  void trap(std::string_view why) override {
    std::cout << "[trap " << why << "]\n";
    throw TrapException();
  }

  void hostLimit(std::string_view why) override {
    std::cout << "[host limit " << why << "]\n";
    throw HostLimitException();
  }
  void throwException(const WasmException& exn) override { throw exn; }
};

} // namespace wasm

#endif // wasm_shell_interface_h
