
#include "QTree.h"
#include <assert.h>
#include <math.h>
#include <stdio.h>

// Gravitational constant used when computing pairwise force magnitudes.
// The value matches the SI figure (N*m^2/kg^2), so masses and positions are
// presumably expressed in kilograms and metres -- confirm against callers.
#define G (6.67e-11) // gravitational constant
// Barnes-Hut threshold: calculateForceHelper compares a per-node ratio against
// this value to decide whether a node may be treated as a single point mass
// instead of being descended into.
#define THETA 0.9    // barnes-hut parameter

// Creates an empty tree: there is no root node until build() allocates one.
QTree::QTree()
  : m_root(NULL)
{
}

// Releases every node owned by the tree.
//
// m_root stays NULL until build() is called, so the tree may legitimately be
// empty here; only walk the tree when there is something to free.
QTree::~QTree()
{
  if (m_root != NULL) {
    deleteSubTree(m_root);
    m_root = NULL;
  }
}

// Recursively frees the node `root` and all of its descendants.
//
// A NULL `root` is accepted and ignored, which makes the call safe for an
// empty tree and lets the recursion visit child slots without a separate
// NULL test.
void QTree::deleteSubTree(QTreeNode* root)
{
  if (root == NULL) {
    return;
  }
  // Children are released before the node itself (post-order).
  for (int i=0; i<4; i++) {
    deleteSubTree(root->children[i]);
  }
  delete root;
}

// Builds the quadtree for `num_bodies` bodies stored contiguously at `bodies`.
//
// The root region is the rectangle whose upper-left corner is (min_x, max_y)
// and whose lower-right corner is (max_x, min_y). Bodies are inserted in array
// order. Any tree left over from an earlier build() call is freed first so
// that rebuilding does not leak the previous nodes.
//
// The tree keeps copies of each body's point mass, not pointers to the bodies.
void QTree::build(Body* bodies, int num_bodies, 
		  double min_x, double max_x, double min_y, double max_y)
{
  if (m_root != NULL) {
    deleteSubTree(m_root);
    m_root = NULL;
  }
  m_root = new QTreeNode(Rectangle(Point(min_x, max_y), Point(max_x, min_y)));
  if (bodies == NULL) {
    return; // nothing to insert; leave an empty root
  }
  for (int i=0;i<num_bodies;i++) {
    insert(&bodies[i], m_root);
  }
}

// Inserts the point mass of body `b` into the subtree rooted at `root`.
//
//  - An empty node (num_bodies == 0) becomes a leaf holding the body.
//  - A leaf (num_bodies == 1) is split into four quadrants, its current
//    occupant is pushed down into the matching quadrant, and the new body is
//    then routed into a quadrant as well.
//  - An internal node (num_bodies > 1) routes the body into the first child
//    whose region contains it.
//
// After a body has been routed into a child, the node's body count is
// incremented and its aggregate mass / centre of mass is recomputed.
//
// Two bodies at exactly the same position can never be separated by
// subdividing, so they are merged into one leaf whose mass is the sum of both;
// without this the subdivision would recurse without bound.
//
// A body that is not contained in any child region is not stored below this
// node (the node's count and mass are left untouched).
void QTree::insert(Body * b, QTreeNode* root)
{
  if (b == NULL || root == NULL) {
    return;
  }

  if (root->num_bodies == 0) {
    root->num_bodies++;
    root->point_mass = b->point_mass;
    return;
  }

  if (root->num_bodies == 1) {
    const bool has_children = (root->children[UPPER_LEFT_CHILD] != NULL);

    if (!has_children) {
      const bool same_position =
        root->point_mass.center_of_mass.x == b->point_mass.center_of_mass.x &&
        root->point_mass.center_of_mass.y == b->point_mass.center_of_mass.y;
      if (same_position) {
        // Coincident bodies: fold the new mass into the existing leaf.
        root->point_mass.mass += b->point_mass.mass;
        return;
      }

      // Split the leaf's region into four quadrants around its midpoint.
      root->children[UPPER_LEFT_CHILD] = new QTreeNode(Rectangle(root->region.upperLeftCorner(), root->region.middle()));
      root->children[UPPER_RIGHT_CHILD] = new QTreeNode(Rectangle(Point(root->region.middle().x, root->region.upperLeftCorner().y), 
                                                                  Point(root->region.lowerRightCorner().x, root->region.middle().y)));
      root->children[LOWER_LEFT_CHILD] = new QTreeNode(Rectangle(Point(root->region.upperLeftCorner().x, root->region.middle().y),
                                                                 Point(root->region.middle().x, root->region.lowerRightCorner().y)));
      root->children[LOWER_RIGHT_CHILD] = new QTreeNode(Rectangle(root->region.middle(), root->region.lowerRightCorner()));

      // Push the body that currently occupies this node into its quadrant.
      for (int j=0;j<4;j++) {
        if (root->children[j]->region.contains(root->point_mass.center_of_mass)) {
          root->children[j]->point_mass = root->point_mass;
          root->children[j]->num_bodies = 1;
          break; // no need to keep looking
        }
      }
    }
    // If the quadrants already exist (an earlier body was routed to none of
    // them), they are reused rather than reallocated, which would leak them.
  }

  // Route the new body into the first quadrant that contains it.
  for (int i=0; i<4; i++) {
    QTreeNode* child = root->children[i];
    if (child != NULL && child->region.contains(b->point_mass.center_of_mass)) {
      root->num_bodies++;
      insert(b, child);
      recalculateMass(root);
      return;
    }
  }
}

