//===----------------------------------------------------------------------===//
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCPP_COORDINATE
#define _LIBCPP_COORDINATE

#include <initializer_list>
#include <type_traits>
#include <limits>
#include <cassert>
#include <iterator>
#include <iostream>

namespace std
{

// Compile-time list of indices. Holds no data; the pack lives in the type.
template <size_t...> struct __std_indices {};

// Worker behind __std_make_indices. Walks _Cur upwards, collecting every
// visited value into the __std_indices accumulator, until _Cur equals _End.
template <size_t _Cur, class _Acc, size_t _End>
struct __std_make_indices_imp;

// Recursive step: append _Cur to the pack and continue from _Cur + 1.
template <size_t _Cur, size_t ..._Seen, size_t _End>
struct __std_make_indices_imp<_Cur, __std_indices<_Seen...>, _End>
{
    using type = typename __std_make_indices_imp<_Cur + 1,
                                                 __std_indices<_Seen..., _Cur>,
                                                 _End>::type;
};

// Terminal step: _Cur has reached _End, the accumulated pack is the result.
template <size_t _End, size_t ..._Seen>
struct __std_make_indices_imp<_End, __std_indices<_Seen...>, _End>
{
    using type = __std_indices<_Seen...>;
};

// __std_make_indices<_Ep, _Sp>::type is __std_indices<_Sp, _Sp+1, ..., _Ep-1>.
// _Sp defaults to 0 and must not exceed _Ep.
template <size_t _Ep, size_t _Sp = 0>
struct __std_make_indices
{
    static_assert(_Sp <= _Ep, "__make_indices input error");
    using type = typename __std_make_indices_imp<_Sp, __std_indices<>, _Ep>::type;
};

// Logical AND over a pack of conditions.
// Base case: with no conditions left the answer is true.
inline const bool coordinate_check() { return true; }

// Recursive case: true only if `t` holds and every condition in `tail` holds.
// All arguments have already been evaluated by the caller, so nothing is
// skipped beyond the remaining recursion once a false value is seen.
template <typename ..._Tp>
inline const bool coordinate_check(const bool& t, const _Tp&... tail)
{
    return t ? coordinate_check(tail...) : false;
}

template <size_t N> class bounds;

// One component of a coordinate. The _Ip parameter only serves to make every
// base class of __coordinate_impl a distinct type; the payload is a ptrdiff_t.
template <size_t _Ip>
class __coordinate_leaf {
    ptrdiff_t __idx;
    int dummy;  // not read or written by this class; kept so the layout is unchanged
public:
    // Stores __v as the component value.
    explicit __coordinate_leaf(ptrdiff_t __v) restrict(amp,cpu) : __idx(__v) {}

    // Scalar assignment and compound arithmetic. Each one updates the stored
    // value in place and returns *this.
    __coordinate_leaf& operator=(const ptrdiff_t __v) restrict(amp,cpu) {
        get() = __v;
        return *this;
    }
    __coordinate_leaf& operator+=(const ptrdiff_t __v) restrict(amp,cpu) {
        get() += __v;
        return *this;
    }
    __coordinate_leaf& operator-=(const ptrdiff_t __v) restrict(amp,cpu) {
        get() -= __v;
        return *this;
    }
    __coordinate_leaf& operator*=(const ptrdiff_t __v) restrict(amp,cpu) {
        get() *= __v;
        return *this;
    }
    // No guard against __v == 0.
    __coordinate_leaf& operator/=(const ptrdiff_t __v) restrict(amp,cpu) {
        get() /= __v;
        return *this;
    }

