Symbolic differentiation using expression templates in C++

面向向阳花 2021-02-04 12:54

How can I implement symbolic differentiation using expression templates in C++?

Answered 2021-02-04 13:31

    In general you'd want a way to represent your symbols (i.e. the expression templates that encode something like 3 * x * x + 42) and a metafunction that can compute a derivative. Hopefully you're familiar enough with metaprogramming in C++ to know what that means and entails, but to give you an idea:

    // This should come from the expression templates
    template<typename Lhs, typename Rhs>
    struct plus_node;
    
    // Metafunction that computes a derivative
    template<typename T>
    struct derivative;
    
    // derivative<foo>::type is the result of computing the derivative of foo
    
    // Derivative of lhs + rhs
    template<typename Lhs, typename Rhs>
    struct derivative<plus_node<Lhs, Rhs> > {
        typedef plus_node<
            typename derivative<Lhs>::type
            , typename derivative<Rhs>::type
        > type;
    };
    
    // and so on
    

    You'd then tie the two parts (representation and computation) together so that they are convenient to use. For example, derivative(3 * x * x + 42)(6) could mean 'compute the derivative of 3 * x * x + 42 with respect to x, evaluated at x = 6'.

    However, even if you do know what it takes to write expression templates and metaprograms in C++, I wouldn't recommend going about it this way: template metaprogramming requires a lot of boilerplate and can be tedious. Instead, I direct you to the brilliant Boost.Proto library, which is designed precisely to help write EDSLs (embedded domain-specific languages built on expression templates) and to operate on those expression templates. It is not necessarily easy to learn, but I've found that learning to achieve the same thing without it is harder. Here's a sample program that can in fact understand and compute derivative(3 * x * x + 42)(6):

    #include <iostream>
    
    #include <boost/proto/proto.hpp>
    #include <boost/mpl/int.hpp> // boost::mpl::int_, used by the Derivative transform
    
    using namespace boost::proto;
    
    // We only differentiate with respect to one variable, the 'unknown'
    struct unknown {};
    
    // Boost.Proto calls this the expression wrapper
    // elements of the EDSL will have this type
    template<typename Expr>
    struct expression;
    
    // Boost.Proto calls this the domain
    struct derived_domain
    : domain<generator<expression>> {};
    
    // We will use a context to evaluate expression templates
    struct evaluation_context: callable_context<evaluation_context const> {
        double value;
    
        explicit evaluation_context(double value)
            : value(value)
        {}
    
        typedef double result_type;
    
        double operator()(tag::terminal, unknown) const
        { return value; }
    };
    // And now we can do:
    // evaluation_context context(42);
    // eval(expr, context);
    // to evaluate an expression as though the unknown had value 42
    
    template<typename Expr>
    struct expression: extends<Expr, expression<Expr>, derived_domain> {
        typedef extends<Expr, expression<Expr>, derived_domain> base_type;
    
        expression(Expr const& expr = Expr())
            : base_type(expr)
        {}
    
        typedef double result_type;
    
        // We spare ourselves the need to write eval(expr, context)
        // Instead, expr(42) is available
        double operator()(double d) const
        {
            evaluation_context context(d);
            return eval(*this, context);
        }
    };
    
    // Boost.Proto calls this a transform -- we use this to operate
    // on the expression templates
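    struct Derivative
    : or_<
        // d/dx x = 1
        when<
            terminal<unknown>
            , boost::mpl::int_<1>()
        >
        // Any other terminal is a constant: d/dx c = 0
        , when<
            terminal<_>
            , boost::mpl::int_<0>()
        >
        // Sum rule: (f + g)' = f' + g'
        , when<
            plus<Derivative, Derivative>
            , _make_plus(Derivative(_left), Derivative(_right))
        >
        // Product rule: (f * g)' = f' * g + f * g'
        , when<
            multiplies<Derivative, Derivative>
            , _make_plus(
                _make_multiplies(Derivative(_left), _right)
                , _make_multiplies(_left, Derivative(_right))
            )
        >
        // Anything else (e.g. minus, divides) is returned unchanged,
        // i.e. it is NOT differentiated
        , otherwise<_>
    > {};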
    
    // x is the unknown
    expression<terminal<unknown>::type> const x;
    
    // A transform works as a functor
    Derivative const derivative;
    
    int
    main()
    {
        // d/dx (3x^2 + 42) = 6x, so this prints 36
        double d = derivative(3 * x * x + 42)(6);
        std::cout << d << '\n';
    }
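    Tracing the transform by hand shows where the printed 36 comes from. C++ groups 3 * x * x + 42 as ((3 * x) * x) + 42, and the four rules encoded in Derivative expand it as follows (the transform performs no simplification, so the zero and unit factors stay in the tree until evaluation):

```latex
\begin{aligned}
\frac{d}{dx}\,x &= 1, \qquad \frac{d}{dx}\,c = 0, \qquad (f+g)' = f' + g', \qquad (fg)' = f'g + fg', \\
\frac{d}{dx}\,(3x) &= 0 \cdot x + 3 \cdot 1, \\
\frac{d}{dx}\bigl((3x)\,x + 42\bigr) &= \bigl((0 \cdot x + 3 \cdot 1)\,x + (3x) \cdot 1\bigr) + 0 = 6x, \\
6x \,\big|_{x=6} &= 36.
\end{aligned}
```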
    