/* Ergo, version 3.3, a program for linear scaling electronic structure
 * calculations.
 * Copyright (C) 2013 Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek.
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * Primary academic reference:
 * Kohn−Sham Density Functional Theory Electronic Structure Calculations 
 * with Linearly Scaling Computational Time and Memory Usage,
 * Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek,
 * J. Chem. Theory Comput. 7, 340 (2011),
 * <http://dx.doi.org/10.1021/ct100611z>
 * 
 * For further information about Ergo, see <http://www.ergoscf.org>.
 */
#ifdef USE_CHUNKS_AND_TASKS

#include <cstdio>
#include <iostream>
#include <stdexcept>
#include <vector>
#include "CreateAtomCenteredBasisSet.h"
#include "integrals_general.h"

// Computes the axis-aligned bounding box of the atom positions in atomList:
// for each of the three coordinate directions, box.minCoord/box.maxCoord
// receive the smallest/largest atom coordinate in that direction.
// Throws std::runtime_error if atomList is empty, since there is then no
// atom to seed the box with.
static BoxStruct get_box_for_atom_list(const std::vector<Atom> & atomList) {
  if(atomList.empty())
    throw std::runtime_error("Error in get_box_for_atom_list(): (atomList.size() == 0).");
  BoxStruct box;
  // Seed the box with the first atom.
  for(int coordIdx = 0; coordIdx < 3; coordIdx++) {
    box.minCoord[coordIdx] = atomList[0].coords[coordIdx];
    box.maxCoord[coordIdx] = atomList[0].coords[coordIdx];
  }
  // Atom 0 already initialized the box, so the scan starts at atom 1.
  int nAtoms = atomList.size();
  for(int i = 1; i < nAtoms; i++) {
    for(int coordIdx = 0; coordIdx < 3; coordIdx++) {
      ergo_real currCoord = atomList[i].coords[coordIdx];
      if(currCoord < box.minCoord[coordIdx])
        box.minCoord[coordIdx] = currCoord;
      if(currCoord > box.maxCoord[coordIdx])
        box.maxCoord[coordIdx] = currCoord;
    }
  }
  return box;
}

// Chunks-and-Tasks task that creates the basis functions for one group of
// atoms (the leaf level of the CreateAtomCenteredBasisSet recursion).
// Inputs:  the atom list, the basis set definition, and the integral info.
// Output:  a DistrBasisSetChunk holding the created basis functions.
class CreateAtomCenteredBasisSetLowestLevel : public cht::Task {
public:
  cht::ID execute(const chttl::ChunkVector<Atom> &, const chttl::ChunkBasic<basisset_struct> &, const IntegralInfoChunk &);
  CHT_TASK_INPUT((chttl::ChunkVector<Atom>, chttl::ChunkBasic<basisset_struct>, IntegralInfoChunk));
  CHT_TASK_OUTPUT((DistrBasisSetChunk));
  CHT_TASK_TYPE_DECLARATION;
};

// Task-type boilerplate required by the Chunks and Tasks library for the
// declaration above.
CHT_TASK_TYPE_IMPLEMENTATION((CreateAtomCenteredBasisSetLowestLevel));

// Leaf task body: builds the basis functions for every atom in atomList from
// the basis set definition and integral info. The result is returned as a
// persistent DistrBasisSetChunk, together with the atoms' bounding box.
cht::ID CreateAtomCenteredBasisSetLowestLevel::execute(const chttl::ChunkVector<Atom> & atomList,
                                                       const chttl::ChunkBasic<basisset_struct> & basisSetDef,
                                                       const IntegralInfoChunk & integralInfo) {
  BoxStruct boundingBox = get_box_for_atom_list(atomList);
  // Flag value 0 for the use_6_d_funcs setting (presumably: do not use the
  // 6-component Cartesian d functions) -- TODO confirm against BasisInfoStruct.
  int use6DFuncs = 0;
  BasisInfoStruct basisInfo(use6DFuncs);
  const int normalizeBasis = 1;
  const int skipSortShells = 0;
  basisInfo.addBasisfuncsForAtomList(&atomList[0], atomList.size(),
                                     &basisSetDef.x,
                                     0, NULL,
                                     integralInfo.ii,
                                     0,
                                     normalizeBasis,
                                     skipSortShells);
  return registerChunk(new DistrBasisSetChunk(basisInfo, boundingBox), cht::persistent);
}