    // Direct access to the stored value (mutable and read-only flavours).
    ptrdiff_t& get()       restrict(amp,cpu) { return __idx; }
    const ptrdiff_t& get() const restrict(amp,cpu) { return __idx; }
};

// Accepts any number of arguments and ignores them. It exists so that a pack
// expansion can be written as a function-call argument list, which forces
// every expanded expression to be evaluated.
template <class ..._Args>
inline void __std_swallow(_Args&&...) /*noexcept*/ restrict(amp,cpu)
{
}

inline const ptrdiff_t coordinate_mul()
{
    return 1;
}

template <typename ..._Tp>
inline const ptrdiff_t coordinate_mul(const ptrdiff_t& t, const _Tp&... tail)
{
    const ptrdiff_t ret =  t * coordinate_mul(tail...);
#if __KALMAR_ACCELERATOR__ != 1
    assert(ret <= numeric_limits<ptrdiff_t>::max() && ret >= 0);
#endif
    return ret;
}

template <typename _Indx> struct __coordinate_impl;

// Component storage and component-wise arithmetic for a rank-sizeof...(N)
// coordinate. Every component is held in its own __coordinate_leaf<N> base
// class; operations are written as pack expansions over those bases.
template <size_t ...N>
struct __coordinate_impl<__std_indices<N...>>
    : public __coordinate_leaf<N>...
{
private:
    // One value per component, in index order. Used by the copy constructor.
    template<typename ..._Up>
        explicit __coordinate_impl(_Up... __u) restrict(amp,cpu)
        : __coordinate_leaf<N>(__u)... {}

public:
    // All components zero.
    __coordinate_impl() restrict(amp,cpu)
        : __coordinate_leaf<N>(0)... {}

    // Component i is taken from the i-th list element. Elements missing from
    // a too-short list become 0 instead of being read past the end of the
    // list; callers still assert on the list length afterwards.
    __coordinate_impl(initializer_list<ptrdiff_t> il) restrict(amp,cpu) :
        __coordinate_leaf<N>(N < il.size() ? *(il.begin() + N) : ptrdiff_t(0))... {}

    // Component-wise copy.
    __coordinate_impl(const __coordinate_impl& other) restrict(amp,cpu)
        : __coordinate_impl(static_cast<const __coordinate_leaf<N>&>(other).get()...) {}

    // Every component set to the same value.
    __coordinate_impl(ptrdiff_t component) restrict(amp,cpu)
        : __coordinate_leaf<N>(component)... {}

    // Component access by run-time index. No range check is done here.
    // NOTE(review): this steps through the object in units of
    // sizeof(__coordinate_leaf<0>) starting at `this`, i.e. it assumes the
    // leaf bases are laid out back to back in index order with identical size.
    const ptrdiff_t& operator[] (size_t c) const restrict(amp,cpu) {
        return static_cast<const __coordinate_leaf<0>&>(*((const __coordinate_leaf<0> *)this + c)).get();
    }
    ptrdiff_t& operator[] (size_t c) restrict(amp,cpu) {
        return static_cast<__coordinate_leaf<0>&>(*((__coordinate_leaf<0> *)this + c)).get();
    }

    // Component-wise assignment / arithmetic with another coordinate of the
    // same rank. __std_swallow is only a vehicle for the pack expansion.
    __coordinate_impl& operator=(const __coordinate_impl& __t) restrict(amp,cpu) {
        __std_swallow(__coordinate_leaf<N>::operator=(static_cast<const __coordinate_leaf<N>&>(__t).get())...);
        return *this;
    }
    __coordinate_impl& operator+=(const __coordinate_impl& __t) restrict(amp,cpu) {
        __std_swallow(__coordinate_leaf<N>::operator+=(static_cast<const __coordinate_leaf<N>&>(__t).get())...);
        return *this;
    }
    __coordinate_impl& operator-=(const __coordinate_impl& __t) restrict(amp,cpu) {
        __std_swallow(__coordinate_leaf<N>::operator-=(static_cast<const __coordinate_leaf<N>&>(__t).get())...);
        return *this;
    }
    __coordinate_impl& operator*=(const __coordinate_impl& __t) restrict(amp,cpu) {
        __std_swallow(__coordinate_leaf<N>::operator*=(static_cast<const __coordinate_leaf<N>&>(__t).get())...);
        return *this;
    }
    __coordinate_impl& operator/=(const __coordinate_impl& __t) restrict(amp,cpu) {
        __std_swallow(__coordinate_leaf<N>::operator/=(static_cast<const __coordinate_leaf<N>&>(__t).get())...);
        return *this;
    }

    // Same operations with one scalar applied to every component.
    __coordinate_impl& operator+=(const ptrdiff_t __t) restrict(amp,cpu) {
        __std_swallow(__coordinate_leaf<N>::operator+=(__t)...);
        return *this;
    }
    __coordinate_impl& operator-=(const ptrdiff_t __t) restrict(amp,cpu) {
        __std_swallow(__coordinate_leaf<N>::operator-=(__t)...);
        return *this;
    }
    __coordinate_impl& operator*=(const ptrdiff_t __t) restrict(amp,cpu) {
        __std_swallow(__coordinate_leaf<N>::operator*=(__t)...);
        return *this;
    }
    __coordinate_impl& operator/=(const ptrdiff_t __t) restrict(amp,cpu) {
        __std_swallow(__coordinate_leaf<N>::operator/=(__t)...);
        return *this;
    }

    // Product of all components (computed by coordinate_mul).
    const ptrdiff_t size() const noexcept {
        return coordinate_mul(static_cast<const __coordinate_leaf<N>&>(*this).get()...);
    }

    // Treats *this as extents and __r as an index: true when every component
    // of __r lies in the half-open range [0, extent). An index equal to the
    // extent is one past the last valid position and is therefore rejected.
    const bool contains(const __coordinate_impl& __r) const noexcept {
        auto check = [&] (const ptrdiff_t& id, const ptrdiff_t& ext) { return (id >= 0) && (id < ext); };
        return coordinate_check(check(static_cast<const __coordinate_leaf<N>&>(__r).get(),
                                      static_cast<const __coordinate_leaf<N>&>(*this).get())...);
    }

    // True when no component is negative.
    const bool all_pos() const noexcept {
        auto check = [&] (const ptrdiff_t& id) { return (id >= 0); };
        return coordinate_check(check(static_cast<const __coordinate_leaf<N>&>(*this).get())...);
    }
};
 
// C-linkage function callable from restrict(amp) code only; it is declared
// here and defined elsewhere. offset_helper below calls it with a dimension
// number in [0, rank) to fill in the components of an offset.
// The `const` attribute promises the compiler that the result depends on
// nothing but `n`, so repeated calls with the same argument may be merged.
// NOTE(review): presumably returns the current work-item's global index along
// dimension `n`; confirm against the runtime that provides the definition.
extern "C" __attribute__((const)) uint32_t amp_get_global_id(unsigned int n) restrict(amp);

template<size_t N> class offset;

// Fills the first N components of a coordinate from amp_get_global_id().
// Component i receives the id of dimension (_Tp::rank - 1 - i), so the last
// component maps to dimension 0. The primary template handles component N - 1
// and recurses on N - 1; the N == 1 specialization ends the recursion.
template<size_t N, typename _Tp>
struct offset_helper
{
    static inline void set(_Tp& now) restrict(amp,cpu) {
        const size_t __id = static_cast<size_t>(amp_get_global_id(_Tp::rank - N));
        now[N - 1] = __id;
        offset_helper<N - 1, _Tp>::set(now);
    }
};

// Last step: component 0 takes the id of the highest dimension (rank - 1).
template<typename _Tp>
struct offset_helper<1, _Tp>
{
    static inline void set(_Tp& now) restrict(amp,cpu) {
        const size_t __id = static_cast<size_t>(amp_get_global_id(_Tp::rank - 1));
        now[0] = __id;
    }
};

// Rank-N index made of ptrdiff_t components. No sign restriction is enforced
// on the components. Storage and the component-wise arithmetic are delegated
// to __coordinate_impl.
template <size_t N>
class offset
{
public:
    static constexpr size_t rank = N;
    using reference           = ptrdiff_t&;
    using const_reference     = const ptrdiff_t&;
    using size_type           = size_t;
    using value_type          = ptrdiff_t;

