# -*- coding: utf-8 -*-
"""
Created on Tue Mar 24 12:30:48 2015

@author: adrienkuntz
"""

##### Calculation of the bispectrum and the trispectrum in perturbation theory
##### The trispectrum is only evaluated at configurations (k1, -k1, k2, -k2), which is all this study needs
##### In this program, uij means cos(theta(i,j)), where theta(i,j) is the angle between ki and kj

import numpy as np
#from scipy import integrate
#import matplotlib.pyplot as plt
#import math
from Params import *


### Physical parameters

#omegam0 = 0.27

#mu = 3./7 * np.exp(-2/63. * np.log(omegam0))



### Get the power spectrum from an external file


def power_spectrum():
    """Load the tabulated power spectrum from 'matterpower.dat'.

    Each non-blank line is split on runs of exactly four spaces; fields 1
    and 2 of that split are parsed as floats and become one row of the
    result (field 0 is ignored).

    Returns:
        A float numpy array of shape (n_lines, 2); shape (0, 2) when the
        file holds no data lines.

    Raises:
        IOError/OSError: if the file cannot be opened.
        IndexError, ValueError: if a non-blank line does not follow the
            expected layout.
    """
    rows = []
    # `with` guarantees the file is closed even if a line fails to parse.
    with open('matterpower.dat', 'r') as monfichier:
        for ligne in monfichier:
            # A blank (e.g. trailing) line carries no data; skip it rather
            # than crash on the missing fields.
            if not ligne.strip():
                continue
            part = ligne.split('    ')
            rows.append((float(part[1]), float(part[2])))

    # Collect rows in a list and convert once: concatenating arrays line by
    # line is quadratic in the number of lines. reshape keeps the two-column
    # shape even when no rows were read.
    return np.array(rows, dtype=float).reshape(-1, 2)
    
# Table loaded once at import time; P() interpolates column 1 against column 0.
powerspectrum = power_spectrum()
# Number of tabulated samples. Taken from the row count so that it stays an
# integer regardless of whether `/` is floor or true division.
N = powerspectrum.shape[0]    #powerspectrum is a 2 column vector

###Growth factor

def E2(z):
    """Return omegam0*(1+z)**3 + omegalambda0 at redshift z.

    Presumably the squared dimensionless expansion rate (H/H0)**2 for matter
    plus a cosmological constant -- confirm against Params. `omegam0` and
    `omegalambda0` are module-level names expected to come from the
    `from Params import *` at the top of the file.
    """
    matter_term = omegam0 * (1 + z)**3
    return matter_term + omegalambda0

def D(z):
    """Return the (unnormalised) linear growth factor at redshift z.

    The density parameters are first rescaled to redshift z using E2(z),
    then combined in a closed-form fit; the expression has the form of the
    Carroll, Press & Turner approximation. P() only ever uses the ratio
    D(z)/D(0).
    """
    scale = 1 + z
    expansion2 = E2(z)

    # Density parameters evaluated at redshift z.
    om_z = omegam0 / expansion2 * scale**3
    ol_z = omegalambda0 / expansion2

    bracket = pow(om_z, 4/7.) - ol_z + (1 + om_z/2)*(1 + ol_z/70)
    return 5/2. * om_z / (scale * bracket)



### function which interpolates the power spectrum

def P(k, z):
    """Return the power spectrum at wavenumber k and redshift z.

    The tabulated spectrum in the module-level `powerspectrum` array is
    linearly interpolated at k with np.interp and rescaled by
    (D(z)/D(0))**2.

    When k lies outside the tabulated range a message is printed and the
    computation still proceeds: np.interp returns the first/last tabulated
    value for abscissae beyond the table ends.

    Args:
        k: scalar wavenumber (the range checks use plain `<` / `>`).
        z: redshift passed to the growth factor D.
    """
    k_min = powerspectrum[0, 0]
    k_max = powerspectrum[-1, 0]

    # Single-argument print(...) emits the same text under Python 2 and 3,
    # whereas the print statement is a syntax error on Python 3.
    if k < k_min:
        print('Points missing on the left : k_asked = {}, k_min = {}'.format(k, k_min))
    elif k > k_max:
        print('Points missing on the right : k_asked = {}, k_max = {}'.format(k, k_max))

    growth_ratio = D(z)/D(0)
    return np.interp(k, powerspectrum[:,0], powerspectrum[:,1]) * growth_ratio**2
   


#### Bispectrum

