// Author: Stefan Wunsch CERN  04/2019

/*************************************************************************
 * Copyright (C) 1995-2018, Rene Brun and Fons Rademakers.               *
 * All rights reserved.                                                  *
 *                                                                       *
 * For the licensing terms see $ROOTSYS/LICENSE.                         *
 * For the list of contributors see $ROOTSYS/README/CREDITS.             *
 *************************************************************************/

#include <ROOT/RDataFrame.hxx>
#include <ROOT/RDataSource.hxx>
#include <ROOT/RVec.hxx>
#include <ROOT/TSeq.hxx>

#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

#ifndef ROOT_RVECDS
#define ROOT_RVECDS

namespace ROOT {

namespace Internal {

namespace RDF {

/// Column reader that hands out the address currently stored in a TPointerHolder.
class R__CLING_PTRCHECK(off) RVecDSColumnReader final : public ROOT::Detail::RDF::RColumnReaderBase {
public:
   /// \param ptrHolder Holder whose pointer is returned by every read. It is only
   ///                  referenced here; this class never deletes it.
   RVecDSColumnReader(TPointerHolder *ptrHolder) : fPtrHolder{ptrHolder} {}

private:
   TPointerHolder *fPtrHolder;

   /// Return the holder's pointer; the entry number argument is not used.
   void *GetImpl(Long64_t /*entry*/) final
   {
      return fPtrHolder->GetPointer();
   }
};

////////////////////////////////////////////////////////////////////////////////////////////////
/// \brief An RDataSource implementation which takes a collection of RVecs, which
/// are able to adopt data from Numpy arrays
///
/// This component allows to create a data source on a set of columns with data
/// coming from RVecs. The adoption of externally provided data, e.g., via Numpy
/// arrays, with RVecs allows to read arbitrary data from memory.
/// In addition, the data source has to keep a reference on the Python owned data
/// so that the lifetime of the data is tied to the datasource. For that purpose a
/// callback is stored at construction time and invoked when the data source is
/// destroyed.
///
/// \tparam ColumnTypes Element type of each column, in the order the columns are passed
///                     to the constructor.
template <typename... ColumnTypes>
class RVecDS final : public ROOT::RDF::RDataSource {
   using PointerHolderPtrs_t = std::vector<ROOT::Internal::RDF::TPointerHolder *>;

   std::tuple<ROOT::RVec<ColumnTypes>...> fColumns;
   std::vector<std::string> fColNames;
   std::unordered_map<std::string, std::string> fColTypesMap;
   // The role of the fPointerHoldersModels is to be initialised with the pack
   // of arguments in the constructor signature at construction time.
   // Once the number of slots is known, the fPointerHolders are initialised
   // as deep copies of the models. The models are owned by the data source and
   // are released in the destructor.
   PointerHolderPtrs_t fPointerHoldersModels;
   // Indexed as fPointerHolders[column][slot]; every element is owned by the data source.
   std::vector<PointerHolderPtrs_t> fPointerHolders;
   std::vector<std::pair<ULong64_t, ULong64_t>> fEntryRanges{};
   // Callback invoked once in the destructor to release the externally owned data.
   std::function<void()> fDeleteRVecs;

   /// Always returns an empty record; column readers are provided by GetColumnReaders.
   Record_t GetColumnReadersImpl(std::string_view, const std::type_info &) { return {}; }

   /// Number of entries of the data source, taken from the size of the first column.
   size_t GetEntriesNumber() { return std::get<0>(fColumns).size(); }

   /// Copy, for every column, the value at position `entry` into the buffer owned by
   /// the pointer holder of the given `slot`.
   template <std::size_t... S>
   void SetEntryHelper(unsigned int slot, ULong64_t entry, std::index_sequence<S...>)
   {
      // Comma fold: columns are processed left to right. The cast to void shields the
      // fold from user-defined comma operators on the column types.
      (static_cast<void>(*static_cast<ColumnTypes *>(fPointerHolders[S][slot]->GetPointer()) =
                            std::get<S>(fColumns)[entry]),
       ...);
   }