    // Every component zero.
    offset() /*noexcept*/ restrict(amp,cpu) : base_() {}

    // Rank 1 only (removed by enable_if otherwise): implicit conversion from
    // a scalar.
    template <size_t K = N, class = typename enable_if<K == 1>::type>
    offset(value_type v) /*noexcept*/ restrict(amp,cpu) : base_(v) {}

    // One list element per component. On the host the list length is
    // asserted to equal N.
    offset(initializer_list<value_type> il) restrict(amp,cpu) : base_(il)
    {
#if __KALMAR_ACCELERATOR__ != 1
        assert(il.size() == N);
#endif
    }

    // Component access; the index is range-checked by assert on the host only.
    reference       operator[](size_type n) restrict(amp,cpu) {
#if __KALMAR_ACCELERATOR__ != 1
        assert(n < N);
#endif
        return base_[n];
    }
    const_reference operator[](size_type n) const restrict(amp,cpu) {
#if __KALMAR_ACCELERATOR__ != 1
        assert(n < N);
#endif
        return base_[n];
    }

    // Component-wise addition / subtraction of another offset.
    offset& operator+=(const offset& rhs) restrict(amp,cpu) {
        base_ += rhs.base_;
        return *this;
    }
    offset& operator-=(const offset& rhs) restrict(amp,cpu) {
        base_ -= rhs.base_;
        return *this;
    }

