Multiple dispatch solution with full maintainability

杀马特。学长 韩版系。学妹 提交于 2019-11-28 11:51:26

Here is what I have done for multiple dispatch (turning my comment into an answer):

// Generic visitor interface generator.
// Usage: using MyIVisitor = IVisitorTs<Child1, Child2, ...>;
// The generated class declares one pure virtual visit(const T&) per listed type
// and exposes the list itself as tuple_type.
template <typename... Ts>
class IVisitorTs;

// Recursive case: declare visit() for the first type and inherit the
// overloads for all remaining types.
template <typename Head, typename... Rest>
class IVisitorTs<Head, Rest...> : public IVisitorTs<Rest...>
{
public:
    using tuple_type = std::tuple<Head, Rest...>;

    // Without this the overload declared below would hide the inherited ones.
    using IVisitorTs<Rest...>::visit;

    virtual ~IVisitorTs() = default;
    virtual void visit(const Head& t) = 0;
};

// Terminal case: a single visited type, no base class.
template <typename Last>
class IVisitorTs<Last>
{
public:
    using tuple_type = std::tuple<Last>;

    virtual ~IVisitorTs() = default;
    virtual void visit(const Last& t) = 0;
};

namespace detail {

// get_index<T, Ts...>::value is the zero-based position of the first
// occurrence of T in Ts... . If T is not in the list only the undefined
// primary template is left, so using ::value fails to compile.
template <typename Needle, typename... Haystack>
struct get_index;

// Hit: the head of the list is the type being looked for.
template <typename Needle, typename... Rest>
struct get_index<Needle, Needle, Rest...>
    : std::integral_constant<std::size_t, 0> {};

// Miss: skip the head and add one to the position found in the tail.
template <typename Needle, typename Head, typename... Rest>
struct get_index<Needle, Head, Rest...>
    : std::integral_constant<std::size_t, 1 + get_index<Needle, Rest...>::value> {};

// get_index_in_tuple<T, Holder<Ts...>>::value is the position of T within the
// template argument list of Holder (a std::tuple or any variadic class template).
template <typename Needle, typename TypeList>
struct get_index_in_tuple;

template <typename Needle, template <typename...> class Holder, typename... Elements>
struct get_index_in_tuple<Needle, Holder<Elements...>>
    : get_index<Needle, Elements...> {};

// Element access for nested arrays: multi_array_getter<I>::get(a, index)
// applies the last I entries of `index` as successive subscripts to `a`
// and returns a const reference to what is found there.
template <std::size_t Remaining>
struct multi_array_getter
{
    template <typename Array, std::size_t Rank>
    static constexpr auto get(const Array& a, const std::array<std::size_t, Rank>& index)
    -> decltype(multi_array_getter<Remaining - 1>::get(a[index[Rank - Remaining]], index))
    {
        // Subscript with the current entry, then continue one level deeper.
        using Next = multi_array_getter<Remaining - 1>;
        return Next::get(a[index[Rank - Remaining]], index);
    }
};

// No subscripts left: `a` itself is the requested element.
template <>
struct multi_array_getter<0>
{
    template <typename Leaf, std::size_t Rank>
    static constexpr auto get(const Leaf& a, const std::array<std::size_t, Rank>& index)
    -> const Leaf&
    {
        return a;
    }
};

// Implements every visit(const T&) of the interface IVisitor by forwarding the
// call to C::visit, which does not have to be virtual.
template <typename IVisitor, typename C, typename... Ts>
struct IVisitorImpl;

// Recursive case: override visit() for Head, inherit the overrides for Rest.
template <typename IVisitor, typename C, typename Head, typename... Rest>
struct IVisitorImpl<IVisitor, C, Head, Rest...> : IVisitorImpl<IVisitor, C, Rest...>
{
    void visit(const Head& visited) override
    {
        C::visit(visited);
    }
};

// Terminal case: the last visited type; IVisitor and C become base classes here.
template <typename IVisitor, typename C, typename Last>
struct IVisitorImpl<IVisitor, C, Last> : IVisitor, C
{
    void visit(const Last& visited) override
    {
        C::visit(visited);
    }
};

// Maps a visitor interface IVisitorTs<Ts...> plus an implementation class to
// the IVisitorImpl instantiation that overrides visit() for every Ts.
template <typename Interface, typename Impl>
struct IVisitorImplType;

template <typename... Visited, typename Impl>
struct IVisitorImplType<IVisitorTs<Visited...>, Impl>
{
    using type = IVisitorImpl<IVisitorTs<Visited...>, Impl, Visited...>;
};

// Builds a multi-dimensional array of function pointers, one entry per
// combination of concrete types (i.e. per possible overload).
//   Ret - return type of every table entry
//   F   - callable type; each entry receives it as F&
//   Arg - tuple-like type holding the arguments; each entry receives it as const Arg&
template <typename Ret, typename F, typename Arg>
class GetAllOverload
{
private:
    // Functor<Ts...>::call is the table entry for the type combination Ts...
    template <typename...Ts>
    struct Functor
    {
        // Function whose address is stored in the array.
        static Ret call(F&f, const Arg& arg)
        {
            return call_helper(f, arg, make_index_sequence<sizeof...(Ts)>());
        }
    private:
        // The final dispatched call: element Is of arg is cast to a reference
        // to the Is-th type of Ts... and all of them are passed to f.
        template <std::size_t ... Is>
        static Ret call_helper(F&f, const Arg& arg, index_sequence<Is...>)
        {
            using RetTuple = std::tuple<Ts&...>;
            // static_cast is sufficient if arg holds the abstract type;
            // when the given arg holds a concrete type, reinterpret_cast is required.
            // TODO: build a smaller table with only the possible values to avoid that
            // NOTE(review): the cast is only meaningful for the entry whose Ts...
            // match the dynamic types of the arguments - confirm callers only
            // ever select that entry.
            return f(reinterpret_cast<typename std::tuple_element<Is, RetTuple>::type>(std::get<Is>(arg))...);
        }
    };

