Eigen: Modifiable Custom Expression

北慕城南 提交于 2019-12-12 04:54:21

问题


I'm trying to implement a modifiable custom expression using Eigen, similar to this question. Basically, what I want is something similar to the indexing example in the tutorial, but with the possibility to assign new values to the selected coefficients.

As suggested in the accepted answer to the question mentioned above, I have looked into the Transpose implementation and tried many things, but without success. My attempts fail with errors like 'Eigen::internal::evaluator<SrcXprType>::evaluator(const Eigen::internal::evaluator<SrcXprType> &)': cannot convert argument 1 from 'const Eigen::Indexing<Derived>' to 'Eigen::Indexing<Derived> &'. The problem probably lies in my evaluator struct, which seems to be read-only.

// NOTE(review): question snippet as posted. It does not compile as-is:
// "... very clever stuff ..." elides the index mapping, and Indexing<> is defined elsewhere.
namespace Eigen {
namespace internal {
    // Evaluator specialization for the question's Indexing<ArgType> expression.
    // Coefficient access is forwarded to the evaluator of the nested argument (m_argImpl).
    template<typename ArgType>
    struct evaluator<Indexing<ArgType> >
        : evaluator_base<Indexing<ArgType> >
    {
        typedef Indexing<ArgType> XprType;
        // Type in which the nested argument is held, as chosen by Eigen's nested_eval.
        typedef typename nested_eval<ArgType, XprType::ColsAtCompileTime>::type ArgTypeNested;
        typedef typename remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
        typedef typename XprType::CoeffReturnType CoeffReturnType;
        typedef typename traits<ArgType>::Scalar Scalar;
        enum {
            CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost,
            // NOTE(review): Eigen::ColMajor is a storage option (0 in Eigen's Constants.h), not a
            // flag bit, so presumably no flag bits (e.g. no LvalueBit) are advertised here -- confirm.
            Flags = Eigen::ColMajor
        };

        // NOTE(review): this constructor takes a non-const XprType&. The error quoted in the
        // question ("cannot convert argument 1 from 'const Eigen::Indexing<Derived>' to
        // 'Eigen::Indexing<Derived> &'") indicates Eigen builds the evaluator from a const
        // expression, so a `const XprType&` parameter is presumably what is required.
        evaluator(XprType& xpr)
            : m_argImpl(xpr.m_arg), m_rows(xpr.rows())
        { }
        // Read-only coefficient access (const overload).
        const Scalar& coeffRef(Index row, Index col) const
        {
             return m_argImpl.coeffRef(... very clever stuff ...)
        }

        // Writable coefficient access, forwarded to the nested evaluator.
        Scalar& coeffRef(Index row, Index col)
        {
             return m_argImpl.coeffRef(... very clever stuff ...)
        }

        // NOTE(review): no `coeff(Index, Index) const` member is visible in this snippet.
        evaluator<ArgTypeNestedCleaned> m_argImpl;
        const Index m_rows;  // xpr.rows(), captured at construction
    };
}
}

Also, I've changed all occurrences of typedef typename Eigen::internal::ref_selector<ArgType>::type to ...::non_const_type, but this had no effect.

Due to the complexity of the Eigen library, I can't figure out how to fit the expression and the evaluator together correctly. I don't understand why my evaluator is read-only or how to get a write-enabled evaluator. It would be great if someone could provide a minimal example of a modifiable custom expression.


回答1:


With the help of ggael's hint, I've been able to successfully add my own modifiable expression. I've basically adapted the IndexedView from the Eigen development branch.

As the originally requested functionality is covered by IndexedView, I've written a modifiable circular-shift function as a simple example of a modifiable custom expression. Most of the code is taken directly from IndexedView, so credit goes to its authors.

// circ_shift.h
#pragma once
#include <cassert>
#include <cstdlib>
#include <type_traits>
#include <utility>

#include <Eigen/Core>

