IgANet
IGAnets - Isogeometric Analysis Networks

generator.hpp
#pragma once

#include <any>
#include <iostream>
#include <vector>

#include <core/core.hpp>
#include <core/options.hpp>
#include <net/activation.hpp>
#include <utils/zip.hpp>
namespace iganet {

// clang-format off
/// @brief Enumerator for specifying the initialization of network weights
enum class nn_init : short_t {
  constant        = 0,
  normal          = 1,
  uniform         = 2,
  kaiming_normal  = 3,
  kaiming_uniform = 4,
  xavier_normal   = 5,
  xavier_uniform  = 6,
};
// clang-format on
/// @brief IgANetGeneratorImpl
///
/// Implementation of the IgANet generator module: a fully connected
/// feed-forward network with one activation function per linear layer
template <typename real_t>
class IgANetGeneratorImpl : public torch::nn::Module {
public:
  /// @brief Default constructor
  IgANetGeneratorImpl() = default;

  /// @brief Constructor
  IgANetGeneratorImpl(const std::vector<int64_t> &layers,
                      const std::vector<std::vector<std::any>> &activations,
                      Options<real_t> options = Options<real_t>{}) {
    assert(layers.size() == activations.size() + 1);

    // Generate vector of linear layers and register them as layer[i]
    for (std::size_t i = 0; i + 1 < layers.size(); ++i) {
      layers_.emplace_back(
          register_module("layer[" + std::to_string(i) + "]",
                          torch::nn::Linear(layers[i], layers[i + 1])));
      layers_.back()->to(options.device(), options.dtype(), true);

      torch::nn::init::xavier_uniform_(layers_.back()->weight);
      torch::nn::init::constant_(layers_.back()->bias, 0.0);
    }

    // Generate vector of activation functions
    for (const auto &a : activations)
      switch (std::any_cast<activation>(a[0])) {

      // No activation function
      case activation::none:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new None{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Batch Normalization
      case activation::batch_norm:
        switch (a.size()) {
        case 8:
          activations_.emplace_back(new BatchNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]),
              std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
              std::any_cast<double>(a[6]), std::any_cast<bool>(a[7])});
          break;
        case 7:
          activations_.emplace_back(new BatchNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]),
              std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
              std::any_cast<double>(a[6])});
          break;
        case 4:
          activations_.emplace_back(new BatchNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::nn::functional::BatchNormFuncOptions>(
                  a[3])});
          break;
        case 3:
          activations_.emplace_back(
              new BatchNorm{std::any_cast<torch::Tensor>(a[1]),
                            std::any_cast<torch::Tensor>(a[2])});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // CELU
      case activation::celu:
        switch (a.size()) {
        case 3:
          activations_.emplace_back(
              new CELU{std::any_cast<double>(a[1]), std::any_cast<bool>(a[2])});
          break;
        case 2:
          try {
            activations_.emplace_back(new CELU{
                std::any_cast<torch::nn::functional::CELUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new CELU{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new CELU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // ELU
      case activation::elu:
        switch (a.size()) {
        case 3:
          activations_.emplace_back(
              new ELU{std::any_cast<double>(a[1]), std::any_cast<bool>(a[2])});
          break;
        case 2:
          try {
            activations_.emplace_back(new ELU{
                std::any_cast<torch::nn::functional::ELUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new ELU{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new ELU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // GELU
      case activation::gelu:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new GELU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // GLU
      case activation::glu:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new GLU{
                std::any_cast<torch::nn::functional::GLUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new GLU{std::any_cast<int64_t>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new GLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Group Normalization
      case activation::group_norm:
        switch (a.size()) {
        case 5:
          activations_.emplace_back(new GroupNorm{
              std::any_cast<int64_t>(a[1]), std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]), std::any_cast<double>(a[4])});
          break;
        case 2:
          try {
            activations_.emplace_back(new GroupNorm{
                std::any_cast<torch::nn::functional::GroupNormFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new GroupNorm{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Gumbel-Softmax
      case activation::gumbel_softmax:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new GumbelSoftmax{
              std::any_cast<double>(a[1]), std::any_cast<int>(a[2]),
              std::any_cast<bool>(a[3])});
          break;
        case 2:
          activations_.emplace_back(new GumbelSoftmax{
              std::any_cast<torch::nn::functional::GumbelSoftmaxFuncOptions>(
                  a[1])});
          break;
        case 1:
          activations_.emplace_back(new GumbelSoftmax{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Hardshrink
      case activation::hardshrink:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new Hardshrink{
                std::any_cast<torch::nn::functional::HardshrinkFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new Hardshrink{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new Hardshrink{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Hardsigmoid
      case activation::hardsigmoid:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Hardsigmoid{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Hardswish
      case activation::hardswish:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Hardswish{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Hardtanh
      case activation::hardtanh:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new Hardtanh{std::any_cast<double>(a[1]),
                                                 std::any_cast<double>(a[2]),
                                                 std::any_cast<bool>(a[3])});
          break;
        case 3:
          activations_.emplace_back(new Hardtanh{std::any_cast<double>(a[1]),
                                                 std::any_cast<double>(a[2])});
          break;
        case 2:
          activations_.emplace_back(new Hardtanh{
              std::any_cast<torch::nn::functional::HardtanhFuncOptions>(a[1])});
          break;
        case 1:
          activations_.emplace_back(new Hardtanh{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Instance Normalization
      case activation::instance_norm:
        switch (a.size()) {
        case 8:
          activations_.emplace_back(new InstanceNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]),
              std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
              std::any_cast<double>(a[6]), std::any_cast<bool>(a[7])});
          break;
        case 7:
          activations_.emplace_back(new InstanceNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]),
              std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
              std::any_cast<double>(a[6])});
          break;
        case 2:
          activations_.emplace_back(new InstanceNorm{
              std::any_cast<torch::nn::functional::InstanceNormFuncOptions>(
                  a[1])});
          break;
        case 1:
          activations_.emplace_back(new InstanceNorm{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Layer Normalization
      case activation::layer_norm:
        switch (a.size()) {
        case 5:
          activations_.emplace_back(new LayerNorm{
              std::any_cast<std::vector<int64_t>>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]), std::any_cast<double>(a[4])});
          break;
        case 2:
          try {
            activations_.emplace_back(new LayerNorm{
                std::any_cast<torch::nn::functional::LayerNormFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new LayerNorm{std::any_cast<std::vector<int64_t>>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Leaky ReLU
      case activation::leaky_relu:
        switch (a.size()) {
        case 3:
          activations_.emplace_back(new LeakyReLU{std::any_cast<double>(a[1]),
                                                  std::any_cast<bool>(a[2])});
          break;
        case 2:
          try {
            activations_.emplace_back(new LeakyReLU{
                std::any_cast<torch::nn::functional::LeakyReLUFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new LeakyReLU{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new LeakyReLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Local Response Normalization
      case activation::local_response_norm:
        switch (a.size()) {
        case 5:
          activations_.emplace_back(new LocalResponseNorm{
              std::any_cast<int64_t>(a[1]), std::any_cast<double>(a[2]),
              std::any_cast<double>(a[3]), std::any_cast<double>(a[4])});
          break;
        case 2:
          try {
            activations_.emplace_back(new LocalResponseNorm{std::any_cast<
                torch::nn::functional::LocalResponseNormFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(
                new LocalResponseNorm{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // LogSigmoid
      case activation::logsigmoid:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new LogSigmoid{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // LogSoftmax
      case activation::logsoftmax:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new LogSoftmax{
                std::any_cast<torch::nn::functional::LogSoftmaxFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new LogSoftmax{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Mish
      case activation::mish:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Mish{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Lp Normalization
      case activation::normalize:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new Normalize{
              std::any_cast<double>(a[1]), std::any_cast<double>(a[2]),
              std::any_cast<int64_t>(a[3])});
          break;
        case 2:
          activations_.emplace_back(new Normalize{
              std::any_cast<torch::nn::functional::NormalizeFuncOptions>(
                  a[1])});
          break;
        case 1:
          activations_.emplace_back(new Normalize{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // PReLU
      case activation::prelu:
        switch (a.size()) {
        case 2:
          activations_.emplace_back(
              new PReLU{std::any_cast<torch::Tensor>(a[1])});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // ReLU
      case activation::relu:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new ReLU{
                std::any_cast<torch::nn::functional::ReLUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new ReLU{std::any_cast<bool>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new ReLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // ReLU6
      case activation::relu6:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new ReLU6{
                std::any_cast<torch::nn::functional::ReLU6FuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new ReLU6{std::any_cast<bool>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new ReLU6{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Randomized ReLU
      case activation::rrelu:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new RReLU{std::any_cast<double>(a[1]),
                                              std::any_cast<double>(a[2]),
                                              std::any_cast<bool>(a[3])});
          break;
        case 3:
          activations_.emplace_back(new RReLU{std::any_cast<double>(a[1]),
                                              std::any_cast<double>(a[2])});
          break;
        case 2:
          activations_.emplace_back(new RReLU{
              std::any_cast<torch::nn::functional::RReLUFuncOptions>(a[1])});
          break;
        case 1:
          activations_.emplace_back(new RReLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // SELU
      case activation::selu:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new SELU{
                std::any_cast<torch::nn::functional::SELUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new SELU{std::any_cast<bool>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new SELU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Sigmoid
      case activation::sigmoid:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Sigmoid{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // SiLU
      case activation::silu:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new SiLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softmax
      case activation::softmax:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new Softmax{
                std::any_cast<torch::nn::functional::SoftmaxFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new Softmax{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softmin
      case activation::softmin:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new Softmin{
                std::any_cast<torch::nn::functional::SoftminFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new Softmin{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softplus
      case activation::softplus:
        switch (a.size()) {
        case 3:
          activations_.emplace_back(new Softplus{std::any_cast<double>(a[1]),
                                                 std::any_cast<double>(a[2])});
          break;
        case 2:
          activations_.emplace_back(new Softplus{
              std::any_cast<torch::nn::functional::SoftplusFuncOptions>(a[1])});
          break;
        case 1:
          activations_.emplace_back(new Softplus{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softshrink
      case activation::softshrink:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new Softshrink{
                std::any_cast<torch::nn::functional::SoftshrinkFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new Softshrink{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new Softshrink{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softsign
      case activation::softsign:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Softsign{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Tanh
      case activation::tanh:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Tanh{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Tanhshrink
      case activation::tanhshrink:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Tanhshrink{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Threshold
      case activation::threshold:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new Threshold{std::any_cast<double>(a[1]),
                                                  std::any_cast<double>(a[2]),
                                                  std::any_cast<bool>(a[3])});
          break;
        case 3:
          activations_.emplace_back(new Threshold{std::any_cast<double>(a[1]),
                                                  std::any_cast<double>(a[2])});
          break;
        case 2:
          activations_.emplace_back(new Threshold{
              std::any_cast<torch::nn::functional::ThresholdFuncOptions>(
                  a[1])});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      default:
        throw std::runtime_error("Invalid activation function");
      }
  }
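
  // How the `activations` argument is encoded (a sketch, not normative):
  // each entry is a std::vector<std::any> whose first element is an
  // `activation` enumerator and whose remaining elements are that
  // activation's parameters, matching one of the arities dispatched in the
  // switch above. For example, a hypothetical three-layer specification:
  //
  //   {{activation::relu},              // size 1: default-constructed ReLU
  //    {activation::celu, 1.0, false},  // size 3: double and bool parameters
  //    {activation::none}}              // size 1: identity on the output
  //
  // A parameter of the wrong std::any type throws std::bad_any_cast (except
  // where a try/catch fallback exists); a wrong count throws
  // std::runtime_error.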

  /// @brief Forward evaluation
  torch::Tensor forward(torch::Tensor x) {
    // Standard feed-forward neural network: apply each linear layer
    // followed by its activation function
    for (auto [layer, activation] : utils::zip(layers_, activations_))
      x = activation->apply(layer->forward(x));

    return x;
  }

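  // Shape sketch (illustrative only), assuming layer widths {2, 64, 1}:
  // forward() maps an input x of shape [batch, 2] through layer[0] and its
  // activation to [batch, 64], then through layer[1] and its activation to
  // [batch, 1]; x must match the device and dtype selected via
  // Options<real_t>.
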
  /// @brief Writes the IgANet into a torch::serialize::OutputArchive object
  inline torch::serialize::OutputArchive &
  write(torch::serialize::OutputArchive &archive,
        const std::string &key = "iganet") const {
    assert(layers_.size() == activations_.size());

    archive.write(key + ".layers",
                  torch::full({1}, static_cast<int64_t>(layers_.size())));
    for (std::size_t i = 0; i < layers_.size(); ++i) {
      archive.write(
          key + ".layer[" + std::to_string(i) + "].in_features",
          torch::full({1}, (int64_t)layers_[i]->options.in_features()));
      archive.write(
          key + ".layer[" + std::to_string(i) + "].outputs_features",
          torch::full({1}, (int64_t)layers_[i]->options.out_features()));
      archive.write(key + ".layer[" + std::to_string(i) + "].bias",
                    torch::full({1}, (int64_t)layers_[i]->options.bias()));

      activations_[i]->write(archive, key + ".layer[" + std::to_string(i) +
                                          "].activation");
    }

    return archive;
  }

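  // A minimal save/restore sketch, assuming a hypothetical file "net.pt"
  // and two compatibly-typed module holders `net` and `net2`:
  //
  //   torch::serialize::OutputArchive out;
  //   net->write(out);
  //   out.save_to("net.pt");
  //
  //   torch::serialize::InputArchive in;
  //   in.load_from("net.pt");
  //   net2->read(in);
  //
  // write() records only the architecture metadata (layer sizes, bias
  // flags, activation types and state) under the given key; the weight and
  // bias tensors of the linear layers are not written here and would have
  // to be saved separately.
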
  /// @brief Reads the IgANet from a torch::serialize::InputArchive object
  inline torch::serialize::InputArchive &
  read(torch::serialize::InputArchive &archive,
       const std::string &key = "iganet") {
    torch::Tensor layers, in_features, outputs_features, bias, activation;

    auto options = iganet::Options<real_t>{};

    archive.read(key + ".layers", layers);
    for (int64_t i = 0; i < layers.item<int64_t>(); ++i) {
      archive.read(key + ".layer[" + std::to_string(i) + "].in_features",
                   in_features);
      archive.read(key + ".layer[" + std::to_string(i) + "].outputs_features",
                   outputs_features);
      archive.read(key + ".layer[" + std::to_string(i) + "].bias", bias);
      layers_.emplace_back(register_module(
          "layer[" + std::to_string(i) + "]",
          torch::nn::Linear(
              torch::nn::LinearOptions(in_features.item<int64_t>(),
                                       outputs_features.item<int64_t>())
                  .bias(bias.item<bool>()))));
      layers_.back()->to(options.device(), options.dtype(), true);

      archive.read(key + ".layer[" + std::to_string(i) + "].activation.type",
                   activation);
      // The local tensor `activation` shadows the enum type, hence the
      // elaborated type specifier `enum activation` below
      switch (static_cast<enum activation>(activation.item<int64_t>())) {
      case activation::none:
        activations_.emplace_back(new None{});
        break;
      case activation::batch_norm:
        activations_.emplace_back(
            new BatchNorm{torch::Tensor{}, torch::Tensor{}});
        break;
      case activation::celu:
        activations_.emplace_back(new CELU{});
        break;
      case activation::elu:
        activations_.emplace_back(new ELU{});
        break;
      case activation::gelu:
        activations_.emplace_back(new GELU{});
        break;
      case activation::glu:
        activations_.emplace_back(new GLU{});
        break;
      case activation::group_norm:
        activations_.emplace_back(new GroupNorm{0});
        break;
      case activation::gumbel_softmax:
        activations_.emplace_back(new GumbelSoftmax{});
        break;
      case activation::hardshrink:
        activations_.emplace_back(new Hardshrink{});
        break;
      case activation::hardsigmoid:
        activations_.emplace_back(new Hardsigmoid{});
        break;
      case activation::hardswish:
        activations_.emplace_back(new Hardswish{});
        break;
      case activation::hardtanh:
        activations_.emplace_back(new Hardtanh{});
        break;
      case activation::instance_norm:
        activations_.emplace_back(new InstanceNorm{});
        break;
      case activation::layer_norm:
        activations_.emplace_back(new LayerNorm{{}});
        break;
      case activation::leaky_relu:
        activations_.emplace_back(new LeakyReLU{});
        break;
      case activation::local_response_norm:
        activations_.emplace_back(new LocalResponseNorm{0});
        break;
      case activation::logsigmoid:
        activations_.emplace_back(new LogSigmoid{});
        break;
      case activation::logsoftmax:
        activations_.emplace_back(new LogSoftmax{0});
        break;
      case activation::mish:
        activations_.emplace_back(new Mish{});
        break;
      case activation::normalize:
        activations_.emplace_back(new Normalize{0, 0, 0});
        break;
      case activation::prelu:
        activations_.emplace_back(new PReLU{torch::Tensor{}});
        break;
      case activation::relu:
        activations_.emplace_back(new ReLU{});
        break;
      case activation::relu6:
        activations_.emplace_back(new ReLU6{});
        break;
      case activation::rrelu:
        activations_.emplace_back(new RReLU{});
        break;
      case activation::selu:
        activations_.emplace_back(new SELU{});
        break;
      case activation::sigmoid:
        activations_.emplace_back(new Sigmoid{});
        break;
      case activation::silu:
        activations_.emplace_back(new SiLU{});
        break;
      case activation::softmax:
        activations_.emplace_back(new Softmax{0});
        break;
      case activation::softmin:
        activations_.emplace_back(new Softmin{0});
        break;
      case activation::softplus:
        activations_.emplace_back(new Softplus{});
        break;
      case activation::softshrink:
        activations_.emplace_back(new Softshrink{});
        break;
      case activation::softsign:
        activations_.emplace_back(new Softsign{});
        break;
      case activation::tanh:
        activations_.emplace_back(new Tanh{});
        break;
      case activation::tanhshrink:
        activations_.emplace_back(new Tanhshrink{});
        break;
      case activation::threshold:
        activations_.emplace_back(new Threshold{0, 0});
        break;
      default:
        throw std::runtime_error("Invalid activation function");
      }
      activations_.back()->read(archive, key + ".layer[" + std::to_string(i) +
                                             "].activation");
    }
    return archive;
  }

  /// @brief Pretty-prints the activation functions
  inline void pretty_print(std::ostream &os) const noexcept override {
    os << "(\n";

    int i = 0;
    for (const auto &activation : activations_)
      os << "activation[" << i++ << "] = " << *activation << "\n";
    os << ")\n";
  }

private:
  /// @brief Vector of linear layers
  std::vector<torch::nn::Linear> layers_;

  /// @brief Vector of activation functions
  std::vector<std::unique_ptr<iganet::ActivationFunction>> activations_;
};

/// @brief IgANetGenerator
///
/// Module holder for IgANetGeneratorImpl, following the PyTorch C++ API
/// convention of wrapping module implementations in a shared-pointer holder
template <typename real_t>
class IgANetGenerator
    : public torch::nn::ModuleHolder<IgANetGeneratorImpl<real_t>> {

public:
  using torch::nn::ModuleHolder<IgANetGeneratorImpl<real_t>>::ModuleHolder;
};

} // namespace iganet
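
Example usage (a minimal sketch, not part of generator.hpp; the umbrella
header name `iganet.h` and the build setup are assumptions):

#include <any>
#include <vector>

#include <iganet.h>

int main() {
  using namespace iganet;

  // A fully connected network with two linear layers (2 -> 32 -> 1) and
  // one activation specification per layer
  IgANetGenerator<double> net(
      std::vector<int64_t>{2, 32, 1},
      std::vector<std::vector<std::any>>{{activation::relu},
                                         {activation::none}});

  // Forward evaluation on a batch of 8 two-dimensional samples; the input
  // dtype must match the real_t template parameter (double -> kFloat64)
  auto y = net->forward(torch::rand({8, 2}, torch::kFloat64));

  return 0;
}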