from __future__ import print_function
from optparse import OptionParser    
import sys
import os
import glob
import copy
import numpy as np
import pandas as pd
import random

# Seeding is disabled, so every run draws different random numbers.
# Uncomment both lines for reproducible simulations: the numpy RNG (np.random.*)
# and the stdlib RNG (random.sample in get_causal_snplist) are both used.
#np.random.seed(0)
#random.seed(0)

# Default FASTA reference for `simulate_genotype -r` (placeholder path, set it before use).
DEFREF="/...PATH_TO.../R64_nucl.fasta"

# Paths of external executables (placeholders, set them before use).
# GCTA is invoked through os.system by the simulation commands.
# NOTE(review): PLINK is not referenced in the visible part of this file.
GCTA="/...PATH_TO.../GCTA/gcta64"
PLINK="/...PATH_TO.../Plink/plink-1.07-x86_64/plink"


###############################################
#
#                CLASS DEFINITION
#
###############################################

class Site():
    """
    One polymorphic site of the simulated genome.

    attributes:
    -chromosome: chromosome label of the site
    -position: position of the site on its chromosome
    -ref_all: reference allele
    -alt_all: alternative allele
    -mut_strains: indices of the strains carrying the alternative allele
      (empty at construction, filled in by the caller)
    """

    def __init__(self, chromosome, position, ref_all, alt_all):
        self.chromosome, self.position = chromosome, position
        self.ref_all, self.alt_all = ref_all, alt_all
        self.mut_strains = []

    def __repr__(self):
        template = "chromosome:{0}, position:{1}, reference:{2}, alternative:{3}, mut_strains:{4}"
        return template.format(self.chromosome, self.position,
                               self.ref_all, self.alt_all, self.mut_strains)

    def get_allele(self, strain):
        """
        Return the genotype of `strain` at this site: the same allele written
        twice, separated by a space (alternative allele if the strain is listed
        in mut_strains, reference allele otherwise).
        """
        if strain in self.mut_strains:
            base = self.alt_all
        else:
            base = self.ref_all
        return "{0} {1}".format(base, base)



###############################################
#
#                Genotype Simulation
#
###############################################


def simulate_genotype(args):
    """
    Simulates genotype.

    Draws candidate polymorphic sites on a FASTA reference, gives each simulated
    individual a number of SNPs ~ |N(mu, sigma)| chosen among those sites with
    Beta(0.1, 10)-distributed weights, and writes:

    - ``<prefix>.ped`` / ``<prefix>.map``: PLINK text fileset (each genotype is
      written as the same allele twice);
    - ``<prefix>_sites.log``: number of SNPs drawn for each individual;
    - ``<prefix>_sites.info``: kept sites with their alleles and carrier count.

    Sites carried by no individual, or by every individual, are dropped.

    :param args: command-line arguments (see the options below).
    """

    usage = "Usage: %prog simulate_genotype [-n nb_indiv -o plink_output]"
    parser = OptionParser(usage=usage)
    parser.add_option("-n", "--nb_indiv", dest="nb_indiv", type="int", help="number of individuals", default=1000)
    parser.add_option("-r", "--ref", dest="ref", type="string", help="Fasta reference sequence", default=DEFREF)
    parser.add_option("-o", "--plink_out", dest="plink_out", type="string", help="Output plink files prefix",)
    parser.add_option("-m", "--mu", dest="mu", type="int", help="Average number of SNP by strain", default=43000)
    parser.add_option("-s", "--sigma", dest="sigma", type="int", help="SD of number of SNP by strain", default=10000)
    parser.add_option("-f", "--freq", dest="freq", type="float", help="Probability of site to be mutated for each ind", default=0.014)
    (opts, args) = parser.parse_args(args)

    if not opts.plink_out:
        sys.exit("Please specify an output prefix. (-o option)")

    nb_indiv = opts.nb_indiv
    plink_out = opts.plink_out

    # All outputs are managed by one `with` so every handle (including the
    # .info file) is flushed and closed, even on error.
    with open(plink_out + '.ped', 'w') as outped, \
            open(plink_out + '.map', 'w') as outmap, \
            open(plink_out + "_sites.log", 'w') as outsites, \
            open(plink_out + "_sites.info", 'w') as outinfo:

        dict_ref = read_ref(opts.ref)
        poly_sites = determine_polymorphic_sites(dict_ref, opts.freq)
        # number of SNPs per individual: |N(mu, sigma)| truncated to an int
        nb_snps = [abs(int(x)) for x in np.random.normal(opts.mu, opts.sigma, nb_indiv)]

        # Beta(0.1, 10) weights are heavily skewed: most sites get a tiny
        # probability of being drawn, a few get a large one.
        maf_distr = np.random.beta(0.1, 10, len(poly_sites))
        summ = sum(maf_distr)
        maf_distr_norm = [x / summ for x in maf_distr * 1.]

        # Sampling is without replacement, so no individual can carry more SNPs
        # than there are candidate sites; fail with an explicit message.
        if nb_snps and max(nb_snps) > len(poly_sites):
            sys.exit("Cannot draw {0} SNPs for one individual: only {1} polymorphic sites "
                     "(lower -m/-s or raise -f)".format(max(nb_snps), len(poly_sites)))

        print("strain\tnb_snp", file=outsites)
        for idx, nb_snp in enumerate(nb_snps):
            chosen = np.random.choice(poly_sites, nb_snp, p=maf_distr_norm, replace=False)
            print(idx, nb_snp, file=outsites, sep='\t')
            for site in chosen:
                site.mut_strains.append(idx)

        # Keep only segregating sites (single pass instead of an O(n^2) filter).
        # NOTE(review): the .map rs ids use the 1-based index in poly_sites,
        # while _sites.info uses the 0-based index among kept sites.
        final_sites = []
        for idx, site in enumerate(poly_sites):
            if not site.mut_strains or len(site.mut_strains) == nb_indiv:
                continue
            print(site.chromosome, "rs{0}".format(idx + 1), "0", site.position, file=outmap, sep='\t')
            final_sites.append(site)

        for site_id, site in enumerate(final_sites):
            print(site_id, site.chromosome, site.position, site.ref_all, site.alt_all,
                  len(site.mut_strains), sep='\t', file=outinfo)

        for idx in range(len(nb_snps)):
            alleles = [site.get_allele(idx) for site in final_sites]
            print("indiv{0} indiv{1} 0 0 0 -9 ".format(idx, idx) + " ".join(alleles), file=outped)
    