// Chunks-and-Tasks task that joins two DistrBasisSetChunk sub-results into one
// parent DistrBasisSetChunk. The parent stores the children's chunk IDs, the
// summed function count, the union bounding box and the larger maxExtent.
class CombineResults : public cht::Task {
public:
  cht::ID execute(const DistrBasisSetChunk &, const DistrBasisSetChunk &);
  CHT_TASK_INPUT((DistrBasisSetChunk,  // part 1
		  DistrBasisSetChunk   // part 2
		  ));
  CHT_TASK_OUTPUT((DistrBasisSetChunk));
  CHT_TASK_TYPE_DECLARATION;
};

// Task-type boilerplate required by the Chunks and Tasks library.
CHT_TASK_TYPE_IMPLEMENTATION((CombineResults));

// Builds the parent chunk for two sibling sub-results:
//  - noOfBasisFuncs is the sum of the children's counts,
//  - the bounding box is the union of the children's boxes,
//  - maxExtent is the larger of the children's maxExtent values,
//  - the children themselves are referenced by their chunk IDs.
cht::ID CombineResults::execute(const DistrBasisSetChunk & part1, const DistrBasisSetChunk & part2) {
  cht::ChunkID const & childId1 = getInputChunkID(part1);
  cht::ChunkID const & childId2 = getInputChunkID(part2);
  int totalFuncs = part1.noOfBasisFuncs + part2.noOfBasisFuncs;
  BoxStruct mergedBox;
  for(int d = 0; d < 3; ++d) {
    const ergo_real lo1 = part1.boundingBoxForCenters.minCoord[d];
    const ergo_real lo2 = part2.boundingBoxForCenters.minCoord[d];
    const ergo_real hi1 = part1.boundingBoxForCenters.maxCoord[d];
    const ergo_real hi2 = part2.boundingBoxForCenters.maxCoord[d];
    mergedBox.minCoord[d] = (lo1 < lo2) ? lo1 : lo2;
    mergedBox.maxCoord[d] = (hi1 > hi2) ? hi1 : hi2;
  }
  ergo_real largestExtent = part1.maxExtent;
  largestExtent = (part2.maxExtent > largestExtent) ? part2.maxExtent : largestExtent;
  return registerChunk(new DistrBasisSetChunk(totalFuncs, mergedBox, largestExtent, childId1, childId2),
                       cht::persistent);
}
  

CHT_TASK_TYPE_IMPLEMENTATION((CreateAtomCenteredBasisSet));

// Builds the distributed basis set for the atoms in atomList by recursive
// spatial bisection.
//  - If every side of the atoms' bounding box is shorter than
//    coordDiffLimit.x, one CreateAtomCenteredBasisSetLowestLevel task is
//    registered for the whole list.
//  - Otherwise the atoms are split at the midpoint of the longest box side.
//    A CreateAtomCenteredBasisSet task is registered for each half, and a
//    CombineResults task joins the two sub-results.
// cid_basisSetDef, cid_integralInfo and the coordDiffLimit chunk are passed
// down unchanged.
// Throws std::runtime_error if atomList is empty, or if the split leaves one
// half empty (e.g. all atoms coincide while coordDiffLimit.x <= 0).
cht::ID CreateAtomCenteredBasisSet::execute(const chttl::ChunkVector<Atom> & atomList, 
					    const cht::ChunkID & cid_basisSetDef, 
					    const cht::ChunkID & cid_integralInfo, 
					    const chttl::ChunkBasic<ergo_real> & coordDiffLimit) {
  cht::ChunkID const & cid_atomList = getInputChunkID(atomList);
  int nAtoms = atomList.size();
  if(nAtoms < 1)
    throw std::runtime_error("Error in CreateAtomCenteredBasisSet::execute(), (nAtoms < 1).");
  // Get bounding box.
  BoxStruct box = get_box_for_atom_list(atomList);
  // Find the coordinate direction with the largest extent.
  ergo_real maxDiff = 0;
  int maxDiffCoordIdx = 0;
  for(int coordIdx = 0; coordIdx < 3; coordIdx++) {
    ergo_real currDiff = box.maxCoord[coordIdx] - box.minCoord[coordIdx];
    if(currDiff > maxDiff) {
      maxDiff = currDiff;
      maxDiffCoordIdx = coordIdx;
    }
  }
  // If maxDiff is small enough, let this be lowest level.
  if(maxDiff < coordDiffLimit.x) {
    return registerTask<CreateAtomCenteredBasisSetLowestLevel>(cid_atomList, cid_basisSetDef, cid_integralInfo, cht::persistent);
  }
  // Not lowest level: split the box along its longest dimension and register
  // a new task for each part.
  ergo_real midCoordValue = (box.minCoord[maxDiffCoordIdx] + box.maxCoord[maxDiffCoordIdx]) / 2;
  // Bucket the atoms into the two halves.
  std::vector<Atom> list1;
  std::vector<Atom> list2;
  list1.reserve(nAtoms);
  list2.reserve(nAtoms);
  for(int i = 0; i < nAtoms; i++) {
    if(atomList[i].coords[maxDiffCoordIdx] < midCoordValue)
      list1.push_back(atomList[i]);
    else
      list2.push_back(atomList[i]);
  }
  if(list1.empty() || list2.empty())
    throw std::runtime_error("Error in CreateAtomCenteredBasisSet::execute(), (count1 <= 0 || count2 <= 0).");
  // The limit is passed down unchanged, so reuse the ID of this task's own
  // input chunk instead of registering a fresh copy at every recursion level.
  cht::ChunkID const & cid_coordDiffLimit = getInputChunkID(coordDiffLimit);
  // Register two new tasks of the same type as this one. They are persistent
  // because CombineResults stores their chunk IDs in the combined chunk.
  // Part 1
  cht::ChunkID cid_list1 = registerChunk(new chttl::ChunkVector<Atom>(list1));
  cht::ID id1 = registerTask<CreateAtomCenteredBasisSet>(cid_list1, cid_basisSetDef, cid_integralInfo, cid_coordDiffLimit, cht::persistent);
  // Part 2
  cht::ChunkID cid_list2 = registerChunk(new chttl::ChunkVector<Atom>(list2));
  cht::ID id2 = registerTask<CreateAtomCenteredBasisSet>(cid_list2, cid_basisSetDef, cid_integralInfo, cid_coordDiffLimit, cht::persistent);
  // Register a third task to combine the two sub-results.
  return registerTask<CombineResults>(id1, id2, cht::persistent);
} // end execute


