#!/usr/bin/env python
# -*- coding: utf8 -*-
#----------------------------------------------------------------------------
# Created By  : vloegler
# Last Modification Date: 2022/03/03
# ---------------------------------------------------------------------------
'''
This script runs a genome-wide association study (GWAS) using FaST-LMM. It
performs an association test and a permutation test to output significant SNPs.

It takes as input :
	-g --genotype: the genotype matrix (Plink format)
	-p --phenotype: the phenotypes, in format : 
											Strain		Cond1	Cond2	Cond3
											StrainXXX	0.02	0.98	0.14
											StrainXXX	0.12	0.52	0.65
	-c --covariance: covariance matrix (optional)
	-o --output: prefix of output files
	-n --nbPermutations: number of permutations for the permutation test (default: 100)

Output files:
prefix.first_assoc.txt: FaST-LMM association results for all conditions
prefix.<condition>.threshold.txt: permutation-based p-value threshold for that condition
prefix.<condition>.signif_snps.txt: FaST-LMM results for SNPs whose p-value is below the threshold
'''
# ---------------------------------------------------------------------------
import numpy as np
import pandas as pd
import time
from sys import argv
from fastlmm.association import single_snp 
from pysnptools.snpreader import Bed, SnpData
from pysnptools.util.mapreduce1.runner import LocalMultiProc
import logging
import argparse
# ---------------------------------------------------------------------------
# Configure the root logger so only ERROR (and more severe) records are shown.
logging.basicConfig(level=logging.ERROR)

# Local multiprocessing runner handed to every single_snp call below.
nb_proc = 3
runner = LocalMultiProc(nb_proc)

# =============
# Get arguments
# =============

# Initiate the parser
parser = argparse.ArgumentParser()
parser.add_argument("-g", "--genotype", help="genotype matrix in Plink format", required=True)
parser.add_argument("-p", "--phenotype", help="phenotypes data", required=True)
parser.add_argument("-o", "--output", help="prefix of the output files", required=True)
parser.add_argument("-c", "--covariance", help="covariance matrix (optional)", type=str, default="")
parser.add_argument("-n", "--nbPermutations", help="number of permutations for the permutation test (default = 100)", type=int, default=100)

# Read arguments from the command line
args = parser.parse_args()
# Each condition's threshold is a percentile over its permutation results.
# With zero permutations that set is empty and np.percentile fails, and only
# after both association runs have finished, so reject the value up front.
if args.nbPermutations < 1:
	parser.error("--nbPermutations must be at least 1")

# Variant matrix in PLINK format
bed_fn = args.genotype
# Phenotypes (tab-separated, one column per condition), in format
# Strain	Cond1	Cond2	Cond3
# StrainXXX	0.02	0.98	0.14
# StrainXXX	0.12	0.52	0.65
pheno_fn = args.phenotype
# Covariance matrix in format
# Strain	Strain	Covariant1	Covariant2 ...
# An empty string means "no covariates".
covar_fn = args.covariance
# Prefix of output files
outdir = args.output
# Number of permutations per condition
nperm = args.nbPermutations



# =====================
# Run First Association
# =====================

# Read phenotypes: tab-separated, first column (strain IDs) used as the index,
# every remaining column is one condition.
df = pd.read_csv(pheno_fn, sep="\t", index_col=0)

# Phenotype values, one column per condition.
matrix = df.to_numpy()
nb_phens = matrix.shape[1]
# Strain IDs as a column vector of shape (n_strains, 1).
strainslist = df.index.to_numpy().reshape(-1, 1)
# SnpData iids are two-column rows; the strain name is used for both columns.
# astype(str) sizes the string dtype to the longest name, whereas a fixed
# width such as '<U23' would silently truncate longer strain IDs so they no
# longer match the genotype file.
strainslist_2el = np.hstack((strainslist, strainslist)).astype(str)
# Create phenotype object. Pass a copy of the values so that `matrix`, which
# is reused later to build the permutations, stays independent of the
# SnpData object.
phenoSNPread = SnpData(iid=strainslist_2el, sid=df.columns, val=matrix.copy())