    // Increment / decrement exist for rank 1 only. The postfix forms return
    // the value held before the change.
    template <size_t K = N, class = typename enable_if<K == 1>::type>
    offset& operator++() restrict(amp,cpu) {
        base_ += 1;
        return *this;
    }
    template <size_t K = N, class = typename enable_if<K == 1>::type>
    offset  operator++(int) restrict(amp,cpu) {
        offset __before(*this);
        ++(*this);
        return __before;
    }
    template <size_t K = N, class = typename enable_if<K == 1>::type>
    offset& operator--() restrict(amp,cpu) {
        base_ -= 1;
        return *this;
    }
    template <size_t K = N, class = typename enable_if<K == 1>::type>
    offset  operator--(int) restrict(amp,cpu) {
        offset __before(*this);
        --(*this);
        return __before;
    }

    // Unary plus returns a copy; unary minus returns (zero offset) - *this.
    offset  operator+() const /*noexcept*/ restrict(amp,cpu)  { return *this; }
    offset  operator-() const restrict(amp,cpu) {
        offset __negated;
        __negated -= *this;
        return __negated;
    }

    // Scale / divide every component by the same scalar.
    offset& operator*=(value_type v) restrict(amp,cpu) {
        base_ *= v;
        return *this;
    }
    offset& operator/=(value_type v) restrict(amp,cpu) {
        base_ /= v;
        return *this;
    }
private:
    //static_assert(N >= 1, "Rank should be greater than or equal to 1");
    using base = __coordinate_impl<typename __std_make_indices<N>::type>;

    base base_;

    // bounds reads base_ directly; offset_helper is befriended as well.
    template <size_t K> friend class bounds;