def read_ref(ref):
    """
    reads fasta file for reference (SACE)

    Headers are expected to look like ``>chromosomeN``; N is parsed as an int.

    :param ref: path to the FASTA file.
    :returns: dict mapping each chromosome number to the list of its nucleotides
        (one character per element, line breaks removed). An empty file gives {}.
    :raises ValueError: if a header is not ``>chromosome`` followed by an integer.
    """
    ref_seq = {}
    chrom = None
    # accumulate lines in a list and join once: repeated `str +=` is quadratic
    chunks = []
    with open(ref, 'r') as fasta:
        for line in fasta:
            line = line.strip()
            if line.startswith('>'):
                if chrom is not None:
                    ref_seq[chrom] = list(''.join(chunks))
                chrom = int(line.replace('>chromosome', ''))
                chunks = []
            else:
                chunks.append(line)
    # store the last chromosome (only if at least one header was seen)
    if chrom is not None:
        ref_seq[chrom] = list(''.join(chunks))
    return ref_seq


def determine_polymorphic_sites(ref_seq, freq=0.014):
    """
    Draw candidate polymorphic sites along the reference genome.

    Each position is independently selected with probability `freq`. A selected
    site gets an alternative allele drawn uniformly among the three other bases.

    :param ref_seq: dict chromosome -> list of nucleotides (as built by read_ref).
    :param freq: per-position probability of being polymorphic.
    :returns: list of Site objects, ordered by chromosome key then 1-based position.
    """
    sites = []
    # Iterate over the real (sorted) chromosome keys instead of assuming they
    # are exactly 1..N, which raised KeyError for any other numbering.
    for chrom in sorted(ref_seq):
        for idnucl, nucl in enumerate(ref_seq[chrom]):
            pos = idnucl + 1
            mut_val = np.random.random_sample()
            if mut_val < freq:
                ref_base = nucl.upper()
                bases = ['A', 'C', 'G', 'T']
                if ref_base not in bases:
                    # ambiguous base (e.g. 'N'): list.remove would raise ValueError
                    continue
                bases.remove(ref_base)
                new_nucl = np.random.choice(bases)
                sites.append(Site(chrom, pos, ref_base, new_nucl))
    return sites


###############################################
#
#                Phenotype Simulation
#
###############################################

def simulate_several_matrices(args):
    """
    Compute MAF-filtered frequencies for several PLINK filesets and count shared markers.

    For each fileset prefix listed in -l (first tab-separated column; lines
    starting with '#' and blank lines are skipped), GCTA --freq is run with the
    MAF bounds. Its .freq output is joined with the fileset's .bim coordinates into
    ``<outdir>/<prefix>_gcta_pruned.freq.tsv``. Finally, markers (keyed by
    chromosome and position) present in every .tsv of outdir are counted. The
    number of matrices, of shared markers and of all markers are printed.

    NOTE(review): -n, -H and -r are parsed but not used yet; this command looks
    unfinished.
    """
    usage="""Usage: GWAS_sim.py simulate_several_matrices
                 -l         matrix list
                 -n         nb causal SNP (1)
                 -m            MAF (0.0)
                 -u         upper maf threshold
                 -H            heritability (0.1)
                 -r            replicates (1)
                 -o            output prefix
            """
    parser = OptionParser(usage=usage)
    parser.add_option("-l", "--matrix_list", dest="matrix_list", type="string", help="path to list of Plink input files prefix",)
    parser.add_option("-n", "--nb_snp", dest="nb_snps", type="int", help="Number of causal SNPs to simulate", default=1,)
    parser.add_option("-m", "--maf", dest="maf", type="float", help="Minor allele frequency lower bound threshold", default=0.0,)
    parser.add_option("-u", "--umaf", dest="umaf", type="float", help="Minor allele frequency upper bound threshold, default 1", default=0.5)
    parser.add_option("-H", "--heritability", dest="heritability", type="float", help="Heritability (or heritability of liability)", default=0.1,)
    parser.add_option("-r", "--replicates", dest="replicates", type="int", help="Number of simulations replicates", default=1,)
    parser.add_option("-o", "--outdir", dest="outdir", type="string", help="Output directory",)

    (opts, args) = parser.parse_args(args)

    if not opts.matrix_list:
        print(usage)
        sys.exit("Please specify a list of Plink input prefixes (-l option)")
    if not opts.outdir:
        print(usage)
        sys.exit("Please specify an output directory (-o option)")

    matrix_list = opts.matrix_list
    low_maf = opts.maf
    up_maf = opts.umaf
    outdir = opts.outdir

    with open(matrix_list, 'r') as list_fh:
        for line in list_fh:
            if line.startswith("#"):
                continue
            line = line.strip()
            if not line:
                continue
            mat = line.split('\t')[0]
            outfreq = "{0}/{1}_gcta_pruned".format(outdir, os.path.basename(mat))
            cmd1 = "{0} --bfile {1} --maf {2} --max-maf {3} --out {4} --freq --thread-num 10".format(GCTA, mat, low_maf, up_maf, outfreq)
            print(cmd1)
            os.system(cmd1)
            freq_file = outfreq + '.freq'

            # SNP id -> (chromosome, position), from .bim columns 1, 0 and 3
            coords = {}
            with open(mat + ".bim", 'r') as bim_fh:
                for bim_line in bim_fh:
                    bim_cols = bim_line.strip().split('\t')
                    coords[bim_cols[1]] = (bim_cols[0], bim_cols[3])

            with open(freq_file + '.tsv', 'w') as fout, open(freq_file, 'r') as freq_fh:
                for freq_line in freq_fh:
                    freq_line = freq_line.strip()
                    if not freq_line:
                        continue
                    freq_cols = freq_line.split('\t')
                    if freq_cols[0] not in coords:
                        sys.exit("{0} doesn't exist in bim".format(freq_cols[0]))
                    chrom, pos = coords[freq_cols[0]]
                    print(freq_cols[0], freq_cols[1], freq_cols[2], chrom, pos, file=fout, sep='\t')
    print("done")

    # count, for each (chromosome, position), the .tsv lines mentioning it
    tsv_files = glob.glob("{0}/*.tsv".format(outdir))
    nb_mat = len(tsv_files)
    markers = {}
    for info_f in tsv_files:
        with open(info_f, 'r') as info_fh:
            for line in info_fh:
                cols = line.strip().split('\t')
                key = (cols[3], cols[4])
                markers[key] = markers.get(key, 0) + 1

    # markers shared by every matrix (`.items()` works on Python 2 and 3)
    to_keep = {k: v for k, v in markers.items() if v == nb_mat}
    print(nb_mat)
    print(len(to_keep))
    print(len(markers))
                
    
    
    
    
    

