// Copyright (c) ONNX Project Contributors
//
// SPDX-License-Identifier: Apache-2.0

// Adapter for Scatter in default domain from version 10 to 11

#pragma once

#include <cstdint>
#include <memory>

#include "onnx/version_converter/adapters/adapter.h"

namespace ONNX_NAMESPACE {
namespace version_conversion {

/// Version adapter that converts a default-domain `Scatter` node from
/// opset 10 to opset 11 by rewriting it as a `ScatterElements` node with the
/// same `axis` attribute and the same first three inputs.
class Scatter_10_11 final : public Adapter {
 public:
  explicit Scatter_10_11() : Adapter("Scatter", OpSetID(10), OpSetID(11)) {}

  /// Replaces `node` in `graph` with an equivalent ScatterElements node.
  ///
  /// @param graph Graph in which the replacement node is created.
  /// @param node  The Scatter node to convert. It must have at least three
  ///              inputs; only the first three are forwarded. The node is
  ///              destroyed by this call and must not be used afterwards.
  /// @return The newly created ScatterElements node, inserted where `node` was.
  Node* adapt_scatter_10_11(const std::shared_ptr<Graph>& graph, Node* node) const {
    const ArrayRef<Value*>& inputs = node->inputs();
    ONNX_ASSERTM(inputs.size() >= 3, "Scatter in opset 10 needs to have at least 3 inputs.");

    // ONNX integer attributes are 64-bit; keep the full width instead of
    // narrowing to int so the value round-trips unchanged into i_() below.
    // A missing axis attribute defaults to 0.
    const int64_t axis = node->hasAttribute(kaxis) ? node->i(kaxis) : 0;

    // Replace the node with an equivalent ScatterElements node.
    Node* scatter_elements = graph->create(kScatterElements);
    scatter_elements->i_(kaxis, axis);
    scatter_elements->addInput(inputs[0]);
    scatter_elements->addInput(inputs[1]);
    scatter_elements->addInput(inputs[2]);

    // Redirect consumers first, while `node` still exists, then splice the
    // new node into the old position and drop the original.
    node->replaceAllUsesWith(scatter_elements);
    scatter_elements->insertBefore(node);
    node->destroy();

    return scatter_elements;
  }

  /// Adapter entry point; delegates to adapt_scatter_10_11().
  Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
    return adapt_scatter_10_11(graph, node);
  }
};

} // namespace version_conversion
} // namespace ONNX_NAMESPACE