    // Helper that creates the multi-dimensional array of function pointers.
    //   N     - number of dimensions still to be generated
    //   Tuple - std::tuple of all candidate types (one array slot per type)
    //   Ts    - types already chosen for the outer dimensions
    template <std::size_t N, typename Tuple, typename...Ts>
    struct Builder;

    // Innermost dimension: one function pointer per candidate type, each for
    // the combination "types chosen so far + that candidate".
    template <typename...Ts, typename...Ts2>
    struct Builder<1, std::tuple<Ts...>, Ts2...>
    {
        using RetType = std::array<Ret (*)(F&, const Arg&), sizeof...(Ts)>;

        static constexpr RetType build()
        {
            return RetType{ &Functor<Ts2..., Ts>::call... };
        }
    };

    // Outer dimension: one sub-array per candidate type, built recursively
    // with that candidate appended to the types chosen so far.
    template <std::size_t N, typename ...Ts, typename...Ts2>
    struct Builder<N, std::tuple<Ts...>, Ts2...>
    {
        template <typename T>
        using RecType = Builder<N - 1, std::tuple<Ts...>, Ts2..., T>;
        // Any candidate works to name the sub-array type; the first one is used.
        using T0 = typename std::tuple_element<0, std::tuple<Ts...>>::type;
        using RetType = std::array<decltype(RecType<T0>::build()), sizeof...(Ts)>;

        static constexpr RetType build() {
            return RetType{ RecType<Ts>::build()... };
        }
    };

public:
    // Returns the N-dimensional table for the candidate types in VisitorTuple.
    template <std::size_t N, typename VisitorTuple>
    static constexpr auto get()
    -> decltype(Builder<N, VisitorTuple>::build())
    {
        return Builder<N, VisitorTuple>::build();
    }
};

// Resolves the type index of each of the N arguments through the visitor
// interface IVisitor, then calls the matching entry of the overload table.
//   Ret      - return type of the dispatched call
//   IVisitor - visitor interface (an IVisitorTs<...>) accepted by the arguments
//   F        - callable type
//   N        - number of arguments
template <typename Ret, typename IVisitor, typename F, std::size_t N>
class dispatcher
{
private:
    // index[I] receives the position of argument I's type in IVisitor's type list.
    std::array<std::size_t, N> index;

    // Visitor implementation: being visited as a T records T's position
    // into the slot previously registered with setIndexPtr().
    struct visitorCallImpl
    {
        template <typename T>
        void visit(const T&) const
        {
            *index = get_index_in_tuple<T, IVisitor>::value;
        }