def simulate_qt(args):
    """
    Simulate a quantitative trait with GCTA from a PLINK binary fileset.

    Steps:
    1. run ``gcta --freq`` with the MAF bounds (-m/-u) to list candidate SNPs;
    2. draw -n causal SNPs from that list (get_causal_snplist), with the effect
       size -f when given;
    3. run ``gcta --simu-qt`` with heritability -H and -r replicates;
    4. convert ``<output>.phen`` into ``<output>.fam`` and ``<output>.phen.2.tsv``.

    The settings are logged to ``<output>.simu.log``.
    """
    usage = """Usage: GWAS_sim.py simulate_qt
                 -i         input plink prefix
                 -n         nb causal SNP (1)
                 -m            MAF (0.0)
                 -u         upper maf threshold
                 -H            heritability (0.1)
                 -r            replicates (1)
                 -f         effect size
                 -o            output prefix
            """
    parser = OptionParser(usage=usage)
    parser.add_option("-i", "--input", dest="input", type="string", help="Plink input files prefix",)
    parser.add_option("-n", "--nb_snp", dest="nb_snps", type="int", help="Number of causal SNPs to simulate", default=1,)
    parser.add_option("-m", "--maf", dest="maf", type="float", help="Minor allele frequency threshold", default=0.0,)
    parser.add_option("-u", "--umaf", dest="umaf", type="float", help="Minor allele frequency upper threshold, default 1", default=1)
    parser.add_option("-H", "--heritability", dest="heritability", type="float", help="Heritability (or heritability of liability)", default=0.1,)
    parser.add_option("-r", "--replicates", dest="replicates", type="int", help="Number of simulations replicates", default=1,)
    parser.add_option("-f", "--effectsize", dest="effect", type="float", help="effect size to simulate")
    parser.add_option("-o", "--output", dest="output", type="string", help="Output files prefix",)

    (opts, args) = parser.parse_args(args)

    if not opts.input:
        print(usage)
        sys.exit("Please specify input files prefix (-i option)")
    input_prefix = opts.input

    if not opts.output:
        print(usage)
        sys.exit("Please specify output files prefix (-o option)")
    out_dir = os.path.dirname(opts.output)
    # dirname is '' for a prefix without a directory part; os.makedirs('') raises
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir)

    output = opts.output
    maf_thresh = opts.maf
    heritability = opts.heritability
    replicates = opts.replicates
    nb_snps = opts.nb_snps
    umaf = opts.umaf
    # an explicit effect size (even 0.0) is forwarded; without -f only SNP ids are listed
    effect = opts.effect if opts.effect is not None else "NA"

    with open(output + '.simu.log', 'w') as log:
        # make freq file with selected snps (cutoffs)
        outfreq = "{0}_gcta_pruned".format(output)
        cmd1 = "{0} --bfile {1} --maf {2} --max-maf {3} --out {4} --freq --thread-num 10".format(GCTA, input_prefix, maf_thresh, umaf, outfreq)
        print(cmd1)
        print("Nb causal SNPs : {0}".format(nb_snps), file=log)
        os.system(cmd1)
        freq_file = outfreq + '.freq'
        # create list of nb_snps snps as causal loci (kept after pruning)
        snp_file = get_causal_snplist(freq_file, nb_snps, output, effect)
        print("input prefix : {0}".format(input_prefix), file=log)
        print("maf threshold : {0}".format(maf_thresh), file=log)
        print("maf upper threshold : {0}".format(umaf), file=log)
        print("heritability : {0}".format(heritability), file=log)
        print("replicates : {0}".format(replicates), file=log)
        print("output prefix : {0}".format(output), file=log)

    # simulation of quantitative trait
    cmd2 = "{0} --bfile {1} --maf {2} --simu-qt --simu-causal-loci {3} --simu-hsq {4} --simu-rep {5} --out {6}".format(GCTA, input_prefix, maf_thresh, snp_file, heritability, replicates, output)
    os.system(cmd2)
    simulated_pheno = output + '.phen'
    # convert the simulated phenotypes to a Gemma readable fam file
    phen2_file_simu = simulated_pheno + '.2.tsv'
    fam_file_simu = output + '.fam'
    simulated_pheno_to_fam(simulated_pheno, fam_file_simu)
    simulated_pheno_to_phen2(simulated_pheno, phen2_file_simu)
    # NOTE: unlike simulate_qt_from_file, the input .bed/.bim are not symlinked here.
    print("Everything is ready for GWAS now")