    template <size_t K, typename Q> friend struct offset_helper;

public:
    // Carries the "__cxxamp_opencl_index" annotation. The body depends on the
    // compilation mode: on the accelerator path it fills *this through
    // offset_helper, in the == 2 modes it is empty, and otherwise the function
    // is only declared here.
    __attribute__((annotate("__cxxamp_opencl_index")))
    void __cxxamp_opencl_index() restrict(amp,cpu)
#if __KALMAR_ACCELERATOR__ == 1
    {
      offset_helper<N, offset<N>>::set(*this);
    }
#elif __KALMAR_ACCELERATOR__ == 2 || __KALMAR_CPU__ == 2
    {
    }
#else
    ;
#endif

};

// Compile-time unrolled comparison of two rank-N coordinates, examining
// components idx, idx + 1, ..., N - 1 in that order. `coord` is any class
// template taking the rank as its only parameter and offering operator[].
//
//   equal      - every remaining component matches.
//   less       - lexicographic "<": the first differing component decides.
//   less_equal - lexicographic "<=", defined as !less(rhs, lhs) so that the
//                two always agree with each other.
//
// (Previously `less` was true if ANY component was smaller, which made both
// a < b and b < a hold for e.g. (1,2) and (2,1), and `less_equal` was a
// component-wise test that did not match it.)
template <size_t idx, size_t N, template <size_t id> class coord>
struct __coord_compare
{
    static bool equal(const coord<N>& lhs, const coord<N>& rhs) noexcept {
        return lhs[idx] == rhs[idx] &&
            __coord_compare<idx + 1, N, coord>::equal(lhs, rhs);
    }
    static bool less(const coord<N>& lhs, const coord<N>& rhs) noexcept {
        if (lhs[idx] < rhs[idx])
            return true;
        if (rhs[idx] < lhs[idx])
            return false;
        // Tie on this component: the remaining components decide.
        return __coord_compare<idx + 1, N, coord>::less(lhs, rhs);
    }
    static bool less_equal(const coord<N>& lhs, const coord<N>& rhs) noexcept {
        return !__coord_compare<idx, N, coord>::less(rhs, lhs);
    }
};

// End of recursion (idx == N): no components left, so the coordinates are
// equal as far as this comparison is concerned.
template <size_t N, template <size_t id> class coord>
struct __coord_compare<N, N, coord>
{
    static const inline bool equal(const coord<N>&, const coord<N>&) noexcept { return true; }
    static const inline bool less(const coord<N>&, const coord<N>&) noexcept { return false; }
    static const inline bool less_equal(const coord<N>&, const coord<N>&) noexcept { return true; }
};

// Two offsets are equal when all N components match.
template <size_t N>
bool operator==(const offset<N>& lhs, const offset<N>& rhs) noexcept {
    using __cmp = __coord_compare<0, N, offset>;
    return __cmp::equal(lhs, rhs);
}

// Negation of operator== for offsets.
template <size_t N>
bool operator!=(const offset<N>& lhs, const offset<N>& rhs) noexcept {
    const bool __same = (lhs == rhs);
    return !__same;
}

// Component-wise sum of two offsets; neither operand is modified.
template <size_t N>
static inline offset<N> operator+(const offset<N>& lhs, const offset<N>& rhs) {
    offset<N> __sum(lhs);
    __sum += rhs;
    return __sum;
}

// Component-wise difference lhs - rhs; neither operand is modified.
template <size_t N>
static inline offset<N> operator-(const offset<N>& lhs, const offset<N>& rhs) {
    offset<N> __diff(lhs);
    __diff -= rhs;
    return __diff;
}

// Offset with every component of lhs multiplied by v.
template <size_t N>
static inline offset<N> operator*(const offset<N>& lhs, ptrdiff_t v) {
    offset<N> __scaled(lhs);
    __scaled *= v;
    return __scaled;
}

// Scalar-on-the-left form: offset with every component of rhs multiplied by v.
template <size_t N>
static inline offset<N> operator*(ptrdiff_t v, const offset<N>& rhs) {
    offset<N> __scaled(rhs);
    __scaled *= v;
    return __scaled;
}

// Offset with every component of lhs divided by v. v is not checked for zero.
template <size_t N>
static inline offset<N> operator/(const offset<N>& lhs, ptrdiff_t v) {
    offset<N> __quot(lhs);
    __quot /= v;
    return __quot;
}

// Random-access iterator over the offsets inside a bounds<N>. The position is
// one linear counter (`stride`); operator* decodes it into an offset with the
// last dimension varying fastest. Only bounds (a friend) can construct one.
template <size_t N>
class bounds_iterator : public std::iterator<std::random_access_iterator_tag,
                                             offset<N>,
                                             ptrdiff_t,
                                             offset<N>*,
                                             offset<N>& >
{
    template <size_t K> friend class bounds;
    ptrdiff_t stride;
    bounds<N> bnd_;  // exposition only
    // __pos is the starting linear position; it defaults to 0 (the first element).
    explicit bounds_iterator(const bounds<N>& __b, ptrdiff_t __pos = 0) restrict(amp,cpu)
        : stride(__pos), bnd_(__b) {}
public:
    using value_type        = offset<N>;
    using difference_type   = ptrdiff_t;
    using reference         = const offset<N>;

    // Comparisons look at the linear position only; the bounds are not compared.
    bool operator==(const bounds_iterator& rhs) const { return stride == rhs.stride; }
    bool operator!=(const bounds_iterator& rhs) const { return stride != rhs.stride; }
    bool operator<(const bounds_iterator& rhs) const { return stride < rhs.stride; }
    bool operator<=(const bounds_iterator& rhs) const { return stride <= rhs.stride; }
    bool operator>(const bounds_iterator& rhs) const { return stride > rhs.stride; }
    bool operator>=(const bounds_iterator& rhs) const { return stride >= rhs.stride; }

    // Single-step movement. Postfix forms return the iterator as it was
    // before the step.
    bounds_iterator& operator++() {
        stride += 1;
        return *this;
    }
    bounds_iterator  operator++(int) {
        bounds_iterator __prev(*this);
        stride += 1;
        return __prev;
    }
    bounds_iterator& operator--() {
        stride -= 1;
        return *this;
    }
    bounds_iterator  operator--(int) {
        bounds_iterator __prev(*this);
        stride -= 1;
        return __prev;
    }

    // Movement by n positions.
    bounds_iterator& operator+=(difference_type n) {
        stride += n;
        return *this;
    }
    bounds_iterator operator+(difference_type n) const {
        return bounds_iterator(bnd_, stride + n);
    }
    bounds_iterator& operator-=(difference_type n) {
        stride -= n;
        return *this;
    }
    bounds_iterator  operator-(difference_type n) const {
        return bounds_iterator(bnd_, stride - n);
    }

    // Distance between two iterators, in linear positions.
    difference_type  operator-(const bounds_iterator& rhs) const {
        return stride - rhs.stride;
    }

    // Decodes the linear position into an offset (returned by value): starting
    // from the last dimension, peel off the remainder modulo that dimension's
    // extent and carry the quotient to the next dimension.
    reference operator*() const {
        offset<N> __result;
        ptrdiff_t __rest = stride;
        for (size_t __dim = N; __dim-- > 0; ) {
            __result[__dim] = __rest % bnd_[__dim];
            __rest = (__rest - __result[__dim]) / bnd_[__dim];
        }
        return __result;
    }
    // Offset found n positions away from the current one.
    reference operator[](difference_type n) const {
        return *bounds_iterator(bnd_, stride + n);
    }
};

// Equality of two bounds_iterators, decided by their linear positions (the
// public iterator difference), matching the in-class operator==.
// Dereferencing is deliberately avoided: operator* reduces the position
// modulo the extents, so a past-the-end iterator decodes to the same offset
// as begin() and would wrongly compare equal to it.
template <size_t Rank>
bool operator==(const bounds_iterator<Rank>& lhs, const bounds_iterator<Rank>& rhs) {
    return (lhs - rhs) == 0;
}

// Negation of the equality comparison (operands are passed to == in swapped
// order, as before).
template <size_t Rank>
bool operator!=(const bounds_iterator<Rank>& lhs, const bounds_iterator<Rank>& rhs) {
    const bool __same = (rhs == lhs);
    return !__same;
}

// True when lhs is positioned before rhs. Decided by the linear positions
// (the public iterator difference), matching the in-class operator<.
// Comparing dereferenced offsets is avoided because operator* wraps a
// past-the-end position back to the first offset.
template <size_t Rank>
bool operator<(const bounds_iterator<Rank>& lhs, const bounds_iterator<Rank>& rhs) {
    return (lhs - rhs) < 0;
}

// True when lhs is positioned at or before rhs. Decided by the linear
// positions (the public iterator difference), matching the in-class
// operator<=. Comparing dereferenced offsets is avoided because operator*
// wraps a past-the-end position back to the first offset.
template <size_t Rank>
bool operator<=(const bounds_iterator<Rank>& lhs, const bounds_iterator<Rank>& rhs) {
    return (lhs - rhs) <= 0;
}

// lhs > rhs is defined as "not (lhs <= rhs)".
template <size_t Rank>
bool operator>(const bounds_iterator<Rank>& lhs, const bounds_iterator<Rank>& rhs) {
    const bool __not_after = (lhs <= rhs);
    return !__not_after;
}

// lhs >= rhs is defined as "not (lhs < rhs)".
template <size_t Rank>
bool operator>=(const bounds_iterator<Rank>& lhs, const bounds_iterator<Rank>& rhs) {
    const bool __before = (lhs < rhs);
    return !__before;
}

template <size_t N>
class bounds {
    static_assert(N >= 1, "Rank should be greater than or equal to 1");
    typedef __coordinate_impl<typename __std_make_indices<N>::type> base;