        void setIndexPtr(std::size_t& index) { this->index = &index; }
    private:
        std::size_t* index = nullptr;
    };

    // Lets the I-th element of the tuple accept a visitor that writes its
    // type index into index[I].
    template <std::size_t I, typename Tuple>
    void set_index(const Tuple&t)
    {
        using VisitorType = typename IVisitorImplType<IVisitor, visitorCallImpl>::type;
        VisitorType visitor;
        visitor.setIndexPtr(index[I]);

        std::get<I>(t).accept(visitor);
    }
public:
    // Determines the index of every tuple element, fetches the function
    // pointer stored at those indices and calls it with f and t.
    template <typename Tuple, std::size_t ... Is>
    Ret operator () (F&& f, const Tuple&t, index_sequence<Is...>)
    {
        // Braced initializer used only to run set_index<Is>(t) for every Is, in order.
        const int dummy[] = {(set_index<Is>(t), 0)...};
        static_cast<void>(dummy); // silence the unused-variable warning
        constexpr auto a = GetAllOverload<Ret, F&&, Tuple>::
            template get<sizeof...(Is), typename IVisitor::tuple_type>();
        auto func = multi_array_getter<N>::get(a, index);
        return (*func)(f, t);
    }
};

} // namespace detail

template <typename Ret, typename Visitor, typename F, typename ... Ts>
Ret dispatch(F&& f, Ts&...args)
{
    constexpr std::size_t size = sizeof...(Ts);
    detail::dispatcher<Ret, Visitor, F&&, size> d;
    return d(std::forward<F>(f), std::tie(args...), make_index_sequence<size>());
}

Example usage

// Forward declarations: the visitor interface has to name every visitable
// type before the types themselves (whose accept() takes the visitor) are defined.
struct A;
struct B;
struct C;
struct D;

// Visitor interface with one visit(const T&) overload for each of A, B, C and D.
using IAVisitor = IVisitorTs<A, B, C, D>;

// Root of the visited hierarchy; accept() presents the object to the visitor as an A.
struct A {
    virtual ~A() = default;

    virtual void accept(IAVisitor& v) const
    {
        v.visit(*this);
    }
};
struct B : A {
    virtual void accept(IAVisitor& v) const override { v.visit(*this); }
};

struct C : A {
    virtual void accept(IAVisitor& v) const override { v.visit(*this); }
};
struct D : A {
    virtual void accept(IAVisitor& v) const override { v.visit(*this); }
};

// Example receiver whose foo/bar overloads differ only in the concrete
// argument types. Calling foo directly picks an overload from the static
// types; the fooMultipleDispatch members pick it from the dynamic types.
class Object {
    public:
        // The class declares virtual functions, so give it a virtual destructor:
        // destroying a derived object through an Object* is then well defined.
        virtual ~Object() = default;

        // Two-argument foo overloads; each prints which overload ran.
        virtual double foo (A*, A*) { std::cout << "Object::foo A,A\n";  return 3.14; }
        virtual double foo (B*, B*) { std::cout << "Object::foo B,B\n";  return 3.14; }
        virtual double foo (B*, C*) { std::cout << "Object::foo B,C\n";  return 3.14; }
        virtual double foo (C*, B*) { std::cout << "Object::foo C,B\n";  return 3.14; }
        virtual double foo (C*, C*) { std::cout << "Object::foo C,C\n";  return 3.14; }

        // Three-argument foo overloads.
        virtual char foo (A*, A*, A*) const { std::cout << "Object::foo A,A,A\n";  return '&'; }
        virtual char foo (C*, B*, D*) const { std::cout << "Object::foo C,B,D\n";  return '!'; }  // Overload of foo with three arguments.

        // bar overloads.
        virtual void bar (A*, A*, A*) const { std::cout << "Object::bar A,A,A\n"; }
        virtual void bar (B*, B*, B*) const { std::cout << "Object::bar B,B,B\n"; }
        virtual void bar (B*, C*, B*) const { std::cout << "Object::bar B,C,B\n"; }
        virtual void bar (B*, C*, C*) const { std::cout << "Object::bar B,C,C\n"; }
        virtual void bar (B*, C*, D*) const { std::cout << "Object::bar B,C,D\n"; }
        virtual void bar (C*, B*, D*) const { std::cout << "Object::bar C,B,D\n"; }
        virtual void bar (C*, C*, C*) const { std::cout << "Object::bar C,C,C\n"; }
        virtual void bar (D*, B*, C*) const { std::cout << "Object::bar D,B,C\n"; }

