# Copyright (c) 2017, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from coremltools import proto

from .. import SPECIFICATION_VERSION
from . import datatypes
from ._interface_management import set_transform_interface_params


def create_array_feature_extractor(
    input_features, output_name, extract_indices, output_type=None
):
    """
    Create an array feature extractor model spec that selects one or more
    elements from a single input array feature.

    Parameters
    ----------
    input_features:
        A list of exactly one ``(name, array)`` tuple, where ``array`` is a
        ``datatypes.Array`` instance.

    output_name:
        The name given to the single output feature.

    extract_indices:
        Either an integer or a list/tuple of integers.
        If it's an integer, the output type is by default a double (but may
        be overridden through ``output_type``).
        If a list or tuple, the output type is by default an array with one
        element per extracted index.
        Every index must be non-negative and smaller than the number of
        elements of the input array.

    output_type:
        Optional explicit type of the output feature. When ``None``, the
        default described under ``extract_indices`` is used.

    Returns
    -------
    The populated model spec (a ``proto.Model_pb2.Model`` instance).

    Raises
    ------
    TypeError
        If ``extract_indices`` is neither an integer nor a list/tuple of
        integers.
    ValueError
        If any index is negative.
    AssertionError
        If ``input_features`` is not a single array feature, or an index is
        not smaller than the number of input elements.
    """

    # Make sure that our starting stuff is in the proper form.
    # These stay as assertions so callers keep seeing AssertionError.
    assert len(input_features) == 1, "input_features must contain exactly one feature."
    assert isinstance(
        input_features[0][1], datatypes.Array
    ), "The input feature must be of type datatypes.Array."

    # Create the model.
    spec = proto.Model_pb2.Model()
    spec.specificationVersion = SPECIFICATION_VERSION

    if isinstance(extract_indices, int):
        # A single index yields a scalar output unless the caller says otherwise.
        extract_indices = [extract_indices]
        if output_type is None:
            output_type = datatypes.Double()

    elif isinstance(extract_indices, (list, tuple)):
        if not all(isinstance(x, int) for x in extract_indices):
            raise TypeError("extract_indices must be an integer or a list of integers.")

        # Several indices yield an array output with one slot per index.
        if output_type is None:
            output_type = datatypes.Array(len(extract_indices))

    else:
        raise TypeError("extract_indices must be an integer or a list of integers.")

    output_features = [(output_name, output_type)]

    # Loop-invariant: size of the (single) input array.
    num_elements = input_features[0][1].num_elements

    for idx in extract_indices:
        # Reject negative indices explicitly instead of handing them to the
        # protobuf field (presumably unsigned -- confirm against the spec),
        # where the failure would be far less readable.
        if idx < 0:
            raise ValueError(
                "extract_indices must be non-negative; got {}.".format(idx)
            )
        assert idx < num_elements, (
            "Index {} is out of range for an input array with {} elements.".format(
                idx, num_elements
            )
        )
        spec.arrayFeatureExtractor.extractIndex.append(idx)

    set_transform_interface_params(spec, input_features, output_features)

    return spec