namespace helper
{
        namespace detail
    {
        template <typename T>
        constexpr std::true_type is_matrix(Eigen::MatrixBase<T>);
        std::false_type constexpr is_matrix(...);

        template <typename T>
        constexpr std::true_type is_array(Eigen::ArrayBase<T>);
        std::false_type constexpr is_array(...);
    }


    template <typename T>
    struct is_matrix : decltype(detail::is_matrix(std::declval<std::remove_cv_t<T>>()))
    {
    };

    template <typename T>
    struct is_array : decltype(detail::is_array(std::declval<std::remove_cv_t<T>>()))
    {
    };

    template <typename T>
    using is_matrix_or_array = std::bool_constant<is_array<T>::value || is_matrix<T>::value>;



    /*
     * Index something if it's not an scalar
     */
    template <typename T, typename std::enable_if<is_matrix_or_array<T>::value, int>::type = 0>
    auto index_if_necessary(T&& thing, Eigen::Index idx)
    {
        return thing(idx);
    }

    /*
    * Overload for scalar.
    */
    template <typename T, typename std::enable_if<std::is_scalar<std::decay_t<T>>::value, int>::type = 0>
    auto index_if_necessary(T&& thing, Eigen::Index)
    {
        return thing;
    }
}

namespace Eigen
{
    template <typename XprType, typename RowShift, typename ColShift>
    class CircShiftedView;

    namespace internal
    {
        /* Compile-time description of CircShiftedView. The view has exactly the shape of
         * the wrapped expression; only the storage-order and lvalue flags are recomputed. */
        template <typename XprType, typename RowShift, typename ColShift>
        struct traits<CircShiftedView<XprType, RowShift, ColShift>> : traits<XprType>
        {
            enum
            {
                // Shape: identical to the wrapped expression.
                RowsAtCompileTime = traits<XprType>::RowsAtCompileTime,
                ColsAtCompileTime = traits<XprType>::ColsAtCompileTime,
                MaxRowsAtCompileTime = (RowsAtCompileTime == Dynamic) ? int(traits<XprType>::MaxRowsAtCompileTime)
                                                                      : int(RowsAtCompileTime),
                MaxColsAtCompileTime = (ColsAtCompileTime == Dynamic) ? int(traits<XprType>::MaxColsAtCompileTime)
                                                                      : int(ColsAtCompileTime),

                // Storage order: vector-shaped views take the order implied by their
                // orientation; everything else inherits the wrapped expression's order.
                XprTypeIsRowMajor = (int(traits<XprType>::Flags) & RowMajorBit) != 0,
                IsRowMajor = (MaxRowsAtCompileTime == 1 && MaxColsAtCompileTime != 1)
                                 ? 1
                                 : ((MaxColsAtCompileTime == 1 && MaxRowsAtCompileTime != 1)
                                        ? 0
                                        : XprTypeIsRowMajor),

                // Flags: hereditary bits of the wrapped expression, plus writability when
                // the wrapped expression is an lvalue, plus the storage order above.
                FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
                FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
                Flags = (traits<XprType>::Flags & HereditaryBits) | FlagsLvalueBit | FlagsRowMajorBit
            };
        };
    }

    template <typename XprType, typename RowShift, typename ColShift, typename StorageKind>
    class CircShiftedViewImpl;


    /** \brief Writable view of an expression with circularly shifted rows and columns.
     *
     * Coefficient (row, col) of the view aliases coefficient
     * (getRowIdx(row, col), getColIdx(row, col)) of the nested expression, so writing
     * through the view modifies the nested expression.
     *
     * \tparam RowShift  a scalar row shift applied to every column, or a vector holding
     *                   one row shift per column of the view.
     * \tparam ColShift  a scalar column shift applied to every row, or a vector holding
     *                   one column shift per row of the view.
     *
     * A positive shift moves coefficients towards higher indices. Shifts of any magnitude
     * are accepted; they are reduced modulo the corresponding dimension.
     */
    template <typename XprType, typename RowShift, typename ColShift>
    class CircShiftedView : public CircShiftedViewImpl<XprType, RowShift, ColShift, typename internal::traits<XprType>::StorageKind>
    {
    public:
        typedef typename CircShiftedViewImpl<XprType, RowShift, ColShift, typename internal::traits<XprType>::StorageKind>::Base Base;
        EIGEN_GENERIC_PUBLIC_INTERFACE(CircShiftedView)
        EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CircShiftedView)

