IgANet
IgANets - Isogeometric Analysis Networks
iganet.hpp
#pragma once

#include <any>

#include <boundary.hpp>
#include <functionspace.hpp>
#include <igabase.hpp>
#include <layer.hpp>
#include <utils/container.hpp>
#include <utils/fqn.hpp>
#include <utils/zip.hpp>

namespace iganet {

/// IgANetOptions
struct IgANetOptions {
  TORCH_ARG(int64_t, batch_size);
  TORCH_ARG(double, min_loss);
  TORCH_ARG(int64_t, max_epoch);
};

/// IgANetGeneratorImpl
template <typename real_t>
class IgANetGeneratorImpl : public torch::nn::Module {
public:
  /// Default constructor
  IgANetGeneratorImpl() = default;

  /// Constructor
  IgANetGeneratorImpl(const std::vector<int64_t> &layers,
                      const std::vector<std::vector<std::any>> &activations,
                      Options<real_t> options = Options<real_t>{}) {
    assert(layers.size() == activations.size() + 1);

    // Generate vector of linear layers and register them as layer[i]
    for (std::size_t i = 0; i < layers.size() - 1; ++i) {
      layers_.emplace_back(
          register_module("layer[" + std::to_string(i) + "]",
                          torch::nn::Linear(layers[i], layers[i + 1])));
      layers_.back()->to(options.device(), options.dtype(), true);

      torch::nn::init::xavier_uniform_(layers_.back()->weight);
      torch::nn::init::constant_(layers_.back()->bias, 0.0);
    }

    // Generate vector of activation functions
    for (const auto &a : activations)

      switch (std::any_cast<activation>(a[0])) {

      // No activation function
      case activation::none:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new None{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Batch Normalization
      case activation::batch_norm:
        switch (a.size()) {
        case 8:
          activations_.emplace_back(new BatchNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]),
              std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
              std::any_cast<double>(a[6]), std::any_cast<bool>(a[7])});
          break;
        case 7:
          activations_.emplace_back(new BatchNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]),
              std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
              std::any_cast<double>(a[6])});
          break;
        case 4:
          activations_.emplace_back(new BatchNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::nn::functional::BatchNormFuncOptions>(
                  a[3])});
          break;
        case 3:
          activations_.emplace_back(
              new BatchNorm{std::any_cast<torch::Tensor>(a[1]),
                            std::any_cast<torch::Tensor>(a[2])});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // CELU
      case activation::celu:
        switch (a.size()) {
        case 3:
          activations_.emplace_back(
              new CELU{std::any_cast<double>(a[1]), std::any_cast<bool>(a[2])});
          break;
        case 2:
          try {
            activations_.emplace_back(new CELU{
                std::any_cast<torch::nn::functional::CELUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new CELU{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new CELU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // ELU
      case activation::elu:
        switch (a.size()) {
        case 3:
          activations_.emplace_back(
              new ELU{std::any_cast<double>(a[1]), std::any_cast<bool>(a[2])});
          break;
        case 2:
          try {
            activations_.emplace_back(new ELU{
                std::any_cast<torch::nn::functional::ELUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new ELU{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new ELU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // GELU
      case activation::gelu:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new GELU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // GLU
      case activation::glu:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new GLU{
                std::any_cast<torch::nn::functional::GLUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new GLU{std::any_cast<int64_t>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new GLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Group Normalization
      case activation::group_norm:
        switch (a.size()) {
        case 5:
          activations_.emplace_back(new GroupNorm{
              std::any_cast<int64_t>(a[1]), std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]), std::any_cast<double>(a[4])});
          break;
        case 2:
          try {
            activations_.emplace_back(new GroupNorm{
                std::any_cast<torch::nn::functional::GroupNormFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new GroupNorm{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Gumbel-Softmax
      case activation::gumbel_softmax:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new GumbelSoftmax{
              std::any_cast<double>(a[1]), std::any_cast<int>(a[2]),
              std::any_cast<bool>(a[3])});
          break;
        case 2:
          activations_.emplace_back(new GumbelSoftmax{
              std::any_cast<torch::nn::functional::GumbelSoftmaxFuncOptions>(
                  a[1])});
          break;
        case 1:
          activations_.emplace_back(new GumbelSoftmax{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Hardshrink
      case activation::hardshrink:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new Hardshrink{
                std::any_cast<torch::nn::functional::HardshrinkFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new Hardshrink{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new Hardshrink{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Hardsigmoid
      case activation::hardsigmoid:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Hardsigmoid{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Hardswish
      case activation::hardswish:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Hardswish{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Hardtanh
      case activation::hardtanh:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new Hardtanh{std::any_cast<double>(a[1]),
                                                 std::any_cast<double>(a[2]),
                                                 std::any_cast<bool>(a[3])});
          break;
        case 3:
          activations_.emplace_back(new Hardtanh{std::any_cast<double>(a[1]),
                                                 std::any_cast<double>(a[2])});
          break;
        case 2:
          activations_.emplace_back(new Hardtanh{
              std::any_cast<torch::nn::functional::HardtanhFuncOptions>(a[1])});
          break;
        case 1:
          activations_.emplace_back(new Hardtanh{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Instance Normalization
      case activation::instance_norm:
        switch (a.size()) {
        case 8:
          activations_.emplace_back(new InstanceNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]),
              std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
              std::any_cast<double>(a[6]), std::any_cast<bool>(a[7])});
          break;
        case 7:
          activations_.emplace_back(new InstanceNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]),
              std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
              std::any_cast<double>(a[6])});
          break;
        case 2:
          activations_.emplace_back(new InstanceNorm{
              std::any_cast<torch::nn::functional::InstanceNormFuncOptions>(
                  a[1])});
          break;
        case 1:
          activations_.emplace_back(new InstanceNorm{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Layer Normalization
      case activation::layer_norm:
        switch (a.size()) {
        case 5:
          activations_.emplace_back(new LayerNorm{
              std::any_cast<std::vector<int64_t>>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]), std::any_cast<double>(a[4])});
          break;
        case 2:
          try {
            activations_.emplace_back(new LayerNorm{
                std::any_cast<torch::nn::functional::LayerNormFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new LayerNorm{std::any_cast<std::vector<int64_t>>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Leaky ReLU
      case activation::leaky_relu:
        switch (a.size()) {
        case 3:
          activations_.emplace_back(new LeakyReLU{std::any_cast<double>(a[1]),
                                                  std::any_cast<bool>(a[2])});
          break;
        case 2:
          try {
            activations_.emplace_back(new LeakyReLU{
                std::any_cast<torch::nn::functional::LeakyReLUFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new LeakyReLU{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new LeakyReLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Local response Normalization
      case activation::local_response_norm:
        switch (a.size()) {
        case 5:
          activations_.emplace_back(new LocalResponseNorm{
              std::any_cast<int64_t>(a[1]), std::any_cast<double>(a[2]),
              std::any_cast<double>(a[3]), std::any_cast<double>(a[4])});
          break;
        case 2:
          try {
            activations_.emplace_back(new LocalResponseNorm{std::any_cast<
                torch::nn::functional::LocalResponseNormFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(
                new LocalResponseNorm{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // LogSigmoid
      case activation::logsigmoid:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new LogSigmoid{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // LogSoftmax
      case activation::logsoftmax:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new LogSoftmax{
                std::any_cast<torch::nn::functional::LogSoftmaxFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new LogSoftmax{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Mish
      case activation::mish:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Mish{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Lp Normalization
      case activation::normalize:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new Normalize{
              std::any_cast<double>(a[1]), std::any_cast<double>(a[2]),
              std::any_cast<int64_t>(a[3])});
          break;
        case 2:
          activations_.emplace_back(new Normalize{
              std::any_cast<torch::nn::functional::NormalizeFuncOptions>(
                  a[1])});
          break;
        case 1:
          activations_.emplace_back(new Normalize{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // PReLU
      case activation::prelu:
        switch (a.size()) {
        case 2:
          activations_.emplace_back(
              new PReLU{std::any_cast<torch::Tensor>(a[1])});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // ReLU
      case activation::relu:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new ReLU{
                std::any_cast<torch::nn::functional::ReLUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new ReLU{std::any_cast<bool>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new ReLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // ReLU6
      case activation::relu6:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new ReLU6{
                std::any_cast<torch::nn::functional::ReLU6FuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new ReLU6{std::any_cast<bool>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new ReLU6{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Randomized ReLU
      case activation::rrelu:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new RReLU{std::any_cast<double>(a[1]),
                                              std::any_cast<double>(a[2]),
                                              std::any_cast<bool>(a[3])});
          break;
        case 3:
          activations_.emplace_back(new RReLU{std::any_cast<double>(a[1]),
                                              std::any_cast<double>(a[2])});
          break;
        case 2:
          activations_.emplace_back(new RReLU{
              std::any_cast<torch::nn::functional::RReLUFuncOptions>(a[1])});
          break;
        case 1:
          activations_.emplace_back(new RReLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // SELU
      case activation::selu:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new SELU{
                std::any_cast<torch::nn::functional::SELUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new SELU{std::any_cast<bool>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new SELU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Sigmoid
      case activation::sigmoid:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Sigmoid{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // SiLU
      case activation::silu:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new SiLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softmax
      case activation::softmax:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new Softmax{
                std::any_cast<torch::nn::functional::SoftmaxFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new Softmax{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softmin
      case activation::softmin:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new Softmin{
                std::any_cast<torch::nn::functional::SoftminFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new Softmin{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softplus
      case activation::softplus:
        switch (a.size()) {
        case 3:
          activations_.emplace_back(new Softplus{std::any_cast<double>(a[1]),
                                                 std::any_cast<double>(a[2])});
          break;
        case 2:
          activations_.emplace_back(new Softplus{
              std::any_cast<torch::nn::functional::SoftplusFuncOptions>(a[1])});
          break;
        case 1:
          activations_.emplace_back(new Softplus{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softshrink
      case activation::softshrink:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new Softshrink{
                std::any_cast<torch::nn::functional::SoftshrinkFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new Softshrink{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new Softshrink{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softsign
      case activation::softsign:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Softsign{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Tanh
      case activation::tanh:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Tanh{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Tanhshrink
      case activation::tanhshrink:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Tanhshrink{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Threshold
      case activation::threshold:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new Threshold{std::any_cast<double>(a[1]),
                                                  std::any_cast<double>(a[2]),
                                                  std::any_cast<bool>(a[3])});
          break;
        case 3:
          activations_.emplace_back(new Threshold{std::any_cast<double>(a[1]),
                                                  std::any_cast<double>(a[2])});
          break;
        case 2:
          activations_.emplace_back(new Threshold{
              std::any_cast<torch::nn::functional::ThresholdFuncOptions>(
                  a[1])});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      default:
        throw std::runtime_error("Invalid activation function");
      }
  }
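
  // Illustrative sketch (not from the original header): each entry of
  // `activations` starts with the `activation` enumerator, followed by the
  // parameters unpacked by the switch above. A three-layer network could
  // thus be specified as
  //
  //   IgANetGeneratorImpl<double> mlp(
  //       {2, 32, 32, 1},                  // layer widths
  //       {{activation::relu},             // hidden layer 1
  //        {activation::leaky_relu, 0.01}, // hidden layer 2, negative slope
  //        {activation::none}});           // linear output layer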

  /// Forward evaluation
  torch::Tensor forward(torch::Tensor x) {
    torch::Tensor x_in = x.clone();

    // Standard feed-forward neural network: apply each linear layer
    // followed by its activation function
    for (auto &&[layer, activation] : utils::zip(layers_, activations_))
      x = activation->apply(layer->forward(x));

    return x;
  }

  /// Writes the IgANet into a torch::serialize::OutputArchive object
  inline torch::serialize::OutputArchive &
  write(torch::serialize::OutputArchive &archive,
        const std::string &key = "iganet") const {
    assert(layers_.size() == activations_.size());

    archive.write(key + ".layers", torch::full({1}, (int64_t)layers_.size()));
    for (std::size_t i = 0; i < layers_.size(); ++i) {
      archive.write(
          key + ".layer[" + std::to_string(i) + "].in_features",
          torch::full({1}, (int64_t)layers_[i]->options.in_features()));
      archive.write(
          key + ".layer[" + std::to_string(i) + "].outputs_features",
          torch::full({1}, (int64_t)layers_[i]->options.out_features()));
      archive.write(key + ".layer[" + std::to_string(i) + "].bias",
                    torch::full({1}, (int64_t)layers_[i]->options.bias()));

      activations_[i]->write(archive, key + ".layer[" + std::to_string(i) +
                                          "].activation");
    }

    return archive;
  }

  /// Reads the IgANet from a torch::serialize::InputArchive object
  inline torch::serialize::InputArchive &
  read(torch::serialize::InputArchive &archive,
       const std::string &key = "iganet") {
    torch::Tensor layers, in_features, out_features, bias, activation;

    archive.read(key + ".layers", layers);
    for (int64_t i = 0; i < layers.item<int64_t>(); ++i) {
      archive.read(key + ".layer[" + std::to_string(i) + "].in_features",
                   in_features);
      archive.read(key + ".layer[" + std::to_string(i) + "].outputs_features",
                   out_features);
      archive.read(key + ".layer[" + std::to_string(i) + "].bias", bias);
      layers_.emplace_back(register_module(
          "layer[" + std::to_string(i) + "]",
          torch::nn::Linear(
              torch::nn::LinearOptions(in_features.item<int64_t>(),
                                       out_features.item<int64_t>())
                  .bias(bias.item<bool>()))));

      archive.read(key + ".layer[" + std::to_string(i) + "].activation.type",
                   activation);
      switch (static_cast<enum activation>(activation.item<int64_t>())) {
      case activation::none:
        activations_.emplace_back(new None{});
        break;
      case activation::batch_norm:
        activations_.emplace_back(
            new BatchNorm{torch::Tensor{}, torch::Tensor{}});
        break;
      case activation::celu:
        activations_.emplace_back(new CELU{});
        break;
      case activation::elu:
        activations_.emplace_back(new ELU{});
        break;
      case activation::gelu:
        activations_.emplace_back(new GELU{});
        break;
      case activation::glu:
        activations_.emplace_back(new GLU{});
        break;
      case activation::group_norm:
        activations_.emplace_back(new GroupNorm{0});
        break;
      case activation::gumbel_softmax:
        activations_.emplace_back(new GumbelSoftmax{});
        break;
      case activation::hardshrink:
        activations_.emplace_back(new Hardshrink{});
        break;
      case activation::hardsigmoid:
        activations_.emplace_back(new Hardsigmoid{});
        break;
      case activation::hardswish:
        activations_.emplace_back(new Hardswish{});
        break;
      case activation::hardtanh:
        activations_.emplace_back(new Hardtanh{});
        break;
      case activation::instance_norm:
        activations_.emplace_back(new InstanceNorm{});
        break;
      case activation::layer_norm:
        activations_.emplace_back(new LayerNorm{{}});
        break;
      case activation::leaky_relu:
        activations_.emplace_back(new LeakyReLU{});
        break;
      case activation::local_response_norm:
        activations_.emplace_back(new LocalResponseNorm{0});
        break;
      case activation::logsigmoid:
        activations_.emplace_back(new LogSigmoid{});
        break;
      case activation::logsoftmax:
        activations_.emplace_back(new LogSoftmax{0});
        break;
      case activation::mish:
        activations_.emplace_back(new Mish{});
        break;
      case activation::normalize:
        activations_.emplace_back(new Normalize{0, 0, 0});
        break;
      case activation::prelu:
        activations_.emplace_back(new PReLU{torch::Tensor{}});
        break;
      case activation::relu:
        activations_.emplace_back(new ReLU{});
        break;
      case activation::relu6:
        activations_.emplace_back(new ReLU6{});
        break;
      case activation::rrelu:
        activations_.emplace_back(new RReLU{});
        break;
      case activation::selu:
        activations_.emplace_back(new SELU{});
        break;
      case activation::sigmoid:
        activations_.emplace_back(new Sigmoid{});
        break;
      case activation::silu:
        activations_.emplace_back(new SiLU{});
        break;
      case activation::softmax:
        activations_.emplace_back(new Softmax{0});
        break;
      case activation::softmin:
        activations_.emplace_back(new Softmin{0});
        break;
      case activation::softplus:
        activations_.emplace_back(new Softplus{});
        break;
      case activation::softshrink:
        activations_.emplace_back(new Softshrink{});
        break;
      case activation::softsign:
        activations_.emplace_back(new Softsign{});
        break;
      case activation::tanh:
        activations_.emplace_back(new Tanh{});
        break;
      case activation::tanhshrink:
        activations_.emplace_back(new Tanhshrink{});
        break;
      case activation::threshold:
        activations_.emplace_back(new Threshold{0, 0});
        break;
      default:
        throw std::runtime_error("Invalid activation function");
      }
      activations_.back()->read(archive, key + ".layer[" + std::to_string(i) +
                                             "].activation");
    }
    return archive;
  }

  /// Prints the activation functions of the network
  inline virtual void
  pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
    os << "(\n";

    int i = 0;
    for (const auto &activation : activations_)
      os << "activation[" << i++ << "]: " << *activation << "\n";
    os << ")\n";
  }

private:
  /// Vector of linear layers
  std::vector<torch::nn::Linear> layers_;

  /// Vector of activation functions
  std::vector<std::unique_ptr<iganet::ActivationFunction>> activations_;
};

/// IgANetGenerator
template <typename real_t>
class IgANetGenerator
    : public torch::nn::ModuleHolder<IgANetGeneratorImpl<real_t>> {

public:
  using torch::nn::ModuleHolder<IgANetGeneratorImpl<real_t>>::ModuleHolder;
};
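
// Usage sketch (assumption, not from the original file): IgANetGenerator is
// a torch::nn::ModuleHolder, i.e. a shared handle that forwards constructor
// arguments to IgANetGeneratorImpl and is dereferenced with operator->, e.g.
//
//   IgANetGenerator<double> net(
//       std::vector<int64_t>{2, 16, 1},
//       std::vector<std::vector<std::any>>{{activation::tanh},
//                                          {activation::none}});
//   torch::Tensor y = net->forward(torch::rand({10, 2}).to(torch::kDouble));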

/// IgANet
template <typename Optimizer, typename GeometryMap, typename Variable,
          template <typename, typename> typename IgABase = ::iganet::IgABase>
class IgANet : public IgABase<GeometryMap, Variable>,
               public utils::Serializable,
               private utils::FullQualifiedName {
public:
  /// Base type
  using Base = IgABase<GeometryMap, Variable>;

  /// Type of the optimizer
  using optimizer_type = Optimizer;

protected:
  /// IgANet generator
  IgANetGenerator<typename Base::value_type> net_;

  /// Optimizer
  Optimizer opt_;

  /// Options
  IgANetOptions options_;

public:

  /// Default constructor
  IgANet(IgANetOptions defaults = {},
         iganet::Options<typename Base::value_type> options =
             iganet::Options<typename Base::value_type>{})
      : opt_(net_->parameters()), options_(defaults) {}

  /// Constructor: number of layers, activation functions, and number of
  /// spline coefficients (same for geometry map and variables)
  template <std::size_t Coeffs>
  IgANet(const std::vector<int64_t> &layers,
         const std::vector<std::vector<std::any>> &activations,
         std::array<int64_t, Coeffs> ncoeffs, IgANetOptions defaults = {},
         iganet::Options<typename Base::value_type> options =
             iganet::Options<typename Base::value_type>{})
      : IgANet(layers, activations, std::tuple{ncoeffs}, std::tuple{ncoeffs},
               defaults, options) {}

  /// Constructor: number of layers, activation functions, and number of
  /// spline coefficients (same for geometry map and variables)
  template <std::size_t... Coeffs>
  IgANet(const std::vector<int64_t> &layers,
         const std::vector<std::vector<std::any>> &activations,
         std::tuple<std::array<int64_t, Coeffs>...> ncoeffs,
         IgANetOptions defaults = {},
         iganet::Options<typename Base::value_type> options =
             iganet::Options<typename Base::value_type>{})
      : IgANet(layers, activations, ncoeffs, ncoeffs, defaults, options) {}

  /// Constructor: number of layers, activation functions, and number of
  /// spline coefficients (different for geometry map and variables)
  template <std::size_t GeometryMapNumCoeffs, std::size_t VariableNumCoeffs>
  IgANet(const std::vector<int64_t> &layers,
         const std::vector<std::vector<std::any>> &activations,
         std::array<int64_t, GeometryMapNumCoeffs> geometryMapNumCoeffs,
         std::array<int64_t, VariableNumCoeffs> variableNumCoeffs,
         IgANetOptions defaults = {},
         iganet::Options<typename Base::value_type> options =
             iganet::Options<typename Base::value_type>{})
      : IgANet(layers, activations, std::tuple{geometryMapNumCoeffs},
               std::tuple{variableNumCoeffs}, defaults, options) {}

  /// Constructor: number of layers, activation functions, and number of
  /// spline coefficients (different for geometry map and variables)
  template <std::size_t... GeometryMapNumCoeffs,
            std::size_t... VariableNumCoeffs>
  IgANet(
      const std::vector<int64_t> &layers,
      const std::vector<std::vector<std::any>> &activations,
      std::tuple<std::array<int64_t, GeometryMapNumCoeffs>...>
          geometryMapNumCoeffs,
      std::tuple<std::array<int64_t, VariableNumCoeffs>...> variableNumCoeffs,
      IgANetOptions defaults = {},
      iganet::Options<typename Base::value_type> options =
          iganet::Options<typename Base::value_type>{})
      : Base(geometryMapNumCoeffs, variableNumCoeffs, options),

        // Construct the deep neural network
        net_(utils::concat(std::vector<int64_t>{inputs(/* epoch */ 0).size(0)},
                           layers,
                           std::vector<int64_t>{Base::u_.as_tensor_size()}),
             activations, options),

        // Construct the optimizer
        opt_(net_->parameters()),

        // Set options
        options_(defaults) {}

  /// Returns a constant reference to the IgANet generator
  inline const IgANetGenerator<typename Base::value_type> &net() const {
    return net_;
  }

  /// Returns a non-constant reference to the IgANet generator
  inline IgANetGenerator<typename Base::value_type> &net() { return net_; }

  /// Returns a constant reference to the optimizer
  inline const Optimizer &opt() const { return opt_; }

  /// Returns a non-constant reference to the optimizer
  inline Optimizer &opt() { return opt_; }

  /// Returns a constant reference to the options structure
  inline const auto &options() const { return options_; }

  /// Returns a non-constant reference to the options structure
  inline auto &options() { return options_; }

  /// Returns the network inputs
  virtual torch::Tensor inputs(int64_t epoch) const {
    if constexpr (Base::has_GeometryMap && Base::has_RefData)
      return torch::cat({Base::G_.as_tensor(), Base::f_.as_tensor()});
    else if constexpr (Base::has_GeometryMap && !Base::has_RefData)
      return Base::G_.as_tensor();
    else if constexpr (!Base::has_GeometryMap && Base::has_RefData)
      return Base::f_.as_tensor();
    else
      return torch::empty({0});
  }

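  // For example (illustrative sketch): for a planar geometry map on a 4x4
  // grid of control points and scalar reference data on the same grid,
  // G_.as_tensor() holds 2*16 = 32 coefficients and f_.as_tensor() holds 16,
  // so inputs() returns a 48-entry tensor; this length also fixes the width
  // of the first linear layer of the network.
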
  /// Initializes the epoch
  ///
  /// Returns true if the inputs have to be recomputed, e.g., because the
  /// geometry map or the reference data have changed
  virtual bool epoch(int64_t) = 0;

  /// Computes the loss function
  virtual torch::Tensor loss(const torch::Tensor &, int64_t) = 0;

  /// Trains the IgANet
  virtual void train(
#ifdef IGANET_WITH_MPI
      c10::intrusive_ptr<c10d::ProcessGroupMPI> pg =
          c10d::ProcessGroupMPI::createProcessGroupMPI()
#endif
  ) {
    torch::Tensor inputs, outputs, loss;

    // Loop over epochs
    for (int64_t epoch = 0; epoch != options_.max_epoch(); ++epoch) {

      // Update epoch and inputs
      if (this->epoch(epoch))
        inputs = this->inputs(epoch);

      auto closure = [&]() {
        // Reset gradients
        net_->zero_grad();

        // Execute the model on the inputs
        outputs = net_->forward(inputs);

        // Compute the loss value
        loss = this->loss(outputs, epoch);

        // Compute gradients of the loss w.r.t. the model parameters
        loss.backward({}, true, false);

        return loss;
      };

#ifdef IGANET_WITH_MPI
      // Average the gradients of the parameters over all processors.
      // Note: This may lag behind DistributedDataParallel (DDP) in
      // performance since it synchronizes parameters after the backward
      // pass, while DDP overlaps parameter synchronization with gradient
      // computation in the backward pass
      std::vector<c10::intrusive_ptr<::c10d::Work>> works;
      for (auto &param : net_->named_parameters()) {
        std::vector<torch::Tensor> tmp = {param.value().grad()};
        works.emplace_back(pg->allreduce(tmp));
      }

      waitWork(pg, works);

      for (auto &param : net_->named_parameters()) {
        param.value().grad().data() =
            param.value().grad().data() / pg->getSize();
      }
#endif

      // Update the parameters based on the calculated gradients
      opt_.step(closure);

      Log(log::verbose) << "Epoch " << std::to_string(epoch) << ": "
                        << loss.template item<typename Base::value_type>()
                        << std::endl;

      if (loss.template item<typename Base::value_type>() <
          options_.min_loss()) {
        Log(log::info) << "Total epochs: " << epoch << ", loss: "
                       << loss.template item<typename Base::value_type>()
                       << std::endl;
        break;
      }
    }
  }
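
  // Sketch of a concrete subclass (hypothetical names; epoch() and loss()
  // are pure virtual, so user code must provide them before train() can
  // run). Returning true from epoch() makes train() recompute the inputs:
  //
  //   class Poisson final
  //       : public iganet::IgANet<torch::optim::Adam, Geometry, Solution> {
  //   public:
  //     using iganet::IgANet<torch::optim::Adam, Geometry,
  //                          Solution>::IgANet; // inherit constructors
  //
  //     bool epoch(int64_t e) override { return e == 0; } // static inputs
  //
  //     torch::Tensor loss(const torch::Tensor &outputs, int64_t) override {
  //       Base::u_.from_tensor(outputs);
  //       return /* PDE residual evaluated at the collocation points */;
  //     }
  //   };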

  /// Trains the IgANet with data provided by the given data loader
  template <typename DataLoader>
  void train(DataLoader &loader
#ifdef IGANET_WITH_MPI
             ,
             c10::intrusive_ptr<c10d::ProcessGroupMPI> pg =
                 c10d::ProcessGroupMPI::createProcessGroupMPI()
#endif
  ) {
    torch::Tensor inputs, outputs, loss;

    // Loop over epochs
    for (int64_t epoch = 0; epoch != options_.max_epoch(); ++epoch) {

      typename Base::value_type Loss(0);

      for (auto &batch : loader) {
        inputs = batch.data;

        if (inputs.dim() > 0) {
          if constexpr (Base::has_GeometryMap && Base::has_RefData) {
            Base::G_.from_tensor(
                inputs.slice(1, 0, Base::G_.as_tensor_size()).t());
            Base::f_.from_tensor(inputs
                                     .slice(1, Base::G_.as_tensor_size(),
                                            Base::G_.as_tensor_size() +
                                                Base::f_.as_tensor_size())
                                     .t());
          } else if constexpr (Base::has_GeometryMap && !Base::has_RefData)
            Base::G_.from_tensor(
                inputs.slice(1, 0, Base::G_.as_tensor_size()).t());
          else if constexpr (!Base::has_GeometryMap && Base::has_RefData)
            Base::f_.from_tensor(
                inputs.slice(1, 0, Base::f_.as_tensor_size()).t());

        } else {
          if constexpr (Base::has_GeometryMap && Base::has_RefData) {
            Base::G_.from_tensor(
                inputs.slice(1, 0, Base::G_.as_tensor_size()).flatten());
            Base::f_.from_tensor(inputs
                                     .slice(1, Base::G_.as_tensor_size(),
                                            Base::G_.as_tensor_size() +
                                                Base::f_.as_tensor_size())
                                     .flatten());
          } else if constexpr (Base::has_GeometryMap && !Base::has_RefData)
            Base::G_.from_tensor(
                inputs.slice(1, 0, Base::G_.as_tensor_size()).flatten());
          else if constexpr (!Base::has_GeometryMap && Base::has_RefData)
            Base::f_.from_tensor(
                inputs.slice(1, 0, Base::f_.as_tensor_size()).flatten());
        }

        this->epoch(epoch);

        auto closure = [&]() {
          // Reset gradients
          net_->zero_grad();

          // Execute the model on the inputs
          outputs = net_->forward(inputs);

          // Compute the loss value
          loss = this->loss(outputs, epoch);

          // Compute gradients of the loss w.r.t. the model parameters
          loss.backward({}, true, false);

          return loss;
        };

        // Update the parameters based on the calculated gradients
        opt_.step(closure);

        // Accumulate the loss over all batches
        Loss += loss.template item<typename Base::value_type>();
      }

      Log(log::verbose) << "Epoch " << std::to_string(epoch) << ": " << Loss
                        << std::endl;

      if (Loss < options_.min_loss()) {
        Log(log::info) << "Total epochs: " << epoch << ", loss: " << Loss
                       << std::endl;
        break;
      }

      if (epoch == options_.max_epoch() - 1)
        Log(log::warning) << "Total epochs: " << epoch << ", loss: " << Loss
                          << std::endl;
    }
  }

  /// Evaluates the IgANet, i.e. computes the solution from the current
  /// network inputs and stores it in Base::u_
  void eval() {
    torch::Tensor inputs = this->inputs(0);
    torch::Tensor outputs = net_->forward(inputs);
    Base::u_.from_tensor(outputs);
  }

  /// Returns the IgANet object as JSON object
  inline virtual nlohmann::json to_json() const override {
    return "Not implemented yet";
  }

  /// Returns the parameters of the IgANet object
  inline std::vector<torch::Tensor> parameters() const noexcept {
    return net_->parameters();
  }

  /// Returns the named parameters of the IgANet object
  inline torch::OrderedDict<std::string, torch::Tensor>
  named_parameters() const noexcept {
    return net_->named_parameters();
  }

  /// Returns the total number of parameters of the IgANet object
  inline std::size_t nparameters() const noexcept {
    std::size_t result = 0;
    for (const auto &param : this->parameters()) {
      result += param.numel();
    }
    return result;
  }

  /// Returns a string representation of the IgANet object
  inline virtual void
  pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
    os << name() << "(\n"
       << "net = " << net_ << "\n";
    if constexpr (Base::has_GeometryMap)
      os << "G = " << Base::G_ << "\n";
    if constexpr (Base::has_RefData)
      os << "f = " << Base::f_ << "\n";
    if constexpr (Base::has_Solution)
      os << "u = " << Base::u_ << "\n)";
  }

  /// Saves the IgANet to file
  inline void save(const std::string &filename,
                   const std::string &key = "iganet") const {
    torch::serialize::OutputArchive archive;
    write(archive, key).save_to(filename);
  }

  /// Loads the IgANet from file
  inline void load(const std::string &filename,
                   const std::string &key = "iganet") {
    torch::serialize::InputArchive archive;
    archive.load_from(filename);
    read(archive, key);
  }
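
  // Round-trip sketch (hypothetical `net` object of a concrete subclass):
  // save() serializes the splines, the network, and the optimizer state
  // under `key`, and load() restores them from the same key, e.g.
  //
  //   net.save("poisson.pt");
  //   net.load("poisson.pt");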

  /// Writes the IgANet into a torch::serialize::OutputArchive object
  inline torch::serialize::OutputArchive &
  write(torch::serialize::OutputArchive &archive,
        const std::string &key = "iganet") const {
    if constexpr (Base::has_GeometryMap)
      Base::G_.write(archive, key + ".geo");
    if constexpr (Base::has_RefData)
      Base::f_.write(archive, key + ".ref");
    if constexpr (Base::has_Solution)
      Base::u_.write(archive, key + ".out");

    net_->write(archive, key + ".net");
    torch::serialize::OutputArchive archive_net;
    net_->save(archive_net);
    archive.write(key + ".net.data", archive_net);

    torch::serialize::OutputArchive archive_opt;
    opt_.save(archive_opt);
    archive.write(key + ".opt", archive_opt);

    return archive;
  }

  /// Reads the IgANet from a torch::serialize::InputArchive object
  inline torch::serialize::InputArchive &
  read(torch::serialize::InputArchive &archive,
       const std::string &key = "iganet") {
    if constexpr (Base::has_GeometryMap)
      Base::G_.read(archive, key + ".geo");
    if constexpr (Base::has_RefData)
      Base::f_.read(archive, key + ".ref");
    if constexpr (Base::has_Solution)
      Base::u_.read(archive, key + ".out");

    net_->read(archive, key + ".net");
    torch::serialize::InputArchive archive_net;
    archive.read(key + ".net.data", archive_net);
    net_->load(archive_net);

    opt_.add_parameters(net_->parameters());
    torch::serialize::InputArchive archive_opt;
    archive.read(key + ".opt", archive_opt);
    opt_.load(archive_opt);

    return archive;
  }

  /// Returns true if both IgANet objects are the same
  bool operator==(const IgANet &other) const {
    bool result(true);

    if constexpr (Base::has_GeometryMap)
      result *= (Base::G_ == other.G());
    if constexpr (Base::has_RefData)
      result *= (Base::f_ == other.f());
    if constexpr (Base::has_Solution)
      result *= (Base::u_ == other.u());

    return result;
  }

  /// Returns true if both IgANet objects are different
  bool operator!=(const IgANet &other) const { return !(*this == other); }

#ifdef IGANET_WITH_MPI
private:
  /// Waits until all MPI work requests have completed
  static void waitWork(c10::intrusive_ptr<c10d::ProcessGroupMPI> pg,
                       std::vector<c10::intrusive_ptr<c10d::Work>> works) {
    for (auto &work : works) {
      try {
        work->wait();
      } catch (const std::exception &ex) {
        Log(log::error) << "Exception received during waitWork: " << ex.what()
                        << std::endl;
        pg->abort();
      }
    }
  }
#endif
};

/// Prints (as string) an IgANet object
template <typename Optimizer, typename GeometryMap, typename Variable>
inline std::ostream &
operator<<(std::ostream &os,
           const IgANet<Optimizer, GeometryMap, Variable> &obj) {
  obj.pretty_print(os);
  return os;
}

/// IgANetCustomizable
///
/// Collects the types of the knot and coefficient indices of the geometry
/// map and the variables, in the interior and at the boundary
template <typename GeometryMap, typename Variable> class IgANetCustomizable {
public:
  /// Type of the knot indices of the geometry map in the interior
  using geometryMap_interior_knot_indices_type =
      decltype(std::declval<GeometryMap>()
                   .template find_knot_indices<functionspace::interior>(
                       std::declval<typename GeometryMap::eval_type>()));

  /// Type of the knot indices of the geometry map at the boundary
  using geometryMap_boundary_knot_indices_type =
      decltype(std::declval<GeometryMap>()
                   .template find_knot_indices<functionspace::boundary>(
                       std::declval<
                           typename GeometryMap::boundary_eval_type>()));

  /// Type of the knot indices of the variables in the interior
  using variable_interior_knot_indices_type =
      decltype(std::declval<Variable>()
                   .template find_knot_indices<functionspace::interior>(
                       std::declval<typename Variable::eval_type>()));

  /// Type of the knot indices of the variables at the boundary
  using variable_boundary_knot_indices_type =
      decltype(std::declval<Variable>()
                   .template find_knot_indices<functionspace::boundary>(
                       std::declval<typename Variable::boundary_eval_type>()));

  /// Type of the coefficient indices of the geometry map in the interior
  using geometryMap_interior_coeff_indices_type =
      decltype(std::declval<GeometryMap>()
                   .template find_coeff_indices<functionspace::interior>(
                       std::declval<typename GeometryMap::eval_type>()));

  /// Type of the coefficient indices of the geometry map at the boundary
  using geometryMap_boundary_coeff_indices_type =
      decltype(std::declval<GeometryMap>()
                   .template find_coeff_indices<functionspace::boundary>(
                       std::declval<
                           typename GeometryMap::boundary_eval_type>()));

  /// Type of the coefficient indices of the variables in the interior
  using variable_interior_coeff_indices_type =
      decltype(std::declval<Variable>()
                   .template find_coeff_indices<functionspace::interior>(
                       std::declval<typename Variable::eval_type>()));

  /// Type of the coefficient indices of the variables at the boundary
  using variable_boundary_coeff_indices_type =
      decltype(std::declval<Variable>()
                   .template find_coeff_indices<functionspace::boundary>(
                       std::declval<typename Variable::boundary_eval_type>()));
};
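
// Usage sketch (assumption; `Geometry` and `Solution` are hypothetical
// spline types): a solver can derive from both IgANet and IgANetCustomizable
// and cache the index objects once per geometry instead of recomputing them
// in every call to loss(), e.g.
//
//   class Solver
//       : public iganet::IgANet<torch::optim::Adam, Geometry, Solution>,
//         public iganet::IgANetCustomizable<Geometry, Solution> {
//     using Customizable = iganet::IgANetCustomizable<Geometry, Solution>;
//
//     typename Customizable::variable_interior_knot_indices_type knot_idx_;
//     typename Customizable::variable_interior_coeff_indices_type coeff_idx_;
//   };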

} // namespace iganet