def norme2(k1, k2, u12):
    """Return the norm of the sum of two vectors, |k1 + k2|.

    Args:
        k1, k2: norms of the two vectors.
        u12: cosine of the angle between them.

    Returns:
        sqrt(|k1**2 + k2**2 + 2*k1*k2*u12|). If the squared norm comes out
        negative (only possible for inconsistent inputs or rounding), a
        message is printed and the absolute value is used.
    """
    tmp = k1**2 + k2**2 + 2*k1*k2*u12
    # Single-argument print(...) behaves identically on Python 2 and 3; the
    # print statement would be a syntax error on Python 3.
    if tmp < 0.:
        print('oops ! Negative norm2 : {} ! k1 = {}, k2 = {}'.format(tmp, k1, k2))
    return np.sqrt(np.abs(tmp))
    
def norme3(k1, k2, k3, u12, u13, u23):
    """Return the norm of the sum of three vectors, |k1 + k2 + k3|.

    Args:
        k1, k2, k3: norms of the three vectors.
        u12, u13, u23: cosines of the angles between each pair.

    Returns:
        The square root of the absolute value of the expanded squared norm.
        If that squared norm comes out negative (only possible for
        inconsistent inputs or rounding), a message is printed first.
    """
    tmp = k1**2 + k2**2 + k3**2 + 2*k1*k2*u12 + 2*k1*k3*u13 + 2*k2*k3*u23
    # Single-argument print(...) behaves identically on Python 2 and 3; the
    # print statement would be a syntax error on Python 3.
    if tmp < 0.:
        print('oops ! Negative norm3 : {} !'.format(tmp))
    return np.sqrt(np.abs(tmp))


def F2(k1, k2, u12):
    """Return 5/7 + (k1/k2 + k2/k1)*u12/2 + (2/7)*u12**2.

    This has the form of the second-order density kernel of perturbation
    theory. k1 and k2 are vector norms, u12 the cosine of the angle between
    them.

    NOTE: k1 and k2 should be floats; under Python 2, integer arguments would
    make k1/k2 and k2/k1 floor-divide.
    """
    constant_part = 5./7
    dipole_part = 0.5*(k1/k2 + k2/k1)*u12
    quadratic_part = 2./7*u12**2
    return constant_part + dipole_part + quadratic_part
    

def BPT(k1, k2, k3, u12, u13, u23, z):
    """Return the bispectrum built from F2 and the power spectrum P.

    The result is the sum over the three pairs (i, j) of
    2 * F2(ki, kj, uij) * P(ki, z) * P(kj, z).

    Args:
        k1, k2, k3: norms of the three wavevectors.
        u12, u13, u23: cosines of the angles between each pair.
        z: redshift passed to P.
    """
    # Each spectrum value enters two of the three terms; evaluate it once so
    # the interpolation, the growth factor and any out-of-range warning are
    # not repeated.
    p1 = P(k1, z)
    p2 = P(k2, z)
    p3 = P(k3, z)

    return (2*F2(k1, k2, u12)*p1*p2 + 2*F2(k1, k3, u13)*p1*p3 + 2*F2(k2, k3, u23)*p2*p3)



### Trispectrum
## TPT (below) returns T(k1, -k1, k2, -k2)

def G2(k1, k2, u12):
    """Return 3/7 + (k1/k2 + k2/k1)*u12/2 + (4/7)*u12**2.

    This has the form of the second-order velocity-divergence kernel of
    perturbation theory. k1 and k2 are vector norms, u12 the cosine of the
    angle between them.

    NOTE: k1 and k2 should be floats; under Python 2, integer arguments would
    make k1/k2 and k2/k1 floor-divide.
    """
    constant_part = 3./7
    dipole_part = 0.5*(k1/k2 + k2/k1)*u12
    quadratic_part = 4./7*u12**2
    return constant_part + dipole_part + quadratic_part
    