        // How the wrapped expression is held, as chosen by Eigen's ref_selector.
        typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
        typedef typename internal::remove_all<XprType>::type NestedExpression;

        /** Wraps \p xpr; \p rowShift / \p colShift are converted to RowShift / ColShift. */
        template <typename T0, typename T1>
        CircShiftedView(XprType& xpr, const T0& rowShift, const T1& colShift)
            : m_xpr(xpr), m_rowShift(rowShift), m_colShift(colShift)
        {
            // No range check on the shift values is needed: wrapIndex() reduces any shift
            // modulo the dimension.
        }

        /** \returns number of rows */
        Index rows() const { return m_xpr.rows(); }

        /** \returns number of columns */
        Index cols() const { return m_xpr.cols(); }

        /** \returns the nested expression */
        const typename internal::remove_all<XprType>::type&
        nestedExpression() const { return m_xpr; }

        /** \returns the nested expression */
        typename internal::remove_reference<XprType>::type&
        nestedExpression() { return m_xpr.const_cast_derived(); }

        /** \returns the row of the nested expression that view coefficient (row, col) refers to. */
        EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
        Index getRowIdx(Index row, Index col) const
        {
            assert(row >= 0 && row < m_xpr.rows() && col >= 0 && col < m_xpr.cols());
            return wrapIndex(row - helper::index_if_necessary(m_rowShift, col), m_xpr.rows());
        }

        /** \returns the column of the nested expression that view coefficient (row, col) refers to. */
        EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
        Index getColIdx(Index row, Index col) const
        {
            assert(row >= 0 && row < m_xpr.rows() && col >= 0 && col < m_xpr.cols());
            return wrapIndex(col - helper::index_if_necessary(m_colShift, row), m_xpr.cols());
        }

    protected:
        /** Maps \p i into [0, n). Requires n > 0, which holds whenever a valid (row, col) exists. */
        EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE
        Index wrapIndex(Index i, Index n)
        {
            // Fast path: for |shift| < n the shifted index lies in [-n, 2n), so a single
            // add/subtract brings it into range (same result as before for such shifts).
            if (i >= n)
                i -= n;
            else if (i < 0)
                i += n;
            // Larger shifts: fall back to a modulo. '%' truncates toward zero, so a negative
            // remainder still needs one more +n.
            if (i < 0 || i >= n)
            {
                i %= n;
                if (i < 0)
                    i += n;
            }
            return i;
        }

