# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

import torch
from executorch.backends.test.suite.flow import TestFlow

from executorch.backends.test.suite.operators import (
    dtype_test,
    operator_test,
    OperatorTest,
)


class TruncModel(torch.nn.Module):
    """Minimal module whose forward pass is exactly ``torch.trunc``.

    It has no parameters or buffers. ``nn.Module.__init__`` is inherited
    unchanged, so no constructor override is needed.
    """

    def forward(self, x):
        """Round each element of ``x`` toward zero."""
        truncated = torch.trunc(x)
        return truncated


@operator_test
class TestTrunc(OperatorTest):
    """Operator tests for ``torch.trunc``, exercised through ``TruncModel``.

    Every case is run through ``self._test_op`` for the given ``flow``. That
    helper and the ``operator_test`` / ``dtype_test`` decorators are defined
    outside this file.
    """

    @dtype_test
    def test_trunc_dtype(self, flow: TestFlow, dtype) -> None:
        # Test with different dtypes.
        # Generate the values in the default floating-point dtype first and
        # cast once at the end. Doing `rand().to(dtype) * 10 - 5` would run
        # the arithmetic in the target dtype. For an integer dtype (if
        # dtype_test ever yields one — TODO confirm) `rand().to(int)` is all
        # zeros, which would turn the input into a constant tensor.
        model = TruncModel().to(dtype)
        x = (torch.rand(10, 10) * 10 - 5).to(dtype)
        self._test_op(model, (x,), flow)

    def test_trunc_shapes(self, flow: TestFlow) -> None:
        # Test with different tensor shapes.

        # 1D tensor
        self._test_op(TruncModel(), (torch.randn(20) * 5,), flow)

        # 2D tensor
        self._test_op(TruncModel(), (torch.randn(5, 10) * 5,), flow)

        # 3D tensor
        self._test_op(TruncModel(), (torch.randn(3, 4, 5) * 5,), flow)

    def test_trunc_edge_cases(self, flow: TestFlow) -> None:
        # Finite edge cases. These were previously skipped together with the
        # NaN/Inf inputs, but the skip reason only applies to non-finite
        # values (now in test_trunc_non_finite). These cases stay enabled.
        # generate_random_test_inputs=False is passed so the hand-picked
        # values are used as-is (presumed semantics of _test_op; verify).

        # Integer values (should remain unchanged)
        self._test_op(
            TruncModel(),
            (torch.arange(-5, 6).float(),),
            flow,
            generate_random_test_inputs=False,
        )

        # Values with different fractional parts on both sides of zero,
        # including exact .5 values where trunc differs from round.
        x = torch.tensor(
            [-2.9, -2.5, -2.1, -0.9, -0.5, -0.1, 0.0, 0.1, 0.5, 0.9, 2.1, 2.5, 2.9]
        )
        self._test_op(TruncModel(), (x,), flow, generate_random_test_inputs=False)

    @unittest.skip("NaN and Inf are not enforced for backends.")
    def test_trunc_non_finite(self, flow: TestFlow) -> None:
        # Non-finite inputs, kept for reference but skipped: backends are
        # not required to match NaN/Inf behavior.

        # Tensor with infinity
        x = torch.tensor([float("inf"), float("-inf"), 1.4, -1.4])
        self._test_op(TruncModel(), (x,), flow, generate_random_test_inputs=False)

        # Tensor with NaN
        x = torch.tensor([float("nan"), 1.4, -1.4])
        self._test_op(TruncModel(), (x,), flow, generate_random_test_inputs=False)