   /// Verify that all columns have the same number of entries as the first one.
   /// \throws std::runtime_error listing every column whose length differs from the first column.
   template <std::size_t... S>
   void ColLengthChecker(std::index_sequence<S...>)
   {
      if (sizeof...(S) < 2)
         return;

      const std::vector<size_t> colLengths{std::get<S>(fColumns).size()...};
      const auto expectedLen = colLengths[0];
      std::string err;
      for (std::size_t i = 1; i < colLengths.size(); ++i) {
         if (expectedLen == colLengths[i])
            continue;
         // Keep several mismatches readable by putting each one on its own line.
         if (!err.empty())
            err += '\n';
         // The lengths are reported in the same order as the column names.
         err += "Column \"" + fColNames[i] + "\" and column \"" + fColNames[0] +
                "\" have different lengths: " + std::to_string(colLengths[i]) + " and " +
                std::to_string(expectedLen);
      }
      if (!err.empty()) {
         throw std::runtime_error(err);
      }
   }

protected:
   std::string AsString() { return "Numpy data source"; }

public:
   /// \param deleteRVecs Callback invoked when the data source is destroyed. An empty
   ///                    std::function is accepted and simply never called.
   /// \param colsNameVals One (column name, column data) pair per column.
   RVecDS(std::function<void()> deleteRVecs, std::pair<std::string, ROOT::RVec<ColumnTypes>> const &...colsNameVals)
      : fColumns(colsNameVals.second...),
        fColNames{colsNameVals.first...},
        fColTypesMap({{colsNameVals.first, ROOT::Internal::RDF::TypeID2TypeName(typeid(ColumnTypes))}...}),
        fPointerHoldersModels({new ROOT::Internal::RDF::TTypedPointerHolder<ColumnTypes>(new ColumnTypes())...}),
        fDeleteRVecs(std::move(deleteRVecs))
   {
   }

   // Rule of five
   RVecDS(const RVecDS &) = delete;
   RVecDS &operator=(const RVecDS &) = delete;
   RVecDS(RVecDS &&) = delete;
   RVecDS &operator=(RVecDS &&) = delete;
   ~RVecDS() final
   {
      // The models are owned for the whole lifetime of the data source, so they are
      // released here whether or not SetNSlots was ever called.
      for (auto *model : fPointerHoldersModels) {
         delete model;
      }
      for (auto &slotHolders : fPointerHolders) {
         for (auto *ptrHolder : slotHolders) {
            delete ptrHolder;
         }
      }
      // Release the data associated to this data source. Calling an empty
      // std::function would throw from a destructor, hence the check.
      if (fDeleteRVecs)
         fDeleteRVecs();
   }

   /// Create the reader of column `colName` for the processing slot `slot`.
   /// \param slot Index of the processing slot, must be smaller than the number of slots set via SetNSlots.
   /// \param colName Name of the column to read.
   /// \param id Type the caller wants to read the column as; its name must match the registered one.
   /// \throws std::runtime_error if the column is unknown, the type does not match or no pointer
   ///         holder exists for the requested column/slot pair.
   std::unique_ptr<ROOT::Detail::RDF::RColumnReaderBase>
   GetColumnReaders(unsigned int slot, std::string_view colName, const std::type_info &id) final
   {
      auto colNameStr = std::string(colName);

      auto it = fColTypesMap.find(colNameStr);
      if (fColTypesMap.end() == it) {
         std::string err = "The specified column name, \"" + colNameStr + "\" is not known to the data source.";
         throw std::runtime_error(err);
      }

      const auto &colIdName = it->second;
      const auto idName = ROOT::Internal::RDF::TypeID2TypeName(id);
      if (colIdName != idName) {
         std::string err = "Column " + colNameStr + " has type " + colIdName +
                           " while the id specified is associated to type " + idName;
         throw std::runtime_error(err);
      }

      if (auto colNameIt = std::find(fColNames.begin(), fColNames.end(), colNameStr); colNameIt != fColNames.end()) {
         const auto index = static_cast<std::size_t>(colNameIt - fColNames.begin());
         // The holders only exist after SetNSlots; refuse to index out of bounds.
         if (index >= fPointerHolders.size() || slot >= fPointerHolders[index].size()) {
            throw std::runtime_error("No pointer holder is available for column \"" + colNameStr + "\" and slot " +
                                     std::to_string(slot) + ". Was SetNSlots called?");
         }
         return std::make_unique<ROOT::Internal::RDF::RVecDSColumnReader>(fPointerHolders[index][slot]);
      }

      throw std::runtime_error("Could not find column name \"" + colNameStr + "\" in available column names.");
   }