def simulate_qt_from_file(args):
    """
    Simulate a quantitative trait with GCTA using causal SNPs read from a file.

    The causal file (-k) is tab-separated with a header line. Rows whose first
    column equals the run id (-d) give a SNP id (column 2, as named in the old
    bim -b) and an effect size (column 4). SNP ids are translated to the bim of
    the input fileset (-i) by matching chromosome and position. GCTA then
    simulates the trait, and the phenotypes are converted to .fam and
    .phen.2.tsv. The input .bed/.bim are symlinked next to the output when not
    already present.
    """
    usage = """Usage: GWAS_sim.py simulate_qt_from_file
                 -i         input plink prefix
                 -n         nb causal SNP (1)
                 -m            MAF (0.0)
                 -u         upper maf threshold
                 -H            heritability (0.1)
                 -r            replicates (1)
                 -o            output prefix
                 -d         run id (selects rows of the causal file)
                 -k         causal snp file
                 -b         old bim file
            """
    parser = OptionParser(usage=usage)
    parser.add_option("-i", "--input", dest="input", type="string", help="Plink input files prefix",)
    parser.add_option("-n", "--nb_snp", dest="nb_snps", type="int", help="Number of causal SNPs to simulate", default=1,)
    parser.add_option("-m", "--maf", dest="maf", type="float", help="Minor allele frequency threshold", default=0.0,)
    parser.add_option("-u", "--umaf", dest="umaf", type="float", help="Minor allele frequency upper threshold, default 1", default=1)
    parser.add_option("-H", "--heritability", dest="heritability", type="float", help="Heritability (or heritability of liability)", default=0.1,)
    parser.add_option("-r", "--replicates", dest="replicates", type="int", help="Number of simulations replicates", default=1,)
    parser.add_option("-o", "--output", dest="output", type="string", help="Output files prefix",)
    parser.add_option("-d", "--runid", dest="runid", type="int", help="runid",)
    parser.add_option("-k", "--causal", dest="causal", type="string", help="causal snp file",)
    parser.add_option("-b", "--bim", dest="oldbim", type="string", help="old bim file",)

    (opts, args) = parser.parse_args(args)

    if not opts.input:
        print(usage)
        sys.exit("Please specify input files prefix (-i option)")
    input_prefix = opts.input

    if not opts.output:
        print(usage)
        sys.exit("Please specify output files prefix (-o option)")
    if opts.runid is None or not opts.causal or not opts.oldbim:
        print(usage)
        sys.exit("Please specify run id (-d), causal snp file (-k) and old bim file (-b)")
    out_dir = os.path.dirname(opts.output)
    # dirname is '' for a prefix without a directory part; os.makedirs('') raises
    if out_dir and not os.path.exists(out_dir):
        os.makedirs(out_dir)

    output = opts.output
    maf_thresh = opts.maf
    heritability = opts.heritability
    replicates = opts.replicates
    nb_snps = opts.nb_snps
    umaf = opts.umaf
    runid = int(opts.runid)

    with open(output + '.simu.log', 'w') as log:
        # make freq file with selected snps (cutoffs)
        outfreq = "{0}_gcta_pruned".format(output)
        cmd1 = "{0} --bfile {1} --maf {2} --max-maf {3} --out {4} --freq --thread-num 10".format(GCTA, input_prefix, maf_thresh, umaf, outfreq)
        print(cmd1)
        print("Nb causal SNPs : {0}".format(nb_snps), file=log)
        os.system(cmd1)
        print("input prefix : {0}".format(input_prefix), file=log)
        print("maf threshold : {0}".format(maf_thresh), file=log)
        print("maf upper threshold : {0}".format(umaf), file=log)
        print("heritability : {0}".format(heritability), file=log)
        print("replicates : {0}".format(replicates), file=log)
        print("output prefix : {0}".format(output), file=log)

    # (chromosome, position) -> SNP id in the input fileset
    newbim = {}
    with open(input_prefix + '.bim', 'r') as bim_fh:
        for line in bim_fh:
            cols = line.strip().split('\t')
            newbim[(cols[0], cols[3])] = cols[1]

    # old SNP id -> new SNP id; old sites absent from the new bim are simply not
    # mapped (only a causal SNP without a match is an error, see below)
    oldbim_newbim = {}
    with open(opts.oldbim, 'r') as old_fh:
        for line in old_fh:
            cols = line.strip().split('\t')
            key = (cols[0], cols[3])
            if key in newbim:
                oldbim_newbim[cols[1]] = newbim[key]

    snpname = output + '.snplist'
    with open(opts.causal, 'r') as causal_fh, open(snpname, 'w') as snp_file:
        next(causal_fh, None)  # header line
        for line in causal_fh:
            line = line.strip()
            if not line:
                continue
            cols = line.split('\t')
            if int(cols[0]) != runid:
                continue
            try:
                snp = oldbim_newbim[cols[1]]
            except KeyError:
                sys.exit("Causal SNP {0} (run {1}) has no match in {2}.bim".format(cols[1], runid, input_prefix))
            effect = cols[3]
            if runid == 1:
                print("&&&&&&&&&&&&&&&&&&&&&&&", snp, effect)
            print(snp, effect, sep='\t', file=snp_file)

    # simulation of quantitative trait
    cmd2 = "{0} --bfile {1} --maf {2} --simu-qt --simu-causal-loci {3} --simu-hsq {4} --simu-rep {5} --out {6}".format(GCTA, input_prefix, maf_thresh, snpname, heritability, replicates, output)
    print(cmd2)
    os.system(cmd2)
    print("ok")
    simulated_pheno = output + '.phen'
    # convert the simulated phenotypes to a Gemma readable fam file
    fam_file_simu = output + '.fam'
    simulated_pheno_to_fam(simulated_pheno, fam_file_simu)
    phen2_file_simu = simulated_pheno + '.2.tsv'
    print(phen2_file_simu)
    simulated_pheno_to_phen2(simulated_pheno, phen2_file_simu)
    # link genotype files next to the output; skip existing links so reruns work
    for ext in ('.bed', '.bim'):
        if not os.path.lexists(output + ext):
            os.symlink(input_prefix + ext, output + ext)
    print("Everything is ready for GWAS now")


def get_causal_snplist(freq_file, nb_snps, out, effect):
    """
    freq_file contains the frequencies of all snps kept for the analysis (e.g. after maf pruning)

    Draws `nb_snps` distinct lines of `freq_file` at random and writes the SNP id
    (first whitespace-separated field) of each to ``<out>.snplist``, followed by
    `effect` unless `effect` is the string "NA".

    :returns: the path of the written .snplist file.
    :raises ValueError: if fewer than `nb_snps` distinct non-blank lines exist.
    """
    with open(freq_file, 'r') as fin:
        # dict.fromkeys de-duplicates like set() did, but yields a sequence:
        # random.sample no longer accepts sets (Python 3.11) and set order is
        # hash-randomised, which made seeded runs non-reproducible.
        candidates = list(dict.fromkeys(line for line in fin if line.strip()))
    if nb_snps > len(candidates):
        raise ValueError("Cannot draw {0} causal SNPs from {1} candidates in {2}".format(
            nb_snps, len(candidates), freq_file))
    random_lines = random.sample(candidates, nb_snps)

    out_name = out + '.snplist'
    with open(out_name, 'w') as output:
        for line in random_lines:
            snp = line.split()
            if effect == "NA":
                print(snp[0], file=output)
            else:
                print(snp[0], effect, file=output)
    return out_name
    
def simulated_pheno_to_fam(pheno_file, out_fam):
    """
    Convert a simulated phenotype file into a fam-like file (for GEMMA).

    Each non-blank input line ``c1 c2 v1 [v2 ...]`` (whitespace-separated) is
    written as ``c1  c2  0  0  0  v1  v2 ...`` with two-space separators. The
    first two columns are copied as-is (presumably family/individual ids).
    """
    # both handles are closed (and the output flushed) when the block exits
    with open(pheno_file, 'r') as pheno, open(out_fam, 'w') as out:
        for line in pheno:
            cols = line.split()
            if not cols:
                continue
            new_line = "{0}  {1}  0  0  0  ".format(cols[0], cols[1])
            print(new_line + "  ".join(cols[2:]), file=out)

def simulated_pheno_to_phen2(pheno_file, out_phen2):
    """
    Convert a simulated phenotype file into a two-column TSV.

    Writes the header ``strain<TAB>simulated_phenotype`` and then, for each
    non-blank input line, its first and third whitespace-separated fields.
    """
    with open(pheno_file, 'r') as pheno, open(out_phen2, 'w') as out:
        print("strain\tsimulated_phenotype", file=out)
        for line in pheno:
            cols = line.split()
            if not cols:
                continue
            print(cols[0], cols[2], sep='\t', file=out)

            
