# Part of the implementation is borrowed and modified from THUMT,
# publicly available at https://github.com/THUNLP-MT/THUMT
# Copyright 2017-2022 The Alibaba MT Team Authors. All rights reserved.
import math
from collections import namedtuple
from typing import Dict

import tensorflow as tf

from modelscope.metainfo import Models
from modelscope.models.base import Model, Tensor
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks

__all__ = ['CsanmtForTranslation']


@MODELS.register_module(Tasks.translation, module_name=Models.translation)
class CsanmtForTranslation(Model):

    def __init__(self, model_dir, *args, **kwargs):
        """
        Args:
            model_dir (str): the directory containing the model files.
            **kwargs (dict): the model configuration.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.params = kwargs

    def __call__(self,
                 input: Dict[str, Tensor],
                 label: Dict[str, Tensor] = None,
                 prefix: Dict[str, Tensor] = None,
                 prefix_hit: Dict[str, Tensor] = None) -> Dict[str, Tensor]:
        """return the result by the model

        Args:
            input: the preprocessed input source sequence
            label: the ground truth target data for model training
            prefix: the preprocessed input target prefix sequence for interactive translation
            prefix_hit: the preprocessed target prefix subword vector for interactive translation

        Returns:
            output_seqs: output sequence of target ids
        """
        if label is None:
            with tf.compat.v1.variable_scope('NmtModel'):
                output_seqs, output_scores = self.beam_search(
                    {
                        'input_wids': input,
                        'prefix_wids': prefix,
                        'prefix_hit': prefix_hit
                    }, self.params)
            return {
                'output_seqs': output_seqs,
                'output_scores': output_scores,
            }
        else:
            train_op, loss = self.transformer_model_train_fn(input, label)
            return {
                'train_op': train_op,
                'loss': loss,
            }
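
    # A minimal inference sketch (hypothetical tensor names; in practice the
    # ModelScope pipeline feeds preprocessed token ids into __call__):
    #
    #   model = CsanmtForTranslation(model_dir, **config)
    #   outputs = model(input=src_wids)         # beam search decoding
    #   target_ids = outputs['output_seqs']     # [batch, beam_size, length]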

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """
        Run the forward pass for a model.

        Args:
            input (Dict[str, Tensor]): the dict of the model inputs for the forward method

        Returns:
            Dict[str, Tensor]: output from the model forward pass
        """
        ...

    def encoding_graph(self, features, params):
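        """Embeds source ids and runs the Transformer encoder.

        Appends an EOS column to `features`, builds the padding mask and the
        additive attention bias, scales embeddings by sqrt(hidden_size), and
        (for absolute positions) adds the sinusoidal timing signal.

        Returns:
            A tuple (encoder_output, encoder_self_attention_bias).
        """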
        src_vocab_size = params['src_vocab_size']
        hidden_size = params['hidden_size']

        initializer = tf.compat.v1.random_normal_initializer(
            0.0, hidden_size**-0.5, dtype=tf.float32)

        if params['shared_source_target_embedding']:
            with tf.compat.v1.variable_scope(
                    'Shared_Embedding', reuse=tf.compat.v1.AUTO_REUSE):
                src_embedding = tf.compat.v1.get_variable(
                    'Weights', [src_vocab_size, hidden_size],
                    initializer=initializer)
        else:
            with tf.compat.v1.variable_scope('Source_Embedding'):
                src_embedding = tf.compat.v1.get_variable(
                    'Weights', [src_vocab_size, hidden_size],
                    initializer=initializer)
        src_bias = tf.compat.v1.get_variable('encoder_input_bias',
                                             [hidden_size])

        eos_padding = tf.zeros_like(features, dtype=tf.int64)[:, :1]
        src_seq = tf.concat([features, eos_padding], 1)
        src_mask = tf.cast(tf.not_equal(src_seq, 0), dtype=tf.float32)
        shift_src_mask = src_mask[:, :-1]
        shift_src_mask = tf.pad(
            tensor=shift_src_mask,
            paddings=[[0, 0], [1, 0]],
            constant_values=1)

        encoder_input = tf.gather(src_embedding, tf.cast(src_seq, tf.int32))
        encoder_input = encoder_input * (hidden_size**0.5)
        if params['position_info_type'] == 'absolute':
            encoder_input = add_timing_signal(encoder_input)
        encoder_input = tf.multiply(encoder_input,
                                    tf.expand_dims(shift_src_mask, 2))

        encoder_input = tf.nn.bias_add(encoder_input, src_bias)
        encoder_self_attention_bias = attention_bias(shift_src_mask, 'masking')

        if params['residual_dropout'] > 0.0:
            encoder_input = tf.nn.dropout(
                encoder_input, rate=params['residual_dropout'])

        # encode
        encoder_output = transformer_encoder(encoder_input,
                                             encoder_self_attention_bias,
                                             shift_src_mask, params)
        return encoder_output, encoder_self_attention_bias

    def semantic_encoding_graph(self, features, params, name=None):
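        """Encodes a sequence into a single pooled semantic vector.

        Shares one embedding matrix across languages when
        `shared_source_target_embedding` is set; otherwise `name` selects the
        source or target semantic embedding. The pooled output of
        `transformer_semantic_encoder` has shape [batch, 1, hidden_size].
        """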
        hidden_size = params['hidden_size']
        initializer = tf.compat.v1.random_normal_initializer(
            0.0, hidden_size**-0.5, dtype=tf.float32)
        scope = None
        if params['shared_source_target_embedding']:
            vocab_size = params['src_vocab_size']
            scope = 'Shared_Semantic_Embedding'
        elif name == 'source':
            vocab_size = params['src_vocab_size']
            scope = 'Source_Semantic_Embedding'
        elif name == 'target':
            vocab_size = params['trg_vocab_size']
            scope = 'Target_Semantic_Embedding'
        else:
            raise ValueError(
                "name must be 'source' or 'target' when embeddings are not shared.")

        with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
            embedding_mat = tf.compat.v1.get_variable(
                'Weights', [vocab_size, hidden_size], initializer=initializer)

        eos_padding = tf.zeros_like(features, dtype=tf.int64)[:, :1]
        input_seq = tf.concat([features, eos_padding], 1)
        input_mask = tf.cast(tf.not_equal(input_seq, 0), dtype=tf.float32)
        shift_input_mask = input_mask[:, :-1]
        shift_input_mask = tf.pad(
            tensor=shift_input_mask,
            paddings=[[0, 0], [1, 0]],
            constant_values=1)

        encoder_input = tf.gather(embedding_mat, tf.cast(input_seq, tf.int32))
        encoder_input = encoder_input * (hidden_size**0.5)
        encoder_input = tf.multiply(encoder_input,
                                    tf.expand_dims(shift_input_mask, 2))

        encoder_self_attention_bias = attention_bias(shift_input_mask,
                                                     'masking')

        if params['residual_dropout'] > 0.0:
            encoder_input = tf.nn.dropout(
                encoder_input, rate=params['residual_dropout'])

        # encode
        encoder_output = transformer_semantic_encoder(
            encoder_input, encoder_self_attention_bias, shift_input_mask,
            params)
        return encoder_output

    def build_contrastive_training_graph(self, features, labels, params):
        # representations
        source_name = 'source'
        target_name = 'target'
        if params['shared_source_target_embedding']:
            source_name = None
            target_name = None
        feature_output = self.semantic_encoding_graph(
            features, params, name=source_name)
        label_output = self.semantic_encoding_graph(
            labels, params, name=target_name)

        return feature_output, label_output

    def MGMC_sampling(self, x_embedding, y_embedding, params, epsilon=1e-12):
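        """Mixed Gaussian Monte Carlo (MGMC) sampling.

        Draws K augmented vectors on the segment between the source and
        target semantic embeddings: for a bias vector b = y - x, each sample
        is x + omega * b with
            omega ~ eta * N(0, w_r) + (1 - eta) * N(0, 1),
        where w_r rescales |b| per dimension into (0, 1). Half of the samples
        start from x, the other half from y.
        """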
        K = params['num_of_samples']
        eta = params['eta']
        assert K % 2 == 0

        def get_samples(x_vector, y_vector):
            bias_vector = y_vector - x_vector
            w_r = tf.math.divide(
                tf.abs(bias_vector) - tf.reduce_min(
                    input_tensor=tf.abs(bias_vector), axis=2, keepdims=True)
                + epsilon,
                tf.reduce_max(
                    input_tensor=tf.abs(bias_vector), axis=2, keepdims=True)
                - tf.reduce_min(
                    input_tensor=tf.abs(bias_vector), axis=2, keepdims=True)
                + 2 * epsilon)

            R = []
            for i in range(K // 2):
                omega = eta * tf.random.normal(tf.shape(input=bias_vector), 0.0, w_r) + \
                    (1.0 - eta) * tf.random.normal(tf.shape(input=bias_vector), 0.0, 1.0)
                sample = x_vector + omega * bias_vector
                R.append(sample)
            return R

        ALL_SAMPLES = get_samples(x_embedding, y_embedding)
        ALL_SAMPLES.extend(get_samples(y_embedding, x_embedding))

        assert len(ALL_SAMPLES) == K

        return tf.concat(ALL_SAMPLES, axis=0)

    def decoding_graph(self,
                       encoder_output,
                       encoder_self_attention_bias,
                       labels,
                       params={},
                       embedding_augmentation=None):
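        """Computes the label-smoothed training loss for one batch.

        The decoder input is the target sequence shifted right by one
        position; targets use soft labels with `confidence` on the true token
        and the remaining mass spread uniformly over the vocabulary. Padding
        positions are masked out of the cross-entropy average.
        """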
        trg_vocab_size = params['trg_vocab_size']
        hidden_size = params['hidden_size']

        initializer = tf.compat.v1.random_normal_initializer(
            0.0, hidden_size**-0.5, dtype=tf.float32)

        if params['shared_source_target_embedding']:
            with tf.compat.v1.variable_scope(
                    'Shared_Embedding', reuse=tf.compat.v1.AUTO_REUSE):
                trg_embedding = tf.compat.v1.get_variable(
                    'Weights', [trg_vocab_size, hidden_size],
                    initializer=initializer)
        else:
            with tf.compat.v1.variable_scope('Target_Embedding'):
                trg_embedding = tf.compat.v1.get_variable(
                    'Weights', [trg_vocab_size, hidden_size],
                    initializer=initializer)

        eos_padding = tf.zeros_like(labels, dtype=tf.int64)[:, :1]
        trg_seq = tf.concat([labels, eos_padding], 1)
        trg_mask = tf.cast(tf.not_equal(trg_seq, 0), dtype=tf.float32)
        shift_trg_mask = trg_mask[:, :-1]
        shift_trg_mask = tf.pad(
            tensor=shift_trg_mask,
            paddings=[[0, 0], [1, 0]],
            constant_values=1)

        decoder_input = tf.gather(trg_embedding, tf.cast(trg_seq, tf.int32))

        decoder_input *= hidden_size**0.5
        decoder_self_attention_bias = attention_bias(
            tf.shape(input=decoder_input)[1], 'causal')
        decoder_input = tf.pad(
            tensor=decoder_input, paddings=[[0, 0], [1, 0], [0, 0]])[:, :-1, :]
        if params['position_info_type'] == 'absolute':
            decoder_input = add_timing_signal(decoder_input)

        decoder_input = tf.nn.dropout(
            decoder_input, rate=params['residual_dropout'])

        # training
        decoder_output, attention_weights = transformer_decoder(
            decoder_input,
            encoder_output,
            decoder_self_attention_bias,
            encoder_self_attention_bias,
            states_key=None,
            states_val=None,
            embedding_augmentation=embedding_augmentation,
            params=params)

        logits = self.prediction(decoder_output, params)

        on_value = params['confidence']
        off_value = (1.0 - params['confidence']) / tf.cast(
            trg_vocab_size - 1, dtype=tf.float32)
        soft_targets = tf.one_hot(
            tf.cast(trg_seq, tf.int32),
            depth=trg_vocab_size,
            on_value=on_value,
            off_value=off_value)
        mask = tf.cast(shift_trg_mask, logits.dtype)
        xentropy = tf.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=tf.stop_gradient(soft_targets)) * mask
        loss = tf.reduce_sum(input_tensor=xentropy) / tf.reduce_sum(
            input_tensor=mask)

        return loss

    def build_training_graph(self,
                             features,
                             labels,
                             params,
                             feature_embedding=None,
                             label_embedding=None):
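        """Builds the MLE training graph, optionally with CSA augmentation.

        When semantic embeddings are provided, `MGMC_sampling` produces
        `num_of_samples` augmented vectors, and the encoder outputs and
        labels are tiled accordingly so every sample is decoded.
        """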
        # encode
        encoder_output, encoder_self_attention_bias = self.encoding_graph(
            features, params)
        embedding_augmentation = None
        if feature_embedding is not None and label_embedding is not None:
            embedding_augmentation = self.MGMC_sampling(
                feature_embedding, label_embedding, params)

            encoder_output = tf.tile(encoder_output,
                                     [params['num_of_samples'], 1, 1])
            encoder_self_attention_bias = tf.tile(
                encoder_self_attention_bias,
                [params['num_of_samples'], 1, 1, 1])
            labels = tf.tile(labels, [params['num_of_samples'], 1])

        # decode
        loss = self.decoding_graph(
            encoder_output,
            encoder_self_attention_bias,
            labels,
            params,
            embedding_augmentation=embedding_augmentation)

        return loss

    def transformer_model_train_fn(self, features, labels):
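        """Builds the multi-device training op.

        Shards features/labels across `num_gpus` devices (CPU fallback),
        builds the contrastive and MLE graphs per shard, averages the tower
        gradients, optionally clips them by global norm, and applies them
        through a `MultiStepOptimizer` for gradient accumulation.
        """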
        initializer = get_initializer(self.params)
        with tf.compat.v1.variable_scope('NmtModel', initializer=initializer):
            num_gpus = self.params['num_gpus']
            gradient_clip_norm = self.params['gradient_clip_norm']
            global_step = tf.compat.v1.train.get_global_step()

            # learning rate
            learning_rate = get_learning_rate_decay(
                self.params['learning_rate'], global_step, self.params)
            learning_rate = tf.convert_to_tensor(
                value=learning_rate, dtype=tf.float32)

            # optimizer
            if self.params['optimizer'] == 'sgd':
                optimizer = tf.compat.v1.train.GradientDescentOptimizer(
                    learning_rate)
            elif self.params['optimizer'] == 'adam':
                optimizer = tf.compat.v1.train.AdamOptimizer(
                    learning_rate=learning_rate,
                    beta1=self.params['adam_beta1'],
                    beta2=self.params['adam_beta2'],
                    epsilon=self.params['adam_epsilon'])
            else:
                raise ValueError('Unsupported optimizer: %s'
                                 % self.params['optimizer'])
            opt = MultiStepOptimizer(optimizer, self.params['update_cycle'])

            def fill_gpus(inputs, num_gpus):
                outputs = inputs
                for i in range(num_gpus):
                    outputs = tf.concat([outputs, inputs], axis=0)
                outputs = outputs[:num_gpus]
                return outputs

            features = tf.cond(
                pred=tf.shape(input=features)[0] < num_gpus,
                true_fn=lambda: fill_gpus(features, num_gpus),
                false_fn=lambda: features)
            labels = tf.cond(
                pred=tf.shape(input=labels)[0] < num_gpus,
                true_fn=lambda: fill_gpus(labels, num_gpus),
                false_fn=lambda: labels)

            if num_gpus > 0:
                feature_shards = shard_features(features, num_gpus)
                label_shards = shard_features(labels, num_gpus)
            else:
                feature_shards = [features]
                label_shards = [labels]

            if num_gpus > 0:
                devices = ['gpu:%d' % d for d in range(num_gpus)]
            else:
                devices = ['cpu:0']
            multi_grads = []
            sharded_losses = []

            for i, device in enumerate(devices):
                with tf.device(device), tf.compat.v1.variable_scope(
                        tf.compat.v1.get_variable_scope(),
                        reuse=True if i > 0 else None):
                    with tf.name_scope('%s_%d' % ('GPU', i)):
                        feature_output, label_output = self.build_contrastive_training_graph(
                            feature_shards[i], label_shards[i], self.params)
                        mle_loss = self.build_training_graph(
                            feature_shards[i], label_shards[i], self.params,
                            feature_output, label_output)
                        sharded_losses.append(mle_loss)
                        tf.compat.v1.summary.scalar('mle_loss_{}'.format(i),
                                                    mle_loss)

                        # Optimization
                        trainable_vars_list = [
                            v for v in tf.compat.v1.trainable_variables()
                            if 'Semantic_Embedding' not in v.name
                            and 'mini_xlm_encoder' not in v.name
                        ]
                        grads_and_vars = opt.compute_gradients(
                            mle_loss,
                            var_list=trainable_vars_list,
                            colocate_gradients_with_ops=True)
                        multi_grads.append(grads_and_vars)

            total_loss = tf.add_n(sharded_losses) / len(sharded_losses)

            # Average gradients
            grads_and_vars = average_gradients(multi_grads)

            if gradient_clip_norm > 0.0:
                grads, var_list = list(zip(*grads_and_vars))
                grads, _ = tf.clip_by_global_norm(grads, gradient_clip_norm)
                grads_and_vars = zip(grads, var_list)

            train_op = opt.apply_gradients(
                grads_and_vars,
                global_step=tf.compat.v1.train.get_global_step())

            return train_op, total_loss

    def prediction(self, decoder_output, params):
        hidden_size = params['hidden_size']
        trg_vocab_size = params['trg_vocab_size']

        if params['shared_embedding_and_softmax_weights']:
            embedding_scope = 'Shared_Embedding' if params[
                'shared_source_target_embedding'] else 'Target_Embedding'
            with tf.compat.v1.variable_scope(embedding_scope, reuse=True):
                weights = tf.compat.v1.get_variable('Weights')
        else:
            weights = tf.compat.v1.get_variable('Softmax',
                                                [trg_vocab_size, hidden_size])
        shape = tf.shape(input=decoder_output)[:-1]
        decoder_output = tf.reshape(decoder_output, [-1, hidden_size])
        logits = tf.matmul(decoder_output, weights, transpose_b=True)
        logits = tf.reshape(logits, tf.concat([shape, [trg_vocab_size]], 0))
        return logits

    def inference_func(self,
                       encoder_output,
                       feature_output,
                       encoder_self_attention_bias,
                       trg_seq,
                       states_key,
                       states_val,
                       params={},
                       is_prefix=False):
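        """Runs one decoding step (or a whole prefix) and returns log-probs.

        With `is_prefix=False` only the last position is decoded, reusing the
        cached key/value states; with `is_prefix=True` the full target prefix
        is decoded at once to warm up the caches for interactive translation.
        """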
        trg_vocab_size = params['trg_vocab_size']
        hidden_size = params['hidden_size']

        initializer = tf.compat.v1.random_normal_initializer(
            0.0, hidden_size**-0.5, dtype=tf.float32)

        if params['shared_source_target_embedding']:
            with tf.compat.v1.variable_scope(
                    'Shared_Embedding', reuse=tf.compat.v1.AUTO_REUSE):
                trg_embedding = tf.compat.v1.get_variable(
                    'Weights', [trg_vocab_size, hidden_size],
                    initializer=initializer)
        else:
            with tf.compat.v1.variable_scope('Target_Embedding'):
                trg_embedding = tf.compat.v1.get_variable(
                    'Weights', [trg_vocab_size, hidden_size],
                    initializer=initializer)

        decoder_input = tf.gather(trg_embedding, tf.cast(trg_seq, tf.int32))
        decoder_input *= hidden_size**0.5
        decoder_self_attention_bias = attention_bias(
            tf.shape(input=decoder_input)[1], 'causal')
        decoder_input = tf.pad(
            tensor=decoder_input, paddings=[[0, 0], [1, 0], [0, 0]])[:, :-1, :]
        if params['position_info_type'] == 'absolute':
            decoder_input = add_timing_signal(decoder_input)
        if not is_prefix:
            decoder_input = decoder_input[:, -1:, :]
            decoder_self_attention_bias = decoder_self_attention_bias[:, :,
                                                                      -1:, :]
        decoder_output, attention_weights = transformer_decoder(
            decoder_input,
            encoder_output,
            decoder_self_attention_bias,
            encoder_self_attention_bias,
            states_key=states_key,
            states_val=states_val,
            embedding_augmentation=feature_output,
            params=params)
        if not is_prefix:
            decoder_output_last = decoder_output[:, -1, :]
            attention_weights_last = attention_weights[:, -1, :]
        else:
            decoder_output_last = decoder_output
            attention_weights_last = attention_weights

        if params['shared_embedding_and_softmax_weights']:
            embedding_scope = \
                'Shared_Embedding' if params['shared_source_target_embedding'] else 'Target_Embedding'
            with tf.compat.v1.variable_scope(embedding_scope, reuse=True):
                weights = tf.compat.v1.get_variable('Weights')
        else:
            weights = tf.compat.v1.get_variable('Softmax',
                                                [trg_vocab_size, hidden_size])
        logits = tf.matmul(decoder_output_last, weights, transpose_b=True)
        log_prob = tf.nn.log_softmax(logits)
        return log_prob, attention_weights_last, states_key, states_val

    def beam_search(self, features, params):
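        """Beam search decoding with GNMT-style length normalization.

        Keeps `beam_size` alive hypotheses and a separate finished set per
        batch entry; at each step 2 * beam_size candidates are expanded and
        scores are divided by ((5 + t) / 6) ** lp_rate. An optional target
        prefix constrains the first decoded tokens for interactive mode.
        """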
        beam_size = params['beam_size']
        trg_vocab_size = params['trg_vocab_size']
        hidden_size = params['hidden_size']
        num_decoder_layers = params['num_decoder_layers']
        lp_rate = params['lp_rate']
        max_decoded_trg_len = params['max_decoded_trg_len']
        src_input = features['input_wids']
        if 'prefix_wids' in features:
            prefix = features['prefix_wids']
            prefix_hit = features['prefix_hit']
        else:
            prefix = None
            prefix_hit = None
        batch_size = tf.shape(src_input)[0]

        src_input = tile_to_beam_size(src_input, beam_size)
        src_input = merge_first_two_dims(src_input)
        if prefix is not None:
            prefix = tf.cast(tile_to_beam_size(prefix, beam_size), tf.int32)
            prefix_hit = tile_to_beam_size(prefix_hit, beam_size)

        encoder_output, encoder_self_attention_bias = self.encoding_graph(
            src_input, params)
        source_name = 'source'
        if params['shared_source_target_embedding']:
            source_name = None
        feature_output = self.semantic_encoding_graph(
            src_input, params, name=source_name)

        states_key = [
            tf.fill([batch_size, 0, hidden_size], 0.0)
            for layer in range(num_decoder_layers)
        ]
        states_val = [
            tf.fill([batch_size, 0, hidden_size], 0.0)
            for layer in range(num_decoder_layers)
        ]
        for layer in range(num_decoder_layers):
            states_key[layer].set_shape(
                tf.TensorShape([None, None, hidden_size]))
            states_val[layer].set_shape(
                tf.TensorShape([None, None, hidden_size]))
        states_key = [
            tile_to_beam_size(states_key[layer], beam_size)
            for layer in range(num_decoder_layers)
        ]
        states_val = [
            tile_to_beam_size(states_val[layer], beam_size)
            for layer in range(num_decoder_layers)
        ]
        fixed_length = 1
        if prefix is not None:
            init_seqs = tf.concat(
                [prefix, tf.fill([batch_size, beam_size, 1], 0)], axis=2)
            fixed_length = tf.shape(init_seqs)[-1]
            flat_seqs = merge_first_two_dims(init_seqs)
            flat_states_key = [
                merge_first_two_dims(states_key[layer])
                for layer in range(num_decoder_layers)
            ]
            flat_states_val = [
                merge_first_two_dims(states_val[layer])
                for layer in range(num_decoder_layers)
            ]

            step_log_probs, step_attn_weights, step_states_key, step_states_val = self.inference_func(
                encoder_output,
                feature_output,
                encoder_self_attention_bias,
                flat_seqs,
                flat_states_key,
                flat_states_val,
                params=params,
                is_prefix=True)

            states_key = [
                split_first_two_dims(step_states_key[layer], batch_size,
                                     beam_size)
                for layer in range(num_decoder_layers)
            ]
            states_val = [
                split_first_two_dims(step_states_val[layer], batch_size,
                                     beam_size)
                for layer in range(num_decoder_layers)
            ]

            prefix_hit = merge_first_two_dims(prefix_hit)
            log_probs = tf.where(
                prefix_hit, step_log_probs[:, -1, :],
                tf.ones_like(step_log_probs[:, -1, :]) * tf.float32.min)

            init_seqs = tf.concat([
                flat_seqs[:, :-1],
                tf.expand_dims(
                    tf.cast(tf.argmax(log_probs, -1), tf.int32), -1)
            ], -1)

            init_seqs = split_first_two_dims(init_seqs, batch_size, beam_size)
            init_seqs = tf.concat(
                [init_seqs, tf.fill([batch_size, beam_size, 1], 0)], axis=2)
        else:
            init_seqs = tf.fill([batch_size, beam_size, 1], 0)

        init_log_probs = \
            tf.constant([[0.] + [tf.float32.min] * (beam_size - 1)])
        init_log_probs = tf.tile(init_log_probs, [batch_size, 1])
        init_scores = tf.zeros_like(init_log_probs)
        fin_seqs = init_seqs
        fin_scores = tf.fill([batch_size, beam_size], tf.float32.min)
        fin_flags = tf.cast(tf.fill([batch_size, beam_size], 0), tf.bool)

        state = BeamSearchState(
            inputs=(init_seqs, init_log_probs, init_scores),
            state=(states_key, states_val),
            finish=(fin_flags, fin_seqs, fin_scores),
        )

        def _beam_search_step(time, state):
            seqs, log_probs = state.inputs[:2]
            states_key, states_val = state.state

            flat_seqs = merge_first_two_dims(seqs)
            flat_states_key = [
                merge_first_two_dims(states_key[layer])
                for layer in range(num_decoder_layers)
            ]
            flat_states_val = [
                merge_first_two_dims(states_val[layer])
                for layer in range(num_decoder_layers)
            ]

            step_log_probs, step_attn_weights, step_states_key, step_states_val = self.inference_func(
                encoder_output,
                feature_output,
                encoder_self_attention_bias,
                flat_seqs,
                flat_states_key,
                flat_states_val,
                params=params,
                is_prefix=False)

            step_log_probs = split_first_two_dims(step_log_probs, batch_size,
                                                  beam_size)
            curr_log_probs = tf.expand_dims(log_probs, 2) + step_log_probs

            next_states_key = [
                split_first_two_dims(step_states_key[layer], batch_size,
                                     beam_size)
                for layer in range(num_decoder_layers)
            ]
            next_states_val = [
                split_first_two_dims(step_states_val[layer], batch_size,
                                     beam_size)
                for layer in range(num_decoder_layers)
            ]

            # Apply length penalty
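            # GNMT-style penalty (Wu et al., 2016): lp(t) = ((5 + t) / 6) ** lp_rate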
            length_penalty = tf.pow(
                (5.0 + tf.cast(time + 1, dtype=tf.float32)) / 6.0, lp_rate)
            curr_scores = curr_log_probs / length_penalty

            # Select top-k candidates
            # [batch_size, beam_size * vocab_size]
            curr_scores = tf.reshape(curr_scores,
                                     [-1, beam_size * trg_vocab_size])
            # [batch_size, 2 * beam_size]
            top_scores, top_indices = tf.nn.top_k(curr_scores, k=2 * beam_size)
            # Shape: [batch_size, 2 * beam_size]
            beam_indices = top_indices // trg_vocab_size
            symbol_indices = top_indices % trg_vocab_size
            # Expand sequences
            # [batch_size, 2 * beam_size, time]
            candidate_seqs = gather_2d(seqs, beam_indices)
            candidate_seqs = tf.concat(
                [candidate_seqs[:, :, :-1],
                 tf.expand_dims(symbol_indices, 2)],
                axis=2)
            pad_seqs = tf.fill([batch_size, 2 * beam_size, 1],
                               tf.constant(0, tf.int32))
            candidate_seqs = tf.concat([candidate_seqs, pad_seqs], axis=2)

            # Suppress finished sequences
            flags = tf.equal(symbol_indices, 0)
            # [batch, 2 * beam_size]
            alive_scores = top_scores + tf.cast(
                flags, dtype=tf.float32) * tf.float32.min
            # [batch, beam_size]
            alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size)
            alive_symbols = gather_2d(symbol_indices, alive_indices)
            alive_indices = gather_2d(beam_indices, alive_indices)
            alive_seqs = gather_2d(seqs, alive_indices)
            alive_seqs = tf.concat(
                [alive_seqs[:, :, :-1],
                 tf.expand_dims(alive_symbols, 2)],
                axis=2)
            pad_seqs = tf.fill([batch_size, beam_size, 1],
                               tf.constant(0, tf.int32))
            alive_seqs = tf.concat([alive_seqs, pad_seqs], axis=2)
            alive_states_key = [
                gather_2d(next_states_key[layer], alive_indices)
                for layer in range(num_decoder_layers)
            ]
            alive_states_val = [
                gather_2d(next_states_val[layer], alive_indices)
                for layer in range(num_decoder_layers)
            ]
            alive_log_probs = alive_scores * length_penalty

            # Select finished sequences
            prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish
            # [batch, 2 * beam_size]
            step_fin_scores = top_scores + (
                1.0 - tf.cast(flags, dtype=tf.float32)) * tf.float32.min
            # [batch, 3 * beam_size]
            fin_flags = tf.concat([prev_fin_flags, flags], axis=1)
            fin_scores = tf.concat([prev_fin_scores, step_fin_scores], axis=1)
            # [batch, beam_size]
            fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size)
            fin_flags = gather_2d(fin_flags, fin_indices)
            pad_seqs = tf.fill([batch_size, beam_size, 1],
                               tf.constant(0, tf.int32))
            prev_fin_seqs = tf.concat([prev_fin_seqs, pad_seqs], axis=2)
            fin_seqs = tf.concat([prev_fin_seqs, candidate_seqs], axis=1)
            fin_seqs = gather_2d(fin_seqs, fin_indices)

            new_state = BeamSearchState(
                inputs=(alive_seqs, alive_log_probs, alive_scores),
                state=(alive_states_key, alive_states_val),
                finish=(fin_flags, fin_seqs, fin_scores),
            )

            return time + 1, new_state

        def _is_finished(t, s):
            log_probs = s.inputs[1]
            finished_flags = s.finish[0]
            finished_scores = s.finish[2]
            max_lp = tf.pow(
                ((5.0 + tf.cast(max_decoded_trg_len, dtype=tf.float32)) / 6.0),
                lp_rate)
            best_alive_score = log_probs[:, 0] / max_lp
            worst_finished_score = tf.reduce_min(
                input_tensor=finished_scores
                * tf.cast(finished_flags, dtype=tf.float32),
                axis=1)
            add_mask = 1.0 - tf.cast(
                tf.reduce_any(input_tensor=finished_flags, axis=1),
                dtype=tf.float32)
            worst_finished_score += tf.float32.min * add_mask
            bound_is_met = tf.reduce_all(
                input_tensor=tf.greater(worst_finished_score,
                                        best_alive_score))

            cond = tf.logical_and(
                tf.less(t, max_decoded_trg_len), tf.logical_not(bound_is_met))

            return cond

        def _loop_fn(t, s):
            outs = _beam_search_step(t, s)
            return outs

        time = tf.constant(0, name='time')
        shape_invariants = BeamSearchState(
            inputs=(tf.TensorShape([None, None, None]),
                    tf.TensorShape([None, None]), tf.TensorShape([None,
                                                                  None])),
            state=([
                tf.TensorShape([None, None, None, hidden_size])
                for layer in range(num_decoder_layers)
            ], [
                tf.TensorShape([None, None, None, hidden_size])
                for layer in range(num_decoder_layers)
            ]),
            finish=(tf.TensorShape([None,
                                    None]), tf.TensorShape([None, None, None]),
                    tf.TensorShape([None, None])))
        outputs = tf.while_loop(
            cond=_is_finished,
            body=_loop_fn,
            loop_vars=[time, state],
            shape_invariants=[tf.TensorShape([]), shape_invariants],
            parallel_iterations=1,
            back_prop=False)

        final_state = outputs[1]
        alive_seqs = final_state.inputs[0]
        alive_scores = final_state.inputs[2]
        final_flags = final_state.finish[0]
        final_seqs = final_state.finish[1]
        final_scores = final_state.finish[2]

        alive_seqs.set_shape([None, beam_size, None])
        final_seqs.set_shape([None, beam_size, None])

        final_seqs = tf.compat.v1.where(
            tf.reduce_any(input_tensor=final_flags, axis=1), final_seqs,
            alive_seqs)
        final_scores = tf.compat.v1.where(
            tf.reduce_any(input_tensor=final_flags, axis=1), final_scores,
            alive_scores)

        final_seqs = final_seqs[:, :, fixed_length - 1:-1]
        return final_seqs, final_scores


class BeamSearchState(
        namedtuple('BeamSearchState', ('inputs', 'state', 'finish'))):
    pass


def tile_to_beam_size(tensor, beam_size):
    """Tiles a given tensor by beam_size. """
    tensor = tf.expand_dims(tensor, axis=1)
    tile_dims = [1] * tensor.shape.ndims
    tile_dims[1] = beam_size

    return tf.tile(tensor, tile_dims)


def infer_shape(x):
    x = tf.convert_to_tensor(x)

    if x.shape.dims is None:
        return tf.shape(x)

    static_shape = x.shape.as_list()
    dynamic_shape = tf.shape(x)

    ret = []
    for i in range(len(static_shape)):
        dim = static_shape[i]
        if dim is None:
            dim = dynamic_shape[i]
        ret.append(dim)

    return ret


def split_first_two_dims(tensor, dim_0, dim_1):
    shape = infer_shape(tensor)
    new_shape = [dim_0] + [dim_1] + shape[1:]
    return tf.reshape(tensor, new_shape)


def merge_first_two_dims(tensor):
    shape = infer_shape(tensor)
    shape[0] *= shape[1]
    shape.pop(1)
    return tf.reshape(tensor, shape)


def gather_2d(params, indices, name=None):
    """ Gather the 2nd dimension given indices
    :param params: A tensor with shape [batch_size, M, ...]
    :param indices: A tensor with shape [batch_size, N]
    :param name: An optional string
    :return: A tensor with shape [batch_size, N, ...]
    """
    batch_size = tf.shape(params)[0]
    range_size = tf.shape(indices)[1]
    batch_pos = tf.range(batch_size * range_size) // range_size
    batch_pos = tf.reshape(batch_pos, [batch_size, range_size])
    indices = tf.stack([batch_pos, indices], axis=-1)
    output = tf.gather_nd(params, indices, name=name)

    return output


def linear(inputs, output_size, bias, concat=True, dtype=None, scope=None):
    with tf.compat.v1.variable_scope(
            scope, default_name='linear', values=[inputs], dtype=dtype):
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        input_size = [item.get_shape()[-1] for item in inputs]

        if len(inputs) != len(input_size):
            raise RuntimeError('inputs and input_size unmatched!')

        output_shape = tf.concat([tf.shape(inputs[0])[:-1], [output_size]],
                                 axis=0)
        # Flatten to 2D
        inputs = [tf.reshape(inp, [-1, inp.shape[-1]]) for inp in inputs]

        results = []
        if concat:
            input_size = sum(input_size)
            inputs = tf.concat(inputs, 1)

            shape = [input_size, output_size]
            matrix = tf.compat.v1.get_variable('matrix', shape)
            results.append(tf.matmul(inputs, matrix))
        else:
            for i in range(len(input_size)):
                shape = [input_size[i], output_size]
                name = 'matrix_%d' % i
                matrix = tf.compat.v1.get_variable(name, shape)
                results.append(tf.matmul(inputs[i], matrix))

        output = tf.add_n(results)

        if bias:
            shape = [output_size]
            bias = tf.compat.v1.get_variable('bias', shape)
            output = tf.nn.bias_add(output, bias)

        output = tf.reshape(output, output_shape)

        return output


def layer_norm(inputs, epsilon=1e-6, name=None, reuse=None):
    with tf.compat.v1.variable_scope(
            name, default_name='layer_norm', values=[inputs], reuse=reuse):
        channel_size = inputs.get_shape().as_list()[-1]

        scale = tf.compat.v1.get_variable(
            'layer_norm_scale', [channel_size],
            initializer=tf.ones_initializer())

        offset = tf.compat.v1.get_variable(
            'layer_norm_offset', [channel_size],
            initializer=tf.zeros_initializer())

        mean = tf.reduce_mean(inputs, -1, True)
        variance = tf.reduce_mean(tf.square(inputs - mean), -1, True)

        norm_inputs = (inputs - mean) * tf.compat.v1.rsqrt(variance + epsilon)

        return norm_inputs * scale + offset


def _layer_process(x, mode):
    if not mode or mode == 'none':
        return x
    elif mode == 'layer_norm':
        return layer_norm(x)
    else:
        raise ValueError('Unknown mode %s' % mode)


def _residual_fn(x, y, keep_prob=None):
    if keep_prob and keep_prob < 1.0:
        y = tf.nn.dropout(y, rate=1.0 - keep_prob)
    return x + y


def embedding_augmentation_layer(x, embedding_augmentation, params, name=None):
    hidden_size = params['hidden_size']
    keep_prob = 1.0 - params['relu_dropout']
    with tf.compat.v1.variable_scope(
            name,
            default_name='embedding_augmentation_layer',
            values=[x, embedding_augmentation]):
        with tf.compat.v1.variable_scope('input_layer'):
            hidden = linear(embedding_augmentation, hidden_size, True, True)
            hidden = tf.nn.relu(hidden)

        if keep_prob and keep_prob < 1.0:
            hidden = tf.nn.dropout(hidden, rate=1.0 - keep_prob)

        with tf.compat.v1.variable_scope('output_layer'):
            output = linear(hidden, hidden_size, True, True)

        return x + output


def transformer_ffn_layer(x, params, name=None):
    filter_size = params['filter_size']
    hidden_size = params['hidden_size']
    keep_prob = 1.0 - params['relu_dropout']
    with tf.compat.v1.variable_scope(
            name, default_name='ffn_layer', values=[x]):
        with tf.compat.v1.variable_scope('input_layer'):
            hidden = linear(x, filter_size, True, True)
            hidden = tf.nn.relu(hidden)

        if keep_prob and keep_prob < 1.0:
            hidden = tf.nn.dropout(hidden, rate=1.0 - keep_prob)

        with tf.compat.v1.variable_scope('output_layer'):
            output = linear(hidden, hidden_size, True, True)

        return output


def transformer_encoder(encoder_input,
                        encoder_self_attention_bias,
                        mask,
                        params={},
                        name='encoder'):
    num_encoder_layers = params['num_encoder_layers']
    hidden_size = params['hidden_size']
    num_heads = params['num_heads']
    residual_dropout = params['residual_dropout']
    attention_dropout = params['attention_dropout']
    layer_preproc = params['layer_preproc']
    layer_postproc = params['layer_postproc']
    x = encoder_input
    mask = tf.expand_dims(mask, 2)
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        for layer in range(num_encoder_layers):
            with tf.compat.v1.variable_scope('layer_%d' % layer):
                max_relative_dis = params['max_relative_dis'] \
                    if params['position_info_type'] == 'relative' else None
                o, w = multihead_attention(
                    _layer_process(x, layer_preproc),
                    None,
                    encoder_self_attention_bias,
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    num_heads,
                    attention_dropout,
                    max_relative_dis=max_relative_dis,
                    name='encoder_self_attention')
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

                o = transformer_ffn_layer(
                    _layer_process(x, layer_preproc), params)
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

                x = tf.multiply(x, mask)

        return _layer_process(x, layer_preproc)


def transformer_semantic_encoder(encoder_input,
                                 encoder_self_attention_bias,
                                 mask,
                                 params={},
                                 name='mini_xlm_encoder'):
    num_encoder_layers = params['num_semantic_encoder_layers']
    hidden_size = params['hidden_size']
    num_heads = params['num_heads']
    residual_dropout = params['residual_dropout']
    attention_dropout = params['attention_dropout']
    layer_preproc = params['layer_preproc']
    layer_postproc = params['layer_postproc']
    x = encoder_input
    mask = tf.expand_dims(mask, 2)
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        for layer in range(num_encoder_layers):
            with tf.compat.v1.variable_scope('layer_%d' % layer):
                max_relative_dis = params['max_relative_dis']
                o, w = multihead_attention(
                    _layer_process(x, layer_preproc),
                    None,
                    encoder_self_attention_bias,
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    num_heads,
                    attention_dropout,
                    max_relative_dis=max_relative_dis,
                    name='encoder_self_attention')
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

                o = transformer_ffn_layer(
                    _layer_process(x, layer_preproc), params)
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

                x = tf.multiply(x, mask)

        with tf.compat.v1.variable_scope(
                'pooling_layer', reuse=tf.compat.v1.AUTO_REUSE):
            output = tf.reduce_sum(
                input_tensor=x, axis=1) / tf.reduce_sum(
                    input_tensor=mask, axis=1)
            output = linear(
                tf.expand_dims(output, axis=1), hidden_size, True, True)

        return _layer_process(output, layer_preproc)


def transformer_decoder(decoder_input,
                        encoder_output,
                        decoder_self_attention_bias,
                        encoder_decoder_attention_bias,
                        states_key=None,
                        states_val=None,
                        embedding_augmentation=None,
                        params={},
                        name='decoder'):
    num_decoder_layers = params['num_decoder_layers']
    hidden_size = params['hidden_size']
    num_heads = params['num_heads']
    residual_dropout = params['residual_dropout']
    attention_dropout = params['attention_dropout']
    layer_preproc = params['layer_preproc']
    layer_postproc = params['layer_postproc']
    x = decoder_input
    with tf.compat.v1.variable_scope(name, reuse=tf.compat.v1.AUTO_REUSE):
        for layer in range(num_decoder_layers):
            with tf.compat.v1.variable_scope('layer_%d' % layer):
                max_relative_dis = params['max_relative_dis'] \
                    if params['position_info_type'] == 'relative' else None
                # continuous semantic augmentation
                if embedding_augmentation is not None:
                    x = embedding_augmentation_layer(
                        x, _layer_process(embedding_augmentation,
                                          layer_preproc), params)
                    x = _layer_process(x, layer_postproc)
                o, w = multihead_attention(
                    _layer_process(x, layer_preproc),
                    None,
                    decoder_self_attention_bias,
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    num_heads,
                    attention_dropout,
                    states_key=states_key,
                    states_val=states_val,
                    layer=layer,
                    max_relative_dis=max_relative_dis,
                    name='decoder_self_attention')
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

                o, w = multihead_attention(
                    _layer_process(x, layer_preproc),
                    encoder_output,
                    encoder_decoder_attention_bias,
                    hidden_size,
                    hidden_size,
                    hidden_size,
                    num_heads,
                    attention_dropout,
                    max_relative_dis=max_relative_dis,
                    name='encdec_attention')
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

                o = transformer_ffn_layer(
                    _layer_process(x, layer_preproc), params)
                x = _residual_fn(x, o, 1.0 - residual_dropout)
                x = _layer_process(x, layer_postproc)

        return _layer_process(x, layer_preproc), w


def add_timing_signal(x, min_timescale=1.0, max_timescale=1.0e4):
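    """Adds the sinusoidal position encoding of Vaswani et al. (2017).

    Each position p contributes sin(p * w_i) and cos(p * w_i) for timescales
    w_i spaced geometrically between min_timescale and max_timescale.
    """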
    length = tf.shape(x)[1]
    channels = tf.shape(x)[2]
    position = tf.cast(tf.range(length), tf.float32)
    num_timescales = channels // 2

    log_timescale_increment = \
        (math.log(float(max_timescale) / float(min_timescale)) / (tf.cast(num_timescales, tf.float32) - 1))
    inv_timescales = min_timescale * tf.exp(
        tf.cast(tf.range(num_timescales), tf.float32)
        * -log_timescale_increment)

    scaled_time = \
        tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
    signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    signal = tf.pad(signal, [[0, 0], [0, tf.compat.v1.mod(channels, 2)]])
    signal = tf.reshape(signal, [1, length, channels])

    return x + tf.cast(signal, x.dtype)


def attention_bias(inputs, mode, inf=-1e9, dtype=None):
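    """Builds an additive attention bias.

    'masking' turns a [batch, length] padding mask into a
    [batch, 1, 1, length] bias of 0 / -inf; 'causal' builds a
    [1, 1, length, length] lower-triangular bias that blocks attention to
    future positions.
    """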
    if dtype is None:
        dtype = tf.float32

    if dtype != tf.float32:
        inf = dtype.min

    if mode == 'masking':
        mask = inputs
        ret = (1.0 - mask) * inf
        ret = tf.expand_dims(tf.expand_dims(ret, 1), 1)

    elif mode == 'causal':
        length = inputs
        lower_triangle = tf.linalg.band_part(
            tf.fill([length, length], 1.0), -1, 0)
        ret = inf * (1.0 - lower_triangle)
        ret = tf.reshape(ret, [1, 1, length, length])
    else:
        raise ValueError('Unknown mode %s' % mode)

    return tf.cast(ret, dtype)


def split_heads(x, num_heads):
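    """Reshapes [batch, length, depth] to [batch, num_heads, length, depth // num_heads]."""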
    n = num_heads
    old_shape = x.get_shape().dims
    ndims = x.shape.ndims

    last = old_shape[-1]
    new_shape = old_shape[:-1] + [n] + [last // n if last else None]
    ret = tf.reshape(x, tf.concat([tf.shape(x)[:-1], [n, -1]], 0))
    ret.set_shape(new_shape)
    perm = [0, ndims - 1] + [i for i in range(1, ndims - 1)] + [ndims]
    return tf.transpose(ret, perm)


def dot_product_attention(q,
                          k,
                          v,
                          bias,
                          dropout_rate=0.0,
                          name=None,
                          rpr=None):
    with tf.compat.v1.variable_scope(
            name, default_name='dot_product_attention', values=[q, k, v]):
        q_shape = tf.shape(q)
        bs, hd, lq, dk = q_shape[0], q_shape[1], q_shape[2], q_shape[3]
        lk = tf.shape(k)[2]
        dv = tf.shape(v)[3]

        if rpr is not None:
            rpr_k, rpr_v = rpr['rpr_k'], rpr[
                'rpr_v']  # (lq, lk, dk), (lq, lk, dv)

        if rpr is None:
            logits = tf.matmul(q, k, transpose_b=True)
        else:  # self-attention with relative position representation
            logits_part1 = tf.matmul(q, k, transpose_b=True)  # bs, hd, lq, lk

            q = tf.reshape(tf.transpose(q, [2, 0, 1, 3]),
                           [lq, bs * hd, dk])  # lq, bs*hd, dk
            logits_part2 = tf.matmul(q,
                                     tf.transpose(rpr_k,
                                                  [0, 2, 1]))  # lq, bs*hd, lk
            logits_part2 = tf.reshape(
                tf.transpose(logits_part2, [1, 0, 2]), [bs, hd, lq, lk])

            logits = logits_part1 + logits_part2  # bs, hd, lq, lk

        if bias is not None:
            logits += bias

        weights = tf.nn.softmax(logits, name='attention_weights')

        if dropout_rate > 0.0:
            weights = tf.nn.dropout(weights, rate=dropout_rate)

        if rpr is None:
            return tf.matmul(weights, v), weights
        else:
            outputs_part1 = tf.matmul(weights, v)  # bs, hd, lq, dv

            weights = tf.reshape(
                tf.transpose(weights, [2, 0, 1, 3]),
                [lq, bs * hd, lk])  # lq, bs*hd, lk
            outputs_part2 = tf.matmul(weights, rpr_v)  # lq, bs*hd, dv
            outputs_part2 = tf.reshape(
                tf.transpose(outputs_part2, [1, 0, 2]), [bs, hd, lq, dv])

            outputs = outputs_part1 + outputs_part2  # bs, hd, lq, dv
            weights = tf.reshape(
                tf.transpose(weights, [1, 0, 2]),
                [bs, hd, lq, lk])  # bs, hd, lq, lk

            return outputs, weights


def combine_heads(x):
    x = tf.transpose(x, [0, 2, 1, 3])
    old_shape = x.get_shape().dims
    a, b = old_shape[-2:]
    new_shape = old_shape[:-2] + [a * b if a and b else None]
    x = tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], 0))
    x.set_shape(new_shape)

    return x


def create_rpr(original_var,
               length_q,
               length_kv,
               max_relative_dis,
               name='create_rpr'):
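    """Builds relative-position representations (Shaw et al., 2018).

    Relative distances are clipped to [-max_relative_dis, max_relative_dis]
    and used to index `original_var`, yielding a [length_q, length_kv, depth]
    tensor of per-pair embeddings for self-attention.
    """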
    with tf.name_scope(name):
        idxs = tf.reshape(tf.range(length_kv), [-1, 1])  # only self-attention
        idys = tf.reshape(tf.range(length_kv), [1, -1])
        ids = idxs - idys
        ids = ids + max_relative_dis
        ids = tf.maximum(ids, 0)
        ids = tf.minimum(ids, 2 * max_relative_dis)
        ids = ids[-length_q:, :]
        rpr = tf.gather(original_var, ids)
        return rpr


def multihead_attention(queries,
                        memories,
                        bias,
                        key_depth,
                        value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        states_key=None,
                        states_val=None,
                        layer=0,
                        max_relative_dis=None,
                        name=None):
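    """Multi-head scaled dot-product attention.

    Self-attention when `memories` is None (q, k, v projected from
    `queries`), encoder-decoder attention otherwise. `states_key` /
    `states_val` act as incremental decoding caches, and `max_relative_dis`
    enables relative position representations.
    """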
    if key_depth % num_heads != 0:
        raise ValueError(
            'Key depth (%d) must be divisible by the number of attention heads (%d).'
            % (key_depth, num_heads))

    if value_depth % num_heads != 0:
        raise ValueError(
            'Value depth (%d) must be divisible by the number of attention heads (%d).'
            % (value_depth, num_heads))

    with tf.compat.v1.variable_scope(
            name, default_name='multihead_attention',
            values=[queries, memories]):
        if memories is None:
            # self attention
            combined = linear(
                queries,
                key_depth * 2 + value_depth,
                True,
                True,
                scope='qkv_transform')
            q, k, v = tf.split(
                combined, [key_depth, key_depth, value_depth], axis=2)
        else:
            q = linear(queries, key_depth, True, True, scope='q_transform')
            combined = linear(
                memories,
                key_depth + value_depth,
                True,
                True,
                scope='kv_transform')
            k, v = tf.split(combined, [key_depth, value_depth], axis=2)

        if states_key is not None:
            k = states_key[layer] = tf.concat([states_key[layer], k], axis=1)
        if states_val is not None:
            v = states_val[layer] = tf.concat([states_val[layer], v], axis=1)

        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        key_depth_per_head = key_depth // num_heads
        q *= key_depth_per_head**-0.5

        length_q = tf.shape(q)[2]
        length_kv = tf.shape(k)[2]

        # relative position representation (only in self-attention)
        if memories is None and max_relative_dis is not None:
            rpr_k = tf.compat.v1.get_variable(
                'rpr_k', [2 * max_relative_dis + 1, key_depth // num_heads])
            rpr_v = tf.compat.v1.get_variable(
                'rpr_v', [2 * max_relative_dis + 1, value_depth // num_heads])
            rpr_k = create_rpr(rpr_k, length_q, length_kv, max_relative_dis)
            rpr_v = create_rpr(rpr_v, length_q, length_kv, max_relative_dis)
            rpr = {'rpr_k': rpr_k, 'rpr_v': rpr_v}
            x, w = dot_product_attention(q, k, v, bias, dropout_rate, rpr=rpr)
        else:
            x, w = dot_product_attention(q, k, v, bias, dropout_rate)
        x = combine_heads(x)
        w = tf.reduce_mean(w, 1)
        x = linear(x, output_depth, True, True, scope='output_transform')
        return x, w


def get_initializer(params):
    if params['initializer'] == 'uniform':
        max_val = params['initializer_scale']
        return tf.compat.v1.random_uniform_initializer(-max_val, max_val)
    elif params['initializer'] == 'normal':
        return tf.compat.v1.random_normal_initializer(
            0.0, params['initializer_scale'])
    elif params['initializer'] == 'normal_unit_scaling':
        return tf.compat.v1.variance_scaling_initializer(
            params['initializer_scale'], mode='fan_avg', distribution='normal')
    elif params['initializer'] == 'uniform_unit_scaling':
        return tf.compat.v1.variance_scaling_initializer(
            params['initializer_scale'],
            mode='fan_avg',
            distribution='uniform')
    else:
        raise ValueError('Unrecognized initializer: %s'
                         % params['initializer'])


def get_learning_rate_decay(learning_rate, global_step, params):
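    """Learning-rate schedules.

    'noam' / 'linear_warmup_rsqrt_decay' follows the Transformer schedule:
        lr * hidden_size ** -0.5 * min((s + 1) * warmup ** -1.5, (s + 1) ** -0.5)
    'piecewise_constant' uses explicit boundaries/values; 'none' keeps the
    base rate.
    """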
    if params['learning_rate_decay'] in ['linear_warmup_rsqrt_decay', 'noam']:
        step = tf.cast(global_step, dtype=tf.float32)
        warmup_steps = tf.cast(params['warmup_steps'], dtype=tf.float32)
        multiplier = params['hidden_size']**-0.5
        decay = multiplier * tf.minimum((step + 1) * (warmup_steps**-1.5),
                                        (step + 1)**-0.5)
        return learning_rate * decay
    elif params['learning_rate_decay'] == 'piecewise_constant':
        return tf.compat.v1.train.piecewise_constant(
            tf.cast(global_step, dtype=tf.int32),
            params['learning_rate_boundaries'], params['learning_rate_values'])
    elif params['learning_rate_decay'] == 'none':
        return learning_rate
    else:
        raise ValueError('Unknown learning_rate_decay')


def average_gradients(tower_grads):
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        grads = []
        for g, _ in grad_and_vars:
            expanded_g = tf.expand_dims(g, 0)
            grads.append(expanded_g)
        grad = tf.concat(axis=0, values=grads)
        grad = tf.reduce_mean(grad, 0)
        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)
    return average_grads


# Optional distributed communication engine (e.g. Horovod); when unset,
# all_reduce falls back to a no-op.
_ENGINE = None


def all_reduce(tensor):
    if _ENGINE is None:
        return tensor

    return _ENGINE.allreduce(tensor, compression=_ENGINE.Compression.fp16)


class MultiStepOptimizer(tf.compat.v1.train.Optimizer):
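    """Wraps an optimizer to accumulate gradients over `step` iterations.

    Gradients are summed into per-variable 'grad_acc' slots and applied
    (averaged and all-reduced) only every `step`-th call, emulating a larger
    effective batch size; with step == 1 it degenerates to the wrapped
    optimizer.
    """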

    def __init__(self,
                 optimizer,
                 step=1,
                 use_locking=False,
                 name='MultiStepOptimizer'):
        super(MultiStepOptimizer, self).__init__(use_locking, name)
        self._optimizer = optimizer
        self._step = step
        self._step_t = tf.convert_to_tensor(step, name='step')

    def _all_reduce(self, tensor):
        with tf.name_scope(self._name + '_Allreduce'):
            if tensor is None:
                return tensor

            if isinstance(tensor, tf.IndexedSlices):
                tensor = tf.convert_to_tensor(tensor)

            return all_reduce(tensor)

    def compute_gradients(self,
                          loss,
                          var_list=None,
                          gate_gradients=tf.compat.v1.train.Optimizer.GATE_OP,
                          aggregation_method=None,
                          colocate_gradients_with_ops=False,
                          grad_loss=None):
        grads_and_vars = self._optimizer.compute_gradients(
            loss, var_list, gate_gradients, aggregation_method,
            colocate_gradients_with_ops, grad_loss)

        grads, var_list = list(zip(*grads_and_vars))

        # Do not create extra variables when step is 1
        if self._step == 1:
            grads = [self._all_reduce(t) for t in grads]
            return list(zip(grads, var_list))

        first_var = min(var_list, key=lambda x: x.name)
        iter_var = self._create_non_slot_variable(
            initial_value=0 if self._step == 1 else 1,
            name='iter',
            colocate_with=first_var)

        new_grads = []

        for grad, var in zip(grads, var_list):
            grad_acc = self._zeros_slot(var, 'grad_acc', self._name)

            if isinstance(grad, tf.IndexedSlices):
                grad_acc = tf.compat.v1.scatter_add(
                    grad_acc,
                    grad.indices,
                    grad.values,
                    use_locking=self._use_locking)
            else:
                grad_acc = tf.compat.v1.assign_add(
                    grad_acc, grad, use_locking=self._use_locking)

            def _acc_grad():
                return grad_acc

            def _avg_grad():
                return self._all_reduce(grad_acc / self._step)

            grad = tf.cond(tf.equal(iter_var, 0), _avg_grad, _acc_grad)
            new_grads.append(grad)

        return list(zip(new_grads, var_list))

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        if self._step == 1:
            return self._optimizer.apply_gradients(
                grads_and_vars, global_step, name=name)

        grads, var_list = list(zip(*grads_and_vars))

        def _pass_gradients():
            return tf.group(*grads)

        def _apply_gradients():
            op = self._optimizer.apply_gradients(
                zip(grads, var_list), global_step, name)
            with tf.control_dependencies([op]):
                zero_ops = []
                for var in var_list:
                    grad_acc = self.get_slot(var, 'grad_acc')
                    zero_ops.append(
                        grad_acc.assign(
                            tf.zeros_like(grad_acc),
                            use_locking=self._use_locking))
                zero_op = tf.group(*zero_ops)
            return tf.group(*[op, zero_op])

        iter_var = self._get_non_slot_variable(
            'iter', tf.compat.v1.get_default_graph())
        update_op = tf.cond(
            tf.equal(iter_var, 0), _apply_gradients, _pass_gradients)

        with tf.control_dependencies([update_op]):
            iter_op = iter_var.assign(
                tf.compat.v1.mod(iter_var + 1, self._step_t),
                use_locking=self._use_locking)

        return tf.group(*[update_op, iter_op])


def shard_features(x, num_datashards):
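    """Splits a batch into `num_datashards` nearly equal shards along axis 0.

    The first (batch_size mod num_datashards) shards receive one extra row.
    """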
    x = tf.convert_to_tensor(x)
    batch_size = tf.shape(x)[0]
    size_splits = []

    with tf.device('/cpu:0'):
        for i in range(num_datashards):
            size_splits.append(
                tf.cond(
                    tf.greater(
                        tf.compat.v1.mod(batch_size, num_datashards),
                        i), lambda: batch_size // num_datashards + 1,
                    lambda: batch_size // num_datashards))

    return tf.split(x, size_splits, axis=0)