# Run the association for every condition in a single call. An empty
# --covariance is passed as covar=None (presumably single_snp's own default,
# i.e. the same as omitting it — see the single_snp API docs).
start_time = time.time()
results_df = single_snp(
	test_snps=bed_fn,
	pheno=phenoSNPread,
	covar=covar_fn if covar_fn != "" else None,
	count_A1=False,
	runner=runner,
	map_reduce_outer=False,
	GB_goal=4,
)
print("association lasted %s seconds " % (time.time() - start_time))

results_df.to_csv(outdir + ".first_assoc.txt", index=False, sep="\t")
# Result rows per condition (assumes every condition yields the same number of rows).
nbsnps = results_df.shape[0] // nb_phens



# ================
# Run permutations
# ================

### random shuffling of matrix columns
def matrix_permut(x):
	"""Shuffle every column of the 2-D array ``x`` independently, in place.

	Each column keeps the same multiset of values; only the row order within
	the column changes. NumPy's global random state is used, so call
	``np.random.seed`` beforehand for reproducible permutations.

	Parameters
	----------
	x : numpy.ndarray
		2-D array, modified in place. Any number of columns (including 0).
	"""
	# Loop over x's own column count rather than the global ``nperm``. The two
	# only coincide for the permutation tables built in this script.
	for i in range(x.shape[1]):
		# x[:, i] is a view, so shuffling it reorders the column inside x.
		np.random.shuffle(x[:, i])


#### Permuted phenotypes for every condition, all within the same table.
# For condition k, nperm copies of its phenotype column are made and each copy
# is shuffled independently. The per-condition blocks are placed side by side,
# so columns k*nperm .. (k+1)*nperm-1 hold the permutations of condition k.
permut_blocks = []
for k in range(nb_phens):
	copies = np.repeat(matrix[:, [k]], repeats=nperm, axis=1)
	matrix_permut(copies)
	permut_blocks.append(copies)
# Start from an empty (n_strains, 0) float array so the table is float and
# well-formed even when there are no conditions.
final_permuttable = np.hstack([np.empty([strainslist.shape[0], 0])] + permut_blocks)

phenoSNPread_permut = SnpData(
	iid=strainslist_2el,
	sid=np.arange(nperm * nb_phens).astype('<U23'),
	val=final_permuttable,
)

# Run the association on all permuted phenotypes in a single call. An empty
# --covariance is passed as covar=None (presumably single_snp's own default,
# i.e. the same as omitting it — see the single_snp API docs).
start_time = time.time()
permut_df = single_snp(
	test_snps=bed_fn,
	pheno=phenoSNPread_permut,
	covar=covar_fn if covar_fn != "" else None,
	count_A1=False,
	runner=runner,
	map_reduce_outer=False,
	GB_goal=20,
)
print("association lasted %s seconds " % (time.time() - start_time))

# For every condition, derive a p-value threshold from its permutations, save
# it, and save the first-association rows whose p-value is below it.
for k, phen_name in enumerate(df.columns):
	# NOTE(review): this reads only nperm rows, starting at row k*nperm of
	# permut_df. If single_snp returns one row per (SNP, phenotype) pair, as
	# `nbsnps = rows / nb_phens` assumes, permut_df holds about nbsnps*nperm
	# rows per condition. Confirm single_snp's row order, and whether a
	# per-permutation minimum p-value was intended here.
	threshold = np.percentile(permut_df.iloc[(k * nperm):((k + 1) * nperm)]["PValue"], 5)
	# Write the threshold file (header line "x", then the value). The context
	# manager closes the file even if the write raises.
	with open(outdir + "." + phen_name + ".threshold.txt", "w") as thresh:
		thresh.write("x\n" + str(threshold) + "\n")
	# NOTE(review): condition k's window is shifted by +k rows relative to
	# consecutive blocks of nbsnps rows. Confirm this matches single_snp's row
	# order before relying on these files.
	pheno_rows = results_df.iloc[(k * nbsnps + k):((k + 1) * nbsnps + k)]
	# Keep SNPs whose p-value is below (more significant than) the threshold.
	signif_snps = pheno_rows[pheno_rows["PValue"] < threshold]
	signif_snps.to_csv(outdir + "." + phen_name + ".signif_snps.txt", index=False, sep="\t")

