/***********************************************************************
ESPRIT-Forest: Parallel Clustering of Massive Amplicon Sequence Data in Subquadratic Time 
by: Yunpeng Cai, Yijun Sun, Wei Zheng, Jin Yao and Yujie Yang  (C) 2016
Please kindly cite [Y.Cai et.al PLOS Comp. Biol. 2016]

THE LICENSED WORK IS PROVIDED UNDER THE TERMS OF THE ADAPTIVE PUBLIC LICENSE ("LICENSE") AS FIRST COMPLETED BY: _Yunpeng Cai, Yijun Sun, Wei Zheng, Jin Yao, Yujie Yang_ [Insert the name of the Initial Contributor here]. ANY USE, PUBLIC DISPLAY, PUBLIC PERFORMANCE, REPRODUCTION OR DISTRIBUTION OF, OR PREPARATION OF DERIVATIVE WORKS BASED ON, THE LICENSED WORK CONSTITUTES RECIPIENT'S ACCEPTANCE OF THIS LICENSE AND ITS TERMS, WHETHER OR NOT SUCH RECIPIENT READS THE TERMS OF THE LICENSE. "LICENSED WORK" AND "RECIPIENT" ARE DEFINED IN THE LICENSE. A COPY OF THE LICENSE IS LOCATED IN THE TEXT FILE ENTITLED "LICENSE.TXT" ACCOMPANYING THE CONTENTS OF THIS FILE. IF A COPY OF THE LICENSE DOES NOT ACCOMPANY THIS FILE, A COPY OF THE LICENSE MAY ALSO BE OBTAINED AT THE FOLLOWING WEB SITE: http://www.acsu.buffalo.edu/~yijunsun/lab/ESPRIT-Forest.html [Insert Initial Contributor's Designated Web Site here]

Software distributed under the License is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for the specific language governing rights and limitations under the License.
*/

#include <iostream>
#include <fstream>
#include <time.h>
#include <math.h>
#include "global.h"
#include "kmer.h"
#include "TreeClust.h"
#include "needle.h"
#include "MinHeap.h"
using namespace std;

// Pairs a position in a candidate list (idx) with the distance measured for
// that candidate (dist), so the pairs can be sorted by distance.
struct KDist
{
	int idx;
	float dist;
};

// qsort()-compatible comparator over KDist records, ascending by `dist`.
// Returns 1 when *a is farther than *b, -1 when it is nearer, and 0 otherwise
// (which includes equal distances and any comparison involving NaN).
int CompDist(const void * a, const void * b)
{
	const KDist *lhs=static_cast<const KDist *>(a);
	const KDist *rhs=static_cast<const KDist *>(b);
	// bool-to-int arithmetic yields exactly the -1 / 0 / 1 contract above.
	return (lhs->dist > rhs->dist) - (lhs->dist < rhs->dist);
}

// Creates an empty tree (no root, zero sequences) and derives the range of
// level thresholds from the requested bounds:
//   levelinc - multiplicative step between levels, read from Global::level_inc
//   levelmax - lmax scaled up by one step
//   levelmin - levelmax divided by levelinc as many times as possible while
//              the next lower level would still be greater than lmin
// NOTE(review): the search for levelmin presumably relies on
// Global::level_inc being > 1; with a step <= 1 the loop may never finish.
TreeClust::TreeClust(float lmin,float lmax)
{
	NumSeq=0;
	root=NULL;
	levelinc=Global::level_inc;
	levelmax=lmax*levelinc;
	// All the work happens in the loop header; the body is intentionally empty.
	for (levelmin=levelmax; levelmin/levelinc > lmin; levelmin=levelmin/levelinc)
	{
	}
}

