//==============================================================================
/*! \file
 * OpenMesh Toolkit for mesh analysis    \n
 * Copyright (c) 2010 by Rostislav Hulik     \n
 *
 * Author:  Rostislav Hulik, rosta.hulik@gmail.com  \n
 * Date:    2010/10/20                          \n
 *
 * This file is part of software developed for support of Rostislav Hulik's dissertation thesis at dcgm-robotics@FIT group.
 *
 * This file is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This file is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this file.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * Description:
 * - Class for constructing a BSP tree for raycasting
 */

#include <OMToolkit\OMTriBSPTree.h>
#include <MDSTk/Module/mdsModule.h>
#include <float.h>
#include <OpenMesh\Tools\Utils\Timer.hh>

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

namespace OMToolkit {

#ifdef USE_MULTIPLE_IN_ONE
	// Leaf capacity: a cell holding at most this many faces is not split further
	static const int maxElem(16);
#endif

// pi/2 in the tree's scalar type
static const OMTriBSPTree::ScalarT PIHALF = static_cast<OMTriBSPTree::ScalarT>(M_PI / 2);



//////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Constructor - creates a BSP from a mesh
// @param mesh Pointer to a mesh
// @param maxPointsInElement Maximum number of vertices in one tree cell
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
OMTriBSPTree::OMTriBSPTree(MeshT *mesh)
{
	m_mesh = mesh;
	// precomp normals (if not computed)
	if (!m_mesh->has_face_normals())
	{
		m_mesh->request_face_normals();
		m_mesh->update_normals();
	}

	// normalize normals
	MeshT::FaceIter end = m_mesh->faces_end();
	MeshT::Normal normal;
	for (MeshT::FaceIter face = m_mesh->faces_begin(); face != end; ++face)
	{
		normal = m_mesh->normal(face);
		normal.normalize_cond();
		m_mesh->set_normal(face, normal);
	}

	// create a tree
	m_root = NULL;
	ConstructTree();
	MDS_LOG_NOTE("R-tree constructed...");




	///////////////////////////////////
}

//////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Aux function for tree construction (used in constructor)
// @param maxPointsInElement Maximum number of vertices in one tree cell
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
void OMTriBSPTree::ConstructTree()
{
	std::vector<StackElement> stack;
	PointT min(FLT_MAX, FLT_MAX, FLT_MAX);
	PointT max(-FLT_MAX, -FLT_MAX, -FLT_MAX);
	PointT pt, pt2;
	TreeElement *aux;

	// create a list of all faces, their bounding boxes
	MeshT::FaceIter fend = m_mesh->faces_end();
	for (MeshT::FaceIter face = m_mesh->faces_begin(); face != fend; ++face)
	{
		PointT center(0.0, 0.0, 0.0);
		PointT aabbLow(FLT_MAX, FLT_MAX, FLT_MAX);
		PointT aabbHigh(-FLT_MAX, -FLT_MAX, -FLT_MAX);
		PointT current;

		for (MeshT::FVIter vertex = m_mesh->fv_begin(face); vertex; ++vertex)
		{
			current = m_mesh->point(vertex);
			center += current;
			if (aabbLow[0] > current[0]) aabbLow[0] = current[0];
			if (aabbLow[1] > current[1]) aabbLow[1] = current[1];
			if (aabbLow[2] > current[2]) aabbLow[2] = current[2];

			if (aabbHigh[0] < current[0]) aabbHigh[0] = current[0];
			if (aabbHigh[1] < current[1]) aabbHigh[1] = current[1];
			if (aabbHigh[2] < current[2]) aabbHigh[2] = current[2];
		}
		center /= 3.0f;
		
		m_sorted.push_back(TriCenter(center, face, aabbLow, aabbHigh, m_mesh));



		if (aabbLow[0] < min[0]) min[0] = aabbLow[0];
		if (aabbLow[1] < min[1]) min[1] = aabbLow[1];
		if (aabbLow[2] < min[2]) min[2] = aabbLow[2];

		if (aabbHigh[0] > max[0]) max[0] = aabbHigh[0];
		if (aabbHigh[1] > max[1]) max[1] = aabbHigh[1];
		if (aabbHigh[2] > max[2]) max[2] = aabbHigh[2];
	}

	int index = 0;
	
	// Push on the stack root element
	stack.push_back(StackElement(0, m_sorted.size(), 0, 0.0, NULL, index));
	++index;
	StackElement current = stack.back();
	int center;

	// go recursive
	while (stack.size() != 0)
	{
		// Current stack element entry:
		// ============================
		// current minimum index
		// current first non valid index
		// parent sort mechanism - so division plane coordinate
		// parent tree element
		current = stack.back();
		stack.pop_back();
		
		// 1st thing to do - register new tree cell
		// empty tree test
		aux = new TreeElement(current._id);
		if (m_root == NULL)
		{
			aux->aabb[0] = min;
			aux->aabb[1] = max;
			aux->axis = 0;
			m_root = aux;
		}
		// anything else
		else
		{
			// save aabb
			aux->aabb[0] = PointT(FLT_MAX, FLT_MAX, FLT_MAX);
			aux->aabb[1] = PointT(-FLT_MAX, -FLT_MAX, -FLT_MAX);
			for (int i = current.lower; i < current.higher; ++i)
			{
				if (aux->aabb[0][0] > m_sorted[i].m_aabb[0][0]) aux->aabb[0][0] = m_sorted[i].m_aabb[0][0];
				if (aux->aabb[0][1] > m_sorted[i].m_aabb[0][1]) aux->aabb[0][1] = m_sorted[i].m_aabb[0][1];
				if (aux->aabb[0][2] > m_sorted[i].m_aabb[0][2]) aux->aabb[0][2] = m_sorted[i].m_aabb[0][2];

				if (aux->aabb[1][0] < m_sorted[i].m_aabb[1][0]) aux->aabb[1][0] = m_sorted[i].m_aabb[1][0];
				if (aux->aabb[1][1] < m_sorted[i].m_aabb[1][1]) aux->aabb[1][1] = m_sorted[i].m_aabb[1][1];
				if (aux->aabb[1][2] < m_sorted[i].m_aabb[1][2]) aux->aabb[1][2] = m_sorted[i].m_aabb[1][2];
			}

			// register with parent
			aux->axis = current.sorted;
			if (current.parent->left == NULL)
				current.parent->left = aux;
			else
				current.parent->right = aux;
		}

		// 2nd thing to do - sort this cell by next axis
		current.sorted = (++current.sorted)%3;
		switch (current.sorted)
		{
			case 0:
				std::sort(m_sorted.begin() + current.lower, m_sorted.begin() + current.higher, sortXFunc);
				break;
			case 1:
				std::sort(m_sorted.begin() + current.lower, m_sorted.begin() + current.higher, sortYFunc);
				break;
			case 2:
				std::sort(m_sorted.begin() + current.lower, m_sorted.begin() + current.higher, sortZFunc);
				break;
		}

		// compute new center
		center = (current.lower + current.higher)/2;

		#ifdef USE_MULTIPLE_IN_ONE
		// 3rd thing to do - If there is less than maximum elements in the cell, we save a face in the cell
		if (current.higher - current.lower <= maxElem)
		{
			for (int i = current.lower; i < current.higher; ++i)
			{
				aux->faces.push_back(sorted[i].m_face);
			}
		}
#else
		// 3rd thing to do - If there is less than maximum elements in the cell, we save a face in the cell
		if (current.higher - current.lower <= 1)
		{
			for (int i = current.lower; i < current.higher; ++i)
			{
				//aux->face = sorted[i].m_face;
				aux->triangle = *(m_sorted.begin() + i);
				/*std::cout << "Tri: " << std::endl;
				std::cout << aux->triangle.m_points[0] << std::endl;
				std::cout << aux->triangle.m_points[1] << std::endl;
				std::cout << aux->triangle.m_points[2] << std::endl;*/
			}
		}
#endif
		// otherwise, we divide a cell into two offspings
		else
		{
			if (center - current.lower > 0)
			{
				stack.push_back(StackElement(center, current.higher, current.sorted, 0.0f, aux, index));
				++index;
			}
			if (current.higher - center > 0)
			{
				stack.push_back(StackElement(current.lower, center, current.sorted, 0.0f, aux, index));
				++index;
			}
		}
		//std::cout << std::endl;
	}
	std::sort(m_sorted.begin(), m_sorted.end(), sortIndexFunc);
}

//////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Function returns a nearest intersected triangle (both directions)
// @param origin Ray origin - intersection is computed also in inverse direction
// @param vector Ray direction
// @param face Returned face handle
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
void OMTriBSPTree::getPassingFace(PointT& origin, PointT &vector, FaceHT &face)
{
	face.invalidate();
	
	OMRay ray1(origin, vector);
	OMRay ray2(origin, -vector);

	// if ray misses root, we return nothing
	if (!m_root->intersects(ray1) && !m_root->intersects(ray2)) return;
	
	std::vector<TreeElement *> stack;
	TreeElement *current;
	stack.push_back(m_root);
	ScalarT distance = std::numeric_limits<ScalarT>::max();
	ScalarT aux;
	
	// recursively search tree
	while(stack.size() != 0)
	{
		current = stack.back();
		stack.pop_back();

#ifdef USE_MULTIPLE_IN_ONE
		// when leaf, test for intersection and distance
		if (current->left == NULL || current->right == NULL)
		{
			unsigned int size = current->faces.size();
			for (unsigned int i = 0; i < size; ++i)
				if (intersects(origin, vector, current->faces[i], aux) && distance > abs(aux))
				{
					distance = abs(aux);
					face = current->faces[i];
				}
		}
#else
		// when leaf, test for intersection and distance
		if (current->left == NULL || current->right == NULL)
		{
			if ((current->triangle.intersects(ray1, aux) || current->triangle.intersects(ray2, aux)) && distance > abs(aux))
			{
				//std::cout << "xxx";
				distance = abs(aux);
				face = current->triangle.m_face;
			}
		}	
#endif
		// else test offsprings and if intersected, push into stack
		else
		{
			if (current->left->intersects(ray1)) stack.push_back(current->left);
			//else if (current->left->intersects(ray2)) stack.push_back(current->left);

			if (current->right->intersects(ray1)) stack.push_back(current->right);
			//else if (current->right->intersects(ray2)) stack.push_back(current->right);
		}
	}

}

//////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Function returns all intersected triangles (both directions)
// @param origin Ray origin - intersection is computed also in inverse direction
// @param vector Ray direction
// @param faces Returned face handles
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
void OMTriBSPTree::getAllPassingFaces(PointT& origin, PointT &vector, std::vector<FaceHT> &faces)
{
	OMRay ray1(origin, vector);
	OMRay ray2(origin, -vector);

	// if ray misses root, we return nothing
	if (!m_root->intersects(ray1) && !m_root->intersects(ray2)) return;
	
	std::vector<TreeElement *> stack;
	TreeElement *current;
	stack.push_back(m_root);
	
	// recursively search tree
	while(stack.size() != 0)
	{
		current = stack.back();
		stack.pop_back();

#ifdef USE_MULTIPLE_IN_ONE
		if (current->left == NULL || current->right == NULL)
			for (unsigned int i = 0; i < current->faces.size(); ++i)
				faces.push_back(current->faces[i]);
#else
		// when leaf, test for intersection and distance
		if (current->left == NULL || current->right == NULL)
			faces.push_back(current->triangle.m_face);
		// else test offsprings and if intersected, push into stack
#endif
		else
		{
			if (current->left->intersects(ray1)) stack.push_back(current->left);
			//else if (current->left->intersects(ray2)) stack.push_back(current->left);

			if (current->right->intersects(ray1)) stack.push_back(current->right);
			//else if (current->right->intersects(ray2)) stack.push_back(current->right);
		}
	}
}


//////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Sorting function - X axis values
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
bool OMTriBSPTree::sortXFunc(TriCenter const &first, TriCenter const &second)
{ 
	return (first.m_center[0]<second.m_center[0]); 
}

//////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Sorting function - Y axis values
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
bool OMTriBSPTree::sortYFunc(TriCenter const &first, TriCenter const &second)
{ 
	return (first.m_center[0]<second.m_center[0]); 
}

//////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Sorting function - Z axis values
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
bool OMTriBSPTree::sortZFunc(TriCenter const &first, TriCenter const &second)
{ 
	return (first.m_center[0]<second.m_center[0]); 
}

//////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Sorting function - Index values
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
bool OMTriBSPTree::sortIndexFunc(TriCenter const &first, TriCenter const &second)
{ 
	return (first.m_face.idx()<second.m_face.idx()); 
}
}