# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import scipy.io as sio
import re

FLAGS = tf.flags.FLAGS

def _variable_summary(var):
    """
    Attach scalar summaries (mean, max, min, stddev) and a histogram to a variable tensor, e.g. weights and biases.
    :param var: variable tensor
    :return: None
    """
    tensor_name = var.op.name
    tf.summary.scalar(tensor_name + '/mean', tf.reduce_mean(var))
    tf.summary.scalar(tensor_name + '/max', tf.reduce_max(var))
    tf.summary.scalar(tensor_name + '/min', tf.reduce_min(var))
    tf.summary.scalar(tensor_name + '/stddev', tf.sqrt(tf.reduce_mean(tf.square(var - tf.reduce_mean(var)))))
    tf.summary.histogram(tensor_name + '/histogram', var)

def _activation_summary(op):
    """
    Attach a sparsity summary to an op tensor, e.g. activation outputs.
    :param op: op tensor
    :return: None
    """
    tensor_name = op.op.name
    tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(op))


def model(sequence_data, sequence_length, sequence_alleles, keep_prob):
    """
    Define the model.
    :param sequence_data: shape = [batch_size, max_length_of_batch, 20, 1]
    :param sequence_length: shape = [batch_size, 1]
    :param sequence_alleles: shape = [batch_size, 2, 20]
    :param keep_prob: the parameter of the dropout layer.
    :return: the outputs of the model.
    """
    batch_size_op = tf.shape(sequence_data)[0]

    with tf.variable_scope('Evolution_Module'):
        blosum_matrix = sio.loadmat('./data/pssm.mat')['pssm']

        initializer_filters = tf.reshape(tf.constant(blosum_matrix, dtype=tf.float32), [1, 20, 1, 20])
        PSSM_filters = tf.get_variable('PSSM_filters', initializer=initializer_filters, dtype=tf.float32)

        reshaped_sequence_data = tf.reshape(sequence_data, [batch_size_op, 30, 20, 1])
        Evolution_sequence_data = tf.nn.conv2d(reshaped_sequence_data, PSSM_filters, strides=[1, 1, 20, 1], padding='SAME')
        # Evolution_sequence_data.shape = [batch_size, 30, 1, 20]
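        # The 1x20 filters are initialized from the PSSM/BLOSUM matrix and remain trainable.
        # With stride 20 along the width, each position's 20-dim encoding is multiplied by
        # that matrix, i.e. the convolution acts as a per-position linear projection.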

    _variable_summary(PSSM_filters)


    with tf.variable_scope('Merge_Module'):
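        # Prepend the two allele encodings to the 30 evolutionary-encoded positions
        # (giving 32 steps per sequence) and grow the reported length by 2 to match.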
        Merge_sequence_data = tf.concat([sequence_alleles, tf.reshape(Evolution_sequence_data, [-1, 30, 20])], axis=1)
        sequence_length = tf.add(sequence_length, 2)


    with tf.variable_scope('Bid_LSTM_Module') as scope:
        cell_fw = tf.nn.rnn_cell.LSTMCell(num_units=64, initializer=tf.glorot_normal_initializer(seed=0), name='Forward_Lstm')
        cell_fw_dropout = tf.nn.rnn_cell.DropoutWrapper(cell_fw, output_keep_prob=keep_prob, seed=0)
        cell_bw = tf.nn.rnn_cell.LSTMCell(num_units=64, initializer=tf.glorot_normal_initializer(seed=0), name='Backward_Lstm')
        cell_bw_dropout = tf.nn.rnn_cell.DropoutWrapper(cell_bw, output_keep_prob=keep_prob, seed=0)
        init_fw = cell_fw.zero_state(batch_size_op, dtype=tf.float32)
        init_bw = cell_bw.zero_state(batch_size_op, dtype=tf.float32)

        merged_sequence_data = tf.reshape(Merge_sequence_data, [-1, 32, 20])
        sequence_length = tf.reshape(sequence_length, [-1])

        # The LSTMCell weights are only created once bidirectional_dynamic_rnn builds the
        # graph, so they are fetched afterwards for the variable summaries.
        bidrnn_outputs, bidrnn_states = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_fw_dropout, cell_bw=cell_bw_dropout,
                                                                        inputs=merged_sequence_data,
                                                                        sequence_length=sequence_length,
                                                                        initial_state_fw=init_fw,
                                                                        initial_state_bw=init_bw)

        # lstm_params = tf.trainable_variables(scope=scope.name)
        cell_fw_weights = cell_fw.weights[0]
        cell_fw_biases = cell_fw.weights[1]
        cell_bw_weights = cell_bw.weights[0]
        cell_bw_biases = cell_bw.weights[1]

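        # bidrnn_states is a pair of LSTMStateTuple(c, h); index [1] picks the final
        # hidden state h of each direction, each of shape [batch_size, 64].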
        fw_lstm_outputs = bidrnn_states[0][1]
        bw_lstm_outputs = bidrnn_states[1][1]
        Lstm_outputs = tf.concat((fw_lstm_outputs, bw_lstm_outputs), axis=1)
        # shape = [batch_size, 128]

    _variable_summary(cell_fw_weights)
    _variable_summary(cell_fw_biases)
    _variable_summary(cell_bw_weights)
    _variable_summary(cell_bw_biases)


    with tf.variable_scope('Dense_1_Module'):
        initializer_weights = tf.truncated_normal_initializer(stddev=0.4, seed=0)
        initializer_biases = tf.constant_initializer(0.1)
        weights = tf.get_variable('weight', [128, 64], initializer=initializer_weights, dtype=tf.float32)
        biases = tf.get_variable('biases', [64], initializer=initializer_biases, dtype=tf.float32)

        temp = tf.reshape(Lstm_outputs, [-1, 128])
        Dense_1_outputs = tf.nn.leaky_relu(tf.nn.xw_plus_b(temp, weights, biases), name='Dense_1_outputs')

    _variable_summary(weights)
    _variable_summary(biases)
    _activation_summary(Dense_1_outputs)


    with tf.variable_scope('Dense_2_Module'):
        initializer_weights = tf.truncated_normal_initializer(stddev=0.4, seed=0)
        initializer_biases = tf.constant_initializer(0.1)
        weights = tf.get_variable('weight', [64, 1], initializer=initializer_weights, dtype=tf.float32)
        biases = tf.get_variable('biases', [1], initializer=initializer_biases, dtype=tf.float32)

        temp = tf.reshape(Dense_1_outputs, [-1, 64])
        logits = tf.nn.xw_plus_b(temp, weights, biases)
        predicted_values = tf.reshape(tf.sigmoid(logits), [-1], name='Predicted_values')
        # predicted_values.shape = [batch_size]

    _variable_summary(weights)
    _variable_summary(biases)
    _activation_summary(predicted_values)

    return predicted_values

def losses(predicted_values, label_values):
    """
    Define the loss.
    :param value: tf.float32, [batch_size, n_classes]
    :param labels: tf.float32, [batch_size, n_classes]
    :return:
    """
    with tf.variable_scope('loss'):
        with tf.name_scope('binary_crossentropy'):
            # loss = tf.losses.mean_squared_error(label_values, predicted_values)
            loss = tf.keras.losses.binary_crossentropy(label_values, predicted_values)
        tf.summary.scalar('loss_value', loss)
    return loss
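
# Worked form of the loss above (an illustrative sketch, not used by the pipeline):
# mean binary cross-entropy, BCE = -1/N * sum_i [y_i*log(p_i) + (1-y_i)*log(1-p_i)].
# Because predicted_values and label_values are rank-1 [batch_size] tensors,
# tf.keras.losses.binary_crossentropy's mean over its last axis already yields this scalar.
def _reference_bce_numpy(label_values, predicted_values, eps=1e-7):
    """NumPy re-implementation of the mean binary cross-entropy, for sanity checks only."""
    p = np.clip(predicted_values, eps, 1.0 - eps)
    return float(np.mean(-(label_values * np.log(p) + (1.0 - label_values) * np.log(1.0 - p))))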

def training(loss, learning_rate):
    """
    Define the training operation.
    :param loss: the loss.
    :param learning_rate: learning rate.
    :return: the training op, the learning-rate constant, and the global-step tensor.
    """
    with tf.name_scope('optimizer'):
        learning_rate_op = tf.constant(learning_rate, dtype=tf.float32, name='learning_rate')
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_op)
        global_step_op = tf.train.get_or_create_global_step()
        train_op = optimizer.minimize(loss, global_step=global_step_op)

        tf.summary.scalar('global_step', global_step_op)
        tf.summary.scalar('learning_rate', learning_rate_op)

    return train_op, learning_rate_op, global_step_op
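

# ---------------------------------------------------------------------------
# Illustrative smoke test (an assumption, not part of the original training
# pipeline): wires model(), losses() and training() together on random dummy
# data so graph construction and a few optimizer steps can be checked. It still
# requires ./data/pssm.mat to be present, like the model itself. The batch
# size, learning rate and feed values below are arbitrary choices.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    sequence_data_ph = tf.placeholder(tf.float32, [None, 30, 20, 1], name='sequence_data')
    sequence_length_ph = tf.placeholder(tf.int32, [None, 1], name='sequence_length')
    sequence_alleles_ph = tf.placeholder(tf.float32, [None, 2, 20], name='sequence_alleles')
    keep_prob_ph = tf.placeholder(tf.float32, name='keep_prob')
    label_values_ph = tf.placeholder(tf.float32, [None], name='label_values')

    predicted_values = model(sequence_data_ph, sequence_length_ph, sequence_alleles_ph, keep_prob_ph)
    loss = losses(predicted_values, label_values_ph)
    train_op, learning_rate_op, global_step_op = training(loss, learning_rate=1e-3)

    batch_size = 4
    feed_dict = {
        sequence_data_ph: np.random.rand(batch_size, 30, 20, 1).astype(np.float32),
        sequence_length_ph: np.full((batch_size, 1), 9, dtype=np.int32),
        sequence_alleles_ph: np.random.rand(batch_size, 2, 20).astype(np.float32),
        keep_prob_ph: 0.8,
        label_values_ph: np.random.randint(0, 2, batch_size).astype(np.float32),
    }

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(3):
            _, loss_value, step = sess.run([train_op, loss, global_step_op], feed_dict=feed_dict)
            print('step %d: loss = %.4f' % (step, loss_value))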