How to specify degenerate dimension of boost multi_array at runtime?

前端 未结 2 2281
独厮守ぢ
独厮守ぢ 2021-02-20 14:06

I have a 3D multi_array and I would like to make 2D slices using dimensions specified at runtime. I know the index of degenerate dimension and the index of a slice that I want t

相关标签:
2条回答
  • 2021-02-20 14:23

    What you're trying to do is move a variable from run time to compile time. This can only be done with a chain of if else statements or a switch statement.

    A simplified example

    // Print an int that is fixed at compile time: the value arrives as a
    // template argument, so every distinct value is its own instantiation.
    template< int Value >
    void printer()
    {
       std::cout << Value << '\n';
    }
    
    // Print an int that is only known at run time by forwarding it to the
    // matching compile-time instantiation of printer<>.
    void printer( int i )
    {
       // A runtime value cannot be a template argument, so each supported
       // value needs its own explicit branch naming the constant.
       if( i == 1 ) { printer<1>(); return; }
       if( i == 2 ) { printer<2>(); return; }
       if( i == 3 ) { printer<3>(); return; }
       if( i == 4 ) { printer<4>(); return; }

       // No instantiation was listed for this value.
       throw std::logic_error( "not implemented" );
    }
    
    // compile time ints: each of these is a constant expression and can
    // therefore be used as a template argument
    enum{ enum_i = 2 };
    const int const_i = 3;
    // constexpr function returning int; spelled constexpr_i to match the
    // calls printer<constexpr_i()>() and printer( constexpr_i() ) in main
    constexpr int constexpr_i( void ) { return 4; }
    
    // run time ints: only declared here, so their values are not visible to
    // the compiler at this point and cannot be used as template arguments
    extern int func_i( void ); // defined elsewhere, e.g. { return 5; }
    extern int global_i; // defined elsewhere, e.g. = 6
    
    // Demonstrates which kinds of ints can be used as template arguments
    // (compile time) and what the run time overload does with each of them.
    int main()
    {
       int local_i = 7;
       const int local_const_i = 8;
    
       // Template arguments must be constant expressions.
       printer<enum_i>();
       printer<const_i>();
       printer<constexpr_i()>();
       //printer<func_i()>();   // rejected: result of a non-constexpr call
       //printer<global_i>();   // rejected: non-const global
       //printer<local_i>();    // rejected: non-const local
       printer<local_const_i>(); // accepted: const int initialized from a literal
    
       // The run time overload accepts any int but only handles 1 through 4;
       // the first three calls pass 2, 3 and 4 and print normally.
       printer( enum_i );
       printer( const_i );
       printer( constexpr_i() );
       // The remaining values (5, 6, 7, 8 per the declarations and comments
       // above) have no case in the switch, so each call would throw.
       printer( func_i()      ); // throws an exception
       printer( global_i      ); // throws an exception
       printer( local_i       ); // throws an exception
       printer( local_const_i ); // throws an exception
    }
    
    0 讨论(0)
  • 2021-02-20 14:33

    Please try this. The code has one disadvantage: it refers to the ranges_ array member, which is declared in the boost::detail::multi_array namespace.

    #include <boost/multi_array.hpp>                                                                                                                              
    
    // 3D array of doubles that the slices are taken from.
    using array_type = boost::multi_array<double, 3>;
    // Index generator over 3 index ranges that yields a 2-dimensional view.
    using index_gen_type = boost::multi_array_types::index_gen::gen_type<2,3>::type;
    // Shorthand for a single per-dimension index range.
    using range = boost::multi_array_types::index_range;
    
    // Build, at run time, the index generator that slices a 3D array down to
    // a 2D view.
    //
    // degenerate_dimension: which of the three dimensions (0, 1 or 2) is
    //                       collapsed to a single index.
    // slice_index:          the index selected along that dimension.
    //
    // Returns a generator suitable for array_type::operator[].
    //
    // NOTE(review): degenerate_dimension is not validated. For a value outside
    // 0..2 no range is given a single index, which presumably does not match
    // the 2 dimensions promised by index_gen_type -- confirm before relying on it.
    index_gen_type
    func(int degenerate_dimension, int slice_index)
    {
        index_gen_type slicer;
        // Writes the generator's ranges_ member directly, because the usual
        // indices[...][...] syntax fixes the degenerate position at compile time.
        for (int dim = 0; dim < 3; ++dim) {
            if (dim == degenerate_dimension)
                slicer.ranges_[dim] = range(slice_index); // single index
            else
                slicer.ranges_[dim] = range();            // default-constructed range
        }
        return slicer;
    }
    
    int main(int argc, char **argv)                                                                                                                               
    {                                                                                                                                                             
        array_type myarray(boost::extents[3][3][3]);                                                                                                              
        array_type::array_view<2>::type myview = myarray[ func(2, 1) ];                                                                                           
        return 0;                                                                                                                                                 
    }
    
    0 讨论(0)
提交回复
热议问题