Как определить вырожденную размерность boost_ multi_array во время выполнения?

У меня есть 3D multi_array, и я хотел бы сделать 2D-фрагменты, используя измерения, указанные во время выполнения. Я знаю индекс вырожденной размерности и индекс среза, который я хочу извлечь в этой вырожденной размерности. В настоящее время уродливое обходное решение выглядит так:

if (0 == degenerate_dimension)
{
    Slice slice = input_array[boost::indices[slice_index][range()][range()]];
}
else if (1 == degenerate_dimension)
{
    Slice slice = input_array[boost::indices[range()][slice_index][range()]];
}
else if (2 == degenerate_dimension)
{
    Slice slice = input_array[boost::indices[range()][range()][slice_index]];
}

Есть ли более красивый способ создания объекта index_gen? Что-то вроде этого:

var slicer;
for(int i = 0; i < 3; ++i) {
    if (degenerate_dimension == i)
        slicer = boost::indices[slice_index];
    else
        slicer = boost::indices[range()];
}
Slice slice = input_array[slicer];

Кажется, что каждый последующий вызов boost:: indices:: operator [] возвращает другой тип в зависимости от размерности (т.е. количества предыдущих вызовов), поэтому нет возможности использовать одну переменную, которая может содержать временный индекс_индекс объект.

Ответ 1

Пожалуйста, попробуйте это. У кода есть один недостаток - он ссылается на переменную массива range_, объявленную в пространстве имен boost:: detail:: multi_array.

#include <boost/multi_array.hpp>                                                                                                                              

typedef boost::multi_array<double, 3> array_type;                                                                                                             
typedef boost::multi_array_types::index_gen::gen_type<2,3>::type index_gen_type;                                                                                   
typedef boost::multi_array_types::index_range range;                                                                                                          

index_gen_type                                                                                                                                                     
func(int degenerate_dimension, int slice_index)                                                                                                               
{                                                                                                                                                             
    index_gen_type slicer;                                                                                                                                         
    int i;                                                                                                                                                    
    for(int i = 0; i < 3; ++i) {                                                                                                                              
        if (degenerate_dimension == i)                                                                                                                        
            slicer.ranges_[i] = range(slice_index);                                                                                                           
        else                                                                                                                                                  
            slicer.ranges_[i] = range();                                                                                                                      
    }                                                                                                                                                         
    return slicer;                                                                                                                                            
}                                                                                                                                                             

int main(int argc, char **argv)                                                                                                                               
{                                                                                                                                                             
    array_type myarray(boost::extents[3][3][3]);                                                                                                              
    array_type::array_view<2>::type myview = myarray[ func(2, 1) ];                                                                                           
    return 0;                                                                                                                                                 
}

Ответ 2

То, что вы пытаетесь сделать, - это переместить переменную из времени выполнения для компиляции времени. Это можно сделать только с помощью цепочки операторов if else или оператора switch.

Упрощенный пример

// print a compile time int
template< int I >
void printer( void )
{
   std::cout << I << '\n';
}

// print a run time int
void printer( int i )
{
   // translate a runtime int to a compile time int
   switch( i )
   {
      case 1: printer<1>(); break;
      case 2: printer<2>(); break;
      case 3: printer<3>(); break;
      case 4: printer<4>(); break;
      default: throw std::logic_error( "not implemented" );
   }
}

// compile time ints
enum{ enum_i = 2 };
const int const_i = 3;
constexpr i constexper_i( void ) { return 4; }

// run time ints
extern int func_i( void ); // { return 5; }
extern int global_i; // = 6

int main()
{
   int local_i = 7;
   const int local_const_i = 8;

   printer<enum_i>();
   printer<const_i>();
   printer<constexpr_i()>();
   //printer<func_i()>();
   //printer<global_i>();
   //printer<local_i>();
   printer<local_const_i>();

   printer( enum_i );
   printer( const_i );
   printer( constexpr_i() );
   printer( func_i()      ); // throws an exception
   printer( global_i      ); // throws an exception
   printer( local_i       ); // throws an exception
   printer( local_const_i ); // throws an exception
}