// Builds a single-path branch for sequence `uid` and returns its topmost node.
// The branch starts as one plain node for `uid`; it is then wrapped, one level
// at a time, in new nodes that also carry `uid` and whose thresholds run
// levelmin, levelmin*levelinc, ... for as long as the threshold does not
// exceed `levelup` (compared with a 0.1% tolerance for rounding).
// If levelmin is already above that limit the bare node is returned.
Tree* TreeClust::CreateBranch(int uid,float levelup)
{
	Tree *chain=new Tree(uid);
	for (float thres=levelmin; thres <=levelup*1.001; thres*=levelinc)
	{
		Tree *wrapper=new Tree(uid);
		wrapper->AddChild(chain);
		wrapper->SetThres(thres);
		chain=wrapper;	// the wrapper becomes the new top of the chain
	}
	return chain;
}

// Replaces everything below `node` with one fresh single-path branch for
// `uid` and returns the leaf at the bottom of that branch.
// The new branch is built first (up to one level below node's threshold),
// then every existing child of `node` is removed through DeleteChild, and
// finally the new branch is attached as the only child.
Tree *TreeClust::Condense(Tree *node,int uid)
{
	Tree *branch=CreateBranch(uid,node->GetThres()/levelinc);

	for (; node->NumChildren() >0; )
	{
		node->DeleteChild(node->FirstChild());
	}
	node->AddChild(branch);

	// Follow first children down to the leaf of the new branch.
	Tree *leaf=branch;
	for (; !leaf->IsLeaf(); leaf=leaf->FirstChild())
	{
	}
	return leaf;
}

// Inserts sequence `uid` into the tree, descending from the root, and returns
// the leaf node created for it. NumSeq is incremented on every call.
//
// Empty tree: a full-height branch for `uid` becomes the root.
// Otherwise, at each node on the way down:
//   * a bottom-level node receives `uid` as a new direct child;
//   * else a candidate child is chosen (the only child, or the one picked by
//     FindSpanChild) and the walk continues into it when the distance is
//     below that child's threshold;
//   * else a new branch for `uid` is attached to the current node.
Tree *TreeClust::AddSeq(int uid)
{
	++NumSeq;

	Tree *inserted;
	if (root==NULL)
	{
		root=CreateBranch(uid,levelmax*levelinc);
		inserted=root;
	}
	else
	{
		Tree *cursor=root;
		Tree *candidate;
		// Kept outside the loop on purpose: below the root, a single-child
		// node computes no new distance and the previous level's value is
		// compared against the child's threshold instead.
		float spandist;

		for (;;)
		{
			if (cursor->IsLeaf()) cerr <<"Shouldn't Get to Leaf" <<endl;

			if (cursor->BottomLevel())
			{
				inserted=new Tree(uid);
				cursor->AddChild(inserted);
				break;
			}

			if (cursor->NumChildren()!=1)
			{
				spandist=cursor->FindSpanChild(uid,candidate);
			}
			else
			{
				candidate=cursor->FirstChild();
				if (cursor==root)
					spandist=Kdist2Ndist(KmerDist(uid,candidate->UID));
			}

			// No acceptable child: start a new branch under the current node.
			if (candidate==NULL || !(spandist < candidate->GetThres()))
			{
				inserted=CreateBranch(uid,cursor->GetThres()/levelinc);
				cursor->AddChild(inserted);
				break;
			}

			cursor=candidate;
		}
	}

	// Whatever was attached, hand back the leaf at its bottom.
	for (; !inserted->IsLeaf(); inserted=inserted->FirstChild())
	{
	}
	return inserted;
}