// Chunks-and-Tasks task that distributes a list of basis function indexes
// over two sibling DistrBasisSetChunk parts. The first part gets the leading
// entries and the second part the following ones. The output is the
// recombined DistrBasisSetChunk.
class SetFuncIndexesForBasisSetParts : public cht::Task {
public:
  cht::ID execute(const DistrBasisSetChunk &, const DistrBasisSetChunk &, const chttl::ChunkVector<int> &);
  CHT_TASK_INPUT((DistrBasisSetChunk, DistrBasisSetChunk, chttl::ChunkVector<int>));
  CHT_TASK_OUTPUT((DistrBasisSetChunk));
  CHT_TASK_TYPE_DECLARATION;
};

// Task-type boilerplate required by the Chunks and Tasks library.
CHT_TASK_TYPE_IMPLEMENTATION((SetFuncIndexesForBasisSetParts));

// Splits indexList between two sibling basis-set chunks:
//  - entries [0, n1) go to basisSet1,
//  - entries [n1, n1+n2) go to basisSet2,
// where n1/n2 are the parts' noOfBasisFuncs.
// A SetFuncIndexesForBasisSet task is registered for each part, and a
// CombineResults task joins their results.
// Throws std::runtime_error if indexList has fewer than n1+n2 entries; the
// copy loops below would otherwise read past its end.
cht::ID SetFuncIndexesForBasisSetParts::execute(const DistrBasisSetChunk & basisSet1, 
						const DistrBasisSetChunk & basisSet2, 
						const chttl::ChunkVector<int> & indexList) {
  cht::ChunkID const & cid_basisSet1 = getInputChunkID(basisSet1);
  cht::ChunkID const & cid_basisSet2 = getInputChunkID(basisSet2);
  int n1 = basisSet1.noOfBasisFuncs;
  int n2 = basisSet2.noOfBasisFuncs;
  int nIndexes = indexList.size();
  if(nIndexes < n1 + n2)
    throw std::runtime_error("Error in SetFuncIndexesForBasisSetParts::execute(): (indexList.size() < n1 + n2).");
  std::vector<int> indexList1(n1);
  std::vector<int> indexList2(n2);
  for(int i = 0; i < n1; i++)
    indexList1[i] = indexList[i];
  for(int i = 0; i < n2; i++)
    indexList2[i] = indexList[n1+i];
  // Part 1
  cht::ChunkID cid_list1 = registerChunk(new chttl::ChunkVector<int>(indexList1));
  cht::ID id1 = registerTask<SetFuncIndexesForBasisSet>(cid_basisSet1, cid_list1, cht::persistent);
  // Part 2
  cht::ChunkID cid_list2 = registerChunk(new chttl::ChunkVector<int>(indexList2));
  cht::ID id2 = registerTask<SetFuncIndexesForBasisSet>(cid_basisSet2, cid_list2, cht::persistent);
  // Register a third task to combine the two sub-results.
  return registerTask<CombineResults>(id1, id2, cht::persistent);
}