        // Call the foo overload that best matches the dynamic argument types.
        double fooMultipleDispatch (A*, A*);
        char fooMultipleDispatch (A*, A*, A*);
};

class FooDispatcher
{
public:
    explicit FooDispatcher(Object& object) : object(object) {}

    template <typename T1, typename T2>
    double operator() (T1& a1, T2& a2) const
    {
        return object.foo(&a1, &a2);
    }

    template <typename T1, typename T2, typename T3>
    char operator() (T1& a1, T2& a2, T3& a3) const
    {
        return object.foo(&a1, &a2, &a3);
    }
private:
    Object& object;
};

// Two-argument multiple dispatch: the foo overload is chosen from the dynamic
// types of *a1 and *a2.
double Object::fooMultipleDispatch (A* a1, A* a2)
{
    A& first = *a1;
    A& second = *a2;
    return dispatch<double, IAVisitor>(FooDispatcher(*this), first, second);
}
// Three-argument multiple dispatch: the foo overload is chosen from the
// dynamic types of *a1, *a2 and *a3.
char Object::fooMultipleDispatch (A* a1, A* a2, A* a3)
{
    A& first = *a1;
    A& second = *a2;
    A& third = *a3;
    return dispatch<char, IAVisitor>(FooDispatcher(*this), first, second, third);
}


int main() {
    A a_a;
    B a_b;
    C a_c;
    D a_d;
    A* a[] = {&a_b, &a_c, &a_d, &a_a};
    Object object;

    double d = object.foo (a[0], a[1]);  // Object::foo A,A  (no multiple dispatch)
    d = object.fooMultipleDispatch (a[0], a[1]);  // Object::foo B,C
    std::cout << "d = " << d << std::endl;  // 3.14

    object.fooMultipleDispatch (a[0], a[3]);  // B,A -> so best match is Object::foo A,A

    const char k = object.fooMultipleDispatch (a[1], a[0], a[2]);  // Object::foo C,B,D
    std::cout << "k = " << k << std::endl;  // !
}

Live example

I would use something like this, after some brushing up:

requisite headers

#include <iostream>
#include <typeinfo>
#include <map>
#include <array>
#include <functional>
#include <stdexcept>
#include <algorithm>

dynamic_call: call any function with downcasted arguments, used by the dispatcher

// Base case (std::function): nothing left to convert, so simply invoke the
// callable and hand back its result.
template<typename Result>
Result dynamic_call (std::function<Result()> fun)
{
    return fun();
}
// Base case (plain function pointer): nothing left to convert, so simply call
// the function and hand back its result.
template<typename Result>
Result dynamic_call (Result(*fun)())
{
    return (*fun)();
}

// One or more arguments: dynamic_cast the first argument to the pointer type
// the callable expects, bind it, and recurse on the remaining arguments.
// Throws std::runtime_error("Argument type error!") when the cast yields null.
template<typename Result, typename Arg0, typename FunArg0, typename ... Args, typename ... FunArgs>
Result dynamic_call (std::function<Result(FunArg0*, FunArgs*...)> fun, Arg0* arg0, Args*... args)
{
    FunArg0* const head = dynamic_cast<FunArg0*>(arg0);
    if (!head)
        throw std::runtime_error("Argument type error!");

    // Callable over the remaining parameters with the first one already bound.
    using Remaining = std::function<Result(FunArgs*...)>;
    Remaining bound = [fun, head](FunArgs*... rest) -> Result
    {
        return fun(head, rest...);
    };
    return dynamic_call(bound, args...);
}

// Function-pointer front end: wraps the pointer in a std::function and
// delegates to the std::function overload.
template<typename Result, typename Arg0, typename FunArg0, typename ... Args, typename ... FunArgs>
Result dynamic_call (Result (*fun)(FunArg0*, FunArgs*...), Arg0* arg0, Args*... args)
{
    using Wrapped = std::function<Result(FunArg0*, FunArgs*...)>;
    return dynamic_call(Wrapped(fun), arg0, args...);
}

dispatcher: store a bunch of functions in a map, find them by actual dynamic types of passed arguments

