//===----------------------------------------------------------------------===//
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCPP_ARRAY_VIEW
#define _LIBCPP_ARRAY_VIEW

#include <coordinate>
#include <type_traits>

namespace std
{

// VIEW_ACCESS(data, idx, stride, rank)
//
// Expands to a braced block that accumulates the linear element offset
//     stride[0] * idx[0] + ... + stride[rank - 1] * idx[rank - 1]
// and then returns data[offset] from the *enclosing* function. Because the
// expansion contains a `return`, it has to be the final statement of a
// function whose return type can bind to data[offset].
//
// Every macro argument is parenthesized and the internal locals carry
// macro-specific names, so argument expressions may be arbitrary (including
// identifiers such as `i` or `offset`) without being captured by the
// expansion. Note that `stride` and `idx` are evaluated once per dimension.
#define VIEW_ACCESS(data, idx, stride, rank) \
    { \
        ptrdiff_t view_access_offset_ = 0; \
        for (size_t view_access_dim_ = 0; view_access_dim_ < (rank); ++view_access_dim_) \
            view_access_offset_ += (stride)[view_access_dim_] * (idx)[view_access_dim_]; \
        return (data)[view_access_offset_]; \
    }


// Builds the per-dimension element strides for the extents in `bnd`.
// The last dimension gets stride 1; every earlier dimension gets the stride
// of the following dimension multiplied by that following dimension's extent.
// N is expected to be at least 1 (the last component is written
// unconditionally).
template <size_t N>
static inline offset<N> get_stride(const bounds<N>& bnd)
{
    offset<N> strides;
    size_t dim = N - 1;
    strides[dim] = 1;
    // Walk from the innermost dimension outwards.
    while (dim > 0) {
        strides[dim - 1] = strides[dim] * bnd[dim];
        --dim;
    }
    return strides;
}


// Forward declaration: array_view::section() returns a strided_array_view,
// whose definition only appears after array_view in this header.
template <class T, size_t Rank> class strided_array_view;

// Compile-time probe: `value` is true when calling data() on an rvalue of
// type T is well-formed and the result type can be initialised from the
// literal 0 (which holds for any pointer type); otherwise it is false.
template <typename T>
struct __has_data
{
private:
    typedef char __yes;
    struct __no { char __pad[2]; };

    // Selected only when the data() call expression is valid for C.
    template <typename C>
    static __yes __probe(decltype(std::declval<C>().data()));

    // Catch-all, chosen when the overload above is not viable.
    template <typename C>
    static __no __probe(...);

public:
    static const bool value = sizeof(__probe<T>(0)) == sizeof(__yes);
};

// Compile-time probe: `value` is true when the expression &T::size is
// well-formed, i.e. T has a single, addressable member named size; otherwise
// it is false.
template <typename T>
struct __has_size
{
private:
    typedef char __yes;
    struct __no { char __pad[2]; };

    // Selected only when &C::size names exactly one member.
    template <typename C>
    static __yes __probe(decltype(&C::size));

    // Catch-all, chosen when the overload above is not viable.
    template <typename C>
    static __no __probe(...);

public:
    static const bool value = sizeof(__probe<T>(0)) == sizeof(__yes);
};

// A type is "viewable" when, after stripping any reference, it passes both
// the __has_size and the __has_data probes.
template <typename T>
struct __is_viewable
{
    // T without its reference qualifier; this is what the probes inspect.
    using _T = typename std::remove_reference<T>::type;
    static const bool value =
        __has_data<_T>::value ? __has_size<_T>::value : false;
};

template <class T, size_t Rank = 1>
class array_view {
public:
    static constexpr size_t rank = Rank;
    using offset_type         = offset<Rank>;
    using bounds_type         = bounds<Rank>;
    using size_type           = size_t;
    using value_type          = T;
    using pointer             = T*;
    using reference           = T&;

    array_view() noexcept : data_(nullptr), bnd_(), stride_() {}