###############################################
#
#                Permutations
#
###############################################

def run_permutations(args):
    """
    Write permuted copies of a phenotype file.

    The pheno file (-p) is read as space-separated without header. For each of
    the -n permutations its third column (index 2) is shuffled independently of
    the other columns. Each copy is written to -o under the pheno file's basename,
    with the extension -e replaced by ``_perm_<i>.phen``.
    """
    usage = "Usage: %prog run_permutations [-p pheno -n nb_permut]"
    parser = OptionParser(usage=usage)
    parser.add_option("-p", "--pheno", dest="pheno_file", type="string", help="Input pheno file",)
    parser.add_option("-n", "--nb_permut", dest="nb_permut", type="int", help="number of permutations", default=100)
    parser.add_option("-o", "--outdir", dest="outdir", type="string", help="output directory",)
    parser.add_option("-e", "--extension", dest="extension", type="string", default='.fam.phen')
    (opts, args) = parser.parse_args(args)

    # check arguments (the output directory itself must exist; checking only its
    # dirname rejected valid relative paths such as "perm")
    if not opts.pheno_file:
        sys.exit("Please specify a pheno file.")
    elif not opts.outdir or not os.path.isdir(opts.outdir):
        sys.exit("Please specify a valid output directory path (-o option)")

    pheno_file = opts.pheno_file
    extension = opts.extension
    outdir = opts.outdir
    if not outdir.endswith('/'):
        outdir += '/'
    basename = os.path.basename(pheno_file)
    # without the extension in the name, every permutation would overwrite the
    # same file (possibly the input itself)
    if extension not in basename:
        sys.exit("Pheno file name {0} does not contain the extension {1} (-e option)".format(basename, extension))

    # read once; each permutation shuffles a copy of the original column
    df = pd.read_csv(pheno_file, sep=" ", header=None)
    for i in range(1, opts.nb_permut + 1):
        outname = outdir + basename.replace(extension, '_perm_{0}.phen'.format(i))
        permuted = df.copy()
        permuted[2] = np.random.permutation(df[2])
        permuted.to_csv(outname, sep=" ", header=False, index=False)

def permut_analysis(args):
    """
    Derive a significance threshold from permutation GWAS results.

    For every file in -d whose name ends with -e, the p-value (column 6) of the
    first data line (after one header line) is collected. The 5th percentile of
    these values is printed and written to ``<dir>/threshold.txt`` after an ``x``
    header line.

    NOTE(review): using only the first data line assumes each result file is
    sorted by increasing p-value (so it holds the per-permutation minimum).
    """
    usage = "Usage: %prog permut_analysis [-d dirin -e extension]"
    parser = OptionParser(usage=usage)
    parser.add_option("-d", "--dir", dest="dir", type="string", help="Idirectory containing permutation association files",)
    parser.add_option("-e", "--extension", dest="extension", type="string", default='.gwas.txt')
    (opts, args) = parser.parse_args(args)

    if not opts.dir or not os.path.isdir(opts.dir):
        sys.exit("Please specify a valid directory (-d option)")

    perm_dir = opts.dir
    if not perm_dir.endswith('/'):
        perm_dir += '/'

    pvals = []
    # honour -e (it was parsed but the pattern used to be hard-coded)
    for gwas in glob.glob(perm_dir + '*' + opts.extension):
        with open(gwas, 'r') as g:
            next(g, None)  # header line
            first = next(g, None)
        if first is not None:
            pvals.append(float(first.strip().split('\t')[5]))

    if not pvals:
        sys.exit("No permutation results matching '*{0}' in {1}".format(opts.extension, perm_dir))

    thresh = np.percentile(pvals, 5)
    print(thresh)
    with open(perm_dir + 'threshold.txt', 'w') as out:
        print("x", file=out)
        print(thresh, file=out)



def list_significant_snps(args):
    """
    gets threshold after permutations
    extracts from real results the snps higher than threshold
    writes the snps in signif_snp1

    More precisely:
    - the threshold is the last value (after a header line) of
      ``<dir>/permutations/threshold.txt``;
    - ``<dir>/<dirname><ext>`` is scanned after its header. Every SNP (column 1)
      whose p-value (column 6) is below the threshold is written to the output
      file (-o, default ``<dir>/signif_snp1``). Rows with a non-numeric p-value
      are skipped. The scan stops at the first p-value not below the threshold.
    - when -l is given, significant SNPs whose id is the first space-separated
      field of a line of that file are also written to ``<output>.causal``;
    - the output file is removed when no SNP is significant.

    NOTE(review): the early stop assumes results sorted by increasing p-value.
    """
    usage = "Usage: %prog list_significant_snps [-d dir ]"
    parser = OptionParser(usage=usage)
    parser.add_option("-d", "--dir", dest="gwas_dir", type="string", help="gwas results directory for condition")
    parser.add_option("-e", "--ext", dest="extension", type="string", help="extension of gwas result files", default=".gwas.txt")
    parser.add_option("-o", "--output", dest="out_f", type="string", help="name of output file")
    parser.add_option("-l", "--list", dest="list", type="string", help="list of causal snps (simulations)")
    (opts, args) = parser.parse_args(args)

    if not opts.gwas_dir or not os.path.isdir(opts.gwas_dir):
        sys.exit("Please specify a valid directory path (-d option)")

    extension = opts.extension
    gwas_dir = opts.gwas_dir
    if not gwas_dir.endswith('/'):
        gwas_dir += '/'
    out_f = opts.out_f if opts.out_f else gwas_dir + 'signif_snp1'

    # set for O(1) membership tests in the scan below
    causal_snps = set()
    if opts.list:
        with open(opts.list, 'r') as list_fh:
            for line in list_fh:
                causal_snps.add(line.strip().split(' ')[0])

    thresh = 0.0
    with open(gwas_dir + 'permutations/threshold.txt', 'r') as thresh_fh:
        next(thresh_fh, None)  # header line
        for line in thresh_fh:
            thresh = float(line.strip())

    gwas_res = gwas_dir + os.path.basename(os.path.dirname(gwas_dir)) + extension
    nb_signif = 0
    out_causal = open(out_f + '.causal', 'w') if opts.list else None
    try:
        # a single handle on out_f (it used to be opened twice)
        with open(out_f, 'w') as out_sign, open(gwas_res, 'r') as res_fh:
            next(res_fh, None)  # header line
            for line in res_fh:
                cols = line.strip().split('\t')
                snp = cols[0]
                try:
                    p_val = float(cols[5])
                except ValueError:
                    continue
                # `not <` (rather than `>=`) also stops on NaN, as before
                if not p_val < thresh:
                    break
                nb_signif += 1
                if out_causal is not None and snp in causal_snps:
                    print(snp, file=out_causal)
                print(snp, file=out_sign)
    finally:
        if out_causal is not None:
            out_causal.close()

    # remove the file actually written (not the default path when -o is used)
    if nb_signif == 0:
        os.remove(out_f)


