#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import codecs
import math
import os
import re
import sys
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import xngen
import xnncommon


parser = argparse.ArgumentParser(
    description="Vector binary operation microkernel test generator"
)
parser.add_argument(
    "-t",
    "--tester",
    metavar="TESTER",
    required=True,
    choices=["VCMulMicrokernelTester", "VBinaryMicrokernelTester"],
    help="Tester class to be used in the generated test",
)
parser.add_argument(
    "-b",
    "--broadcast_b",
    action="store_true",
    help="Broadcast the RHS of the operation",
)
parser.add_argument("-k", "--ukernel", required=True, help="Microkernel type")
parser.add_argument(
    "-o",
    "--output",
    metavar="FILE",
    required=True,
    help="Output (C++ source) file",
)
parser.set_defaults(defines=list())

OP_TYPES = {
    "vadd": "Add",
    "vaddc": "Add",
    "vcopysign": "CopySign",
    "vcopysignc": "CopySign",
    "vrcopysign": "RCopySign",
    "vrcopysignc": "RCopySign",
    "vdiv": "Div",
    "vdivc": "Div",
    "vrdiv": "RDiv",
    "vrdivc": "RDiv",
    "vmax": "Max",
    "vmaxc": "Max",
    "vmin": "Min",
    "vminc": "Min",
    "vmul": "Mul",
    "vmulc": "Mul",
    "vcmul": "CMul",
    "vsub": "Sub",
    "vsubc": "Sub",
    "vrsub": "RSub",
    "vrsubc": "RSub",
    "vsqrdiff": "SqrDiff",
    "vsqrdiffc": "SqrDiff",
    "vprelu": "Prelu",
    "vpreluc": "Prelu",
    "vrpreluc": "RPrelu",
}

BINOP_TEST_TEMPLATE = """
#define XNN_UKERNEL(arch_flags, ukernel, batch_tile, vector_tile, datatype, params_type, init_params)
XNN_TEST_BINARY_BATCH_EQ(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
XNN_TEST_BINARY_BATCH_DIV(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
XNN_TEST_BINARY_BATCH_LT(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
XNN_TEST_BINARY_BATCH_GT(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});

$if TESTER in ["VMulCMicrokernelTester"]:
  XNN_TEST_BINARY_INPLACE_A(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
$elif ${BROADCAST_B} == "true":
  XNN_TEST_BINARY_INPLACE_A(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
$else:
  XNN_TEST_BINARY_INPLACE_A(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
  XNN_TEST_BINARY_INPLACE_B(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
  XNN_TEST_BINARY_INPLACE_A_AND_B(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});

$if DATATYPE.startswith("q"):
  XNN_TEST_BINARY_A_ZERO_POINT(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
  XNN_TEST_BINARY_B_ZERO_POINT(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
  XNN_TEST_BINARY_Y_ZERO_POINT(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
  XNN_TEST_BINARY_A_SCALE(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
  XNN_TEST_BINARY_B_SCALE(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
  XNN_TEST_BINARY_Y_SCALE(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});

$if "minmax" in ACTIVATION_TYPE:
  XNN_TEST_BINARY_QMIN(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
  XNN_TEST_BINARY_QMAX(ukernel, arch_flags, batch_tile, ${BROADCAST_B}, datatype, ${", ".join(TEST_ARGS)});
"""


def main(args):
  options = parser.parse_args(args)

  tester = options.tester
  tester_header = {
      "VCMulMicrokernelTester": "vcmul-microkernel-tester.h",
      "VBinaryMicrokernelTester": "vbinary-microkernel-tester.h",
  }[tester]
  tests = """\
// clang-format off
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Microkernel: {microkernel}
//   Generator: {generator}


#include "src/xnnpack/microparams-init.h"
#include "src/xnnpack/vbinary.h"
#include "test/{tester_header}"

""".format(
      microkernel=options.ukernel,
      generator=sys.argv[0],
      tester_header=tester_header,
  )

  ukernel_parts = options.ukernel.split("-")
  datatype = ukernel_parts[0]
  op = ukernel_parts[1]
  activation = ukernel_parts[2] if len(ukernel_parts) >= 3 else ""

  broadcast_b = False
  if op[-1] == "c":
    broadcast_b = True
  op_type = OP_TYPES[op]

  test_args = ["ukernel"]
  if tester in ["VBinaryMicrokernelTester"]:
    if datatype in ["qs8", "qu8"] and op in ["vprelu", "vpreluc", "vrpreluc"]:
      op_type = "Prelu" if op in ["vprelu", "vpreluc"] else "RPrelu"
      test_args.append("%s::OpType::%s" % (tester, op_type))
    elif not datatype in ["qs8", "qu8"]:
      test_args.append("%s::OpType::%s" % (tester, op_type))
  test_args.append("init_params")
  tests += xnncommon.make_multiline_macro(
      xngen.preprocess(
          BINOP_TEST_TEMPLATE,
          {
              "TEST_ARGS": test_args,
              "TESTER": tester,
              "BROADCAST_B": str(broadcast_b).lower(),
              "DATATYPE": datatype,
              "OP_TYPE": op_type,
              "ACTIVATION_TYPE": activation,
          },
      )
  )

  folder = datatype + "-" + ("vbinary" if datatype.startswith("f") else op)
  tests += (
      f'#include "src/{folder}/{options.ukernel}.inc"\n'
  )
  tests += "#undef XNN_UKERNEL\n"
  tests = tests.replace("s32-vmulc/s32-vmulc.inc", "s32-vmul/s32-vmulc.inc")

  xnncommon.overwrite_if_changed(options.output, tests)


if __name__ == "__main__":
  main(sys.argv[1:])