// Multiple-dispatch table: stores functions in a map keyed by the types of
// their pointer parameters and, on a call, selects the function whose key
// equals the dynamic types of the passed arguments.
//   Result - return type shared by all registered functions
//   Args   - static (base) types through which the arguments are passed
// The lookup is an exact match on std::type_info: an argument whose dynamic
// type merely derives from a registered type does not match that entry.
template <typename Result, typename ... Args>
class Dispatcher
{
  public:

    // Calls the function registered for the dynamic types of *args.
    // Throws std::runtime_error("Function not found!") if there is none.
    // The arguments are dereferenced for typeid, so they must not be null.
    // The lookup does not modify the table, hence the const qualification.
    Result operator() (Args*... args) const
    {
        const key k{tiholder(typeid(*args))...};
        const typename map::const_iterator it = functions.find(k);
        if (it == functions.end())
            throw std::runtime_error("Function not found!");
        return it->second(args...);
    }

    // Registers fun under the key (typeid(FunArgs)...). A later registration
    // with the same parameter types replaces the earlier one.
    template <typename ... FunArgs>
    void register_fn(std::function<Result(FunArgs*...)> fun)
    {
        static_assert(sizeof...(FunArgs) == sizeof...(Args),
                      "registered function must take exactly one pointer parameter per dispatcher argument");
        // Adapter with the dispatcher's uniform signature; dynamic_call
        // downcasts every argument to the type fun expects.
        auto lam = [fun](Args*... args) -> Result
        {
            return dynamic_call(fun, args...);
        };
        key k{tiholder(typeid(FunArgs))...};
        functions[k] = lam;
    }

    // Convenience overload for plain function pointers.
    template <typename ... FunArgs>
    void register_fn(Result(*fun)(FunArgs*...))
    {
        return register_fn(std::function<Result(FunArgs*...)>(fun));
    }

  private:

    // Copyable, ordered handle to a std::type_info (type_info itself cannot be
    // copied); the ordering is the one given by type_info::before.
    struct tiholder
    {
        const std::type_info* ti;
        tiholder(const std::type_info& ti) : ti(&ti) {}
        bool operator< (const tiholder& other) const { return ti->before(*other.ti); }
    };

    // Number of dispatched arguments; used as the extent of the key array.
    static constexpr std::size_t PackSize = sizeof ... (Args);
    using key = std::array<tiholder, PackSize>;
    using value = std::function<Result(Args*...)>;
    using map = std::map<key, value>;
    map functions;
};

test case

// Polymorphic root of the test hierarchy (the virtual destructor makes the
// type polymorphic, which typeid and dynamic_cast rely on).
struct Base
{
    virtual ~Base() {}
};

// Two unrelated leaves.
struct A : Base
{
};

struct B : Base
{
};

// One free function per (A|B, A|B) combination; each prints its own signature.
void foo1(A*, A*)
{
    std::cout << "foo(A*,A*)\n";
}

void foo2(A*, B*)
{
    std::cout << "foo(A*,B*)\n";
}

void foo3(B*, A*)
{
    std::cout << "foo(B*,A*)\n";
}

void foo4(B*, B*)
{
    std::cout << "foo(B*,B*)\n";
}

test driver and user guide

int main ()
{
    Base* x = new A;
    Base* y = new B;

    Dispatcher<void,Base,Base> foo;
    foo.register_fn(foo1);
    foo.register_fn(foo2);
    foo.register_fn(foo3);
    foo.register_fn(foo4);

    foo(x,x);
    foo(x,y);
    foo(y,x);
    foo(y,y);

}

Here is a shorter implementation of multiple dispatch, based originally on Jarod42's ideas.

#include <iostream>
#include <tuple>
#include <functional>
#include <utility>
#include <array>

// Abstract root of the dispatched hierarchy: every concrete type reports an
// index through getIndex().
class A {
public:
    virtual ~A() = default;

    virtual std::size_t getIndex() const = 0;
};

// CRTP layer that implements getIndex() once for every Derived; the
// definition follows further down, after ArgumentTuple is known.
template <typename Derived>
class CRTP : public A {
private:
    std::size_t getIndex() const override;
};

// Concrete argument types; each gets its getIndex() from CRTP<itself>.
class B : public CRTP<B>
{
};

class C : public CRTP<C>
{
};