###############################################
#
#                Meta analysis
#
###############################################

def get_causal_infos(snplist, freq_file, assoc_f):
    """
    Collect effect size, MAF and association p-value of each causal SNP.

    :param snplist: tab-separated file with a header line; column 1 is the SNP
        id, column 3 its MAF and column 4 its effect size.
    :param freq_file: not used (kept for call compatibility).
    :param assoc_f: tab-separated association results with a header line;
        column 1 is the SNP id and column 6 the p-value.
    :returns: dict snp -> [effect_size, maf], with the p-value appended as a
        third element when the SNP is found in assoc_f (first occurrence only).
    """
    causal_snps = {}
    with open(snplist, 'r') as snp_fh:
        next(snp_fh, None)  # header line
        for line in snp_fh:
            line = line.strip()
            if not line:
                continue
            cols = line.split('\t')
            causal_snps[cols[0]] = [float(cols[3]), float(cols[2])]

    nb_causal = len(causal_snps)
    nb_found = 0
    with open(assoc_f, 'r') as assoc_fh:
        next(assoc_fh, None)  # header line
        for line in assoc_fh:
            # stop reading as soon as every causal SNP has its p-value
            if nb_found == nb_causal:
                break
            cols = line.strip().split('\t')
            infos = causal_snps.get(cols[0])
            # only the first occurrence counts; a duplicated row used to append a
            # second p-value and end the scan before the other SNPs were found
            if infos is not None and len(infos) == 2:
                infos.append(float(cols[5]))
                nb_found += 1

    return causal_snps

def effect_size_maf_power(args):
    """
    Links Effect size of causal snps (determined by GCTA) with the p-value

    For each run directory of -d, reads ``<run>/<run><snpext>`` and
    ``<run>/<run><gwasext>`` through get_causal_infos. It writes one row per
    causal SNP (SNP id, effect size, MAF, p-value, run id) to
    ``<dir>/<dirname>_effect_pvals.tsv``. A causal SNP absent from the GWAS
    file gets ``NA`` as p-value.
    """
    usage = "Usage: %prog effect_size_maf_power [-d dir ]"
    parser = OptionParser(usage=usage)
    parser.add_option("-d", "--dir", dest="gwas_dir", type="string", help="directory containing all runs to analyse")
    parser.add_option("-f", "--freqext", dest="freqext", type="string", help="freq file extension", default=".frq")
    parser.add_option("-s", "--snpext", dest="snpext", type="string", help="extension of snplist files", default='_sim.par')
    parser.add_option("-g", "--gwasext", dest="gwasext", type="string", help="extension of gwas files", default='_sim.gwas.txt')
    (opts, args) = parser.parse_args(args)
    if not opts.gwas_dir or not os.path.isdir(opts.gwas_dir):
        sys.exit("Please specify a valid directory path (-d option)")
    gwas_dir = opts.gwas_dir
    snpext = opts.snpext
    gwasext = opts.gwasext
    freqext = opts.freqext
    if not gwas_dir.endswith('/'):
        gwas_dir += '/'

    out_path = "{0}{1}_effect_pvals.tsv".format(gwas_dir, os.path.basename(gwas_dir[:-1]))
    with open(out_path, 'w') as out:
        print("SNPID\tEFFECT_SIZE\tMAF\tPVALUE\tRUN", file=out)
        for rundir in sorted(glob.glob(gwas_dir + '*')):
            if not os.path.isdir(rundir):
                continue
            runid = os.path.basename(rundir)
            print(runid)
            snplist = "{0}/{1}{2}".format(rundir, runid, snpext)
            assoc_f = "{0}/{1}{2}".format(rundir, runid, gwasext)
            # passed through for compatibility; get_causal_infos does not read it
            freqfile = "{0}/run_1/{1}{2}".format(os.path.dirname(rundir), os.path.basename(os.path.dirname(rundir)), freqext)
            causal_infos = get_causal_infos(snplist, freqfile, assoc_f)
            for snp, infos in causal_infos.items():
                # infos = [effect_size, maf] (+ p-value when the SNP was found)
                pval = infos[2] if len(infos) > 2 else 'NA'
                print(snp, infos[0], infos[1], pval, runid, file=out, sep='\t')
 

def count_nb_causal_best(args):
    """
    Placeholder for the count_nb_causal_best command.

    Only parses -d and exits with an error when it is not an existing
    directory; nothing is computed or written yet.
    """
    parser = OptionParser(usage="Usage: %prog count_nb_causal_best [-d dir ]")
    parser.add_option("-d", "--dir", dest="gwas_dir", type="string",
                      help="directory containing all runs to analyse")
    opts = parser.parse_args(args)[0]

    run_dir = opts.gwas_dir
    if not (run_dir and os.path.isdir(run_dir)):
        exit("ERROR: Please specify a valid directory path (-d option)")

    # normalised like the other commands; TODO: use run_dir once implemented
    run_dir = run_dir if run_dir.endswith('/') else run_dir + '/'