def F3(k1, k2, k3, u12, u13, u23):
    """Return the unsymmetrised third-order kernel for (k1, k2, k3).

    The kernel is 1/18 times the sum of two groups of terms:
      * a group weighted by F2(k2, k3) and G2(k2, k3), which divides by
        |k2 + k3|**2;
      * a group weighted by G2(k1, k2), which divides by |k1 + k2|**2.

    When |k1 + k2|**2 is exactly zero the second group is left out; otherwise,
    when |k2 + k3|**2 is exactly zero, the first group is left out. This
    avoids dividing by the vanishing norm.

    Args:
        k1, k2, k3: norms of the three wavevectors.
        u12, u13, u23: cosines of the angles between each pair.
    """
    sq_123 = norme3(k1, k2, k3, u12, u13, u23)**2
    sq_12 = norme2(k1, k2, u12)**2
    sq_23 = norme2(k2, k3, u23)**2

    def group_23():
        # Terms built on the pair (k2, k3); needs sq_23 != 0.
        weight_f = 7.*(1 + k2/k1*u12 + k3/k1*u13)
        weight_g = sq_123*(k1*k2*u12 + k1*k3*u13)/(k1**2 * sq_23)
        return weight_f * F2(k2, k3, u23) + weight_g * G2(k2, k3, u23)

    def group_12():
        # Terms built on the pair (k1, k2); needs sq_12 != 0.
        weight = 7.*(k1**2 + k2**2 + 2*k1*k2*u12 + k1*k3*u13 + k2*k3*u23)/(sq_12) + sq_123*(k1*k3*u13 + k2*k3*u23)/(sq_12*k3**2)
        return weight * G2(k1, k2, u12)

    if sq_12 == 0.:                 # F3(k1, -k1, k2) is frequently used
        return 1/18.*(group_23())
    if sq_23 == 0.:
        return 1/18.*(group_12())
    return 1/18.*(group_23() + group_12())
    
    
    
    
def F3s(k1, k2, k3, u12, u13, u23):
    """Return F3 symmetrised over the six orderings of (k1, k2, k3).

    Each row below is one ordering (a, b, c) of the three wavevectors,
    followed by the cosines in the order F3 expects: (u_ab, u_ac, u_bc).
    The result is the mean of F3 over these six orderings.
    """
    orderings = (
        (k1, k2, k3, u12, u13, u23),
        (k1, k3, k2, u13, u12, u23),
        (k2, k1, k3, u12, u23, u13),
        (k3, k1, k2, u13, u23, u12),
        (k3, k2, k1, u23, u13, u12),
        (k2, k3, k1, u23, u12, u13),
    )
    terms = [F3(*args) for args in orderings]

    # Add the terms left to right, starting from the first one.
    total = terms[0]
    for term in terms[1:]:
        total = total + term
    return 1/6.*(total)



def TPT(k1, k2, u12, z):
    """Return the trispectrum T(k1, -k1, k2, -k2) at redshift z.

    The result is the sum of three contributions:
      * 4 * P(|k1 + k2|) * (F2-weighted combination of P(k1), P(k2))**2
      * 4 * P(|k1 - k2|) * (F2-weighted combination of P(k1), P(k2))**2
      * 12 * (F3s(k1, -k1, k2) * P(k1)**2 * P(k2)
              + F3s(k1, k2, -k2) * P(k1) * P(k2)**2)

    Args:
        k1, k2: norms of the two wavevectors.
        u12: cosine of the angle between k1 and k2.
        z: redshift passed to P.

    NOTE: the cosines below divide by |k1 + k2| and |k1 - k2|, so
    configurations where either norm vanishes are not handled.
    """
    n12 = norme2(k1, k2, u12)       # |k1 + k2|
    m12 = norme2(k1, k2, -u12)      # |k1 - k2|

    # P(k1) and P(k2) appear in every contribution; evaluate them once so
    # the interpolation, the growth factor and any out-of-range warning are
    # not repeated.
    p1 = P(k1, z)
    p2 = P(k2, z)

    # Cosines between k1 (resp. k2) and -(k1 + k2).
    u1_sum = -(k1**2 + k1*k2*u12)/(k1*n12)
    u2_sum = -(k2**2 + k1*k2*u12)/(k2*n12)
    # Cosines between k1 and (k2 - k1), and between k2 and (k1 - k2).
    u1_diff = (k1*k2*u12 - k1**2)/(k1*m12)
    u2_diff = (k1*k2*u12 - k2**2)/(k2*m12)

    sum_term = 4*P(n12, z)*(F2(k1, n12, u1_sum) * p1 + F2(k2, n12, u2_sum) * p2)**2
    diff_term = 4*P(m12, z)*(F2(k1, m12, u1_diff) * p1 + F2(k2, m12, u2_diff) * p2)**2
    f3_term = 12*(F3s(k1, k1, k2, -1., u12, -u12) * p1**2 * p2 + F3s(k1, k2, k2, u12, -u12, -1.) * p1 * p2**2)

    return ( sum_term
           + diff_term
           + f3_term )