    template <size_t Rank>
        friend bounds<Rank> operator+(const offset<Rank>& lhs, const bounds<Rank>& rhs);

    base base_;
    void check() const {
#if __KALMAR_ACCELERATOR__ != 1
        assert(this->size() >= 0 && this->size() <= numeric_limits<ptrdiff_t>::max());
        assert(base_.all_pos());
#endif
    }
public:
    static constexpr size_t rank = N;
    using reference           = ptrdiff_t&;
    using const_reference     = const ptrdiff_t&;
    using iterator            = bounds_iterator<N>;
    using const_iterator      = bounds_iterator<N>;
    using size_type           = size_t;
    using value_type          = ptrdiff_t;

    bounds() restrict(amp,cpu) : base_() {}

    template <size_t K = N, class = typename enable_if<K == 1>::type>
    bounds(value_type v) restrict(amp,cpu) : base_(v) {
#if __KALMAR_ACCELERATOR__ != 1
        assert(v >= 0 && v <= numeric_limits<ptrdiff_t>::max());
#endif
    }

    bounds(initializer_list<value_type> il) restrict(amp,cpu) : base_(il) {
#if __KALMAR_ACCELERATOR__ != 1
        assert(il.size() == N);
#endif
        check();
    }

    size_type size() const noexcept { return base_.size(); }
    bool      contains(const offset<N>& idx) const noexcept { return base_.contains(idx.base_); }