class D : public CRTP<D>
{
};

// Example receiver: foo is overloaded on the concrete argument types and each
// overload prints which one ran. fooMultipleDispatch selects the overload from
// the dynamic types of its arguments.
class Object {
public:
    // Two-argument overloads.
    double foo (A*, A*)
    {
        std::cout << "Object::foo A,A\n";
        return 3.14;
    }
    double foo (B*, B*)
    {
        std::cout << "Object::foo B,B\n";
        return 3.14;
    }
    double foo (B*, C*)
    {
        std::cout << "Object::foo B,C\n";
        return 3.14;
    }
    double foo (C*, B*)
    {
        std::cout << "Object::foo C,B\n";
        return 3.14;
    }
    double foo (C*, C*)
    {
        std::cout << "Object::foo C,C\n";
        return 3.14;
    }
    double foo (B*, D*)
    {
        std::cout << "Object::foo B,D\n";
        return 3.14;
    }

    // Three-argument overloads.
    char foo (A*, A*, A*) const
    {
        std::cout << "Object::foo A,A,A\n";
        return '&';
    }
    char foo (C*, B*, D*) const
    {
        std::cout << "Object::foo C,B,D\n";
        return '!';
    }

    // R is the return type of the foo overload set being called.
    template <typename R, typename... Args>
    R fooMultipleDispatch (Args*...);
};

using ArgumentTuple = std::tuple<A,B,C,D>;  // This must be updated whenever the Object::foo overloads are updated with new argument types.

template <std::size_t N>
using ArgumentType = std::tuple_element_t<N, ArgumentTuple>;

// Type<N, Arg> is always Arg; the index parameter only exists so the alias can
// be expanded once per element of an index pack.
template <std::size_t, typename Arg>
using Type = Arg;

// Value<T, N, v> is always v; as above, N is only there for pack expansion.
template <typename T, std::size_t, T v>
constexpr T Value = v;

// GetIndex<T, Ts...>::value is the zero-based position of the first
// occurrence of T in Ts... . Only declared (not defined) in general, so a type
// that is not in the list produces a compile-time error.
template <typename Needle, typename... Haystack>
struct GetIndex;

// Hit: the head of the list is the type being looked for.
template <typename Needle, typename... Rest>
struct GetIndex<Needle, Needle, Rest...>
    : std::integral_constant<std::size_t, 0> {};

// Miss: skip the head and add one to the position found in the tail.
template <typename Needle, typename Head, typename... Rest>
struct GetIndex<Needle, Head, Rest...>
    : std::integral_constant<std::size_t, GetIndex<Needle, Rest...>::value + 1> {};

// GetIndexInTuple<T, Holder<Ts...>>::value is the position of T within the
// template argument list of Holder (e.g. a std::tuple).
template <typename Needle, typename TypeList>
struct GetIndexInTuple;

template <typename Needle, template <typename...> class Holder, typename... Elements>
struct GetIndexInTuple<Needle, Holder<Elements...>>
    : GetIndex<Needle, Elements...> {};

// Reports the position of Derived inside ArgumentTuple.
template <typename Derived>
std::size_t CRTP<Derived>::getIndex() const {
    using Position = GetIndexInTuple<Derived, ArgumentTuple>;
    return Position::value;
}

// Terminal overload: `element` is not a std::array any more, so it is the
// requested value and is returned by value.
template <typename Leaf, typename Indices>
auto getArrayElement (const Leaf& element, const Indices&, std::size_t) {
    return element;
}

// Recursive overload: subscripts `array` with indices[n] and continues with
// the next index on whatever is stored there.
template <typename Inner, std::size_t Extent, typename Indices>
auto getArrayElement (const std::array<Inner, Extent>& array, const Indices& indices, std::size_t n = 0) {
    const Inner& selected = array[indices[n]];
    return getArrayElement(selected, indices, n + 1);
}

// Builds a std::array from the arguments; the element type is the common type
// of all arguments and the size is the number of arguments.
template <typename... Ts>
constexpr std::array<std::common_type_t<Ts...>, sizeof...(Ts)> makeArray (Ts&&... ts) {
    using Result = std::array<std::common_type_t<Ts...>, sizeof...(Ts)>;
    return Result{ {std::forward<Ts>(ts)...} };
}