        MatrixTypeNested m_xpr;
        RowShift m_rowShift;
        ColShift m_colShift;
    };


    // Storage-kind-agnostic implementation layer: CircShiftedView inherits its dense
    // expression API from whatever base internal::generic_xpr_base selects for it.
    template <typename XprType, typename RowShift, typename ColShift, typename StorageKind>
    class CircShiftedViewImpl
        : public internal::generic_xpr_base<CircShiftedView<XprType, RowShift, ColShift>>::type
    {
    public:
        using Base = typename internal::generic_xpr_base<CircShiftedView<XprType, RowShift, ColShift>>::type;
    };

    namespace internal
    {
        /* Index-based evaluator for CircShiftedView. Every coefficient access is forwarded to
         * the nested expression's evaluator at the position computed by the view's
         * getRowIdx()/getColIdx(). Providing coeffRef() is what lets the view appear on the
         * left-hand side of an assignment.
         *
         * m_xpr is a reference, so the evaluator must not outlive the view it was built from.
         */
        template <typename ArgType, typename RowShift, typename ColShift>
        struct unary_evaluator<CircShiftedView<ArgType, RowShift, ColShift>, IndexBased>
            : evaluator_base<CircShiftedView<ArgType, RowShift, ColShift>>
        {
            typedef CircShiftedView<ArgType, RowShift, ColShift> XprType;

            enum
            {
                CoeffReadCost = evaluator<ArgType>::CoeffReadCost + NumTraits<Index>::AddCost /* for comparison */ + NumTraits<Index>::AddCost /*for addition*/,

                // Storage order is taken from the view's own traits (as Eigen's IndexedView
                // evaluator does) rather than from the nested evaluator. This way the evaluator
                // can never disagree with traits<XprType>::Flags, which is what the expression
                // side (DenseBase::IsRowMajor) uses.
                Flags = (evaluator<ArgType>::Flags & (HereditaryBits & ~RowMajorBit)) | (traits<XprType>::Flags & RowMajorBit),

                Alignment = 0
            };

            EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
            {
                EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
            }

            typedef typename XprType::Scalar Scalar;
            typedef typename XprType::CoeffReturnType CoeffReturnType;


            /* Read access at the wrapped-around position. */
            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
            CoeffReturnType coeff(Index row, Index col) const
            {
                return m_argImpl.coeff(m_xpr.getRowIdx(row, col), m_xpr.getColIdx(row, col));
            }

            /* Write access at the wrapped-around position. This compiles only if the nested
             * evaluator provides coeffRef(). Bounds are asserted by getRowIdx()/getColIdx(). */
            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
            Scalar& coeffRef(Index row, Index col)
            {
                return m_argImpl.coeffRef(m_xpr.getRowIdx(row, col), m_xpr.getColIdx(row, col));
            }

        protected:

            evaluator<ArgType> m_argImpl;
            const XprType& m_xpr;
        };
    } // end namespace internal
} // end namespace Eigen


/** Creates a writable, circularly shifted view of \p x.
 *
 * \param x  expression to view. The view may refer to \p x directly, so keep \p x alive
 *           while the view is in use.
 * \param r  row shift: a scalar, or a vector with one entry per column of the view.
 * \param c  column shift: a scalar, or a vector with one entry per row of the view.
 */
template <typename XprType, typename RowShift, typename ColShift>
auto circShift(Eigen::DenseBase<XprType>& x, RowShift r, ColShift c)
{
    using View = Eigen::CircShiftedView<XprType, RowShift, ColShift>;
    XprType& target = x.derived();
    return View(target, r, c);
}

And:

// main.cpp
#include "stdafx.h"
#include "Eigen/Core"
#include <iostream>
#include "circ_shift.h"

using namespace Eigen;


int main()
{
    // x is 4x2. Filling its transpose row by row sets column 0 = 1..4 and column 1 = 10..40.
    ArrayXXf x(4, 2);
    x.transpose() << 1, 2, 3, 4, 10, 20, 30, 40;


    // One row shift per column of the *view* (0-based): view col 0 is rotated by 3,
    // view col 1 by -3. With colShift = 1, view col 0 shows original col 1 and vice versa.
    Vector2i rowShift;
    rowShift << 3, -3;

    Index colShift = 1; // with two columns, a shift of 1 swaps them

    auto shifted = circShift(x, rowShift, colShift);

    std::cout << "shifted: " << '\n' << shifted << '\n';

    shifted.block(2, 0, 2, 1) << -1, -2; // lands in rows 3 and 0 of original col 1
    shifted.col(1) << 2, 4, 6, 8;        // view col 1 is col 0 of the original

    // '\n' instead of std::endl: no need to flush after every line (cout is flushed at exit).
    std::cout << "modified original:" << '\n' << x << '\n';

    return 0;
}


来源:https://stackoverflow.com/questions/46077242/eigen-modifyable-custom-expression

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!