    template <class Viewable, size_t N = Rank,
              typename = typename enable_if<
                  (N == 1) &&
                  __is_viewable<Viewable>::value
                  >::type
              >
        array_view(Viewable&& vw) : data_(vw.data()), bnd_(vw.size()), stride_() {
            static_assert(is_convertible<decltype(vw.size()), ptrdiff_t>::value, "illegal Viewable");
            static_assert(is_convertible<decltype(vw.data()), pointer>::value, "illegal Viewable");
            static_assert(is_same<typename remove_cv<typename remove_pointer<decltype(vw.data())>::type>::type,
                                  typename remove_cv<T>::type>::value, "illegal Viewable");
        }

    template <class U, size_t AnyN, size_t N = Rank,
             typename = typename enable_if<
                 N == 1 &&
                 is_convertible<typename add_pointer<U>::type, pointer>::value &&
                 is_same<typename remove_cv<T>::type, typename remove_cv<value_type>::type>::value
                 >::type
             >
        array_view(const array_view<U, AnyN>& rhs) noexcept
        : data_(rhs.data()), bnd_(rhs.size()), stride_(1) {}

    template <size_t Extent,
              size_t N = Rank,
              typename = typename enable_if<N == 1>::type
             >
        array_view(value_type (&arr)[Extent]) noexcept
        : data_(arr), bnd_(Extent), stride_(get_stride(bnd_)) {}

    template <class U,
              typename = typename enable_if<
                                            is_convertible<typename add_pointer<U>::type, pointer>::value &&
                                            is_same<typename remove_cv<U>::type, typename remove_cv<value_type>::type>::value
                                           >::type
             >
        array_view(const array_view<U, Rank>& rhs) noexcept
        : data_(rhs.data()), bnd_(rhs.bounds()), stride_(rhs.stride()) {}

    template <class Viewable>
        array_view(Viewable&& vw, bounds_type bounds)
        : data_(vw.data()), bnd_(bounds), stride_(get_stride(bounds)) { 
#ifndef __KALMAR_ACCELERATOR__
            assert(bnd_.size() <= vw.size());
#endif
            static_assert(is_convertible<decltype(vw.size()), ptrdiff_t>::value, "illegal Viewable");
            static_assert(is_convertible<decltype(vw.data()), pointer>::value, "illegal Viewable");
            static_assert(is_same<typename remove_cv<typename remove_pointer<decltype(vw.data())>::type>::type,
                          typename remove_cv<T>::type>::value, "illegal Viewable");
        }

    array_view(pointer ptr, bounds_type bounds)
    : data_(ptr), bnd_(bounds), stride_(get_stride(bounds)) {}

    bounds_type bounds() const noexcept { return bnd_; }
    size_type   size() const noexcept { return bnd_.size(); }
    offset_type  stride() const noexcept { return stride_; }
    pointer     data() const noexcept { return data_; }

    reference operator[](const offset_type& idx) const {
#ifndef __KALMAR_ACCELERATOR__
        assert(bnd_.contains(idx));
#endif
        VIEW_ACCESS(data_, idx, stride_, Rank);
    }

    // [arrayview.subview], array_view slicing and sectioning
    template<size_t N = Rank, typename = typename enable_if<(N > 1)>::type>
    array_view<T, Rank - 1>
        operator[](ptrdiff_t slice) const {
#ifndef __KALMAR_ACCELERATOR__
            assert(slice < bnd_[0]);
#endif
            std::bounds<Rank - 1> bnd;
            for (auto i = 1; i < Rank; ++i)
                bnd[i - 1] = bnd_[i];
            return array_view<T, Rank - 1>(data_ + stride_[0] * slice, bnd);
        }
    strided_array_view<T, Rank>
        section(const offset_type& origin, const bounds_type& section_bnd) const {
            auto range = bnd_ - origin;
#ifndef __KALMAR_ACCELERATOR__
            for (auto i = 0; i < Rank; ++i)
                assert(range[i] >= section_bnd[i]);
#endif
            ptrdiff_t offset = 0;
            for (auto i = 0; i < Rank; ++i)
                offset += origin[i] * stride_[i];
            return strided_array_view<T, Rank>(data_ + offset, section_bnd, stride_);
        }
    strided_array_view<T, Rank>
        section(const offset_type& origin) const { return section(origin, bnd_ - origin); }

private:
    static_assert(Rank >= 1, "Rank should be greater than or equal to 1");
    pointer data_;
    bounds_type bnd_;
    offset_type stride_;
};

template <class T, size_t Rank = 1>
class strided_array_view {
public:
    // constants and types
    static constexpr size_t rank = Rank;
    using offset_type          = offset<Rank>;
    using bounds_type         = bounds<Rank>;
    using size_type           = size_t;
    using value_type          = T;
    using pointer             = T*;
    using reference           = T&;

