#pragma once

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

namespace kdbush {

/// Extracts coordinate I (0 = x, 1 = y) from a point of type T.
///
/// The primary template forwards to std::get, so it works for std::pair,
/// std::tuple and std::array points. Other point types can be supported by
/// specializing this template with a static get(const T &) member.
template <std::uint8_t I, typename T>
struct nth {
    static typename std::tuple_element<I, T>::type get(const T &t) {
        return std::get<I>(t);
    }
};

/// Static 2-D point index laid out as an implicit KD-tree.
///
/// The input points are copied into a flat array of (x, y) pairs and reordered
/// in place: the middle element of every sub-range splits it on x (even depth)
/// or y (odd depth). No per-node storage is kept; the tree shape follows from
/// the array bounds alone. Queries report each match by calling visitor(id),
/// where id is the 0-based position of the point in the input sequence.
///
/// @tparam TPoint point type; coordinates are read via nth<0, TPoint> and
///                nth<1, TPoint>, which must yield the same type.
/// @tparam TIndex integer type used for stored ids and array positions.
template <typename TPoint, typename TIndex = std::size_t>
class KDBush {

public:
    using TNumber = decltype(nth<0, TPoint>::get(std::declval<TPoint>()));
    static_assert(
        std::is_same<TNumber, decltype(nth<1, TPoint>::get(std::declval<TPoint>()))>::value,
        "point component types must be identical");

    /// A sub-range spanning at most nodeSize + 1 points is a leaf: it is not
    /// split further and is scanned linearly by queries.
    static constexpr std::uint8_t defaultNodeSize = 64;

    /// Creates an empty index; populate it once with fill().
    KDBush(const std::uint8_t nodeSize_ = defaultNodeSize) : nodeSize(nodeSize_) {
    }

    /// Builds the index from a vector of points.
    KDBush(const std::vector<TPoint> &points_, const std::uint8_t nodeSize_ = defaultNodeSize)
        : KDBush(std::begin(points_), std::end(points_), nodeSize_) {
    }

    /// Builds the index from the point range [points_begin, points_end).
    template <typename TPointIter>
    KDBush(const TPointIter &points_begin,
           const TPointIter &points_end,
           const std::uint8_t nodeSize_ = defaultNodeSize)
        : nodeSize(nodeSize_) {
        fill(points_begin, points_end);
    }

    /// Populates an empty index from a vector of points.
    void fill(const std::vector<TPoint> &points_) {
        fill(std::begin(points_), std::end(points_));
    }

    /// Populates an empty index from [points_begin, points_end). Must be called
    /// at most once (asserted). An empty range leaves the index empty.
    template <typename TPointIter>
    void fill(const TPointIter &points_begin, const TPointIter &points_end) {
        assert(points.empty());
        const TIndex size = static_cast<TIndex>(std::distance(points_begin, points_end));
        // Nothing to sort; also avoids computing size - 1 on an unsigned zero.
        if (size == 0) return;

        points.reserve(size);
        ids.reserve(size);

        TIndex i = 0;
        for (auto p = points_begin; p != points_end; ++p) {
            points.emplace_back(nth<0, TPoint>::get(*p), nth<1, TPoint>::get(*p));
            ids.push_back(i++);
        }

        sortKD(0, size - 1, 0);
    }

    /// Calls visitor(id) for every point with minX <= x <= maxX and
    /// minY <= y <= maxY (all bounds inclusive). Visit order follows the tree
    /// layout, not input order. Does nothing on an empty index.
    template <typename TVisitor>
    void range(const TNumber minX,
               const TNumber minY,
               const TNumber maxX,
               const TNumber maxY,
               const TVisitor &visitor) const {
        if (ids.empty()) return;
        range(minX, minY, maxX, maxY, visitor, 0, static_cast<TIndex>(ids.size() - 1), 0);
    }

    /// Calls visitor(id) for every point whose squared distance to (qx, qy) is
    /// <= r * r. Does nothing on an empty index.
    /// NOTE: qx - r and qx + r are computed in TNumber, so unsigned coordinate
    /// types can wrap around.
    template <typename TVisitor>
    void within(const TNumber qx, const TNumber qy, const TNumber r, const TVisitor &visitor) const {
        if (ids.empty()) return;
        within(qx, qy, r, visitor, 0, static_cast<TIndex>(ids.size() - 1), 0);
    }

private:
    std::vector<TIndex> ids;                          // ids[i]: input position of points[i]
    std::vector<std::pair<TNumber, TNumber>> points;  // (x, y), reordered into KD order by sortKD
    std::uint8_t nodeSize;

    /// Split position of [left, right]. Written so it cannot overflow TIndex;
    /// building and querying must use this same formula.
    static TIndex middle(const TIndex left, const TIndex right) {
        return static_cast<TIndex>(left + (right - left) / 2);
    }