// Inserts sequence `uid` starting from a known branch instead of the root:
// the walk begins at the child of `par` that FindChild returns for `branch`.
// Returns the new leaf, or NULL (after printing a message) when `par` has no
// such child. Note that NumSeq is incremented even in the failure case.
//
// The descent rule matches AddSeq: go into the candidate child while the
// distance is below its threshold, add a direct child at a bottom-level node,
// otherwise attach a new branch to the current node.
Tree *TreeClust::AddSeqFrom(int uid, int branch, Tree *par)
{
	++NumSeq;

	Tree *cursor=par->FindChild(branch);
	if (cursor==NULL)
	{
		cerr <<" Wrong Branch ID " <<endl;
		return NULL;
	}

	Tree *candidate;
	float spandist=cursor->FindSpanChild(uid,candidate);
	Tree *inserted;

	for (;;)
	{
		if (cursor->IsLeaf()) cerr <<"Shouldn't Get to Leaf" <<endl;

		if (cursor->BottomLevel())
		{
			inserted=new Tree(uid);
			cursor->AddChild(inserted);
			break;
		}

		if (candidate==NULL || !(spandist < candidate->GetThres()))
		{
			inserted=CreateBranch(uid,cursor->GetThres()/levelinc);
			cursor->AddChild(inserted);
			break;
		}

		// Step down. A single-child node keeps the distance from the level
		// above; otherwise the best child and its distance are recomputed.
		cursor=candidate;
		if (cursor->NumChildren()==1)
			candidate=cursor->FirstChild();
		else
			spandist=cursor->FindSpanChild(uid,candidate);
	}

	for (; !inserted->IsLeaf(); inserted=inserted->FirstChild())
	{
	}
	return inserted;
}


// Attaches sequence `uid` directly under `par` without any distance search
// and returns the leaf created for it; NumSeq is incremented.
// A bottom-level parent receives a plain node; any other parent receives a
// single-path branch reaching up to one level below the parent's threshold.
Tree *TreeClust::AddSeqAt(int uid, Tree *par){
	++NumSeq;

	Tree *attached=par->BottomLevel()
		? new Tree(uid)
		: CreateBranch(uid,par->GetThres()/levelinc);
	par->AddChild(attached);

	Tree *leaf=attached;
	for (; !leaf->IsLeaf(); leaf=leaf->FirstChild())
	{
	}
	return leaf;
}


// Read-only counterpart of AddSeq: walks down from the root with the same
// descent rule and returns the node under which `uid` would be attached,
// without modifying the tree. Returns NULL when the tree is empty.
Tree* TreeClust::FindInsertParent(int uid)
{
	if (root==NULL) return NULL;

	Tree *cursor=root;
	Tree *candidate;
	// Carried across iterations: below the root a single-child node does not
	// recompute the distance, so the previous level's value is reused.
	float spandist;

	for (;;)
	{
		if (cursor->BottomLevel() || cursor->IsLeaf()) return cursor;

		if (cursor->NumChildren()!=1)
		{
			spandist=cursor->FindSpanChild(uid,candidate);
		}
		else
		{
			candidate=cursor->FirstChild();
			if (cursor==root)
				spandist=Kdist2Ndist(KmerDist(uid,candidate->UID));
		}

		// Stop here unless the candidate child is close enough to enter.
		if (candidate==NULL || !(spandist < candidate->GetThres())) return cursor;

		cursor=candidate;
	}

	return NULL;	// not reached; the loop above only exits through return
}

// Unlinks `node` from the tree and prunes every ancestor that is left without
// children as a result. If pruning reaches the root and the root itself ends
// up childless, the root is deleted and the tree becomes empty (root==NULL).
//
// A NULL `node` is treated as a fatal error: a message is written to stderr
// and the process terminates with a non-zero status (the original exited with
// 0, which reported the failure as success to the calling shell).
//
// A node that has no parent is no longer dereferenced through a NULL parent
// pointer; nothing is unlinked for it and only the empty-root check applies.
void TreeClust::RemoveNode(Tree *node)
{
	if (node==NULL)
	{
		fprintf(stderr,"Error Deleting Empty Node\n");
		exit(1);
	}

	Tree *parent=node->GetParent();
	while (parent !=NULL)
	{
		parent->DeleteChild(node);
		// Move one level up; keep pruning only while the node we just
		// emptied has no children left.
		node=parent;
		parent=node->GetParent();
		if (!node->IsLeaf()) break;
	}

	if (node==root && node->IsLeaf())
	{
		root=NULL;
		delete node;
	}
}