CHT_TASK_TYPE_IMPLEMENTATION((SetFuncIndexesForBasisSet));

// Attaches the basis function indexes in indexList to a DistrBasisSetChunk.
// A chunk whose noOfBasisFuncs equals basisInfo.noOfBasisFuncs is treated as
// a leaf: a new chunk carrying indexList is created directly, and indexList
// must have exactly noOfBasisFuncs entries. Any other chunk is an interior
// node, and the work is delegated to SetFuncIndexesForBasisSetParts on its
// two children.
cht::ID SetFuncIndexesForBasisSet::execute(const DistrBasisSetChunk & basisSet,
                                           const chttl::ChunkVector<int> & indexList) {
  cht::ChunkID const & cid_indexList = getInputChunkID(indexList);
  const bool isLeaf = (basisSet.noOfBasisFuncs == basisSet.basisInfo.noOfBasisFuncs);
  if(!isLeaf) {
    // Interior node: split the index list between the two children.
    cht::ChunkID leftChild = basisSet.cid_child_chunks[0];
    cht::ChunkID rightChild = basisSet.cid_child_chunks[1];
    return registerTask<SetFuncIndexesForBasisSetParts>(leftChild, rightChild, cid_indexList, cht::persistent);
  }
  // Leaf node: the index list must match this chunk's function count.
  if(indexList.size() != basisSet.noOfBasisFuncs)
    throw std::runtime_error("Error in SetFuncIndexesForBasisSet::execute(): (indexList.x.size() != basisSet.noOfBasisFuncs).");
  return registerChunk(new DistrBasisSetChunk(basisSet.basisInfo, basisSet.boundingBoxForCenters, indexList),
                       cht::persistent);
}



CHT_TASK_TYPE_IMPLEMENTATION((SetExtentsForBasisSet));

// Returns a copy of the DistrBasisSetChunk tree in which every leaf also
// carries per-function extents.
// At a leaf (noOfBasisFuncs == basisInfo.noOfBasisFuncs), the extents are
// computed with get_basis_func_extent_list() using maxAbsValue.x (presumably
// the value threshold that defines a function's extent -- TODO confirm).
// At an interior node, the task recurses into both children (persistently,
// since the combined chunk refers to them) and rejoins them with
// CombineResults.
cht::ID SetExtentsForBasisSet::execute(const DistrBasisSetChunk & basisSet,
                                       const chttl::ChunkBasic<ergo_real> & maxAbsValue) {
  cht::ChunkID const & cid_maxAbsValue = getInputChunkID(maxAbsValue);
  const bool isLeaf = (basisSet.noOfBasisFuncs == basisSet.basisInfo.noOfBasisFuncs);
  if(!isLeaf) {
    cht::ChunkID leftChild = basisSet.cid_child_chunks[0];
    cht::ChunkID rightChild = basisSet.cid_child_chunks[1];
    cht::ID leftTask = registerTask<SetExtentsForBasisSet>(leftChild, cid_maxAbsValue, cht::persistent);
    cht::ID rightTask = registerTask<SetExtentsForBasisSet>(rightChild, cid_maxAbsValue, cht::persistent);
    return registerTask<CombineResults>(leftTask, rightTask, cht::persistent);
  }
  // Leaf: compute one extent per basis function.
  const int nFuncs = basisSet.noOfBasisFuncs;
  std::vector<ergo_real> extents(nFuncs);
  get_basis_func_extent_list(basisSet.basisInfo, &extents[0], maxAbsValue.x);
  return registerChunk(new DistrBasisSetChunk(basisSet.basisInfo,
                                              basisSet.boundingBoxForCenters,
                                              basisSet.basisFuncIndexList,
                                              extents),
                       cht::persistent);
}


// Chunks-and-Tasks task that outputs the larger of two ergo_real values, each
// wrapped in a chttl::ChunkBasic<ergo_real>.
class MaxTask : public cht::Task {
public:
  cht::ID execute(const chttl::ChunkBasic<ergo_real> &, const chttl::ChunkBasic<ergo_real> &);
  CHT_TASK_INPUT((chttl::ChunkBasic<ergo_real>, chttl::ChunkBasic<ergo_real>));
  CHT_TASK_OUTPUT((chttl::ChunkBasic<ergo_real>));
  CHT_TASK_TYPE_DECLARATION;
};

// Task-type boilerplate required by the Chunks and Tasks library.
CHT_TASK_TYPE_IMPLEMENTATION((MaxTask));