def meta_analysis(args):
    """
    Classify every SNP of a bim file as TP/FN/FP/TN in each simulation run.

    For each run directory of -d (sorted by name), the causal SNPs (first field
    of each line of ``<run>/<run><causal_ext>``) are compared with the
    significant SNPs (``<run>/signif_snp1``, absent when nothing was
    significant). The threshold file (-t) must exist and have a second line.

    One row per bim SNP is written to ``<dir>/<dirname>_simu_summary.tsv``:
    SNP id, chromosome, position, one label per run, then the false-positive,
    false-negative, true-positive and true-negative rates ('NA' when the
    denominator is 0).
    """
    usage = "Usage: %prog meta_analysis [-d dir ]"
    parser = OptionParser(usage=usage)
    parser.add_option("-d", "--dir", dest="gwas_dir", type="string", help="directory containing all runs to analyse")
    parser.add_option("-c", "--causal_ext", dest="causal_ext", type="string", help="extension of causal snps file", default="_sim.snplist",)
    parser.add_option("-t", "--threshold", dest="threshold", type="string", help="path of threshold file (relative to the run dir)", default="permutations/threshold.txt")
    parser.add_option("-b", "--bim", dest="bim", type="string", help="Plink bim file",)
    (opts, args) = parser.parse_args(args)

    if not opts.gwas_dir or not os.path.isdir(opts.gwas_dir):
        sys.exit("ERROR: Please specify a valid directory path (-d option)")
    if not opts.bim or not os.path.isfile(opts.bim):
        sys.exit("ERROR: Please specify a valid Plink bim file (-b option)")

    gwas_dir = opts.gwas_dir
    if not gwas_dir.endswith('/'):
        gwas_dir += '/'
    causal_ext = opts.causal_ext
    threshold = opts.threshold

    # rs id -> [chromosome, position, label_run1, label_run2, ...]
    sites = {}
    sites_list = []
    with open(opts.bim, 'r') as bim_fh:
        for line in bim_fh:
            line = line.strip()
            if not line:
                continue
            cols = line.split('\t')
            sites[cols[1]] = [cols[0], cols[3]]
            sites_list.append(cols[1])

    # (is causal, is significant) -> label
    labels = {(True, True): 'TP', (True, False): 'FN',
              (False, True): 'FP', (False, False): 'TN'}

    for rundir in sorted(glob.glob(gwas_dir + '*')):
        if not os.path.isdir(rundir):
            continue
        # The threshold value is not used here; this only checks that the
        # permutation step produced a threshold for the run.
        thresh = rundir + '/' + threshold
        with open(thresh, 'r') as thresh_fh:
            if len(thresh_fh.readlines()) < 2:
                sys.exit("ERROR: no threshold value in {0}".format(thresh))

        # sets: membership is tested for every bim site below
        causal_f = rundir + '/' + os.path.basename(rundir) + causal_ext
        causal_snps = set()
        with open(causal_f, 'r') as causal_fh:
            for line in causal_fh:
                cols = line.split()
                if cols:
                    causal_snps.add(cols[0])

        signif_file = rundir + '/signif_snp1'
        signif_snps = set()
        if os.path.isfile(signif_file):
            with open(signif_file, 'r') as signif_fh:
                for line in signif_fh:
                    signif_snps.add(line.strip().split('\t')[0])

        for site in sites_list:
            sites[site].append(labels[(site in causal_snps, site in signif_snps)])

    def _rate(num, den):
        # float ratio, or 'NA' when the denominator is 0
        try:
            return float(num) / den
        except ZeroDivisionError:
            return 'NA'

    summary_path = gwas_dir + os.path.basename(gwas_dir[:-1]) + '_simu_summary.tsv'
    with open(summary_path, 'w') as summary:
        for site in sites_list:
            print(site, sites[site][0], sites[site][1])
            run_labels = sites[site][2:]
            tp = run_labels.count('TP')
            fp = run_labels.count('FP')
            tn = run_labels.count('TN')
            fn = run_labels.count('FN')
            fp_rate = _rate(fp, fp + tn)
            fn_rate = _rate(fn, tp + fn)
            tp_rate = _rate(tp, tp + fn)
            tn_rate = _rate(tn, tn + fp)
            print(tp, fp, tn, fn)
            print(tp + fp + tn + fn)
            print(site, '\t'.join(sites[site]), str(fp_rate), str(fn_rate), str(tp_rate), str(tn_rate), file=summary, sep='\t')


def post_process(args):
    """
    Aggregate the outputs of a directory of simulation datasets.

    Command-line options (parsed from args):
    -d/--dir: directory containing one sub-directory per dataset (required)
    -c/--causal_ext, -t/--threshold: accepted for compatibility, not used here

    Writes, inside the analysed directory:
    - all_datasets_freq.tsv: every line of each <dataset>/run_1/*.freq file,
      prefixed by the dataset (sub-directory) name.
    For each <dataset>/*simu_summary.tsv file:
    - a .short.tsv companion with the rs, chr, pos and the four trailing rate
      columns (FPR, FNR, TPR, TNR);
    - FN/TP/FP SNP lists per run, via write_FN_list, write_TP_list and
      write_FP_list.

    Returns "ok". Exits with a message when -d is missing.
    """
    usage = "Usage: %prog post_process [-d dir]"
    parser = OptionParser(usage=usage)
    parser.add_option("-d", "--dir", dest="gwas_dir", type="string", help="directory containing all runs to analyse")
    parser.add_option("-c", "--causal_ext", dest="causal_ext", type="string", help="extension of causal snps file", default="_sim.snplist",)
    parser.add_option("-t", "--threshold", dest="threshold", type="string", help="path of threshold file (relative to the run dir)", default="permutations/threshold.txt")
    (opts, args) = parser.parse_args(args)

    if not opts.gwas_dir:
        sys.exit("Specify directory to analyse (-d option)")
    gwas_dir = opts.gwas_dir
    if not gwas_dir.endswith('/'):
        gwas_dir += '/'

    with open(gwas_dir + "all_datasets_freq.tsv", "w") as allfrq:
        print("dataset\tSNP\tA1\tMAF", file=allfrq)
        # sorted() makes the dataset order reproducible (glob order is arbitrary).
        for freq_path in sorted(glob.glob(gwas_dir + '*/run_1/*.freq')):
            # <gwas_dir>/<dataset>/run_1/<file>.freq -> <dataset>
            ds = os.path.basename(os.path.dirname(os.path.dirname(freq_path)))
            with open(freq_path, 'r') as freqs:
                for line in freqs:
                    print(ds, line.strip(), file=allfrq, sep='\t')

    for summ in sorted(glob.glob(gwas_dir + '*/*simu_summary.tsv')):
        print(summ)
        _write_short_summary(summ)
        print("write FN")
        write_FN_list(summ)
        print("write TP")
        write_TP_list(summ)
        print("write FP")
        write_FP_list(summ)
    return "ok"


def _write_short_summary(summ):
    """
    Write <summ without .tsv>.short.tsv keeping only rs, chr, pos and the rate columns.

    Each line of summ is expected to be: rs, chr, pos, one status per run,
    then FPR, FNR, TPR, TNR. The rates are taken from the last four columns,
    so any number of runs is supported. Lines with fewer than 7 columns cannot
    hold both the leading and the rate columns; they are reported on stderr
    and skipped.
    """
    # summ comes from a '*simu_summary.tsv' glob, so it always ends with '.tsv';
    # only that suffix is replaced (not any '.tsv' appearing in a directory name).
    short_path = summ[:-len('.tsv')] + '.short.tsv'
    with open(summ, 'r') as summary, open(short_path, 'w') as short:
        print("rs\tchr\tpos\tFPR\tFNR\tTPR\tTNR", file=short)
        for line in summary:
            cols = line.strip().split('\t')
            if len(cols) < 7:
                if line.strip():
                    print("skipping malformed line in", summ, file=sys.stderr)
                continue
            print('\t'.join(cols[:3] + cols[-4:]), file=short)
        


