#include <audioapi/core/AudioNode.h>
#include <audioapi/core/AudioParam.h>
#include <audioapi/core/BaseAudioContext.h>
#include <audioapi/core/utils/AudioGraphManager.h>
#include <audioapi/types/NodeOptions.h>
#include <audioapi/utils/AudioArray.hpp>

#include <memory>

namespace audioapi {

AudioNode::AudioNode(
    const std::shared_ptr<BaseAudioContext> &context,
    const AudioNodeOptions &options)
    : context_(context),
      numberOfInputs_(options.numberOfInputs),
      numberOfOutputs_(options.numberOfOutputs),
      channelCount_(options.channelCount),
      channelCountMode_(options.channelCountMode),
      channelInterpretation_(options.channelInterpretation),
      requiresTailProcessing_(options.requiresTailProcessing) {
  audioBuffer_ = std::make_shared<DSPAudioBuffer>(
      RENDER_QUANTUM_SIZE, channelCount_, context->getSampleRate());
}

AudioNode::~AudioNode() {
  if (isInitialized_.load(std::memory_order_acquire)) {
    cleanup();
  }
}

bool AudioNode::canBeDestructed() const {
  return true;
}

size_t AudioNode::getChannelCount() const {
  return channelCount_;
}

void AudioNode::connect(
    const std::shared_ptr<AudioNode>
        &node) { // NOLINT(readability-convert-member-functions-to-static)
  if (std::shared_ptr<BaseAudioContext> context = context_.lock()) {
    context->getGraphManager()->addPendingNodeConnection(
        shared_from_this(), node, AudioGraphManager::ConnectionType::CONNECT);
  }
}

void AudioNode::connect(
    const std::shared_ptr<AudioParam>
        &param) { // NOLINT(readability-convert-member-functions-to-static)
  if (std::shared_ptr<BaseAudioContext> context = context_.lock()) {
    context->getGraphManager()->addPendingParamConnection(
        shared_from_this(), param, AudioGraphManager::ConnectionType::CONNECT);
  }
}

void AudioNode::disconnect() { // NOLINT(readability-convert-member-functions-to-static)
  if (std::shared_ptr<BaseAudioContext> context = context_.lock()) {
    context->getGraphManager()->addPendingNodeConnection(
        shared_from_this(), nullptr, AudioGraphManager::ConnectionType::DISCONNECT_ALL);
  }
}

void AudioNode::disconnect(
    const std::shared_ptr<AudioNode>
        &node) { // NOLINT(readability-convert-member-functions-to-static)
  if (std::shared_ptr<BaseAudioContext> context = context_.lock()) {
    context->getGraphManager()->addPendingNodeConnection(
        shared_from_this(), node, AudioGraphManager::ConnectionType::DISCONNECT);
  }
}

void AudioNode::disconnect(
    const std::shared_ptr<AudioParam>
        &param) { // NOLINT(readability-convert-member-functions-to-static)
  if (std::shared_ptr<BaseAudioContext> context = context_.lock()) {
    context->getGraphManager()->addPendingParamConnection(
        shared_from_this(), param, AudioGraphManager::ConnectionType::DISCONNECT);
  }
}

bool AudioNode::isEnabled() const {
  return isEnabled_;
}

bool AudioNode::requiresTailProcessing() const {
  return requiresTailProcessing_;
}

void AudioNode::enable() {
  if (isEnabled()) {
    return;
  }

  isEnabled_ = true;

  for (auto it = outputNodes_.begin(), end = outputNodes_.end(); it != end; ++it) {
    it->get()->onInputEnabled();
  }
}

void AudioNode::disable() {
  if (!isEnabled()) {
    return;
  }

  isEnabled_ = false;

  for (auto it = outputNodes_.begin(), end = outputNodes_.end(); it != end; ++it) {
    it->get()->onInputDisabled();
  }
}

std::shared_ptr<DSPAudioBuffer> AudioNode::processAudio(
    const std::shared_ptr<DSPAudioBuffer> &outputBuffer,
    int framesToProcess,
    bool checkIsAlreadyProcessed) {
  if (!isInitialized_.load(std::memory_order_acquire)) {
    return outputBuffer;
  }

  if (checkIsAlreadyProcessed && isAlreadyProcessed()) {
    return audioBuffer_;
  }

  // Process inputs and return the buffer with the most channels.
  auto processingBuffer = processInputs(outputBuffer, framesToProcess, checkIsAlreadyProcessed);

  // Apply channel count mode.
  processingBuffer = applyChannelCountMode(processingBuffer);

  // Mix all input buffers into the processing buffer.
  mixInputsBuffers(processingBuffer);

  assert(processingBuffer != nullptr);

  // Finally, process the node itself.
  return processNode(processingBuffer, framesToProcess);
}

bool AudioNode::isAlreadyProcessed() { // NOLINT(readability-convert-member-functions-to-static)
  if (std::shared_ptr<BaseAudioContext> context = context_.lock()) {
    std::size_t currentSampleFrame = context->getCurrentSampleFrame();

    // check if the node has already been processed for this rendering quantum
    if (currentSampleFrame == lastRenderedFrame_) {
      return true;
    }

    // Update the last rendered frame before processing node and its inputs.
    lastRenderedFrame_ = currentSampleFrame;

    return false;
  }

  // If context is invalid, consider it as already processed to avoid processing
  return true; // NOLINT(readability-simplify-boolean-expr)
}

std::shared_ptr<DSPAudioBuffer> AudioNode::processInputs(
    const std::shared_ptr<DSPAudioBuffer> &outputBuffer,
    int framesToProcess,
    bool checkIsAlreadyProcessed) { // NOLINT(readability-convert-member-functions-to-static)
  auto processingBuffer = audioBuffer_;
  processingBuffer->zero();

  size_t maxNumberOfChannels = 0;
  for (auto *inputNode : inputNodes_) {
    assert(inputNode != nullptr);

    if (!inputNode->isEnabled()) {
      continue;
    }

    auto inputBuffer =
        inputNode->processAudio(outputBuffer, framesToProcess, checkIsAlreadyProcessed);
    inputBuffers_.push_back(inputBuffer);

    if (maxNumberOfChannels < inputBuffer->getNumberOfChannels()) {
      maxNumberOfChannels = inputBuffer->getNumberOfChannels();
      processingBuffer = inputBuffer;
    }
  }

  return processingBuffer;
}

std::shared_ptr<DSPAudioBuffer> AudioNode::applyChannelCountMode(
    const std::shared_ptr<DSPAudioBuffer> &processingBuffer) {
  // If the channelCountMode is EXPLICIT, the node should output the number of
  // channels specified by the channelCount.
  if (channelCountMode_ == ChannelCountMode::EXPLICIT) {
    return audioBuffer_;
  }

  // If the channelCountMode is CLAMPED_MAX, the node should output the maximum
  // number of channels clamped to channelCount.
  if (channelCountMode_ == ChannelCountMode::CLAMPED_MAX &&
      processingBuffer->getNumberOfChannels() >= channelCount_) {
    return audioBuffer_;
  }

  return processingBuffer;
}

void AudioNode::mixInputsBuffers(const std::shared_ptr<DSPAudioBuffer> &processingBuffer) {
  assert(processingBuffer != nullptr);

  for (auto it = inputBuffers_.begin(), end = inputBuffers_.end(); it != end; ++it) {
    processingBuffer->sum(**it, channelInterpretation_);
  }

  inputBuffers_.clear();
}

void AudioNode::connectNode(const std::shared_ptr<AudioNode> &node) {
  auto position = outputNodes_.find(node);

  if (position == outputNodes_.end()) {
    outputNodes_.insert(node);
    node->onInputConnected(this);
  }
}

void AudioNode::connectParam(const std::shared_ptr<AudioParam> &param) {
  auto position = outputParams_.find(param);

  if (position == outputParams_.end()) {
    outputParams_.insert(param);
    param->addInputNode(this);
  }
}

void AudioNode::disconnectNode(const std::shared_ptr<AudioNode> &node) {
  auto position = outputNodes_.find(node);

  if (position != outputNodes_.end()) {
    node->onInputDisconnected(this);
    outputNodes_.erase(node);
  }
}

void AudioNode::disconnectParam(const std::shared_ptr<AudioParam> &param) {
  auto position = outputParams_.find(param);

  if (position != outputParams_.end()) {
    param->removeInputNode(this);
    outputParams_.erase(param);
  }
}

void AudioNode::onInputEnabled() {
  numberOfEnabledInputNodes_ += 1;

  if (!isEnabled()) {
    enable();
  }
}

void AudioNode::onInputDisabled() {
  numberOfEnabledInputNodes_ -= 1;

  if (isEnabled() && numberOfEnabledInputNodes_ == 0) {
    disable();
  }
}

void AudioNode::onInputConnected(AudioNode *node) {
  if (!isInitialized_.load(std::memory_order_acquire)) {
    return;
  }

  inputNodes_.insert(node);

  if (node->isEnabled()) {
    onInputEnabled();
  }
}

void AudioNode::onInputDisconnected(AudioNode *node) {
  if (!isInitialized_.load(std::memory_order_acquire)) {
    return;
  }

  if (node->isEnabled()) {
    onInputDisabled();
  }

  auto position = inputNodes_.find(node);

  if (position != inputNodes_.end()) {
    inputNodes_.erase(position);
  }
}

void AudioNode::cleanup() {
  isInitialized_.store(false, std::memory_order_release);

  for (auto it = outputNodes_.begin(), end = outputNodes_.end(); it != end; ++it) {
    it->get()->onInputDisconnected(this);
  }

  outputNodes_.clear();
}

} // namespace audioapi