// Looks for any sequence other than `uid` whose alignment distance to `uid`
// (NeedleDist) is at most `thres`. Returns that sequence's UID, or -1 when the
// tree is empty or no such sequence is found.
//
// Candidates are tree nodes kept in a heap keyed by the k-mer distance between
// `uid` and the node's UID. A node is only queued when that distance is below
// KdistBound(thres + node threshold); leaves that pass are verified with
// NeedleDist, inner nodes are expanded into their children.
//
// Fix relative to the original: the "skip myself" test used to run before the
// leaf test, so an inner node whose UID equals `uid` was discarded together
// with its whole subtree. The test now applies to leaves only, as in the
// other FindNN overloads of this class.
int TreeClust::FindNN(int uid,float thres)
{
	vector<Tree *> tvec;
	MinHeap heap;
	float dist;
	Tree *top;

	if (root==NULL) return -1;

	//enroll all potential NN branches
	root->ListChildren(tvec);
	for (size_t i=0; i<tvec.size();i++)
	{
		// Collapse chains of single-child nodes before measuring.
		while (!tvec[i]->IsLeaf() && tvec[i]->NumChildren()==1)
		{
			tvec[i]=tvec[i]->FirstChild();
		}
		dist=KmerDist(tvec[i]->UID,uid);
		if (dist < KdistBound(thres+tvec[i]->GetThres()))
		{
			heap.Add((void *)tvec[i],dist);
		}
	}

	//perform A* search
	while (!heap.Empty())
	{
		dist=heap.Pop((void *&)top);
		if (!(dist < KdistBound(thres+top->GetThres()))) continue;

		if (top->IsLeaf()) //reach terminal
		{
			if (top->UID==uid) continue;	// never report the query itself
			dist=NeedleDist(top->UID,uid,2*thres);
			if (dist <= thres)
			{
				return top->UID;
			}
		}
		else //expand
		{
			while (!top->BottomLevel() && top->NumChildren()==1)
				top=top->FirstChild();
			tvec.clear();
			top->ListChildren(tvec);

			for (size_t i=0; i<tvec.size();i++)
			{
				dist=KmerDist(tvec[i]->UID,uid);
				if (dist < KdistBound(thres+tvec[i]->GetThres()))
				{
					heap.Add((void *)tvec[i],dist);
				}
			}
		}
	}
	return -1;
}

void TreeClust::EstimateNN(int uid, ClustRec *clRec)
{
	Tree *top=clRec[uid].Node;
	Tree *par=top->GetParent();
	Tree *brother;
	
	while (par !=root && par->NumChildren() <=1)
	{
		top=par;
		par=top->GetParent();
	}
	if (par == root) return;
	int NNseq;
	float Ndist;
	if (par->UID !=uid && clRec[par->UID].clsid ==-1)
	{
		NNseq=par->UID;
	}
	else
	{
		Tree *brother=top->GetBrother();
		while (!brother->IsLeaf())
			brother=brother->FirstChild();
		NNseq=brother->UID;
	}

	Ndist=NeedleDist(uid,NNseq,par->GetThres());
	#pragma omp critical (CRI_clRec)
	{
		clRec[uid].NNseq=NNseq;
		clRec[uid].NNdist=Ndist;
		clRec[NNseq].NNlist.insert(uid);
	}
}