    /// Recursive worker for the public range(): [left, right] is a non-empty
    /// sub-range whose middle element splits on `axis` (0 = x, 1 = y).
    template <typename TVisitor>
    void range(const TNumber minX,
               const TNumber minY,
               const TNumber maxX,
               const TNumber maxY,
               const TVisitor &visitor,
               const TIndex left,
               const TIndex right,
               const std::uint8_t axis) const {

        if (right - left <= nodeSize) {
            for (auto i = left; i <= right; ++i) {
                const TNumber x = std::get<0>(points[i]);
                const TNumber y = std::get<1>(points[i]);
                if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(ids[i]);
            }
            return;
        }

        const TIndex m = middle(left, right);
        const TNumber x = std::get<0>(points[m]);
        const TNumber y = std::get<1>(points[m]);

        if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(ids[m]);

        const std::uint8_t nextAxis = static_cast<std::uint8_t>((axis + 1) % 2);
        // Elements before m are <= the split value on `axis`, elements after it
        // are >=, so a half is visited only when the box can reach it. The left
        // half is empty when m == left (only possible with nodeSize == 0);
        // checking that first keeps m - 1 from wrapping for unsigned TIndex.
        if (m > left && (axis == 0 ? minX <= x : minY <= y))
            range(minX, minY, maxX, maxY, visitor, left, m - 1, nextAxis);

        if (axis == 0 ? maxX >= x : maxY >= y)
            range(minX, minY, maxX, maxY, visitor, m + 1, right, nextAxis);
    }

    /// Recursive worker for the public within(); same sub-range contract as range().
    template <typename TVisitor>
    void within(const TNumber qx,
                const TNumber qy,
                const TNumber r,
                const TVisitor &visitor,
                const TIndex left,
                const TIndex right,
                const std::uint8_t axis) const {

        const TNumber r2 = r * r;

        if (right - left <= nodeSize) {
            for (auto i = left; i <= right; ++i) {
                const TNumber x = std::get<0>(points[i]);
                const TNumber y = std::get<1>(points[i]);
                if (sqDist(x, y, qx, qy) <= r2) visitor(ids[i]);
            }
            return;
        }

        const TIndex m = middle(left, right);
        const TNumber x = std::get<0>(points[m]);
        const TNumber y = std::get<1>(points[m]);

        if (sqDist(x, y, qx, qy) <= r2) visitor(ids[m]);

        const std::uint8_t nextAxis = static_cast<std::uint8_t>((axis + 1) % 2);
        // Same pruning as range(), using the circle's bounding box on `axis`.
        if (m > left && (axis == 0 ? qx - r <= x : qy - r <= y))
            within(qx, qy, r, visitor, left, m - 1, nextAxis);

        if (axis == 0 ? qx + r >= x : qy + r >= y)
            within(qx, qy, r, visitor, m + 1, right, nextAxis);
    }

    /// Reorders points[left..right] (and ids in lockstep) into KD order. The
    /// middle element becomes the split on `axis`, then both halves are
    /// processed on the other axis until sub-ranges are leaf-sized.
    void sortKD(const TIndex left, const TIndex right, const std::uint8_t axis) {
        if (right - left <= nodeSize) return;
        const TIndex m = middle(left, right);
        if (axis == 0) {
            select<0>(m, left, right);
        } else {
            select<1>(m, left, right);
        }
        const std::uint8_t nextAxis = static_cast<std::uint8_t>((axis + 1) % 2);
        if (m > left) sortKD(left, m - 1, nextAxis);  // left half is empty when m == left
        sortKD(m + 1, right, nextAxis);
    }

    /// Floyd-Rivest selection on coordinate I. Rearranges points[left..right]
    /// so that position k holds the value it would hold if the range were
    /// sorted, with everything before k <= it and everything after k >= it.
    template <std::uint8_t I>
    void select(const TIndex k, TIndex left, TIndex right) {

        while (right > left) {
            if (right - left > 600) {
                // First select within a narrower window around k to get a good pivot.
                const double n = right - left + 1;
                const double m = k - left + 1;
                const double z = std::log(n);
                const double s = 0.5 * std::exp(2 * z / 3);
                const double r =
                    k - m * s / n + 0.5 * std::sqrt(z * s * (1 - s / n)) * (2 * m < n ? -1 : 1);
                // Clamp while still in floating point: r can be negative, and
                // converting a negative double to an unsigned TIndex is undefined.
                const double lo = left;
                const double hi = right;
                select<I>(k,
                          static_cast<TIndex>(std::min(std::max(r, lo), hi)),
                          static_cast<TIndex>(std::min(std::max(r + s, lo), hi)));
            }

            const TNumber t = std::get<I>(points[k]);
            TIndex i = left;
            TIndex j = right;

            swapItem(left, k);
            if (std::get<I>(points[right]) > t) swapItem(left, right);

            while (i < j) {
                swapItem(i++, j--);
                while (std::get<I>(points[i]) < t) i++;
                while (std::get<I>(points[j]) > t) j--;
            }

            if (std::get<I>(points[left]) == t)
                swapItem(left, j);
            else {
                swapItem(++j, right);
            }

            // j is now the final position of the pivot value t.
            if (j == k) return;
            if (j < k) {
                left = j + 1;
            } else {
                right = j - 1;  // j > k >= left >= 0, so this cannot wrap
            }
        }
    }

    /// Swaps entries i and j of ids and points together so they stay aligned.
    void swapItem(const TIndex i, const TIndex j) {
        std::iter_swap(ids.begin() + i, ids.begin() + j);
        std::iter_swap(points.begin() + i, points.begin() + j);
    }

    /// Squared Euclidean distance between (ax, ay) and (bx, by).
    static TNumber sqDist(const TNumber ax, const TNumber ay, const TNumber bx, const TNumber by) {
        const auto dx = ax - bx;
        const auto dy = ay - by;
        return dx * dx + dy * dy;
    }
};

} // namespace kdbush