// Refreshes the aggregate point mass of `root` from its four children.
//
// The node's mass becomes the sum of the child masses and its centre of mass
// becomes the mass-weighted average of the child centres. All four child
// slots are read unconditionally, so the node must already be subdivided.
// A total mass of zero trips the assertion.
void QTree::recalculateMass(QTreeNode* root)
{
  double weighted_x = 0.0;
  double weighted_y = 0.0;

  root->point_mass.mass = 0.0;
  int quadrant = 0;
  while (quadrant < 4) {
    const QTreeNode* child = root->children[quadrant];
    root->point_mass.mass += child->point_mass.mass;
    weighted_x += child->point_mass.mass * child->point_mass.center_of_mass.x;
    weighted_y += child->point_mass.mass * child->point_mass.center_of_mass.y;
    ++quadrant;
  }

  assert(root->point_mass.mass != 0.0);

  root->point_mass.center_of_mass.x = weighted_x/root->point_mass.mass;
  root->point_mass.center_of_mass.y = weighted_y/root->point_mass.mass;
}

// Returns the net gravitational force exerted on `body` by the bodies stored
// in the tree, as an (x, y) vector.
//
// If the tree has not been built yet (no root node) or `body` is NULL, there
// is nothing to evaluate and a zero vector is returned.
DoubleVector QTree::calculateForce(Body* body) 
{
  if (m_root == NULL || body == NULL) {
    return DoubleVector(0.0, 0.0);
  }
  return calculateForceHelper(body, m_root);
}

// Computes the force on `body` contributed by the subtree rooted at `root`.
//
//  - A missing or empty node contributes nothing.
//  - A node holding a single body contributes the direct pairwise force,
//    unless it sits at exactly the same position as `body` (normally the body
//    itself), in which case it contributes nothing.
//  - A node holding several bodies is replaced by its aggregate point mass
//    when `body` lies outside the node's region and the node is far enough
//    away, i.e. cell_size / distance < THETA (the Barnes-Hut opening
//    criterion). Otherwise the four children are evaluated and summed.
//
// The cell size is taken as the square root of the region's area so that the
// ratio is dimensionless and does not depend on the coordinate scale.
DoubleVector QTree::calculateForceHelper(Body* body, QTreeNode* root) 
{
  DoubleVector force(0.0, 0.0);
  if (root == NULL || root->num_bodies == 0) {
    return force;
  }

  const double d_x = root->point_mass.center_of_mass.x - body->point_mass.center_of_mass.x;
  const double d_y = root->point_mass.center_of_mass.y - body->point_mass.center_of_mass.y;
  const double d = sqrt(d_x*d_x + d_y*d_y);

  if (root->num_bodies != 1) {
    const bool contains = root->region.contains(body->point_mass.center_of_mass);
    const double cell_size = sqrt(root->region.area());
    // Never approximate a cell that holds the body itself: its aggregate mass
    // would include the body's own contribution.
    const bool far_enough = !contains && d > 0.0 && (cell_size / d < THETA);
    if (!far_enough) {
      // Too close (or inside): open the node. This must happen even when the
      // body coincides with the node's centre of mass, because the individual
      // bodies below still exert a net force.
      for (int i=0; i<4; i++) {
        DoubleVector f(calculateForceHelper(body, root->children[i]));
        force.x += f.x;
        force.y += f.y;
      }
      return force;
    }
  }

  // Direct interaction with a single body or with an aggregated distant node.
  if (d == 0.0) {
    return force; // same position: skip self-interaction / singular force
  }
  const double mag = (G * body->point_mass.mass * root->point_mass.mass) / (d * d);
  force.x = mag * (d_x / d);
  force.y = mag * (d_y / d);
  return force;
}
