Класс с типами, зависящими от вариационных шаблонов

Недавно я просмотрел видео, которое вдохновило меня написать мою собственную систему нейронных сетей, и я хотел, чтобы количество узлов в сети было настраиваемым.

Сначала я достиг этого во время выполнения, анализируя массив чисел узлов, но мне было интересно, могу ли я сделать это во время компиляции. Вот пример того, что я надеялся выполнить.

template<int FirstNodes, int SecondNodes, int... OtherNodes>
class Net
{
    tuple<Eigen::Matrix<float, FirstNodes, SecondNodes>, ...> m_weights;
    // More matricies with the values from the OtherNodes
};

В качестве более подробного примера Net<784, 16, 16, 10> n; n.m_weight должен иметь тип

tuple<Eigen::Matrix<float, 784, 16>,
    Eigen::Matrix<float, 16, 16>,
    Eigen::Matrix<float, 16, 10>>

Из того, что я знаю о С++ и constexpr, это должно быть возможно.

Я должен добавить, что я смог сделать

template<int FirstNodes, int SecondNodes, int... OtherNodes>
class Net
{
public:
    Net()
    {
        auto nodes = {FirstNodes, SecondNodes, OtherNodes...};

        auto i = nodes.begin();
        do 
        {
            // Eigen::Matrix<float, Dynamic, Dynamic>
            Eigen::MatrixXf m(*(i++), *i);
        } while (i+1 != nodes.end());
    }
};

Но потом я снова использую динамические матрицы, и это не то, на что я надеялся.

Приветствуются любые советы или рабочие примеры.

Ответ 1

Вы хотите, чтобы какое-то преобразование типа, которое дало список целых чисел N, возвращает a tuple из N - 1 матриц. Здесь решение С++ 17:

template <int A, int B, int... Is>
auto make_matrix_tuple()
{   
    if constexpr(sizeof...(Is) == 0)
    {
        return std::tuple<Eigen::Matrix<float, A, B>>{};
    }
    else
    {
        return std::tuple_cat(make_matrix_tuple<A, B>(), 
                            make_matrix_tuple<B, Is...>());
    }
}

живой пример в wandbox


В С++ 11 вы можете реализовать это преобразование типа рекурсивно:

template <int... Is>
struct matrix_tuple_helper;

template <int A, int B, int... Rest>
struct matrix_tuple_helper<A, B, Rest...>
{
    using curr_matrix = Eigen::Matrix<float, A, B>;
    using type = 
        decltype(
            std::tuple_cat(
                std::tuple<curr_matrix>{},
                typename matrix_tuple_helper<B, Rest...>::type{}
            )
        );
};

template <int A, int B>
struct matrix_tuple_helper<A, B>
{
    using curr_matrix = Eigen::Matrix<float, A, B>;
    using type = std::tuple<curr_matrix>;
};

template <int... Is>
using matrix_tuple = typename matrix_tuple_helper<Is...>::type;

С++ 14:

struct matrix_tuple_maker
{
    template <int A, int B, int C, int... Is>
    static auto get()
    {
        return std::tuple_cat(get<A, B>(), get<B, C, Is...>());
    }

    template <int A, int B>
    static auto get()
    {
        return std::tuple<Eigen::Matrix<float, A, B>>{};
    }
};

static_assert(std::is_same_v<
    decltype(matrix_tuple_maker::get<784, 16, 16, 10>()),
    std::tuple<Eigen::Matrix<float, 784, 16>,
               Eigen::Matrix<float, 16, 16>,
               Eigen::Matrix<float, 16, 10>>
    >);

Ответ 2

Мне кажется, что вам нужен два списка целых чисел, не входящих в фазу 1.

Если вы определяете тривиальный контейнер целых чисел (в С++ 14 вы можете использовать std::integer_sequence)

template <int...>
struct iList
 { };

вы можете определить базовый класс следующим образом (извините: используется foo вместо Eigen::Matrix)

template <typename, typename, typename = std::tuple<>>
struct NetBase;

// avoid the first couple
template <int ... Is, int J0, int ... Js>
struct NetBase<iList<0, Is...>, iList<J0, Js...>, std::tuple<>>
   : NetBase<iList<Is...>, iList<Js...>, std::tuple<>>
 { };

// intermediate case
template <int I0, int ... Is, int J0, int ... Js, typename ... Ts>
struct NetBase<iList<I0, Is...>, iList<J0, Js...>, std::tuple<Ts...>>
   : NetBase<iList<Is...>, iList<Js...>,
             std::tuple<Ts..., foo<float, I0, J0>>>
 { };

// avoid the last couple and terminate
template <int I0, typename ... Ts>
struct NetBase<iList<I0>, iList<0>, std::tuple<Ts...>>
 { using type = std::tuple<Ts...>; };

и Net просто становятся (наблюдают за фазой пару целых списков)

template <int F, int S, int... Os>
struct Net : NetBase<iList<0, F, S, Os...>, iList<F, S, Os..., 0>>
 { };

Ниже приведен полный пример компиляции

#include <tuple>

template <int...>
struct iList
 { };

template <typename, int, int>
struct foo
 { };

template <typename, typename, typename = std::tuple<>>
struct NetBase;

// avoid the first couple
template <int ... Is, int J0, int ... Js>
struct NetBase<iList<0, Is...>, iList<J0, Js...>, std::tuple<>>
   : NetBase<iList<Is...>, iList<Js...>, std::tuple<>>
 { };

// intermediate case
template <int I0, int ... Is, int J0, int ... Js, typename ... Ts>
struct NetBase<iList<I0, Is...>, iList<J0, Js...>, std::tuple<Ts...>>
   : NetBase<iList<Is...>, iList<Js...>,
             std::tuple<Ts..., foo<float, I0, J0>>>
 { };

// avoid the last couple and terminate
template <int I0, typename ... Ts>
struct NetBase<iList<I0>, iList<0>, std::tuple<Ts...>>
 { using type = std::tuple<Ts...>; };

template <int F, int S, int... Os>
struct Net : NetBase<iList<0, F, S, Os...>, iList<F, S, Os..., 0>>
 { };

int main()
 {
   static_assert(std::is_same<
      typename Net<784, 16, 16, 10>::type, 
      std::tuple<foo<float, 784, 16>, foo<float, 16, 16>,
                 foo<float, 16, 10>>>{}, "!");
 }

Ответ 3

Вот еще одно решение С++ 14. Я считаю, что это стоит публикации, потому что это нерекурсивный и читаемый.

#include <tuple>
#include <utility>

template<class, int, int> struct Matrix {};

template<int... matsizes, std::size_t... matinds>
constexpr auto make_net(
  std::integer_sequence<int, matsizes...>,
  std::index_sequence<matinds...>
) {
  constexpr int sizes[] = {matsizes...};
  return std::tuple< Matrix<float, sizes[matinds], sizes[1+matinds]>... >{};
}

template<int... matsizes>
constexpr auto make_net(
  std::integer_sequence<int, matsizes...> sizes
) {
  static_assert(sizes.size() >= 2, "");
  constexpr auto number_of_mats = sizes.size() - 1;
  return make_net(sizes, std::make_index_sequence<number_of_mats>{});
}

int main () {
  auto net = make_net(std::integer_sequence<int, 784, 16, 16, 10>{});
  using Net = decltype(net);

  static_assert(
    std::is_same<
      std::tuple<
        Matrix<float, 784, 16>,
        Matrix<float, 16, 16>,
        Matrix<float, 16, 10>
      >,
      Net
    >{}, ""
  );

  return 0;
}