// Registers a persistent chunk holding the larger of a.x and b.x.
// When a.x > b.x is false (ties, or unordered values), b.x is chosen.
cht::ID MaxTask::execute(const chttl::ChunkBasic<ergo_real> & a,
                         const chttl::ChunkBasic<ergo_real> & b) {
  const ergo_real larger = (a.x > b.x) ? a.x : b.x;
  return registerChunk(new chttl::ChunkBasic<ergo_real>(larger), cht::persistent);
}


CHT_TASK_TYPE_IMPLEMENTATION((GetLargestSimpleIntegralForBasisSet));

// Returns (as a persistent ChunkBasic<ergo_real>) the largest value of
// get_largest_simple_integral() over all leaves of the basis set tree.
// Interior nodes recurse into both children and reduce with MaxTask. The
// child tasks are registered without cht::persistent; their outputs only
// feed that MaxTask.
cht::ID GetLargestSimpleIntegralForBasisSet::execute(const DistrBasisSetChunk & basisSet) {
  if(basisSet.noOfBasisFuncs != basisSet.basisInfo.noOfBasisFuncs) {
    // Interior node.
    cht::ChunkID leftChild = basisSet.cid_child_chunks[0];
    cht::ChunkID rightChild = basisSet.cid_child_chunks[1];
    cht::ID leftResult = registerTask<GetLargestSimpleIntegralForBasisSet>(leftChild);
    cht::ID rightResult = registerTask<GetLargestSimpleIntegralForBasisSet>(rightChild);
    return registerTask<MaxTask>(leftResult, rightResult, cht::persistent);
  }
  // Leaf node: evaluate directly.
  ergo_real leafLargest = get_largest_simple_integral(basisSet.basisInfo);
  return registerChunk(new chttl::ChunkBasic<ergo_real>(leafLargest), cht::persistent);
}




// Chunks-and-Tasks task that concatenates two CoordStruct lists: the entries
// of the first list come first, followed by those of the second.
class CombineCoordLists : public cht::Task {
public:
  cht::ID execute(const chttl::ChunkVector<CoordStruct> &, const chttl::ChunkVector<CoordStruct> &);
  CHT_TASK_INPUT((chttl::ChunkVector<CoordStruct>, chttl::ChunkVector<CoordStruct>));
  CHT_TASK_OUTPUT((chttl::ChunkVector<CoordStruct>));
  CHT_TASK_TYPE_DECLARATION;
};

// Task-type boilerplate required by the Chunks and Tasks library.
CHT_TASK_TYPE_IMPLEMENTATION((CombineCoordLists));

// Registers a persistent chunk holding coordList1 followed by coordList2.
cht::ID CombineCoordLists::execute(const chttl::ChunkVector<CoordStruct> & coordList1,
                                   const chttl::ChunkVector<CoordStruct> & coordList2) {
  const int size1 = coordList1.size();
  const int size2 = coordList2.size();
  std::vector<CoordStruct> merged;
  merged.reserve(size1 + size2);
  for(int k = 0; k < size1; ++k)
    merged.push_back(coordList1[k]);
  for(int k = 0; k < size2; ++k)
    merged.push_back(coordList2[k]);
  return registerChunk(new chttl::ChunkVector<CoordStruct>(merged), cht::persistent);
}


CHT_TASK_TYPE_IMPLEMENTATION((GetBasisSetCoords));

// Collects the center coordinates of all basis functions in the tree, in
// leaf order, into a persistent ChunkVector<CoordStruct>.
// Interior nodes recurse into both children and concatenate the two lists
// with CombineCoordLists. The child tasks are registered without
// cht::persistent.
cht::ID GetBasisSetCoords::execute(const DistrBasisSetChunk & basisSet) {
  const int nFuncs = basisSet.noOfBasisFuncs;
  if(basisSet.noOfBasisFuncs != basisSet.basisInfo.noOfBasisFuncs) {
    // Interior node.
    cht::ID leftCoords = registerTask<GetBasisSetCoords>(basisSet.cid_child_chunks[0]);
    cht::ID rightCoords = registerTask<GetBasisSetCoords>(basisSet.cid_child_chunks[1]);
    return registerTask<CombineCoordLists>(leftCoords, rightCoords, cht::persistent);
  }
  // Leaf node: copy each function's centerCoords.
  std::vector<CoordStruct> centers(nFuncs);
  for(int f = 0; f < nFuncs; ++f)
    for(int d = 0; d < 3; ++d)
      centers[f].coords[d] = basisSet.basisInfo.basisFuncList[f].centerCoords[d];
  return registerChunk(new chttl::ChunkVector<CoordStruct>(centers), cht::persistent);
}


#endif
