IgANets - Isogeometric Analysis Networks

iganet.hpp
#pragma once

#include <any>

#include <boundary.hpp>
#include <functionspace.hpp>
#include <igabase.hpp>
#include <layer.hpp>
#include <optimizer.hpp>
#include <utils/container.hpp>
#include <utils/fqn.hpp>
#include <utils/zip.hpp>

namespace iganet {

/// @brief IgANetOptions
struct IgANetOptions {
  TORCH_ARG(int64_t, max_epoch) = 100;
  TORCH_ARG(int64_t, batch_size) = 1000;
  TORCH_ARG(double, min_loss) = 1e-4;
};

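// Example (editorial sketch, not part of the original header): TORCH_ARG
// generates chainable setter/getter pairs, so the training options can be
// configured fluently before being passed to an IgANet constructor:
//
//   IgANetOptions opts;
//   opts.max_epoch(500).batch_size(64).min_loss(1e-6);
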
/// @brief IgANetGeneratorImpl
template <typename real_t>
class IgANetGeneratorImpl : public torch::nn::Module {
public:
  /// @brief Default constructor
  IgANetGeneratorImpl() = default;

  /// @brief Constructor
  IgANetGeneratorImpl(const std::vector<int64_t> &layers,
                      const std::vector<std::vector<std::any>> &activations,
                      Options<real_t> options = Options<real_t>{}) {
    assert(layers.size() == activations.size() + 1);

    // Generate vector of linear layers and register them as layer[i]
    for (std::size_t i = 0; i + 1 < layers.size(); ++i) {
      layers_.emplace_back(
          register_module("layer[" + std::to_string(i) + "]",
                          torch::nn::Linear(layers[i], layers[i + 1])));
      layers_.back()->to(options.device(), options.dtype(), true);

      torch::nn::init::xavier_uniform_(layers_.back()->weight);
      torch::nn::init::constant_(layers_.back()->bias, 0.0);
    }

    // Generate vector of activation functions
    for (const auto &a : activations)
      switch (std::any_cast<activation>(a[0])) {

      // No activation function
      case activation::none:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new None{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Batch Normalization
      case activation::batch_norm:
        switch (a.size()) {
        case 8:
          activations_.emplace_back(new BatchNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]),
              std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
              std::any_cast<double>(a[6]), std::any_cast<bool>(a[7])});
          break;
        case 7:
          activations_.emplace_back(new BatchNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]),
              std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
              std::any_cast<double>(a[6])});
          break;
        case 4:
          activations_.emplace_back(new BatchNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::nn::functional::BatchNormFuncOptions>(
                  a[3])});
          break;
        case 3:
          activations_.emplace_back(
              new BatchNorm{std::any_cast<torch::Tensor>(a[1]),
                            std::any_cast<torch::Tensor>(a[2])});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // CELU
      case activation::celu:
        switch (a.size()) {
        case 3:
          activations_.emplace_back(
              new CELU{std::any_cast<double>(a[1]), std::any_cast<bool>(a[2])});
          break;
        case 2:
          try {
            activations_.emplace_back(new CELU{
                std::any_cast<torch::nn::functional::CELUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new CELU{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new CELU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // ELU
      case activation::elu:
        switch (a.size()) {
        case 3:
          activations_.emplace_back(
              new ELU{std::any_cast<double>(a[1]), std::any_cast<bool>(a[2])});
          break;
        case 2:
          try {
            activations_.emplace_back(new ELU{
                std::any_cast<torch::nn::functional::ELUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new ELU{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new ELU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // GELU
      case activation::gelu:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new GELU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // GLU
      case activation::glu:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new GLU{
                std::any_cast<torch::nn::functional::GLUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new GLU{std::any_cast<int64_t>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new GLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Group Normalization
      case activation::group_norm:
        switch (a.size()) {
        case 5:
          activations_.emplace_back(new GroupNorm{
              std::any_cast<int64_t>(a[1]), std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]), std::any_cast<double>(a[4])});
          break;
        case 2:
          try {
            activations_.emplace_back(new GroupNorm{
                std::any_cast<torch::nn::functional::GroupNormFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new GroupNorm{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Gumbel-Softmax
      case activation::gumbel_softmax:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new GumbelSoftmax{
              std::any_cast<double>(a[1]), std::any_cast<int>(a[2]),
              std::any_cast<bool>(a[3])});
          break;
        case 2:
          activations_.emplace_back(new GumbelSoftmax{
              std::any_cast<torch::nn::functional::GumbelSoftmaxFuncOptions>(
                  a[1])});
          break;
        case 1:
          activations_.emplace_back(new GumbelSoftmax{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Hardshrink
      case activation::hardshrink:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new Hardshrink{
                std::any_cast<torch::nn::functional::HardshrinkFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new Hardshrink{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new Hardshrink{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Hardsigmoid
      case activation::hardsigmoid:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Hardsigmoid{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Hardswish
      case activation::hardswish:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Hardswish{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Hardtanh
      case activation::hardtanh:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new Hardtanh{std::any_cast<double>(a[1]),
                                                 std::any_cast<double>(a[2]),
                                                 std::any_cast<bool>(a[3])});
          break;
        case 3:
          activations_.emplace_back(new Hardtanh{std::any_cast<double>(a[1]),
                                                 std::any_cast<double>(a[2])});
          break;
        case 2:
          activations_.emplace_back(new Hardtanh{
              std::any_cast<torch::nn::functional::HardtanhFuncOptions>(a[1])});
          break;
        case 1:
          activations_.emplace_back(new Hardtanh{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Instance Normalization
      case activation::instance_norm:
        switch (a.size()) {
        case 8:
          activations_.emplace_back(new InstanceNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]),
              std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
              std::any_cast<double>(a[6]), std::any_cast<bool>(a[7])});
          break;
        case 7:
          activations_.emplace_back(new InstanceNorm{
              std::any_cast<torch::Tensor>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]),
              std::any_cast<torch::Tensor>(a[4]), std::any_cast<double>(a[5]),
              std::any_cast<double>(a[6])});
          break;
        case 2:
          activations_.emplace_back(new InstanceNorm{
              std::any_cast<torch::nn::functional::InstanceNormFuncOptions>(
                  a[1])});
          break;
        case 1:
          activations_.emplace_back(new InstanceNorm{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Layer Normalization
      case activation::layer_norm:
        switch (a.size()) {
        case 5:
          activations_.emplace_back(new LayerNorm{
              std::any_cast<std::vector<int64_t>>(a[1]),
              std::any_cast<torch::Tensor>(a[2]),
              std::any_cast<torch::Tensor>(a[3]), std::any_cast<double>(a[4])});
          break;
        case 2:
          try {
            activations_.emplace_back(new LayerNorm{
                std::any_cast<torch::nn::functional::LayerNormFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new LayerNorm{std::any_cast<std::vector<int64_t>>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Leaky ReLU
      case activation::leaky_relu:
        switch (a.size()) {
        case 3:
          activations_.emplace_back(new LeakyReLU{std::any_cast<double>(a[1]),
                                                  std::any_cast<bool>(a[2])});
          break;
        case 2:
          try {
            activations_.emplace_back(new LeakyReLU{
                std::any_cast<torch::nn::functional::LeakyReLUFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new LeakyReLU{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new LeakyReLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Local Response Normalization
      case activation::local_response_norm:
        switch (a.size()) {
        case 5:
          activations_.emplace_back(new LocalResponseNorm{
              std::any_cast<int64_t>(a[1]), std::any_cast<double>(a[2]),
              std::any_cast<double>(a[3]), std::any_cast<double>(a[4])});
          break;
        case 2:
          try {
            activations_.emplace_back(new LocalResponseNorm{std::any_cast<
                torch::nn::functional::LocalResponseNormFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(
                new LocalResponseNorm{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // LogSigmoid
      case activation::logsigmoid:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new LogSigmoid{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // LogSoftmax
      case activation::logsoftmax:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new LogSoftmax{
                std::any_cast<torch::nn::functional::LogSoftmaxFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new LogSoftmax{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Mish
      case activation::mish:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Mish{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Lp Normalization
      case activation::normalize:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new Normalize{
              std::any_cast<double>(a[1]), std::any_cast<double>(a[2]),
              std::any_cast<int64_t>(a[3])});
          break;
        case 2:
          activations_.emplace_back(new Normalize{
              std::any_cast<torch::nn::functional::NormalizeFuncOptions>(
                  a[1])});
          break;
        case 1:
          activations_.emplace_back(new Normalize{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // PReLU
      case activation::prelu:
        switch (a.size()) {
        case 2:
          activations_.emplace_back(
              new PReLU{std::any_cast<torch::Tensor>(a[1])});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // ReLU
      case activation::relu:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new ReLU{
                std::any_cast<torch::nn::functional::ReLUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new ReLU{std::any_cast<bool>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new ReLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // ReLU6
      case activation::relu6:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new ReLU6{
                std::any_cast<torch::nn::functional::ReLU6FuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new ReLU6{std::any_cast<bool>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new ReLU6{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Randomized ReLU
      case activation::rrelu:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new RReLU{std::any_cast<double>(a[1]),
                                              std::any_cast<double>(a[2]),
                                              std::any_cast<bool>(a[3])});
          break;
        case 3:
          activations_.emplace_back(new RReLU{std::any_cast<double>(a[1]),
                                              std::any_cast<double>(a[2])});
          break;
        case 2:
          activations_.emplace_back(new RReLU{
              std::any_cast<torch::nn::functional::RReLUFuncOptions>(a[1])});
          break;
        case 1:
          activations_.emplace_back(new RReLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // SELU
      case activation::selu:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new SELU{
                std::any_cast<torch::nn::functional::SELUFuncOptions>(a[1])});
          } catch (...) {
            activations_.emplace_back(new SELU{std::any_cast<bool>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new SELU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Sigmoid
      case activation::sigmoid:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Sigmoid{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // SiLU
      case activation::silu:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new SiLU{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softmax
      case activation::softmax:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new Softmax{
                std::any_cast<torch::nn::functional::SoftmaxFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new Softmax{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softmin
      case activation::softmin:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new Softmin{
                std::any_cast<torch::nn::functional::SoftminFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new Softmin{std::any_cast<int64_t>(a[1])});
          }
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softplus
      case activation::softplus:
        switch (a.size()) {
        case 3:
          activations_.emplace_back(new Softplus{std::any_cast<double>(a[1]),
                                                 std::any_cast<double>(a[2])});
          break;
        case 2:
          activations_.emplace_back(new Softplus{
              std::any_cast<torch::nn::functional::SoftplusFuncOptions>(a[1])});
          break;
        case 1:
          activations_.emplace_back(new Softplus{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softshrink
      case activation::softshrink:
        switch (a.size()) {
        case 2:
          try {
            activations_.emplace_back(new Softshrink{
                std::any_cast<torch::nn::functional::SoftshrinkFuncOptions>(
                    a[1])});
          } catch (...) {
            activations_.emplace_back(
                new Softshrink{std::any_cast<double>(a[1])});
          }
          break;
        case 1:
          activations_.emplace_back(new Softshrink{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Softsign
      case activation::softsign:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Softsign{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Tanh
      case activation::tanh:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Tanh{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Tanhshrink
      case activation::tanhshrink:
        switch (a.size()) {
        case 1:
          activations_.emplace_back(new Tanhshrink{});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      // Threshold
      case activation::threshold:
        switch (a.size()) {
        case 4:
          activations_.emplace_back(new Threshold{std::any_cast<double>(a[1]),
                                                  std::any_cast<double>(a[2]),
                                                  std::any_cast<bool>(a[3])});
          break;
        case 3:
          activations_.emplace_back(new Threshold{std::any_cast<double>(a[1]),
                                                  std::any_cast<double>(a[2])});
          break;
        case 2:
          activations_.emplace_back(new Threshold{
              std::any_cast<torch::nn::functional::ThresholdFuncOptions>(
                  a[1])});
          break;
        default:
          throw std::runtime_error("Invalid number of parameters");
        }
        break;

      default:
        throw std::runtime_error("Invalid activation function");
      }
  }

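  // Example (editorial sketch): each entry of `activations` is a
  // std::vector<std::any> whose first element is an `activation` enumerator
  // and whose remaining elements must match one of the parameter sets handled
  // in the corresponding case above, e.g.
  //
  //   std::vector<std::vector<std::any>> activations = {
  //       {activation::elu, 0.5, true}, // ELU with alpha=0.5, inplace=true
  //       {activation::relu},           // plain ReLU
  //       {activation::none}};          // identity on the output layer
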
  /// @brief Forward evaluation
  torch::Tensor forward(torch::Tensor x) {
    // Standard feed-forward neural network
    for (auto [layer, activation] : utils::zip(layers_, activations_))
      x = activation->apply(layer->forward(x));

    return x;
  }

  /// @brief Writes the network into a torch::serialize::OutputArchive object
  inline torch::serialize::OutputArchive &
  write(torch::serialize::OutputArchive &archive,
        const std::string &key = "iganet") const {
    assert(layers_.size() == activations_.size());

    archive.write(key + ".layers", torch::full({1}, (int64_t)layers_.size()));
    for (std::size_t i = 0; i < layers_.size(); ++i) {
      archive.write(
          key + ".layer[" + std::to_string(i) + "].in_features",
          torch::full({1}, (int64_t)layers_[i]->options.in_features()));
      archive.write(
          key + ".layer[" + std::to_string(i) + "].outputs_features",
          torch::full({1}, (int64_t)layers_[i]->options.out_features()));
      archive.write(key + ".layer[" + std::to_string(i) + "].bias",
                    torch::full({1}, (int64_t)layers_[i]->options.bias()));

      activations_[i]->write(archive, key + ".layer[" + std::to_string(i) +
                                          "].activation");
    }

    return archive;
  }

  /// @brief Reads the network from a torch::serialize::InputArchive object
  inline torch::serialize::InputArchive &
  read(torch::serialize::InputArchive &archive,
       const std::string &key = "iganet") {
    torch::Tensor layers, in_features, outputs_features, bias, activation;

    archive.read(key + ".layers", layers);
    for (int64_t i = 0; i < layers.item<int64_t>(); ++i) {
      archive.read(key + ".layer[" + std::to_string(i) + "].in_features",
                   in_features);
      archive.read(key + ".layer[" + std::to_string(i) + "].outputs_features",
                   outputs_features);
      archive.read(key + ".layer[" + std::to_string(i) + "].bias", bias);
      layers_.emplace_back(register_module(
          "layer[" + std::to_string(i) + "]",
          torch::nn::Linear(
              torch::nn::LinearOptions(in_features.item<int64_t>(),
                                       outputs_features.item<int64_t>())
                  .bias(bias.item<bool>()))));

      archive.read(key + ".layer[" + std::to_string(i) + "].activation.type",
                   activation);
      switch (static_cast<enum activation>(activation.item<int64_t>())) {
      case activation::none:
        activations_.emplace_back(new None{});
        break;
      case activation::batch_norm:
        activations_.emplace_back(
            new BatchNorm{torch::Tensor{}, torch::Tensor{}});
        break;
      case activation::celu:
        activations_.emplace_back(new CELU{});
        break;
      case activation::elu:
        activations_.emplace_back(new ELU{});
        break;
      case activation::gelu:
        activations_.emplace_back(new GELU{});
        break;
      case activation::glu:
        activations_.emplace_back(new GLU{});
        break;
      case activation::group_norm:
        activations_.emplace_back(new GroupNorm{0});
        break;
      case activation::gumbel_softmax:
        activations_.emplace_back(new GumbelSoftmax{});
        break;
      case activation::hardshrink:
        activations_.emplace_back(new Hardshrink{});
        break;
      case activation::hardsigmoid:
        activations_.emplace_back(new Hardsigmoid{});
        break;
      case activation::hardswish:
        activations_.emplace_back(new Hardswish{});
        break;
      case activation::hardtanh:
        activations_.emplace_back(new Hardtanh{});
        break;
      case activation::instance_norm:
        activations_.emplace_back(new InstanceNorm{});
        break;
      case activation::layer_norm:
        activations_.emplace_back(new LayerNorm{{}});
        break;
      case activation::leaky_relu:
        activations_.emplace_back(new LeakyReLU{});
        break;
      case activation::local_response_norm:
        activations_.emplace_back(new LocalResponseNorm{0});
        break;
      case activation::logsigmoid:
        activations_.emplace_back(new LogSigmoid{});
        break;
      case activation::logsoftmax:
        activations_.emplace_back(new LogSoftmax{0});
        break;
      case activation::mish:
        activations_.emplace_back(new Mish{});
        break;
      case activation::normalize:
        activations_.emplace_back(new Normalize{0, 0, 0});
        break;
      case activation::prelu:
        activations_.emplace_back(new PReLU{torch::Tensor{}});
        break;
      case activation::relu:
        activations_.emplace_back(new ReLU{});
        break;
      case activation::relu6:
        activations_.emplace_back(new ReLU6{});
        break;
      case activation::rrelu:
        activations_.emplace_back(new RReLU{});
        break;
      case activation::selu:
        activations_.emplace_back(new SELU{});
        break;
      case activation::sigmoid:
        activations_.emplace_back(new Sigmoid{});
        break;
      case activation::silu:
        activations_.emplace_back(new SiLU{});
        break;
      case activation::softmax:
        activations_.emplace_back(new Softmax{0});
        break;
      case activation::softmin:
        activations_.emplace_back(new Softmin{0});
        break;
      case activation::softplus:
        activations_.emplace_back(new Softplus{});
        break;
      case activation::softshrink:
        activations_.emplace_back(new Softshrink{});
        break;
      case activation::softsign:
        activations_.emplace_back(new Softsign{});
        break;
      case activation::tanh:
        activations_.emplace_back(new Tanh{});
        break;
      case activation::tanhshrink:
        activations_.emplace_back(new Tanhshrink{});
        break;
      case activation::threshold:
        activations_.emplace_back(new Threshold{0, 0});
        break;
      default:
        throw std::runtime_error("Invalid activation function");
      }
      activations_.back()->read(archive, key + ".layer[" + std::to_string(i) +
                                             "].activation");
    }
    return archive;
  }

  /// @brief Returns a string representation of the network
  inline virtual void
  pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
    os << "(\n";

    int i = 0;
    for (const auto &activation : activations_)
      os << "activation[" << i++ << "] = " << *activation << "\n";
    os << ")\n";
  }

private:
  /// @brief Vector of linear layers
  std::vector<torch::nn::Linear> layers_;

  /// @brief Vector of activation functions
  std::vector<std::unique_ptr<iganet::ActivationFunction>> activations_;
};

/// @brief IgANetGenerator
template <typename real_t>
class IgANetGenerator
    : public torch::nn::ModuleHolder<IgANetGeneratorImpl<real_t>> {
public:
  using torch::nn::ModuleHolder<IgANetGeneratorImpl<real_t>>::ModuleHolder;
};

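// Example (editorial sketch): a generator mapping 2 inputs through two hidden
// layers of width 32 to 16 outputs, with tanh activations on the hidden
// layers and no activation on the output layer:
//
//   IgANetGenerator<double> net(
//       std::vector<int64_t>{2, 32, 32, 16},
//       {{activation::tanh}, {activation::tanh}, {activation::none}});
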
/// @brief IgANet
template <typename Optimizer, typename GeometryMap, typename Variable,
          template <typename, typename> typename IgABase = ::iganet::IgABase>
  requires OptimizerType<Optimizer> && FunctionSpaceType<GeometryMap> &&
           FunctionSpaceType<Variable>
class IgANet : public IgABase<GeometryMap, Variable>,
               public utils::Serializable,
               private utils::FullQualifiedName {
public:
  /// @brief Base type
  using Base = IgABase<GeometryMap, Variable>;

  /// @brief Type of the optimizer
  using optimizer_type = Optimizer;

  /// @brief Type of the optimizer options
  using optimizer_options_type =
      typename optimizer_options_type<Optimizer>::type;

protected:
  /// @brief IgANet generator
  IgANetGenerator<typename Base::value_type> net_;

  /// @brief Optimizer
  std::unique_ptr<optimizer_type> opt_;

  /// @brief Options
  IgANetOptions options_;

public:
  /// @brief Default constructor
  explicit IgANet(IgANetOptions defaults = {},
                  iganet::Options<typename Base::value_type> options =
                      iganet::Options<typename Base::value_type>{})
      : // Construct the base class
        Base(),
        // Construct the optimizer
        opt_(std::make_unique<optimizer_type>(net_->parameters())),
        // Set options
        options_(defaults) {}

  /// @brief Constructor: number of layers, activation functions, and number
  /// of spline coefficients (same for geometry map and variables)
  template <std::size_t Coeffs>
  IgANet(const std::vector<int64_t> &layers,
         const std::vector<std::vector<std::any>> &activations,
         std::array<int64_t, Coeffs> ncoeffs, IgANetOptions defaults = {},
         iganet::Options<typename Base::value_type> options =
             iganet::Options<typename Base::value_type>{})
      : IgANet(layers, activations, std::tuple{ncoeffs}, std::tuple{ncoeffs},
               defaults, options) {}

  template <std::size_t... Coeffs>
  IgANet(const std::vector<int64_t> &layers,
         const std::vector<std::vector<std::any>> &activations,
         std::tuple<std::array<int64_t, Coeffs>...> ncoeffs,
         IgANetOptions defaults = {},
         iganet::Options<typename Base::value_type> options =
             iganet::Options<typename Base::value_type>{})
      : IgANet(layers, activations, ncoeffs, ncoeffs, defaults, options) {}

  /// @brief Constructor: number of layers, activation functions, and number
  /// of spline coefficients (different for geometry map and variables)
  template <std::size_t GeometryMapNumCoeffs, std::size_t VariableNumCoeffs>
  IgANet(const std::vector<int64_t> &layers,
         const std::vector<std::vector<std::any>> &activations,
         std::array<int64_t, GeometryMapNumCoeffs> geometryMapNumCoeffs,
         std::array<int64_t, VariableNumCoeffs> variableNumCoeffs,
         IgANetOptions defaults = {},
         iganet::Options<typename Base::value_type> options =
             iganet::Options<typename Base::value_type>{})
      : IgANet(layers, activations, std::tuple{geometryMapNumCoeffs},
               std::tuple{variableNumCoeffs}, defaults, options) {}

  template <std::size_t... GeometryMapNumCoeffs,
            std::size_t... VariableNumCoeffs>
  IgANet(const std::vector<int64_t> &layers,
         const std::vector<std::vector<std::any>> &activations,
         std::tuple<std::array<int64_t, GeometryMapNumCoeffs>...>
             geometryMapNumCoeffs,
         std::tuple<std::array<int64_t, VariableNumCoeffs>...>
             variableNumCoeffs,
         IgANetOptions defaults = {},
         iganet::Options<typename Base::value_type> options =
             iganet::Options<typename Base::value_type>{})
      : // Construct the base class
        Base(geometryMapNumCoeffs, variableNumCoeffs, options),
        // Construct the deep neural network
        net_(utils::concat(std::vector<int64_t>{inputs(/* epoch */ 0).size(0)},
                           layers,
                           std::vector<int64_t>{Base::u_.as_tensor_size()}),
             activations, options),
        // Construct the optimizer
        opt_(std::make_unique<optimizer_type>(net_->parameters())),
        // Set options
        options_(defaults) {}

  /// @brief Returns a constant reference to the IgANet generator
  inline const IgANetGenerator<typename Base::value_type> &net() const {
    return net_;
  }

  /// @brief Returns a non-constant reference to the IgANet generator
  inline IgANetGenerator<typename Base::value_type> &net() { return net_; }

  /// @brief Returns a constant reference to the optimizer
  inline const optimizer_type &optimizer() const { return *opt_; }

  /// @brief Returns a non-constant reference to the optimizer
  inline optimizer_type &optimizer() { return *opt_; }

  /// @brief Resets the optimizer; if resetOptions is false, the current
  /// per-group options are carried over to the new optimizer
  inline void optimizerReset(bool resetOptions = true) {
    if (resetOptions)
      opt_ = std::make_unique<optimizer_type>(net_->parameters());
    else {
      std::vector<optimizer_options_type> options;
      for (auto &group : opt_->param_groups())
        options.push_back(
            static_cast<optimizer_options_type &>(group.options()));
      opt_ = std::make_unique<optimizer_type>(net_->parameters());
      for (auto [group, opts] : utils::zip(opt_->param_groups(), options))
        static_cast<optimizer_options_type &>(group.options()) = opts;
    }
  }

  /// @brief Resets the optimizer and sets the given optimizer options
  inline void optimizerReset(const optimizer_options_type &optimizerOptions) {
    opt_ = std::make_unique<optimizer_type>(net_->parameters(),
                                            optimizerOptions);
  }

  /// @brief Returns a non-constant reference to the optimizer options
  inline optimizer_options_type &
  optimizerOptions(std::size_t param_group = 0) {
    if (param_group < opt_->param_groups().size())
      return static_cast<optimizer_options_type &>(
          opt_->param_groups()[param_group].options());
    else
      throw std::runtime_error("Index exceeds number of parameter groups");
  }

  /// @brief Returns a constant reference to the optimizer options
  inline const optimizer_options_type &
  optimizerOptions(std::size_t param_group = 0) const {
    if (param_group < opt_->param_groups().size())
      return static_cast<optimizer_options_type &>(
          opt_->param_groups()[param_group].options());
    else
      throw std::runtime_error("Index exceeds number of parameter groups");
  }

  /// @brief Resets the optimizer options for all parameter groups
  inline void optimizerOptionsReset(const optimizer_options_type &options) {
    for (auto &group : opt_->param_groups())
      static_cast<optimizer_options_type &>(group.options()) = options;
  }

  /// @brief Resets the optimizer options for all parameter groups
  inline void optimizerOptionsReset(optimizer_options_type &&options) {
    for (auto &group : opt_->param_groups())
      static_cast<optimizer_options_type &>(group.options()) = options;
  }

  /// @brief Resets the optimizer options for the given parameter group
  inline void optimizerOptionsReset(const optimizer_options_type &options,
                                    std::size_t param_group) {
    if (param_group < opt_->param_groups().size())
      static_cast<optimizer_options_type &>(
          opt_->param_groups()[param_group].options()) = options;
    else
      throw std::runtime_error("Index exceeds number of parameter groups");
  }

  /// @brief Resets the optimizer options for the given parameter group
  inline void optimizerOptionsReset(optimizer_options_type &&options,
                                    std::size_t param_group) {
    if (param_group < opt_->param_groups().size())
      static_cast<optimizer_options_type &>(
          opt_->param_groups()[param_group].options()) = options;
    else
      throw std::runtime_error("Index exceeds number of parameter groups");
  }

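  // Example (editorial sketch): assuming the optimizer is torch::optim::Adam,
  // whose options expose a TORCH_ARG-generated lr() setter, the learning rate
  // can be lowered and the optimizer state rebuilt without touching the
  // remaining options:
  //
  //   net.optimizerOptions().lr(1e-4); // adjust options in place
  //   net.optimizerReset(false);       // rebuild, keeping current options
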
  /// @brief Returns a constant reference to the options structure
  inline const auto &options() const { return options_; }

  /// @brief Returns a non-constant reference to the options structure
  inline auto &options() { return options_; }

  /// @brief Returns the network inputs
  virtual torch::Tensor inputs(int64_t epoch) const {
    if constexpr (Base::has_GeometryMap && Base::has_RefData)
      return torch::cat({Base::G_.as_tensor(), Base::f_.as_tensor()});
    else if constexpr (Base::has_GeometryMap && !Base::has_RefData)
      return Base::G_.as_tensor();
    else if constexpr (!Base::has_GeometryMap && Base::has_RefData)
      return Base::f_.as_tensor();
    else
      return torch::empty({0});
  }

  /// @brief Initializes the epoch; returning true triggers recomputation of
  /// the network inputs
  virtual bool epoch(int64_t) = 0;

  /// @brief Computes the loss function
  virtual torch::Tensor loss(const torch::Tensor &, int64_t) = 0;

  /// @brief Trains the IgANet
  virtual void train(
#ifdef IGANET_WITH_MPI
      c10::intrusive_ptr<c10d::ProcessGroupMPI> pg =
          c10d::ProcessGroupMPI::createProcessGroupMPI()
#endif
  ) {
    torch::Tensor inputs, outputs, loss;
    typename Base::value_type previous_loss(-1.0);

    // Loop over epochs
    for (int64_t epoch = 0; epoch != options_.max_epoch(); ++epoch) {

      // Update epoch and inputs
      if (this->epoch(epoch))
        inputs = this->inputs(epoch);

      auto closure = [&]() {
        // Reset gradients
        net_->zero_grad();

        // Execute the model on the inputs
        outputs = net_->forward(inputs);

        // Compute the loss value
        loss = this->loss(outputs, epoch);

        // Compute gradients of the loss w.r.t. the model parameters
        loss.backward({}, true, false);

        return loss;
      };

#ifdef IGANET_WITH_MPI
      // Average the gradients of the parameters across all processes.
      // Note: This may lag behind DistributedDataParallel (DDP) in
      // performance since it synchronizes parameters after the backward
      // pass, while DDP overlaps parameter synchronization with gradient
      // computation during the backward pass.
      std::vector<c10::intrusive_ptr<::c10d::Work>> works;
      for (auto &param : net_->named_parameters()) {
        std::vector<torch::Tensor> tmp = {param.value().grad()};
        works.emplace_back(pg->allreduce(tmp));
      }

      waitWork(pg, works);

      for (auto &param : net_->named_parameters()) {
        param.value().grad().data() =
            param.value().grad().data() / pg->getSize();
      }
#endif

      // Update the parameters based on the calculated gradients
      opt_->step(closure);

      typename Base::value_type current_loss =
          loss.template item<typename Base::value_type>();
      Log(log::verbose) << "Epoch " << std::to_string(epoch) << ": "
                        << current_loss << std::endl;

      // Stop if the loss has dropped below the prescribed tolerance
      if (current_loss < options_.min_loss()) {
        Log(log::info) << "Total epochs: " << epoch
                       << ", loss: " << current_loss << std::endl;
        break;
      }

      // Stop if the loss has stagnated between two consecutive epochs
      if (current_loss == previous_loss ||
          std::abs(current_loss - previous_loss) < previous_loss / 10) {
        Log(log::info) << "Total epochs: " << epoch
                       << ", loss: " << current_loss << std::endl;
        break;
      }

      // Stop if the loss has become NaN
      if (loss.isnan().template item<bool>()) {
        Log(log::info) << "Total epochs: " << epoch
                       << ", loss: " << current_loss << std::endl;
        break;
      }
      previous_loss = current_loss;
    }
  }

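  // Note (editorial): wrapping the forward/backward pass in a closure that is
  // handed to step() keeps this loop compatible with optimizers such as
  // torch::optim::LBFGS, which re-evaluate the loss several times per step;
  // first-order optimizers like Adam or SGD simply invoke the closure once.
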
  /// @brief Trains the IgANet with the data provided by the data loader
  template <typename DataLoader>
  void train(DataLoader &loader
#ifdef IGANET_WITH_MPI
             ,
             c10::intrusive_ptr<c10d::ProcessGroupMPI> pg =
                 c10d::ProcessGroupMPI::createProcessGroupMPI()
#endif
  ) {
    torch::Tensor inputs, outputs, loss;
    typename Base::value_type previous_loss(-1.0);

    // Loop over epochs
    for (int64_t epoch = 0; epoch != options_.max_epoch(); ++epoch) {

      typename Base::value_type Loss(0);

      // Loop over batches
      for (auto &batch : loader) {
        inputs = batch.data;

        if (inputs.dim() > 0) {
          if constexpr (Base::has_GeometryMap && Base::has_RefData) {
            Base::G_.from_tensor(
                inputs.slice(1, 0, Base::G_.as_tensor_size()).t());
            Base::f_.from_tensor(inputs
                                     .slice(1, Base::G_.as_tensor_size(),
                                            Base::G_.as_tensor_size() +
                                                Base::f_.as_tensor_size())
                                     .t());
          } else if constexpr (Base::has_GeometryMap && !Base::has_RefData)
            Base::G_.from_tensor(
                inputs.slice(1, 0, Base::G_.as_tensor_size()).t());
          else if constexpr (!Base::has_GeometryMap && Base::has_RefData)
            Base::f_.from_tensor(
                inputs.slice(1, 0, Base::f_.as_tensor_size()).t());

        } else {
          if constexpr (Base::has_GeometryMap && Base::has_RefData) {
            Base::G_.from_tensor(
                inputs.slice(1, 0, Base::G_.as_tensor_size()).flatten());
            Base::f_.from_tensor(inputs
                                     .slice(1, Base::G_.as_tensor_size(),
                                            Base::G_.as_tensor_size() +
                                                Base::f_.as_tensor_size())
                                     .flatten());
          } else if constexpr (Base::has_GeometryMap && !Base::has_RefData)
            Base::G_.from_tensor(
                inputs.slice(1, 0, Base::G_.as_tensor_size()).flatten());
          else if constexpr (!Base::has_GeometryMap && Base::has_RefData)
            Base::f_.from_tensor(
                inputs.slice(1, 0, Base::f_.as_tensor_size()).flatten());
        }

        this->epoch(epoch);

        auto closure = [&]() {
          // Reset gradients
          net_->zero_grad();

          // Execute the model on the inputs
          outputs = net_->forward(inputs);

          // Compute the loss value
          loss = this->loss(outputs, epoch);

          // Compute gradients of the loss w.r.t. the model parameters
          loss.backward({}, true, false);

          return loss;
        };

        // Update the parameters based on the calculated gradients
        opt_->step(closure);

        Loss += loss.template item<typename Base::value_type>();
      }

      Log(log::verbose) << "Epoch " << std::to_string(epoch) << ": " << Loss
                        << std::endl;

      if (Loss < options_.min_loss()) {
        Log(log::info) << "Total epochs: " << epoch << ", loss: " << Loss
                       << std::endl;
        break;
      }

      if (Loss == previous_loss) {
        Log(log::info) << "Total epochs: " << epoch << ", loss: " << Loss
                       << std::endl;
        break;
      }
      previous_loss = Loss;

      if (epoch == options_.max_epoch() - 1)
        Log(log::warning) << "Total epochs: " << epoch << ", loss: " << Loss
                          << std::endl;
    }
  }

  /// @brief Evaluates the IgANet
  void eval() {
    torch::Tensor inputs = this->inputs(0);
    torch::Tensor outputs = net_->forward(inputs);
    Base::u_.from_tensor(outputs);
  }

  /// @brief Returns the IgANet object as a JSON object
  inline virtual nlohmann::json to_json() const override {
    return "Not implemented yet";
  }

  /// @brief Returns the parameters of the IgANet object
  inline std::vector<torch::Tensor> parameters() const noexcept {
    return net_->parameters();
  }

  /// @brief Returns the named parameters of the IgANet object
  inline torch::OrderedDict<std::string, torch::Tensor>
  named_parameters() const noexcept {
    return net_->named_parameters();
  }

  /// @brief Returns the total number of parameters of the IgANet object
  inline std::size_t nparameters() const noexcept {
    std::size_t result = 0;
    for (const auto &param : this->parameters()) {
      result += param.numel();
    }
    return result;
  }

  /// @brief Returns a string representation of the IgANet object
  inline virtual void
  pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
    os << name() << "(\n"
       << "net = " << net_ << "\n";
    if constexpr (Base::has_GeometryMap)
      os << "G = " << Base::G_ << "\n";
    if constexpr (Base::has_RefData)
      os << "f = " << Base::f_ << "\n";
    if constexpr (Base::has_Solution)
      os << "u = " << Base::u_ << "\n)";
  }

  /// @brief Saves the IgANet to file
  inline void save(const std::string &filename,
                   const std::string &key = "iganet") const {
    torch::serialize::OutputArchive archive;
    write(archive, key).save_to(filename);
  }

  /// @brief Loads the IgANet from file
  inline void load(const std::string &filename,
                   const std::string &key = "iganet") {
    torch::serialize::InputArchive archive;
    archive.load_from(filename);
    read(archive, key);
  }

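  // Example (editorial sketch): round-trip a trained network through a file;
  // save() wraps write() plus OutputArchive::save_to(), and load() wraps
  // InputArchive::load_from() plus read():
  //
  //   net.save("model.pt");
  //   net.load("model.pt");
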
  /// @brief Writes the IgANet into a torch::serialize::OutputArchive object
  inline torch::serialize::OutputArchive &
  write(torch::serialize::OutputArchive &archive,
        const std::string &key = "iganet") const {
    if constexpr (Base::has_GeometryMap)
      Base::G_.write(archive, key + ".geo");
    if constexpr (Base::has_RefData)
      Base::f_.write(archive, key + ".ref");
    if constexpr (Base::has_Solution)
      Base::u_.write(archive, key + ".out");

    net_->write(archive, key + ".net");
    torch::serialize::OutputArchive archive_net;
    net_->save(archive_net);
    archive.write(key + ".net.data", archive_net);

    torch::serialize::OutputArchive archive_opt;
    opt_->save(archive_opt);
    archive.write(key + ".opt", archive_opt);

    return archive;
  }

  /// @brief Reads the IgANet from a torch::serialize::InputArchive object
  inline torch::serialize::InputArchive &
  read(torch::serialize::InputArchive &archive,
       const std::string &key = "iganet") {
    if constexpr (Base::has_GeometryMap)
      Base::G_.read(archive, key + ".geo");
    if constexpr (Base::has_RefData)
      Base::f_.read(archive, key + ".ref");
    if constexpr (Base::has_Solution)
      Base::u_.read(archive, key + ".out");

    net_->read(archive, key + ".net");
    torch::serialize::InputArchive archive_net;
    archive.read(key + ".net.data", archive_net);
    net_->load(archive_net);

    opt_->add_parameters(net_->parameters());
    torch::serialize::InputArchive archive_opt;
    archive.read(key + ".opt", archive_opt);
    opt_->load(archive_opt);

    return archive;
  }

  /// @brief Returns true if both IgANet objects are the same
  bool operator==(const IgANet &other) const {
    bool result(true);

    if constexpr (Base::has_GeometryMap)
      result &= (Base::G_ == other.G());
    if constexpr (Base::has_RefData)
      result &= (Base::f_ == other.f());
    if constexpr (Base::has_Solution)
      result &= (Base::u_ == other.u());

    return result;
  }

  /// @brief Returns true if both IgANet objects are different
  bool operator!=(const IgANet &other) const { return !(*this == other); }

#ifdef IGANET_WITH_MPI
private:
  /// @brief Waits for all MPI work items to complete
  static void waitWork(c10::intrusive_ptr<c10d::ProcessGroupMPI> pg,
                       std::vector<c10::intrusive_ptr<c10d::Work>> works) {
    for (auto &work : works) {
      try {
        work->wait();
      } catch (const std::exception &ex) {
        Log(log::error) << "Exception received during waitWork: " << ex.what()
                        << std::endl;
        pg->abort();
      }
    }
  }
#endif
};

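// Usage sketch (editorial; all names below are illustrative assumptions, not
// part of this header): a concrete solver derives from IgANet, fixes the
// optimizer and function-space types, and implements the two pure virtuals
// epoch() and loss():
//
//   template <typename Optimizer, typename GeometryMap, typename Variable>
//   class Poisson : public iganet::IgANet<Optimizer, GeometryMap, Variable> {
//   public:
//     using Base = iganet::IgANet<Optimizer, GeometryMap, Variable>;
//     using Base::Base; // inherit the constructors
//
//     bool epoch(int64_t epoch) override {
//       return epoch == 0; // inputs only need to be assembled once
//     }
//
//     torch::Tensor loss(const torch::Tensor &outputs,
//                        int64_t epoch) override {
//       Base::u_.from_tensor(outputs); // outputs are spline coefficients
//       // ... evaluate PDE residual and boundary terms here ...
//       return torch::zeros({1});      // placeholder
//     }
//   };
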
/// @brief Print (as string) an IgANet object
template <typename Optimizer, typename GeometryMap, typename Variable>
  requires OptimizerType<Optimizer> && FunctionSpaceType<GeometryMap> &&
           FunctionSpaceType<Variable>
inline std::ostream &
operator<<(std::ostream &os,
           const IgANet<Optimizer, GeometryMap, Variable> &obj) {
  obj.pretty_print(os);
  return os;
}

/// @brief IgANetCustomizable
template <typename GeometryMap, typename Variable>
  requires FunctionSpaceType<GeometryMap> && FunctionSpaceType<Variable>
class IgANetCustomizable {
public:
  /// @brief Type of the knot indices of the geometry map in the interior
  using geometryMap_interior_knot_indices_type =
      decltype(std::declval<GeometryMap>()
                   .template find_knot_indices<functionspace::interior>(
                       std::declval<typename GeometryMap::eval_type>()));

  /// @brief Type of the knot indices of the geometry map at the boundary
  using geometryMap_boundary_knot_indices_type =
      decltype(std::declval<GeometryMap>()
                   .template find_knot_indices<functionspace::boundary>(
                       std::declval<
                           typename GeometryMap::boundary_eval_type>()));

  /// @brief Type of the knot indices of the variables in the interior
  using variable_interior_knot_indices_type =
      decltype(std::declval<Variable>()
                   .template find_knot_indices<functionspace::interior>(
                       std::declval<typename Variable::eval_type>()));

  /// @brief Type of the knot indices of the variables at the boundary
  using variable_boundary_knot_indices_type =
      decltype(std::declval<Variable>()
                   .template find_knot_indices<functionspace::boundary>(
                       std::declval<typename Variable::boundary_eval_type>()));

  /// @brief Type of the coefficient indices of the geometry map in the
  /// interior
  using geometryMap_interior_coeff_indices_type =
      decltype(std::declval<GeometryMap>()
                   .template find_coeff_indices<functionspace::interior>(
                       std::declval<typename GeometryMap::eval_type>()));

  /// @brief Type of the coefficient indices of the geometry map at the
  /// boundary
  using geometryMap_boundary_coeff_indices_type =
      decltype(std::declval<GeometryMap>()
                   .template find_coeff_indices<functionspace::boundary>(
                       std::declval<
                           typename GeometryMap::boundary_eval_type>()));

  /// @brief Type of the coefficient indices of the variables in the interior
  using variable_interior_coeff_indices_type =
      decltype(std::declval<Variable>()
                   .template find_coeff_indices<functionspace::interior>(
                       std::declval<typename Variable::eval_type>()));

  /// @brief Type of the coefficient indices of the variables at the boundary
  using variable_boundary_coeff_indices_type =
      decltype(std::declval<Variable>()
                   .template find_coeff_indices<functionspace::boundary>(
                       std::declval<typename Variable::boundary_eval_type>()));
};

} // namespace iganet
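
// Sketch (editorial): derived solvers typically also inherit from
// IgANetCustomizable and cache these index types once, reusing them in every
// loss() evaluation instead of recomputing knot and coefficient indices:
//
//   class Solver : public iganet::IgANet<Optimizer, GeometryMap, Variable>,
//                  public iganet::IgANetCustomizable<GeometryMap, Variable> {
//     using Customizable = iganet::IgANetCustomizable<GeometryMap, Variable>;
//     typename Customizable::variable_interior_knot_indices_type knot_idx_;
//     typename Customizable::variable_interior_coeff_indices_type coeff_idx_;
//   };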