58 const std::vector<int64_t> &layers,
59 const std::vector<std::vector<std::any>> &activations,
61 assert(layers.size() == activations.size() + 1);
64 for (
auto i = 0; i < layers.size() - 1; ++i) {
66 register_module(
"layer[" + std::to_string(i) +
"]",
67 torch::nn::Linear(layers[i], layers[i + 1])));
68 layers_.back()->to(options.device(), options.dtype(),
true);
70 torch::nn::init::xavier_uniform_(
layers_.back()->weight);
71 torch::nn::init::constant_(
layers_.back()->bias, 0.0);
75 for (
const auto &a : activations)
84 throw std::runtime_error(
"Invalid number of parameters");
93 std::any_cast<torch::Tensor>(a[1]),
94 std::any_cast<torch::Tensor>(a[2]),
95 std::any_cast<torch::Tensor>(a[3]),
96 std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
97 std::any_cast<double>(a[6]), std::any_cast<bool>(a[7])});
101 std::any_cast<torch::Tensor>(a[1]),
102 std::any_cast<torch::Tensor>(a[2]),
103 std::any_cast<torch::Tensor>(a[3]),
104 std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
105 std::any_cast<double>(a[6])});
109 std::any_cast<torch::Tensor>(a[1]),
110 std::any_cast<torch::Tensor>(a[2]),
111 std::any_cast<torch::nn::functional::BatchNormFuncOptions>(
116 new BatchNorm{std::any_cast<torch::Tensor>(a[1]),
117 std::any_cast<torch::Tensor>(a[2])});
120 throw std::runtime_error(
"Invalid number of parameters");
129 new CELU{std::any_cast<double>(a[1]), std::any_cast<bool>(a[2])});
134 std::any_cast<torch::nn::functional::CELUFuncOptions>(a[1])});
136 activations_.emplace_back(
new CELU{std::any_cast<double>(a[1])});
143 throw std::runtime_error(
"Invalid number of parameters");
152 new ELU{std::any_cast<double>(a[1]), std::any_cast<bool>(a[2])});
157 std::any_cast<torch::nn::functional::ELUFuncOptions>(a[1])});
159 activations_.emplace_back(
new ELU{std::any_cast<double>(a[1])});
166 throw std::runtime_error(
"Invalid number of parameters");
177 throw std::runtime_error(
"Invalid number of parameters");
187 std::any_cast<torch::nn::functional::GLUFuncOptions>(a[1])});
189 activations_.emplace_back(
new GLU{std::any_cast<int64_t>(a[1])});
196 throw std::runtime_error(
"Invalid number of parameters");
205 std::any_cast<int64_t>(a[1]), std::any_cast<torch::Tensor>(a[2]),
206 std::any_cast<torch::Tensor>(a[3]), std::any_cast<double>(a[4])});
211 std::any_cast<torch::nn::functional::GroupNormFuncOptions>(
215 new GroupNorm{std::any_cast<int64_t>(a[1])});
219 throw std::runtime_error(
"Invalid number of parameters");
228 std::any_cast<double>(a[1]), std::any_cast<int>(a[2]),
229 std::any_cast<bool>(a[3])});
233 std::any_cast<torch::nn::functional::GumbelSoftmaxFuncOptions>(
240 throw std::runtime_error(
"Invalid number of parameters");
250 std::any_cast<torch::nn::functional::HardshrinkFuncOptions>(
254 new Hardshrink{std::any_cast<double>(a[1])});
261 throw std::runtime_error(
"Invalid number of parameters");
272 throw std::runtime_error(
"Invalid number of parameters");
283 throw std::runtime_error(
"Invalid number of parameters");
291 activations_.emplace_back(
new Hardtanh{std::any_cast<double>(a[1]),
292 std::any_cast<double>(a[2]),
293 std::any_cast<bool>(a[3])});
296 activations_.emplace_back(
new Hardtanh{std::any_cast<double>(a[1]),
297 std::any_cast<double>(a[2])});
301 std::any_cast<torch::nn::functional::HardtanhFuncOptions>(a[1])});
307 throw std::runtime_error(
"Invalid number of parameters");
316 std::any_cast<torch::Tensor>(a[1]),
317 std::any_cast<torch::Tensor>(a[2]),
318 std::any_cast<torch::Tensor>(a[3]),
319 std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
320 std::any_cast<double>(a[6]), std::any_cast<bool>(a[7])});
324 std::any_cast<torch::Tensor>(a[1]),
325 std::any_cast<torch::Tensor>(a[2]),
326 std::any_cast<torch::Tensor>(a[3]),
327 std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
328 std::any_cast<double>(a[6])});
332 std::any_cast<torch::nn::functional::InstanceNormFuncOptions>(
339 throw std::runtime_error(
"Invalid number of parameters");
348 std::any_cast<std::vector<int64_t>>(a[1]),
349 std::any_cast<torch::Tensor>(a[2]),
350 std::any_cast<torch::Tensor>(a[3]), std::any_cast<double>(a[4])});
355 std::any_cast<torch::nn::functional::LayerNormFuncOptions>(
359 new LayerNorm{std::any_cast<std::vector<int64_t>>(a[1])});
363 throw std::runtime_error(
"Invalid number of parameters");
371 activations_.emplace_back(
new LeakyReLU{std::any_cast<double>(a[1]),
372 std::any_cast<bool>(a[2])});
377 std::any_cast<torch::nn::functional::LeakyReLUFuncOptions>(
381 new LeakyReLU{std::any_cast<double>(a[1])});
388 throw std::runtime_error(
"Invalid number of parameters");
397 std::any_cast<int64_t>(a[1]), std::any_cast<double>(a[2]),
398 std::any_cast<double>(a[3]), std::any_cast<double>(a[4])});
402 activations_.emplace_back(
new LocalResponseNorm{std::any_cast<
403 torch::nn::functional::LocalResponseNormFuncOptions>(a[1])});
406 new LocalResponseNorm{std::any_cast<int64_t>(a[1])});
410 throw std::runtime_error(
"Invalid number of parameters");
421 throw std::runtime_error(
"Invalid number of parameters");
431 std::any_cast<torch::nn::functional::LogSoftmaxFuncOptions>(
435 new LogSoftmax{std::any_cast<int64_t>(a[1])});
439 throw std::runtime_error(
"Invalid number of parameters");
450 throw std::runtime_error(
"Invalid number of parameters");
459 std::any_cast<double>(a[1]), std::any_cast<double>(a[2]),
460 std::any_cast<int64_t>(a[3])});
464 std::any_cast<torch::nn::functional::NormalizeFuncOptions>(
471 throw std::runtime_error(
"Invalid number of parameters");
480 new PReLU{std::any_cast<torch::Tensor>(a[1])});
483 throw std::runtime_error(
"Invalid number of parameters");
493 std::any_cast<torch::nn::functional::ReLUFuncOptions>(a[1])});
495 activations_.emplace_back(
new ReLU{std::any_cast<bool>(a[1])});
502 throw std::runtime_error(
"Invalid number of parameters");
512 std::any_cast<torch::nn::functional::ReLU6FuncOptions>(a[1])});
514 activations_.emplace_back(
new ReLU6{std::any_cast<bool>(a[1])});
521 throw std::runtime_error(
"Invalid number of parameters");
529 activations_.emplace_back(
new RReLU{std::any_cast<double>(a[1]),
530 std::any_cast<double>(a[2]),
531 std::any_cast<bool>(a[3])});
534 activations_.emplace_back(
new RReLU{std::any_cast<double>(a[1]),
535 std::any_cast<double>(a[2])});
539 std::any_cast<torch::nn::functional::RReLUFuncOptions>(a[1])});
545 throw std::runtime_error(
"Invalid number of parameters");
555 std::any_cast<torch::nn::functional::SELUFuncOptions>(a[1])});
557 activations_.emplace_back(
new SELU{std::any_cast<bool>(a[1])});
564 throw std::runtime_error(
"Invalid number of parameters");
575 throw std::runtime_error(
"Invalid number of parameters");
586 throw std::runtime_error(
"Invalid number of parameters");
596 std::any_cast<torch::nn::functional::SoftmaxFuncOptions>(
600 new Softmax{std::any_cast<int64_t>(a[1])});
604 throw std::runtime_error(
"Invalid number of parameters");
614 std::any_cast<torch::nn::functional::SoftminFuncOptions>(
618 new Softmin{std::any_cast<int64_t>(a[1])});
622 throw std::runtime_error(
"Invalid number of parameters");
630 activations_.emplace_back(
new Softplus{std::any_cast<double>(a[1]),
631 std::any_cast<double>(a[2])});
635 std::any_cast<torch::nn::functional::SoftplusFuncOptions>(a[1])});
641 throw std::runtime_error(
"Invalid number of parameters");
651 std::any_cast<torch::nn::functional::SoftshrinkFuncOptions>(
655 new Softshrink{std::any_cast<double>(a[1])});
662 throw std::runtime_error(
"Invalid number of parameters");
673 throw std::runtime_error(
"Invalid number of parameters");
684 throw std::runtime_error(
"Invalid number of parameters");
695 throw std::runtime_error(
"Invalid number of parameters");
703 activations_.emplace_back(
new Threshold{std::any_cast<double>(a[1]),
704 std::any_cast<double>(a[2]),
705 std::any_cast<bool>(a[3])});
708 activations_.emplace_back(
new Threshold{std::any_cast<double>(a[1]),
709 std::any_cast<double>(a[2])});
713 std::any_cast<torch::nn::functional::ThresholdFuncOptions>(
717 throw std::runtime_error(
"Invalid number of parameters");
722 throw std::runtime_error(
"Invalid activation function");