import numpy as np
import pandas as pd
import scipy.stats as stats
import scipy.optimize as optimize
from scipy.special import gammaln,gamma, psi, betaln
import multiprocessing
from scipy.sparse import issparse
import numba
import logging

# Configure logging so that records are written to a file and echoed to the console.
log_format = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=log_format,
    handlers=[
        logging.FileHandler("polya_urn_log.log"),  # persistent log file
        logging.StreamHandler(),  # console output
    ],
)


class PolyaUrnFilter:
    def __init__(self, a=-1, apr_lvl=10, parallel=False, num_processes=None):
        """
        Store the default filter settings on the instance.

        Args:
        a (float): Free parameter of the filter; PF runs the ML estimate when it is negative.
        apr_lvl (int): Default approximation level PF falls back to when none is passed.
        parallel (bool): Default parallel flag PF falls back to when none is passed.
        num_processes (int or None): Number of worker processes, stored as given.
        """
        self.a, self.apr_lvl = a, apr_lvl
        self.parallel, self.num_processes = parallel, num_processes
        


    def eq_from_lhood(self, X, a):
        """
        Optimized version of the derivative of the likelihood function.

        Args:
        X (numpy.ndarray): Array containing w, k, and s values.
        a (float): Free parameter of the filter.

        Returns:
        float: Sum of the derivative of the likelihood.
        """
        w, k, s = X.T
        a_inv = 1 / a
        psi_a_inv = psi(a_inv)
        DL = a_inv ** (-2) * (psi_a_inv + (-1 + k) * psi(a_inv * (-1 + k)) - k * psi(a_inv * k) + \
                              k * psi(a_inv * k + s) - psi(a_inv + w) - (-1 + k) * psi(a_inv * (-1 + k + a * s - a * w)))
        DL = np.nan_to_num(DL, nan=0.0, posinf=1e100, neginf=-1e100)
        return np.sum(DL)
    
    
    def lhood(self, X, a):
        """
        Optimized version of the likelihood function.

        Args:
        X (numpy.ndarray): Array containing w, k, and s values.
        a (float): Free parameter of the filter.

        Returns:
        float: Sum of the negative log-likelihood.
        """
        w, k, s = X.T
        if a == 0:  # Binomial case
            P = stats.binom.pmf(w, s, 1 / k)
        else:
            n = s
            A = 1 / a
            B = (k - 1) / a
            x = w
            P = np.exp(gammaln(n + 1) + betaln(x + A, n - x + B) - gammaln(x + 1) - gammaln(n - x + 1) - betaln(A, B))
        P[P == 0] = 1e-20
        out = np.sum(-np.log(P))
        if np.isinf(out):
            out = np.sign(out) * 1e100
        return out

    def polya_pdf(self, w, s, k, a):
        if a == 0:  # Binomial case
            p = stats.binom.pmf(w, s, 1 / k)
        else:
            n = s
            A = 1 / a
            B = (k - 1) / a
            x = w
            p = np.exp(gammaln(n + 1) + betaln(x + A, n - x + B) - gammaln(x + 1) - gammaln(n - x + 1) - betaln(A, B))
        return p



    def get_ML_estimate(self, W):
        """
        Get the maximum likelihood estimation of the free parameter "a".

        Args:
        W (numpy.ndarray): Adjacency matrix of the network.

        Returns:
        tuple: Tuple containing the best estimate of 'a' and the error.
        """
        logging.info('Iniciando a estimativa de máxima verossimilhança (ML) de "a".')
        if np.array_equal(W, W.T):
            U = np.triu(W)
            i, j = np.nonzero(U)
            w = U[i, j]
        else:
            i, j = np.nonzero(W)
            w = W[i, j]

        # Get degrees and strengths
        k_in = np.sum(W != 0, axis=1)
        s_in = np.sum(W, axis=1)
        k_out = np.sum(W != 0, axis=0)
        s_out = np.sum(W, axis=0)
        w = np.concatenate([w, w])
        k = np.concatenate([k_out[i], k_in[j]])
        s = np.concatenate([s_out[i], s_in[j]])

        # Remove links with degree 1
        mask = k != 1
        w = w[mask]
        s = s[mask]
        k = k[mask]

        f = lambda a: self.eq_from_lhood(np.column_stack([w, k, s]), a)
        g = lambda a: self.lhood(np.column_stack([w, k, s]), a)

        #options_lsq = {'disp': False, 'bounds': (0, 15)}
        # First find the maximum value of the likelihood
        x0_res = optimize.least_squares(g, 0.5, bounds=(0, 20), method='trf')
        x0 = x0_res.x

        # Use this value to calculate the value that puts the derivative to 0
        res = optimize.least_squares(f, x0, bounds=(0, 15), method='trf')
        a_best = res.x
        err = res.fun

        # Try higher precision if feval is not close to 0
        if np.any(np.abs(err) > 1e-6):
            res = optimize.least_squares(f, x0, bounds=(0, 15), method='trf', max_nfev=5000, ftol=1e-15, xtol=1e-15, gtol=1e-15)
            a_best = res.x
            err = res.fun
            if np.any(np.abs(err) > 1e-6):
                print('Try stricter minimization options')
        logging.info(f'Estimativa finalizada, a = {a_best}')
        return a_best, err


    def polya_cdf_parallel(self, args):
        w_val, s_val, k_val, a, L = args
        if a == 0:  # Caso binomial
            p_val = stats.binom.cdf(w_val, s_val, 1 / k_val)
        else:
            A = 1 / a
            B = (k_val - 1) / a
            # As condições para usar a aproximação
            idx1 = s_val - w_val >= L * (B + 1)
            idx2 = w_val >= L * max(A, 1)
            idx3 = s_val >= L * max(k_val / a, 1)
            idx4 = k_val >= L * (a - 1 + 1e-20)
            # Se as condições forem satisfeitas, usamos a aproximação
            if idx1 and idx2 and idx3 and idx4:  
                gamma_A = gamma(A)
                p_val = (1 / gamma_A) * (1 - w_val / s_val) ** B * (w_val * k_val / (s_val * A)) ** (A - 1)
                # Verificação e impressão para comparação, caso necessário
                if p_val > 1:
                    print("Erro na aproximação")
            else:  # Senão, cálculo exato
                x = np.arange(w_val)
                log_p_val = gammaln(s_val + 1) + betaln(x + A, s_val - x + B) - gammaln(x + 1) - gammaln(s_val - x + 1) - betaln(A, B)
                p_val = 1 - np.sum(np.exp(log_p_val))
                if p_val > 1:
                    print("Erro no cálculo exato")
        return p_val




    def polya_cdf(self, w, s, k, a, L, parallel, num_processes=None):
        p = np.full(len(w), np.nan)

        if parallel:
            # Cria um pool de processos
            with multiprocessing.Pool(processes=num_processes) as pool:
                # Prepara os argumentos para passar para a função polya_cdf_parallel
                args = [(w_val, s_val, k_val, a, L) for w_val, s_val, k_val in zip(w, s, k)]

                # Mapeia a função polya_cdf_parallel para os argumentos usando o pool de processos
                results = pool.map(self.polya_cdf_parallel, args)

                # Copia os resultados para o array p
                for idx, result in enumerate(results):
                    p[idx] = result
        else:
            # Código para processamento sequencial
            for idx in range(len(w)):
                p[idx] = self.polya_cdf_parallel((w[idx], s[idx], k[idx], a, L))

        return np.maximum(p, 0)

    
    def PF(self, W, a=None, apr_lvl=None, parallel=None, num_processes=None):

        """
        Polya Filter function to calculate P-values for each link of the network.

        Args:
        W (numpy.ndarray): Adjacency matrix of the network.
        a (float): Free parameter of the filter (optional, default is self.a).
        apr_lvl (int): Constant defining the regime for the approximate p-value form (optional, default is self.apr_lvl).
        parallel (int): Flag for parallel computing (optional, default is self.parallel).

        Returns:
        numpy.ndarray: P-values prescribed by the Polya Filter for each link.
        """
        logging.info('Iniciando o cálculo de p-valores.')

        # Use valores padrão da instância se os argumentos não forem fornecidos
        if a is None:
            a = self.a
        if apr_lvl is None:
            apr_lvl = self.apr_lvl
        if parallel is None:
            parallel = self.parallel
            
        # Check if the network is symmetric (i.e., undirected) and get the edge list
        if issparse(W):
            W = W.toarray()
        if np.array_equal(W, W.T):
            U = np.triu(W)
            i, j = np.nonzero(U)
            w = U[i, j]
        else:
            i, j = np.nonzero(W)
            w = W[i, j]

        # Get the degrees and strengths
        k_in = np.sum(W != 0, axis=1)
        s_in = np.sum(W, axis=1)
        k_out = np.sum(W != 0, axis=0)
        s_out = np.sum(W, axis=0)
        k_in = k_in[j]
        k_out = k_out[i]
        s_in = s_in[j]
        s_out = s_out[i]

        # If a < 0, get the ML estimates
        if a < 0:
            a, err = self.get_ML_estimate(W)

        # Use the asymptotic form if non-integer weights are present
        if np.any(w % 1 != 0):
            apr_lvl = 0
        
        # Calculate p-values
        p1 = self.polya_cdf(w, s_in, k_in, a, apr_lvl, parallel)
        p2 = self.polya_cdf(w, s_out, k_out, a, apr_lvl, parallel)
        P = np.minimum(p1, p2)
    
        
        # Handle the case k=1
        P[k_in == 1] = p2[k_in == 1]
        P[k_out == 1] = p1[k_out == 1]
        P[(k_in == 1) & (k_out == 1)] = 1

        logging.info('Cálculo de p-valores finalizado.')
        return P

    def extract_backbone(self, df, weight_col, a, apr_lvl, parallel, num_processes=None):
        logging.info('Iniciando a extração do backbone.')
    
        #Padronizando o DF
        df['src'], df['trg'] = np.minimum(df['src'], df['trg']), np.maximum(df['src'], df['trg'])

        # Criando um mapeamento de índices
        unique_nodes = np.unique(df[['src', 'trg']].values.ravel())
        print(unique_nodes)
        index_map = {node: idx for idx, node in enumerate(unique_nodes)}

        # Salvando os IDs originais
        df['src_original'] = df['src']
        df['trg_original'] = df['trg']

        # Remapeando SRC e TRG para índices contíguos
        df['src'] = df['src'].map(index_map)
        df['trg'] = df['trg'].map(index_map)
        
        df[weight_col] = df[weight_col].astype(float)
        df['src'] = df['src'].astype(int)
        df['trg'] = df['trg'].astype(int)
        
        # Ordenando os pares de nós (src, trg) no DataFrame
        df.sort_values(by=['src', 'trg'], inplace=True)


        # Criando a matriz de adjacência usando operações do NumPy
        n = len(unique_nodes)
        W = np.zeros((n, n), dtype=float)
        W[df['src'].values, df['trg'].values] = df[weight_col].values
        W[df['trg'].values, df['src'].values] = df[weight_col].values  # Grafo não direcionado
        
        # Calculando os p-valores
        P_values = self.PF(W, a=a, apr_lvl=apr_lvl, parallel=parallel, num_processes=num_processes)


        # Criando um DataFrame para mapear p-valores para pares originais
        # Mantendo a ordem das arestas conforme aparecem no DataFrame df
        df = pd.DataFrame({
            'src': df['src_original'],
            'trg': df['trg_original'],
             weight_col:df[weight_col].values,
            'p_valor': P_values
        })
        logging.info('Extração do backbone finalizada.')

        return df