void TreeClust::FindNN(int uid,ClustRec *clRec,float leveldn) // leveldn is the possible lower bound of NN dist
{
	vector<Tree *> tvec;
	
	MinHeap heap;
	float dist;
	Tree *top;
	
	if (root==NULL) return;
	
	tvec.clear();	
	root->ListChildren(tvec);
	for (int i=0; i<tvec.size();i++)
	{
		while (!tvec[i]->IsLeaf() && tvec[i]->NumChildren()==1)
		{
			tvec[i]=tvec[i]->FirstChild();
		}
		dist=KmerDist(tvec[i]->UID,uid); 
		if (dist < KdistBound(clRec[uid].NNdist+tvec[i]->GetThres()))
		{
			heap.Add((void *)tvec[i],dist);
		}
	}
	
	//perform A* search
	while (!heap.Empty())
	{
		dist=heap.Pop((void *&)top);
		
		if (dist < KdistBound(clRec[uid].NNdist+top->GetThres()))
		{
			if (top->IsLeaf()) //reach terminal
			{
				if (top->UID==uid) continue;
				dist=NeedleDist(top->UID,uid,min(levelmax,clRec[uid].NNdist+0.02f));
				int tagret=0;
			#pragma omp critical (CRI_clRec)                  //avoid parallel access of clRec
			{
				if (dist < clRec[top->UID].NNdist)
				{
					clRec[top->UID].NNdist2=clRec[top->UID].NNdist;
					clRec[top->UID].NNseq2=clRec[top->UID].NNseq;
					if (clRec[top->UID].NNseq >=0)
						clRec[clRec[top->UID].NNseq].NNlist.erase(top->UID);
					clRec[top->UID].NNdist=dist;
					clRec[top->UID].NNseq=uid;
					clRec[uid].NNlist.insert(top->UID);
				}
				else if (dist < clRec[top->UID].NNdist2)
				{
					clRec[top->UID].NNdist2=dist;
					clRec[top->UID].NNseq2=uid;
				}
				
				if (dist < clRec[uid].NNdist)
				{
					clRec[uid].NNdist2=clRec[uid].NNdist;
					clRec[uid].NNseq2=clRec[uid].NNseq;
					if (clRec[uid].NNseq >=0)
						clRec[clRec[uid].NNseq].NNlist.erase(uid);
					clRec[uid].NNdist=dist;
					clRec[uid].NNseq=top->UID;
					clRec[top->UID].NNlist.insert(uid);
					if (dist <leveldn)
						tagret=1;
				}
				else if (dist < clRec[uid].NNdist2)
				{
					clRec[uid].NNdist2=dist;
					clRec[uid].NNseq2=top->UID;
				}
			}
			if (tagret) return;
			}
			else //expand
			{
				while (!top->BottomLevel() && top->NumChildren()==1)
					top=top->FirstChild();
				tvec.clear();	
				top->ListChildren(tvec);

				for (int i=0; i<tvec.size();i++)
				{
					dist=KmerDist(tvec[i]->UID,uid); 
					if (dist < KdistBound(clRec[uid].NNdist+tvec[i]->GetThres()))
					{
						heap.Add((void *)tvec[i],dist);
					}
				}
			}
		}	
	}
}

