//  ************************************************************************************************
//
//  BornAgain: simulate and fit reflection and scattering
//
//! @file      Device/Data/DataUtil.cpp
//! @brief     Implements namespace DataUtil.
//!
//! @homepage  http://www.bornagainproject.org
//! @license   GNU General Public License v3 or higher (see COPYING)
//! @copyright Forschungszentrum Jülich GmbH 2018
//! @authors   Scientific Computing Group at MLZ (see CITATION, AUTHORS)
//
//  ************************************************************************************************

#include "Device/Data/DataUtil.h"
#include "Base/Axis/Frame.h"
#include "Base/Axis/MakeScale.h"
#include "Base/Axis/Scale.h"
#include "Base/Math/FourierTransform.h"
#include "Base/Util/Assert.h"
#include "Device/Data/ArrayUtil.h"
#include "Device/Data/Datafield.h"
#include <algorithm>
#include <cmath>
#include <functional>
#include <tspectrum.h> // third-party code, extracted from CERN ROOT (class TSpectrum2)

namespace {

std::vector<std::vector<double>> FT2DArray(const std::vector<std::vector<double>>& signal)
{
    FourierTransform ft;
    std::vector<std::vector<double>> result;
    ft.fft(signal, result);
    ft.fftshift(result); // low frequency to center of array
    return result;
}

} // namespace

//! Returns a copy of a 2D array with the element order reversed along one axis.
//!
//! @param axis      1 reverses the order of the rows (outer vector);
//!                  0 reverses the order of the elements inside every row.
//!                  Any other value fails the assertion.
//! @param original  input array; it is not modified
//! @return          the flipped copy of `original`
std::vector<std::vector<double>>
DataUtil::Data::invertAxis(int axis, const std::vector<std::vector<double>>& original)
{
    std::vector<std::vector<double>> inverse = original;

    if (axis == 1) {
        // Reverse the row order directly in the copy; rows are swapped, not copied again.
        std::reverse(inverse.begin(), inverse.end());
    } else if (axis == 0) {
        ASSERT(!original.empty());
        // Each row is reversed over its own length, so a row that is shorter than the
        // first one can no longer cause an out-of-bounds read.
        for (std::vector<double>& row : inverse)
            std::reverse(row.begin(), row.end());
    } else
        ASSERT(false);

    return inverse;
}

//! Returns the transpose of a 2D array.
//!
//! The input must contain at least one row (asserted). The number of columns is taken from
//! the first row; element [i][j] of the input becomes element [j][i] of the result.
std::vector<std::vector<double>>
DataUtil::Data::transpose(const std::vector<std::vector<double>>& original)
{
    ASSERT(original.size() > 0);

    const size_t n_rows = original.size();
    const size_t n_cols = original.front().size();

    std::vector<std::vector<double>> result(n_cols);

    // Build the result one output row (= one input column) at a time.
    for (size_t c = 0; c < n_cols; ++c) {
        std::vector<double>& out_row = result[c];
        out_row.reserve(n_rows);
        for (const std::vector<double>& in_row : original)
            out_row.push_back(in_row[c]);
    }

    return result;
}

std::unique_ptr<Datafield> DataUtil::Data::createRearrangedDataSet(const Datafield& data, int n)
{
    ASSERT(data.rank() == 2);
    n = (4 + n % 4) % 4;
    if (n == 0)
        return std::unique_ptr<Datafield>(data.clone());

    std::unique_ptr<Datafield> output;
    std::function<void(std::vector<int>&)> index_mapping;

    if (n == 2) {
        output.reset(new Datafield({data.axis(0).clone(), data.axis(1).clone()}));
        const int end_bin_x = static_cast<int>(data.axis(0).size()) - 1;
        const int end_bin_y = static_cast<int>(data.axis(1).size()) - 1;
        index_mapping = [end_bin_x, end_bin_y](std::vector<int>& inds) {
            inds[0] = end_bin_x - inds[0];
            inds[1] = end_bin_y - inds[1];
        };

    } else {
        output.reset(new Datafield({data.axis(1).clone(), data.axis(0).clone()}));
        const size_t rev_axis_i = n % 3;
        const size_t end_bin = data.axis(rev_axis_i).size() - 1;
        index_mapping = [rev_axis_i, end_bin](std::vector<int>& inds) {
            const int tm_index = inds[rev_axis_i];
            inds[rev_axis_i] = inds[rev_axis_i ^ 1];
            inds[rev_axis_i ^ 1] = static_cast<int>(end_bin) - tm_index;
        };
    }

    for (size_t i = 0, size = data.size(); i < size; ++i) {
        std::vector<int> axis_inds = data.frame().allIndices(i);
        index_mapping(axis_inds);
        size_t iout = output->frame().toGlobalIndex(
            {static_cast<unsigned>(axis_inds[0]), static_cast<unsigned>(axis_inds[1])});
        (*output)[iout] = data[i];
    }
    return output;
}