    const_iterator begin() const noexcept { return bounds_iterator<N>(*this); }
    const_iterator end() const noexcept { return bounds_iterator<N>(*this, size()); }

    reference       operator[](size_type n) { return base_[n]; }
    const_reference operator[](size_type n) const { return base_[n]; };

    bounds  operator+(const offset<N>& rhs) const {
        bounds __r(*this);
        __r.base_ += rhs.base_;
        __r.check();
        return __r;
    }
    bounds  operator-(const offset<N>& rhs) const {
        bounds __r(*this);
        __r.base_ -= rhs.base_;
        __r.check();
        return __r;
    }
    bounds& operator+=(const offset<N>& rhs) {
        base_ += rhs.base_;
        this->check();
        return *this;
    }
    bounds& operator-=(const offset<N>& rhs) {
        base_ -= rhs.base_;
        this->check();
        return *this;
    }

    bounds  operator*(value_type v) const {
        bounds __r(*this);
        __r *= v;
        __r.check();
        return __r;
    }
    bounds  operator/(value_type v) const {
        bounds __r(*this);
        __r /= v;
        __r.check();
        return __r;
    }
    bounds& operator*=(value_type v) {
        base_ *= v;
        this->check();
        return *this;
    }
    bounds& operator/=(value_type v) {
        base_ /= v;
        this->check();
        return *this;
    }
};

// Two bounds are equal when all extents match.
template <size_t Rank>
bool operator==(const bounds<Rank>& lhs, const bounds<Rank>& rhs) noexcept {
    using __cmp = __coord_compare<0, Rank, bounds>;
    return __cmp::equal(lhs, rhs);
}

// Negation of operator== for bounds.
template <size_t Rank>
bool operator!=(const bounds<Rank>& lhs, const bounds<Rank>& rhs) noexcept {
    const bool __same = (lhs == rhs);
    return !__same;
}

// bounds + offset: copy of lhs with rhs added component-wise (bounds::operator+=
// performs the addition and its host-side check).
template <size_t Rank>
bounds<Rank> operator+(const bounds<Rank>& lhs, const offset<Rank>& rhs) {
    bounds<Rank> __shifted(lhs);
    __shifted += rhs;
    return __shifted;
}

// offset + bounds: copy of rhs with lhs added component-wise (bounds::operator+=
// performs the addition and its host-side check).
template <size_t Rank>
bounds<Rank> operator+(const offset<Rank>& lhs, const bounds<Rank>& rhs) {
    bounds<Rank> __shifted(rhs);
    __shifted += lhs;
    return __shifted;
}

// bounds - offset: copy of lhs with rhs subtracted component-wise
// (bounds::operator-= performs the subtraction and its host-side check).
template <size_t Rank>
bounds<Rank> operator-(const bounds<Rank>& lhs, const offset<Rank>& rhs) {
    bounds<Rank> __shifted(lhs);
    __shifted -= rhs;
    return __shifted;
}

// bounds * scalar: copy of lhs with every extent multiplied by v
// (bounds::operator*= performs the scaling and its host-side check).
template <size_t Rank>
bounds<Rank> operator*(const bounds<Rank>& lhs, ptrdiff_t v) {
    bounds<Rank> __scaled(lhs);
    __scaled *= v;
    return __scaled;
}

// scalar * bounds: copy of rhs with every extent multiplied by v
// (bounds::operator*= performs the scaling and its host-side check).
template <size_t Rank>
bounds<Rank> operator*(ptrdiff_t v, const bounds<Rank>& rhs) {
    bounds<Rank> __scaled(rhs);
    __scaled *= v;
    return __scaled;
}

}  // std

#endif  // _LIBCPP_COORDINATE