// Terminating createMultiArrayHelper overload: creates exactly one table
// entry, i.e. a call to one of the foo overloads.
// Selected when the first argument is the empty index sequence; Is... then
// holds the type index chosen for each argument position. The returned lambda
// takes the object plus one Arg* per position, static_casts every pointer to
// the ArgumentType with the corresponding index and calls object->foo.
template <typename R, typename Class, typename Arg, std::size_t... Is>
constexpr auto createMultiArrayHelper (std::index_sequence<>, std::index_sequence<Is...>) {
    return [](Class* object, Type<Is, Arg*>... args)->R {return object->foo(static_cast<ArgumentType<Is>*>(args)...);};
}

// Recursive createMultiArrayHelper overload: builds one dimension of the table.
//   First, Rest... - explicitly supplied sizes of the dimensions still to build
//   Is...          - indices of the current dimension; sizeof...(Is) is the
//                    dimension size of the current subarray
//   Js...          - type indices already chosen for the outer dimensions
// For every index in Is... a subarray (or, at the end, a single entry) is
// created with that index appended to Js..., and the results are collected
// with makeArray.
template <typename R, typename Class, typename Arg, std::size_t First, std::size_t... Rest, std::size_t... Is, std::size_t... Js>
constexpr auto createMultiArrayHelper (std::index_sequence<Is...>, std::index_sequence<Js...>) {
    return makeArray (createMultiArrayHelper<R, Class, Arg, Rest...>(std::make_index_sequence<First>{}, std::index_sequence<Js..., Is>{})...);
}

// Creates the multi-dimensional table whose dimension sizes are First, Rest... .
// A trailing 0 is appended to the sizes for the terminating case:
// std::make_index_sequence<0> is std::index_sequence<>, which is what the
// terminating createMultiArrayHelper overload expects as its first argument.
template <typename R, typename Class, typename Arg, std::size_t First, std::size_t... Rest>
constexpr auto createMultiArray() {
    using OuterIndices = std::make_index_sequence<First>;
    using NothingChosenYet = std::index_sequence<>;
    return createMultiArrayHelper<R, Class, Arg, Rest..., 0>(OuterIndices{}, NothingChosenYet{});
}

// Expands to createMultiArray<R, Class, Arg, Size, Size, ..., Size>(), where
// Size is repeated sizeof...(Is) times (Value ignores the index and yields Size).
template <typename R, typename Class, typename Arg, std::size_t Size, std::size_t... Is>
constexpr auto createFunctionMultiArrayHelper (std::index_sequence<Is...>) {
    return createMultiArray<
        R, Class, Arg,
        Value<std::size_t, Is, Size>...
    >();
}

// Creates a table with NumDimensions dimensions, each of size Size.
template <typename R, typename Class, typename Arg, std::size_t Size, std::size_t NumDimensions>
constexpr auto createFunctionMultiArray() {
    using Dimensions = std::make_index_sequence<NumDimensions>;
    return createFunctionMultiArrayHelper<R, Class, Arg, Size>(Dimensions{});
}

template <typename R, typename... Args>
R Object::fooMultipleDispatch (Args*... args) {  // Jarod42 says that R must be known during compile-time (as his solution requires as well), though covariant return types can be possible if Args*... is B*,B*,...
    constexpr std::size_t N = sizeof...(Args);
    static const auto fooOverloads = createFunctionMultiArray<R, Object, A, std::tuple_size<ArgumentTuple>::value, N>();  // We use the keyword static because we only want to create this (const) array once for the entire program run.
    const std::array<std::size_t, N> indexArray = {args->getIndex()...};
    const auto f = getArrayElement(fooOverloads, indexArray);
    return f(this, args...);
}

// Demo: compares a direct foo call (static types) with fooMultipleDispatch
// (dynamic types).
int main() {
    // Automatic storage instead of `new`: the objects were never deleted
    // before, and the calls below only need pointers to them.
    B b;
    C c;
    D d;
    A* a[] = {&b, &c, &d};
    Object object;
    object.foo (a[0], a[2]);  // Object::foo A,A  (no multiple dispatch)
    object.fooMultipleDispatch<double>(a[0], a[2]);  // Object::foo B,D  (multiple dispatch!)
    object.fooMultipleDispatch<char>(a[1], a[0], a[2]);  // Object::foo C,B,D  (multiple dispatch!)
}
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!