// Refines the nearest-neighbour record of sequence `uid` by searching the
// tree outward from `node` instead of from the root. Every pair that gets
// aligned updates the first/second nearest-neighbour fields of both
// sequences in `clRec`, inside the CRI_clRec critical section.
//
//   1. If node's parent has several children, `uid` is compared with the
//      UIDs of node's siblings in order of increasing k-mer distance.
//   2. Walking up from node's parent to the root, the siblings of each
//      ancestor are queued in a heap when their k-mer distance to `uid` is
//      below KdistBound(clRec[uid].NNdist + their threshold). An ancestor's
//      siblings are skipped when `uid` already has a neighbour and the
//      k-mer distance to the ancestor is below
//      KdistBound(ancestor threshold - NNdist).
//   3. The queued nodes are processed with the same search as the root-based
//      FindNN overload.
//
// leveldn: possible lower bound of the NN distance; once a new nearest
// neighbour of `uid` closer than this is recorded, step 1 stops early and
// step 3 returns.
//
// Fix relative to the original: in step 2 the heap key `dist` was never
// assigned the k-mer distance that had just been computed, so entries were
// queued with an uninitialized or stale value, which step 3 then used both
// for ordering and for pruning. The distance is now stored before queueing.
void TreeClust::FindNN(int uid,Tree *node,ClustRec *clRec,float leveldn)  // leveldn is the possible lower bound of NN dist
{
	vector<Tree *> tvec;

	MinHeap heap;
	float dist;
	Tree *top;

	tvec.clear();

	// Step 1: direct comparison with the siblings of `node`.
	if (node->GetParent()->NumChildren() >1)
	{
		node->GetParent()->ListChildren(tvec);
		int ptr=0;
		// One slot per sibling; released with free() below.
		KDist *kdistlist=(KDist *)Malloc(tvec.size()*sizeof(KDist));

		for (size_t i=0; i<tvec.size();i++)
		{
			// Skip node itself and the neighbour that is already recorded.
			if (tvec[i] !=node && tvec[i]->UID !=clRec[uid].NNseq)
			{
				kdistlist[ptr].idx=(int)i;
				kdistlist[ptr++].dist=KmerDist(tvec[i]->UID,uid);
			}
		}

		// Try the siblings with the smallest k-mer distance first.
		qsort(kdistlist,ptr,sizeof(KDist),CompDist);

		for (int i=0;i<ptr;i++)
		{
			if (kdistlist[i].dist < KdistBound(clRec[uid].NNdist))
			{
				int NNseq=tvec[kdistlist[i].idx]->UID;
				if (NNseq==uid) continue;
				dist=NeedleDist(uid,NNseq,min(clRec[uid].NNdist+0.02f,levelmax));
				int tagret=0;
				#pragma omp critical (CRI_clRec)
				{
					// Does uid improve the sibling's record?
					if (clRec[NNseq].NNdist > dist)
					{
						clRec[NNseq].NNdist2=clRec[NNseq].NNdist;
						clRec[NNseq].NNdist=dist;
						if (clRec[NNseq].NNseq >=0)
							clRec[clRec[NNseq].NNseq].NNlist.erase(NNseq);
						clRec[uid].NNlist.insert(NNseq);
						clRec[NNseq].NNseq2=clRec[NNseq].NNseq;
						clRec[NNseq].NNseq=uid;
					}
					else if (clRec[NNseq].NNdist2 > dist)
					{
						clRec[NNseq].NNdist2=dist;
						clRec[NNseq].NNseq2=uid;
					}
					// Does the sibling improve uid's record?
					if (dist < clRec[uid].NNdist)
					{
						clRec[uid].NNdist2=clRec[uid].NNdist;
						clRec[uid].NNseq2=clRec[uid].NNseq;
						if (clRec[uid].NNseq >=0)
							clRec[clRec[uid].NNseq].NNlist.erase(uid);
						clRec[uid].NNdist=dist;
						clRec[uid].NNseq=NNseq;
						clRec[NNseq].NNlist.insert(uid);
						if (dist <leveldn)
							tagret=1;
					}
					else if (dist < clRec[uid].NNdist2)
					{
						clRec[uid].NNdist2=dist;
						clRec[uid].NNseq2=NNseq;
					}
				}
				// NOTE(review): unlike step 3 this only leaves the sibling
				// loop; steps 2 and 3 still run afterwards. Kept as in the
				// original - confirm whether an early return was intended.
				if (tagret)
					break;
			}
		}
		free(kdistlist);
	}


	//enroll all potential NN branches
	top=node->GetParent();

	while (top !=root)
	{
		Tree *tpar=top->GetParent();

		// The siblings of `top` are not examined when this test passes.
		if (clRec[uid].NNseq >=0 && KmerDist(uid,top->UID) < KdistBound(top->GetThres()-clRec[uid].NNdist))
		{
			top=tpar;
			continue;
		}
		if (tpar->NumChildren() >1)
		{
			tvec.clear();
			tpar->ListChildren(tvec);

			for (size_t i=0; i<tvec.size();i++)
			{
				if (tvec[i] !=top)
				{
					while (!tvec[i]->IsLeaf() && tvec[i]->NumChildren()==1)
					{
						tvec[i]=tvec[i]->FirstChild();
					}
					// Store the distance first: it is the heap key.
					dist=KmerDist(tvec[i]->UID,uid);
					if (dist < KdistBound(clRec[uid].NNdist+tvec[i]->GetThres()))
					{
						heap.Add((void *)tvec[i],dist);
					}
				}
			}
		}
		top=tpar;
	}

	//perform A* search
	while (!heap.Empty())
	{
		dist=heap.Pop((void *&)top);

		// The bound may have tightened since this node was queued.
		if (dist < KdistBound(clRec[uid].NNdist+top->GetThres()))
		{
			if (top->IsLeaf()) //reach terminal
			{
				if (top->UID==uid) continue;
				dist=NeedleDist(top->UID,uid,min(clRec[uid].NNdist+0.02f,levelmax));
				int tagret=0;
			#pragma omp critical (CRI_clRec)
			{
				if (dist < clRec[top->UID].NNdist)
				{
					clRec[top->UID].NNdist2=clRec[top->UID].NNdist;
					clRec[top->UID].NNseq2=clRec[top->UID].NNseq;
					if (clRec[top->UID].NNseq >=0)
						clRec[clRec[top->UID].NNseq].NNlist.erase(top->UID);
					clRec[top->UID].NNdist=dist;
					clRec[top->UID].NNseq=uid;
					clRec[uid].NNlist.insert(top->UID);
				}
				else if (dist < clRec[top->UID].NNdist2)
				{
					clRec[top->UID].NNdist2=dist;
					clRec[top->UID].NNseq2=uid;
				}

				if (dist < clRec[uid].NNdist)
				{
					clRec[uid].NNdist2=clRec[uid].NNdist;
					clRec[uid].NNseq2=clRec[uid].NNseq;
					if (clRec[uid].NNseq >=0)
						clRec[clRec[uid].NNseq].NNlist.erase(uid);
					clRec[uid].NNdist=dist;
					clRec[uid].NNseq=top->UID;
					clRec[top->UID].NNlist.insert(uid);
					if (dist <leveldn)
						tagret=1;
				}
				else if (dist < clRec[uid].NNdist2)
				{
					clRec[uid].NNdist2=dist;
					clRec[uid].NNseq2=top->UID;
				}
			}
			if (tagret) return;
			}
			else //expand
			{
				while (!top->BottomLevel() && top->NumChildren()==1)
					top=top->FirstChild();
				tvec.clear();
				top->ListChildren(tvec);

				for (size_t i=0; i<tvec.size();i++)
				{
					dist=KmerDist(tvec[i]->UID,uid);
					if (dist < KdistBound(clRec[uid].NNdist+tvec[i]->GetThres()))
					{
						heap.Add((void *)tvec[i],dist);
					}
				}
			}
		}
	}
}


void TreeClust::ListLeaves(FILE *fp)
{
	vector<Tree *> tvec;
	tvec.clear();	
	root->ListLeaf(tvec);
	for (int i=0; i<tvec.size();i++)
	{	
		fprintf(fp,"[%d]",tvec[i]->UID);
	}
	fprintf(fp,"\n");
	fflush(fp);
}

// Replaces the contents of `tvec` with the nodes reported by
// Tree::ListBottom on the root. `tvec` is always cleared first; for an empty
// tree (root==NULL) it is left empty instead of dereferencing the NULL root.
void TreeClust::ListBottom(vector<Tree *> &tvec)
{
	tvec.clear();
	if (root==NULL) return;
	root->ListBottom(tvec);
}

void TreeClust::ListChildrenAt(float level, vector<Tree *> &tvec)
{
	tvec.clear();
	root->ListChildrenAt(level,tvec);
}