    strided_array_view() noexcept : data_(nullptr), bnd_(), stride_() {}

    template <class U,
              typename = typename enable_if<
                                            is_convertible<typename add_pointer<U>::type, pointer>::value &&
                                            is_same<
                                                    typename remove_cv<U>::type,
                                                    typename remove_cv<value_type>::type
                                                   >::value
                                           >::type
             >
        strided_array_view(const array_view<U, Rank>& rhs) noexcept
        : data_(rhs.data()), bnd_(rhs.bounds()), stride_(rhs.stride()) {}

    template <class U,
              typename = typename enable_if<
                                            is_convertible<typename add_pointer<U>::type, pointer>::value &&
                                            is_same<
                                                    typename remove_cv<U>::type,
                                                    typename remove_cv<value_type>::type
                                                   >::value
                                           >::type
             >
        strided_array_view(const strided_array_view<U, Rank>& rhs) noexcept
        : data_(rhs.data_), bnd_(rhs.bnd_), stride_(rhs.stride_) {}

    strided_array_view(pointer ptr, bounds_type bounds, offset_type stride)
        : data_(ptr), bnd_(bounds), stride_(stride) {}

    bounds_type bounds() const noexcept { return bnd_; }
    size_type   size() const noexcept { return bnd_.size(); }
    offset_type  stride() const noexcept { return stride_; }

    reference operator[](const offset_type& idx) const {
#ifndef __KALMAR_ACCELERATOR__
        assert(bnd_.contains(idx));
#endif
        VIEW_ACCESS(data_, idx, stride(), Rank);
    }

    template<size_t N = Rank, typename = typename enable_if<(N > 1)>::type>
    strided_array_view<T, Rank - 1>
        operator[](ptrdiff_t slice) const {
#ifndef __KALMAR_ACCELERATOR__
            assert(slice < bnd_[0]);
#endif
            std::bounds<Rank - 1> bnd;
            for (auto i = 1; i < Rank; ++i)
                bnd[i - 1] = bnd_[i];
            std::offset<Rank - 1> stride;
            for (auto i = 1; i < Rank; ++i)
                stride[i - 1] = stride_[i];
            return strided_array_view<T, Rank - 1>(data_ + stride_[0] * slice, bnd, stride);
        }
    strided_array_view<T, Rank>
        section(const offset_type& origin, const bounds_type& section_bnd) const {
            auto range = bnd_ - origin;
#ifndef __KALMAR_ACCELERATOR__
            for (auto i = 0; i < Rank; ++i)
                assert(range[i] >= section_bnd[i]);
#endif
            ptrdiff_t offset = 0;
            for (auto i = 0; i < Rank; ++i)
                offset += origin[i] * stride_[i];
            return strided_array_view<T, Rank>(data_ + offset, section_bnd, stride_);
        }
    strided_array_view<T, Rank>
        section(const offset_type& origin) const { return section(origin, bnd_ - origin); }

private:
    static_assert(Rank >= 1, "Rank should be greater than or equal to 1");
    template <typename T_, size_t Rank_> friend class strided_array_view;
    pointer data_;  // exposition only
    bounds_type bnd_;
    offset_type stride_;
};

}  // std

#endif  // _LIBCPP_ARRAY_VIEW