   const std::vector<std::string> &GetColumnNames() const { return fColNames; }

   /// Hand over the entry ranges computed by the last call to Initialize.
   /// The internal list is left empty, so a second call without an intervening
   /// Initialize returns an empty vector.
   std::vector<std::pair<ULong64_t, ULong64_t>> GetEntryRanges()
   {
      std::vector<std::pair<ULong64_t, ULong64_t>> entryRanges;
      entryRanges.swap(fEntryRanges); // guarantees fEntryRanges ends up empty
      return entryRanges;
   }

   /// \return The registered type name of column `colName`.
   /// \throws std::out_of_range if the column is not known.
   std::string GetTypeName(std::string_view colName) const
   {
      const auto key = std::string(colName);
      return fColTypesMap.at(key);
   }

   /// \return true if a column called `colName` is registered in this data source.
   bool HasColumn(std::string_view colName) const
   {
      const auto key = std::string(colName);
      return fColTypesMap.find(key) != fColTypesMap.end();
   }

   /// Load the values of entry `entry` into the buffers of slot `slot`. Always returns true.
   bool SetEntry(unsigned int slot, ULong64_t entry)
   {
      SetEntryHelper(slot, entry, std::index_sequence_for<ColumnTypes...>());
      return true;
   }

   /// Set the number of processing slots and make sure every column has one pointer
   /// holder per slot. Holders created by a previous call are kept, so pointers already
   /// handed out stay valid and repeated calls are safe.
   void SetNSlots(unsigned int nSlots) final
   {
      fNSlots = nSlots;
      fPointerHolders.resize(fPointerHoldersModels.size());
      for (std::size_t colIndex = 0; colIndex < fPointerHolders.size(); ++colIndex) {
         auto &slotHolders = fPointerHolders[colIndex];
         slotHolders.reserve(fNSlots);
         // All holders of a column are deep copies of the same model, hence of the same type.
         while (slotHolders.size() < fNSlots) {
            slotHolders.emplace_back(fPointerHoldersModels[colIndex]->GetDeepCopy());
         }
      }
   }

   /// Check the column lengths and split the entries into one contiguous range per slot.
   /// \throws std::runtime_error if the columns have different lengths or if the number of
   ///         slots is zero (i.e. SetNSlots has not been called with a positive value).
   void Initialize()
   {
      ColLengthChecker(std::index_sequence_for<ColumnTypes...>());
      if (0U == fNSlots) {
         throw std::runtime_error("RVecDS: the number of slots must be set before initializing the data source.");
      }
      const auto nEntries = GetEntriesNumber();
      const auto nEntriesInRange = nEntries / fNSlots; // integer division, the rest is handled below
      auto remainder = nEntries % fNSlots;
      fEntryRanges.resize(fNSlots);
      auto init = 0ULL;
      for (auto &range : fEntryRanges) {
         auto end = init + nEntriesInRange;
         if (0 != remainder) { // Distribute the remainder among the first chunks
            --remainder;
            ++end;
         }
         range.first = init;
         range.second = end;
         init = end;
      }
   }

   std::string GetLabel() { return "RVecDS"; }
};

// Factory to create datasource able to read Numpy arrays through RVecs.
// \param pyRVecs Pointer to PyObject holding RVecs.
//                The RVecs itself hold a reference to the associated Numpy arrays so that
//                the data cannot go out of scope as long as the datasource survives.
template <typename... ColumnTypes>
std::unique_ptr<RDataFrame>
MakeRVecDataFrame(std::function<void()> deleteRVecs,
                  std::pair<std::string, ROOT::RVec<ColumnTypes>> const &...colNameProxyPairs)
{
   return std::make_unique<RDataFrame>(std::make_unique<RVecDS<ColumnTypes...>>(deleteRVecs, colNameProxyPairs...));
}

} // namespace RDF
} // namespace Internal
} // namespace ROOT

#endif // ROOT_RVECDS