def _write_status_by_run(summ, status, out_name):
    """
    Group the SNPs of a simulation summary by run, for one classification status.

    summ is a *_simu_summary.tsv file: one tab-separated line per SNP with
    columns rs, chr, pos, one status per run ('TP', 'FP', 'TN' or 'FN'),
    then four rate columns (FPR, FNR, TPR, TNR).

    Writes <directory of summ>/<out_name> with one line per run that has at
    least one SNP with the given status, in increasing run order:
        run_index<TAB>rs<TAB>rs...
    Run indices are 1-based. Lines with fewer than 7 columns (e.g. blank
    lines) carry no run statuses and are ignored.
    """
    by_run = {}
    with open(summ, 'r') as summary:
        for line in summary:
            cols = line.strip().split('\t')
            if len(cols) < 7:
                continue
            rs = cols[0]
            # Everything between the 3 leading and the 4 trailing rate columns
            # is a per-run status, so any number of runs is handled.
            for run, val in enumerate(cols[3:-4], 1):
                if val == status:
                    by_run.setdefault(run, []).append(rs)

    with open(os.path.join(os.path.dirname(summ), out_name), 'w') as out:
        for run in sorted(by_run):
            print(run, '\t'.join(by_run[run]), sep='\t', file=out)


def write_FN_list(summ):
    """Write FN_SNPs_by_run.tsv next to summ: the false-negative SNPs of each run."""
    _write_status_by_run(summ, 'FN', "FN_SNPs_by_run.tsv")


def write_TP_list(summ):
    """Write TP_SNPs_by_run.tsv next to summ: the true-positive SNPs of each run."""
    _write_status_by_run(summ, 'TP', "TP_SNPs_by_run.tsv")


def write_FP_list(summ):
    """Write FP_SNPs_by_run.tsv next to summ: the false-positive SNPs of each run."""
    _write_status_by_run(summ, 'FP', "FP_SNPs_by_run.tsv")



###############################################
#
#                Debug
#
###############################################

def get_site_from_ped(args):
    """
    Debug helper: print the genotype of every individual at one site of a PED file.

    Command-line options (parsed from args):
    -i/--input: plink fileset prefix; <prefix>.ped is read
    -s/--site:  site identifier 'rs<N>', where N is the 1-based position of
                the site among the genotype columns (assumes sites were named
                sequentially when the fileset was generated -- TODO confirm)

    Prints a dict mapping the first column of each PED line to the
    (allele1, allele2) tuple at that site. Exits with a message when an
    option is missing, the site id is not 'rs<positive int>', or a line has
    no genotype at that site.
    """
    usage = "Usage: %prog get_site_from_ped [-i fileset -s site]"
    parser = OptionParser(usage=usage)
    parser.add_option("-i", "--input", dest="input", type="string", help="input plink fileset (ped/map format)",)
    parser.add_option("-s", "--site", dest="site", type="string", help="polymorphic site",)
    (opts, args) = parser.parse_args(args)

    if not opts.input:
        sys.exit("Please specify an input prefix. (-i option)")
    if not opts.site:
        sys.exit("Please specify site. (-s option)")

    prefix = opts.input
    site = opts.site
    try:
        site_id = int(site.replace('rs', '')) - 1
    except ValueError:
        sys.exit("Invalid site '%s': expected rs<number>" % site)
    if site_id < 0:
        # rs0 would otherwise silently index the last site (negative index).
        sys.exit("Invalid site '%s': site numbering starts at rs1" % site)

    # PED lines hold 6 leading columns, then two allele columns per site;
    # only the pair for the requested site is extracted.
    first = 6 + 2 * site_id
    alleles = {}
    with open(prefix + '.ped', 'r') as ped:
        for line in ped:
            cols = line.split()
            if not cols:
                continue
            genotype = tuple(cols[first:first + 2])
            if len(genotype) != 2:
                sys.exit("Site %s is out of range for individual %s" % (site, cols[0]))
            alleles[cols[0]] = genotype
    print(alleles)


def group_list(lst, n):
    """
    Group consecutive items of lst into n-tuples, returned through zip().

    Tuple k is (lst[k*n], ..., lst[k*n + n - 1]); trailing items that do not
    fill a whole tuple are dropped. lst must support slicing.
    """
    columns = []
    for offset in range(n):
        # Every n-th item starting at `offset` fills position `offset` of each tuple.
        columns.append(lst[offset::n])
    return zip(*columns)


######################################################################################## 
########################################################################################

   
if __name__ == "__main__":
    import time
    start_time = time.time()
    # Single source of truth for sub-commands: the accepted names and their
    # dispatch cannot drift apart. The lambdas resolve each handler only when
    # it is dispatched. All handlers except 'help' receive sys.argv[1:]
    # (command name included).
    handlers = {
        'simulate_genotype': lambda argv: simulate_genotype(argv),
        'simulate_qt_from_file': lambda argv: simulate_qt_from_file(argv),
        'simulate_several_matrices': lambda argv: simulate_several_matrices(argv),
        'simulate_qt': lambda argv: simulate_qt(argv),
        'run_permutations': lambda argv: run_permutations(argv),
        'list_significant_snps': lambda argv: list_significant_snps(argv),
        'get_site_from_ped': lambda argv: get_site_from_ped(argv),
        'count_nb_causal_best': lambda argv: count_nb_causal_best(argv),
        'permut_analysis': lambda argv: permut_analysis(argv),
        'effect_size_maf_power': lambda argv: effect_size_maf_power(argv),
        'meta_analysis': lambda argv: meta_analysis(argv),
        'post_process': lambda argv: post_process(argv),
        'help': lambda argv: help(),
    }
    command_list = set(handlers)
    # Sorted so the listing is stable from one run to the next.
    listing = "\n - ".join(sorted(command_list))

    if len(sys.argv) < 2:
        if __doc__:
            print(__doc__)
        sys.exit("No command specified, use one of:\n - %s" % listing)

    cmd = sys.argv[1]
    if cmd not in handlers:
        sys.exit("Unrecognized command '%s', use one of:\n - %s" % (cmd, listing))

    handlers[cmd](sys.argv[1:])
    print("--- %s seconds ---" % (time.time() - start_time))