//! Copies the values of a rank-2 Datafield into a 2D array.
//!
//! The result has data.axis(0).size() rows and data.axis(1).size() columns. It is filled
//! row by row with the values data[0], data[1], ... taken in consecutive order.
std::vector<std::vector<double>> DataUtil::Data::create2DArrayfromDatafield(const Datafield& data)
{
    ASSERT(data.rank() == 2);

    const size_t n_rows = data.axis(0).size();
    const size_t n_cols = data.axis(1).size();

    std::vector<std::vector<double>> result(n_rows, std::vector<double>(n_cols));

    size_t i_flat = 0; // running index into 'data'
    for (std::vector<double>& row : result)
        for (double& value : row)
            value = data[i_flat++];

    return result;
}

std::unique_ptr<Datafield>
DataUtil::Data::vecvecToDatafield(const std::vector<std::vector<double>>& array_2d)
{
    size_t nrows = array_2d.size();
    size_t ncols = array_2d[0].size();

    std::vector<const Scale*> axes{newEquiDivision("x", nrows, 0.0, double(nrows)),
                                   newEquiDivision("y", ncols, 0.0, double(ncols))};
    std::vector<double> out;
    out.reserve(nrows * ncols);
    for (size_t row = 0; row < nrows; row++) {
        for (size_t col = 0; col < ncols; col++)
            out.push_back(array_2d[row][col]);
    }
    return std::make_unique<Datafield>(std::move(axes), out);
}

//! Returns a new Datafield obtained by converting `data` to a 2D array, passing that array
//! through FT2DArray, and converting the outcome back with vecvecToDatafield.
std::unique_ptr<Datafield> DataUtil::Data::createFFT(const Datafield& data)
{
    const std::vector<std::vector<double>> signal =
        DataUtil::Data::create2DArrayfromDatafield(data);
    const std::vector<std::vector<double>> spectrum = FT2DArray(signal);
    return DataUtil::Data::vecvecToDatafield(spectrum);
}

//! Builds a Datafield from a 1D array via DataUtil::Array::createPField1D.
//! The object is returned as a raw pointer that has been released from its smart-pointer
//! owner, so the caller becomes responsible for it.
Datafield* DataUtil::Data::importArrayToDatafield(const std::vector<double>& vec)
{
    auto owned = DataUtil::Array::createPField1D(vec);
    return owned.release();
}

//! Builds a Datafield from a 2D array via DataUtil::Array::createPField2D.
//! The object is returned as a raw pointer that has been released from its smart-pointer
//! owner, so the caller becomes responsible for it.
Datafield* DataUtil::Data::importArrayToDatafield(const std::vector<std::vector<double>>& vec)
{
    auto owned = DataUtil::Array::createPField2D(vec);
    return owned.release();
}

std::vector<std::pair<double, double>> DataUtil::Data::FindPeaks(const Datafield& data,
                                                                 double sigma,
                                                                 const std::string& option,
                                                                 double threshold)
{
    std::vector<std::vector<double>> arr = DataUtil::Array::createVector2D(data);
    tspectrum::Spectrum2D spec;
    auto peaks = spec.find_peaks(arr, sigma, option, threshold);

    // coordinates of peaks in histogram axes units
    std::vector<std::pair<double, double>> result;

    for (const auto& p : peaks) {
        double row_value = p.first;
        double col_value = p.second;

        auto xaxis_index = static_cast<size_t>(col_value);
        size_t yaxis_index = data.yAxis().size() - 1 - static_cast<size_t>(row_value);

        Bin1D xbin = data.xAxis().bin(xaxis_index);
        Bin1D ybin = data.yAxis().bin(yaxis_index);

        double dx = col_value - static_cast<size_t>(col_value);
        double dy = -1.0 * (row_value - static_cast<size_t>(row_value));

        double x = xbin.center() + xbin.binSize() * dx;
        double y = ybin.center() + ybin.binSize() * dy;

        result.emplace_back(x, y);
    }
    return result;
}
