IgANet
IGAnets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
iganet.hpp
Go to the documentation of this file.
1
15#pragma once
16
#include <cmath>
17#include <core/options.hpp>
18#include <net/generator.hpp>
19#include <net/collocation.hpp>
20#include <net/optimizer.hpp>
21#include <splines/boundary.hpp>
23#include <utils/container.hpp>
24#include <utils/fqn.hpp>
25#include <utils/tuple.hpp>
26#include <utils/zip.hpp>
27
28namespace iganet::v1 {
29
32 TORCH_ARG(int64_t, max_epoch) = 100;
33 TORCH_ARG(int64_t, batch_size) = 1000;
34 TORCH_ARG(double, min_loss) = 1e-4;
35 TORCH_ARG(double, min_loss_change) = 0;
36 TORCH_ARG(double, min_loss_rel_change) = 1e-3;
37 };
38
43template <typename GeometryMap, typename Variable>
45class [[deprecated("Use novel IgANet implementation")]] IgABaseNoRefData {
46public:
48 using value_type = std::common_type_t<typename GeometryMap::value_type,
49 typename Variable::value_type>;
50
52 using geometryMap_type = GeometryMap;
53
55 using variable_type = Variable;
56
59 std::pair<typename GeometryMap::eval_type,
60 typename GeometryMap::boundary_eval_type>;
61
64 std::pair<typename Variable::eval_type,
65 typename Variable::boundary_eval_type>;
66
68 bool static constexpr has_GeometryMap = true;
69
71 bool static constexpr has_RefData = false;
72
74 bool static constexpr has_Solution = true;
75
76protected:
78 GeometryMap G_;
79
81 Variable u_;
82
83private:
86 template <std::size_t... GeometryMapNumCoeffs, std::size_t... Is,
87 std::size_t... VariableNumCoeffs, std::size_t... Js>
89 std::tuple<std::array<int64_t, GeometryMapNumCoeffs>...>
90 geometryMapNumCoeffs,
91 std::index_sequence<Is...>,
92 std::tuple<std::array<int64_t, VariableNumCoeffs>...> variableNumCoeffs,
93 std::index_sequence<Js...>,
95 : // Construct the different spline objects individually
96 G_(std::get<Is>(geometryMapNumCoeffs)..., init::greville, options),
97 u_(std::get<Js>(variableNumCoeffs)..., init::random, options) {}
98
99public:
103 : G_(), u_() {}
104
108 template <std::size_t NumCoeffs>
110 std::array<int64_t, NumCoeffs> numCoeffs,
112 : IgABaseNoRefData(std::tuple{numCoeffs}, std::tuple{numCoeffs},
113 options) {}
114
115 template <std::size_t... NumCoeffs>
117 std::tuple<std::array<int64_t, NumCoeffs>...> numCoeffs,
119 : IgABaseNoRefData(numCoeffs, numCoeffs, options) {}
121
125 template <std::size_t GeometryMapNumCoeffs, std::size_t VariableNumCoeffs>
127 std::array<int64_t, GeometryMapNumCoeffs> geometryMapNumCoeffs,
128 std::array<int64_t, VariableNumCoeffs> variableNumCoeffs,
130 : IgABaseNoRefData(std::tuple{geometryMapNumCoeffs},
131 std::tuple{variableNumCoeffs}, options) {}
132
133 template <std::size_t... GeometryMapNumCoeffs,
134 std::size_t... VariableNumCoeffs>
136 std::tuple<std::array<int64_t, GeometryMapNumCoeffs>...>
137 geometryMapNumCoeffs,
138 std::tuple<std::array<int64_t, VariableNumCoeffs>...> variableNumCoeffs,
141 geometryMapNumCoeffs,
142 std::make_index_sequence<sizeof...(GeometryMapNumCoeffs)>{},
143 variableNumCoeffs,
144 std::make_index_sequence<sizeof...(VariableNumCoeffs)>{}, options) {
145 }
147
149 virtual ~IgABaseNoRefData() = default;
150
153 inline const GeometryMap &G() const { return G_; }
154
157 inline GeometryMap &G() { return G_; }
158
161 inline const Variable &u() const { return u_; }
162
165 inline Variable &u() { return u_; }
166
167private:
174 template <std::size_t... Is>
175 geometryMap_collPts_type
176 geometryMap_collPts(enum collPts collPtsType,
177 std::index_sequence<Is...>) const {
179
180 switch (collPtsType) {
181
182 case collPts::greville:
183 // Get Greville abscissae inside the domain and at the boundary
184 ((std::get<Is>(collPts.first) =
185 G_.template space<Is>().greville(/* interior */ false)),
186 ...);
187
188 // Get Greville abscissae at the domain
189 ((std::get<Is>(collPts.second) = G_.template boundary<Is>().greville()),
190 ...);
191 break;
192
193 case collPts::greville_interior:
194 // Get Greville abscissae inside the domain
195 ((std::get<Is>(collPts.first) =
196 G_.template space<Is>().greville(/* interior */ true)),
197 ...);
198
199 // Get Greville abscissae at the domain
200 ((std::get<Is>(collPts.second) = G_.template boundary<Is>().greville()),
201 ...);
202 break;
203
204 case collPts::greville_ref1:
205 // Get Greville abscissae inside the domain and at the boundary
206 ((std::get<Is>(collPts.first) =
207 G_.template space<Is>().clone().uniform_refine().greville(
208 /* interior */ false)),
209 ...);
210
211 // Get Greville abscissae at the domain
212 ((std::get<Is>(collPts.second) =
213 G_.template boundary<Is>().clone().uniform_refine().greville()),
214 ...);
215 break;
216
217 case collPts::greville_interior_ref1:
218 // Get Greville abscissae inside the domain
219 ((std::get<Is>(collPts.first) =
220 G_.template space<Is>().clone().uniform_refine().greville(
221 /* interior */ true)),
222 ...);
223
224 // Get Greville abscissae at the domain
225 ((std::get<Is>(collPts.second) =
226 G_.template boundary<Is>().clone().uniform_refine().greville()),
227 ...);
228 break;
229
230 case collPts::greville_ref2:
231 // Get Greville abscissae inside the domain and at the boundary
232 ((std::get<Is>(collPts.first) =
233 G_.template space<Is>().clone().uniform_refine(2, -1).greville(
234 /* interior */ false)),
235 ...);
236
237 // Get Greville abscissae at the domain
238 ((std::get<Is>(collPts.second) = G_.template boundary<Is>()
239 .clone()
240 .uniform_refine(2, -1)
241 .greville()),
242 ...);
243 break;
244
245 case collPts::greville_interior_ref2:
246 // Get Greville abscissae inside the domain
247 ((std::get<Is>(collPts.first) =
248 G_.template space<Is>().clone().uniform_refine(2, -1).greville(
249 /* interior */ true)),
250 ...);
251
252 // Get Greville abscissae at the domain
253 ((std::get<Is>(collPts.second) = G_.template boundary<Is>()
254 .clone()
255 .uniform_refine(2, -1)
256 .greville()),
257 ...);
258 break;
259
260 case collPts::greville_ref3:
261 // Get Greville abscissae inside the domain and at the boundary
262 ((std::get<Is>(collPts.first) =
263 G_.template space<Is>().clone().uniform_refine(3, -1).greville(
264 /* interior */ false)),
265 ...);
266
267 // Get Greville abscissae at the domain
268 ((std::get<Is>(collPts.second) = G_.template boundary<Is>()
269 .clone()
270 .uniform_refine(3, -1)
271 .greville()),
272 ...);
273 break;
274
275 case collPts::greville_interior_ref3:
276 // Get Greville abscissae inside the domain
277 ((std::get<Is>(collPts.first) =
278 G_.template space<Is>().clone().uniform_refine(3, -1).greville(
279 /* interior */ true)),
280 ...);
281
282 // Get Greville abscissae at the domain
283 ((std::get<Is>(collPts.second) = G_.template boundary<Is>()
284 .clone()
285 .uniform_refine(3, -1)
286 .greville()),
287 ...);
288 break;
289
290 default:
291 throw std::runtime_error("Invalid collocation point specifier");
292 }
293
294 return collPts;
295 }
296
303 template <std::size_t... Is>
305 std::index_sequence<Is...>) const {
307
308 switch (collPtsType) {
309
310 case collPts::greville:
311 // Get Greville abscissae inside the domain and at the boundary
312 ((std::get<Is>(collPts.first) =
313 u_.template space<Is>().greville(/* interior */ false)),
314 ...);
315
316 // Get Greville abscissae at the domain
317 ((std::get<Is>(collPts.second) = u_.template boundary<Is>().greville()),
318 ...);
319 break;
320
321 case collPts::greville_interior:
322 // Get Greville abscissae inside the domain and at the boundary
323 ((std::get<Is>(collPts.first) =
324 u_.template space<Is>().greville(/* interior */ true)),
325 ...);
326
327 // Get Greville abscissae at the domain
328 ((std::get<Is>(collPts.second) = u_.template boundary<Is>().greville()),
329 ...);
330 break;
331
332 case collPts::greville_ref1:
333 // Get Greville abscissae inside the domain and at the boundary
334 ((std::get<Is>(collPts.first) =
335 u_.template space<Is>().clone().uniform_refine().greville(
336 /* interior */ false)),
337 ...);
338
339 // Get Greville abscissae at the domain
340 ((std::get<Is>(collPts.second) =
341 u_.template boundary<Is>().clone().uniform_refine().greville()),
342 ...);
343 break;
344
345 case collPts::greville_interior_ref1:
346 // Get Greville abscissae inside the domain and at the boundary
347 ((std::get<Is>(collPts.first) =
348 u_.template space<Is>().clone().uniform_refine().greville(
349 /* interior */ true)),
350 ...);
351
352 // Get Greville abscissae at the domain
353 ((std::get<Is>(collPts.second) =
354 u_.template boundary<Is>().clone().uniform_refine().greville()),
355 ...);
356 break;
357
358 case collPts::greville_ref2:
359 // Get Greville abscissae inside the domain and at the boundary
360 ((std::get<Is>(collPts.first) =
361 u_.template space<Is>().clone().uniform_refine(2, -1).greville(
362 /* interior */ false)),
363 ...);
364
365 // Get Greville abscissae at the domain
366 ((std::get<Is>(collPts.second) = u_.template boundary<Is>()
367 .clone()
368 .uniform_refine(2, -1)
369 .greville()),
370 ...);
371 break;
372
373 case collPts::greville_interior_ref2:
374 // Get Greville abscissae inside the domain and at the boundary
375 ((std::get<Is>(collPts.first) =
376 u_.template space<Is>().clone().uniform_refine(2, -1).greville(
377 /* interior */ true)),
378 ...);
379
380 // Get Greville abscissae at the domain
381 ((std::get<Is>(collPts.second) = u_.template boundary<Is>()
382 .clone()
383 .uniform_refine(2, -1)
384 .greville()),
385 ...);
386 break;
387
388 case collPts::greville_ref3:
389 // Get Greville abscissae inside the domain and at the boundary
390 ((std::get<Is>(collPts.first) =
391 u_.template space<Is>().clone().uniform_refine(3, -1).greville(
392 /* interior */ false)),
393 ...);
394
395 // Get Greville abscissae at the domain
396 ((std::get<Is>(collPts.second) = u_.template boundary<Is>()
397 .clone()
398 .uniform_refine(3, -1)
399 .greville()),
400 ...);
401 break;
402
403 case collPts::greville_interior_ref3:
404 // Get Greville abscissae inside the domain and at the boundary
405 ((std::get<Is>(collPts.first) =
406 u_.template space<Is>().clone().uniform_refine(3, -1).greville(
407 /* interior */ true)),
408 ...);
409
410 // Get Greville abscissae at the domain
411 ((std::get<Is>(collPts.second) = u_.template boundary<Is>()
412 .clone()
413 .uniform_refine(3, -1)
414 .greville()),
415 ...);
416 break;
417
418 default:
419 throw std::runtime_error("Invalid collocation point specifier");
420 }
421
422 return collPts;
423 }
424
425public:
432 virtual geometryMap_collPts_type
434 if constexpr (GeometryMap::nspaces() == 1)
435
436 switch (collPts) {
437
438 case collPts::greville:
439 return {G_.space().greville(/* interior */ false),
440 G_.boundary().greville()};
441
442 case collPts::greville_interior:
443 return {G_.space().greville(/* interior */ true),
444 G_.boundary().greville()};
445
446 case collPts::greville_ref1:
447 return {
448 G_.space().clone().uniform_refine().greville(/* interior */ false),
449 G_.boundary().clone().uniform_refine().greville()};
450
451 case collPts::greville_interior_ref1:
452 return {
453 G_.space().clone().uniform_refine().greville(/* interior */ true),
454 G_.boundary().clone().uniform_refine().greville()};
455
456 case collPts::greville_ref2:
457 return {G_.space().clone().uniform_refine(2, -1).greville(
458 /* interior */ false),
459 G_.boundary().clone().uniform_refine(2, -1).greville()};
460
461 case collPts::greville_interior_ref2:
462 return {G_.space().clone().uniform_refine(2, -1).greville(
463 /* interior */ true),
464 G_.boundary().clone().uniform_refine(2, -1).greville()};
465
466 case collPts::greville_ref3:
467 return {G_.space().clone().uniform_refine(3, -1).greville(
468 /* interior */ false),
469 G_.boundary().clone().uniform_refine(3, -1).greville()};
470
471 case collPts::greville_interior_ref3:
472 return {G_.space().clone().uniform_refine(3, -1).greville(
473 /* interior */ true),
474 G_.boundary().clone().uniform_refine(3, -1).greville()};
475
476 default:
477 throw std::runtime_error("Invalid collocation point specifier");
478 }
479
480 else
481 return geometryMap_collPts(
482 collPts, std::make_index_sequence<GeometryMap::nspaces()>{});
483 }
484
492 if constexpr (Variable::nspaces() == 1)
493
494 switch (collPts) {
495
496 case collPts::greville:
497 return {u_.space().greville(/* interior */ false),
498 u_.boundary().greville()};
499
500 case collPts::greville_interior:
501 return {u_.space().greville(/* interior */ true),
502 u_.boundary().greville()};
503
504 case collPts::greville_ref1:
505 return {
506 u_.space().clone().uniform_refine().greville(/* interior */ false),
507 u_.boundary().clone().uniform_refine().greville()};
508
509 case collPts::greville_interior_ref1:
510 return {
511 u_.space().clone().uniform_refine().greville(/* interior */ true),
512 u_.boundary().clone().uniform_refine().greville()};
513
514 case collPts::greville_ref2:
515 return {u_.space().clone().uniform_refine(2, -1).greville(
516 /* interior */ false),
517 u_.boundary().clone().uniform_refine(2, -1).greville()};
518
519 case collPts::greville_interior_ref2:
520 return {u_.space().clone().uniform_refine(2, -1).greville(
521 /* interior */ true),
522 u_.boundary().clone().uniform_refine(2, -1).greville()};
523
524 case collPts::greville_ref3:
525 return {u_.space().clone().uniform_refine(3, -1).greville(
526 /* interior */ false),
527 u_.boundary().clone().uniform_refine(3, -1).greville()};
528
529 case collPts::greville_interior_ref3:
530 return {u_.space().clone().uniform_refine(3, -1).greville(
531 /* interior */ true),
532 u_.boundary().clone().uniform_refine(3, -1).greville()};
533
534 default:
535 throw std::runtime_error("Invalid collocation point specifier");
536 }
537
538 else
539 return variable_collPts(collPts,
540 std::make_index_sequence<Variable::nspaces()>{});
541 }
542};
543
547template <typename GeometryMap, typename Variable>
549class [[deprecated("Use novel IgANet implementation")]] IgABase : public IgABaseNoRefData<GeometryMap, Variable> {
550public:
553
556
558 using geometryMap_type = GeometryMap;
559
561 using variable_type = Variable;
562
565
568
570 bool static constexpr has_GeometryMap = true;
571
573 bool static constexpr has_RefData = true;
574
576 bool static constexpr has_Solution = true;
577
578protected:
580 Variable f_;
581
582private:
585 template <std::size_t... GeometryMapNumCoeffs, std::size_t... Is,
586 std::size_t... VariableNumCoeffs, std::size_t... Js>
588 std::tuple<std::array<int64_t, GeometryMapNumCoeffs>...>
589 geometryMapNumCoeffs,
590 std::index_sequence<Is...>,
591 std::tuple<std::array<int64_t, VariableNumCoeffs>...> variableNumCoeffs,
592 std::index_sequence<Js...>,
594 : // Construct the different spline objects individually
595 Base(geometryMapNumCoeffs, variableNumCoeffs, options),
596 f_(std::get<Js>(variableNumCoeffs)..., init::zeros, options) {}
597
598public:
600 explicit IgABase(
602 : Base(), f_() {}
603
607 template <std::size_t NumCoeffs>
608 explicit IgABase(
609 std::array<int64_t, NumCoeffs> numCoeffs,
611 : IgABase(std::tuple{numCoeffs}, std::tuple{numCoeffs}, options) {}
612
613 template <std::size_t... NumCoeffs>
614 explicit IgABase(
615 std::tuple<std::array<int64_t, NumCoeffs>...> numCoeffs,
617 : IgABase(numCoeffs, numCoeffs, options) {}
619
623 template <std::size_t GeometryMapNumCoeffs, std::size_t VariableNumCoeffs>
624 IgABase(std::array<int64_t, GeometryMapNumCoeffs> geometryMapNumCoeffs,
625 std::array<int64_t, VariableNumCoeffs> variableNumCoeffs,
627 : IgABase(std::tuple{geometryMapNumCoeffs}, std::tuple{variableNumCoeffs},
628 options) {}
629
630 template <std::size_t... GeometryMapNumCoeffs,
631 std::size_t... VariableNumCoeffs>
633 std::tuple<std::array<int64_t, GeometryMapNumCoeffs>...>
634 geometryMapNumCoeffs,
635 std::tuple<std::array<int64_t, VariableNumCoeffs>...> variableNumCoeffs,
637 : IgABase(geometryMapNumCoeffs,
638 std::make_index_sequence<sizeof...(GeometryMapNumCoeffs)>{},
639 variableNumCoeffs,
640 std::make_index_sequence<sizeof...(VariableNumCoeffs)>{},
641 options) {}
643
646 inline const Variable &f() const { return f_; }
647
650 inline Variable &f() { return f_; }
651};
652
656template <typename Optimizer, typename GeometryMap, typename Variable,
657 template <typename, typename> typename IgABase = IgABase>
660class [[deprecated("Use novel IgANet implementation")]] IgANet : public IgABase<GeometryMap, Variable>,
663public:
666
668 using optimizer_type = Optimizer;
669
672
673protected:
676
678 std::unique_ptr<optimizer_type> opt_;
679
682
683public:
685 explicit IgANet(const IgANetOptions &defaults = {},
688 : // Construct the base class
689 Base(),
690 // Construct the optimizer
691 opt_(std::make_unique<optimizer_type>(net_->parameters())),
692 // Set options
693 options_(defaults) {}
694
699 template <std::size_t NumCoeffs>
700 IgANet(const std::vector<int64_t> &layers,
701 const std::vector<std::vector<std::any>> &activations,
702 std::array<int64_t, NumCoeffs> numCoeffs, IgANetOptions defaults = {},
705 : IgANet(layers, activations, std::tuple{numCoeffs},
706 std::tuple{numCoeffs}, defaults, options) {}
707
708 template <std::size_t... NumCoeffs>
709 IgANet(const std::vector<int64_t> &layers,
710 const std::vector<std::vector<std::any>> &activations,
711 std::tuple<std::array<int64_t, NumCoeffs>...> numCoeffs,
712 IgANetOptions defaults = {},
715 : IgANet(layers, activations, numCoeffs, numCoeffs, defaults, options) {}
717
722 template <std::size_t GeometryMapNumCoeffs, std::size_t VariableNumCoeffs>
723 IgANet(const std::vector<int64_t> &layers,
724 const std::vector<std::vector<std::any>> &activations,
725 std::array<int64_t, GeometryMapNumCoeffs> geometryMapNumCoeffs,
726 std::array<int64_t, VariableNumCoeffs> variableNumCoeffs,
727 IgANetOptions defaults = {},
730 : IgANet(layers, activations, std::tuple{geometryMapNumCoeffs},
731 std::tuple{variableNumCoeffs}, defaults, options) {}
732
733 template <std::size_t... GeometryMapNumCoeffs,
734 std::size_t... VariableNumCoeffs>
736 const std::vector<int64_t> &layers,
737 const std::vector<std::vector<std::any>> &activations,
738 std::tuple<std::array<int64_t, GeometryMapNumCoeffs>...>
739 geometryMapNumCoeffs,
740 std::tuple<std::array<int64_t, VariableNumCoeffs>...> variableNumCoeffs,
741 IgANetOptions defaults = {},
744 : // Construct the base class
745 Base(geometryMapNumCoeffs, variableNumCoeffs, options),
746 // Construct the deep neural network
747 net_(utils::concat(std::vector<int64_t>{inputs(/* epoch */ 0).size(0)},
748 layers,
749 std::vector<int64_t>{Base::u_.as_tensor_size()}),
750 activations, options),
751
752 // Construct the optimizer
753 opt_(std::make_unique<optimizer_type>(net_->parameters())),
754
755 // Set options
756 options_(defaults) {}
757
760 return net_;
761 }
762
765
767 inline const optimizer_type &optimizer() const { return *opt_; }
768
770 inline optimizer_type &optimizer() { return *opt_; }
771
776 inline void optimizerReset(bool resetOptions = true) {
777 if (resetOptions)
778 opt_ = std::make_unique<optimizer_type>(net_->parameters());
779 else {
780 std::vector<optimizer_options_type> options;
781 for (auto &group : opt_->param_groups())
782 options.push_back(
783 static_cast<optimizer_options_type &>(group.options()));
784 opt_ = std::make_unique<optimizer_type>(net_->parameters());
785 for (auto [group, options] : utils::zip(opt_->param_groups(), options))
786 static_cast<optimizer_options_type &>(group.options()) = options;
787 }
788 }
789
791 inline void optimizerReset(const optimizer_options_type &optimizerOptions) {
792 opt_ =
793 std::make_unique<optimizer_type>(net_->parameters(), optimizerOptions);
794 }
795
797 inline optimizer_options_type &optimizerOptions(std::size_t param_group = 0) {
798 if (param_group < opt_->param_groups().size())
799 return static_cast<optimizer_options_type &>(
800 opt_->param_groups()[param_group].options());
801 else
802 throw std::runtime_error("Index exceeds number of parameter groups");
803 }
804
806 inline const optimizer_options_type &
807 optimizerOptions(std::size_t param_group = 0) const {
808 if (param_group < opt_->param_groups().size())
809 return static_cast<optimizer_options_type &>(
810 opt_->param_groups()[param_group].options());
811 else
812 throw std::runtime_error("Index exceeds number of parameter groups");
813 }
814
816 inline void optimizerOptionsReset(const optimizer_options_type &options) {
817 for (auto &group : opt_->param_groups())
818 static_cast<optimizer_options_type &>(group.options()) = options;
819 }
820
823 for (auto &group : opt_->param_groups())
824 static_cast<optimizer_options_type &>(group.options()) = options;
825 }
826
829 std::size_t param_group) {
830 if (param_group < opt_->param_groups().size())
831 static_cast<optimizer_options_type &>(opt_->param_group().options()) =
832 options;
833 else
834 throw std::runtime_error("Index exceeds number of parameter groups");
835 }
836
839 std::size_t param_group) {
840 if (param_group < opt_->param_groups().size())
841 static_cast<optimizer_options_type &>(opt_->param_group().options()) =
842 options;
843 else
844 throw std::runtime_error("Index exceeds number of parameter groups");
845 }
846
848 inline const auto &options() const { return options_; }
849
851 inline auto &options() { return options_; }
852
859 virtual torch::Tensor inputs(int64_t epoch) const {
860 if constexpr (Base::has_GeometryMap && Base::has_RefData)
861 return torch::cat({Base::G_.as_tensor(), Base::f_.as_tensor()});
862 else if constexpr (Base::has_GeometryMap && !Base::has_RefData)
863 return Base::G_.as_tensor();
864 else if constexpr (!Base::has_GeometryMap && Base::has_RefData)
865 return Base::f_.as_tensor();
866 else
867 return torch::empty({0});
868 }
869
871 virtual bool epoch(int64_t) = 0;
872
874 virtual torch::Tensor loss(const torch::Tensor &, int64_t) = 0;
875
877 virtual void train(
878#ifdef IGANET_WITH_MPI
879 c10::intrusive_ptr<c10d::ProcessGroupMPI> pg =
880 c10d::ProcessGroupMPI::createProcessGroupMPI()
881#endif
882 ) {
883 torch::Tensor inputs, outputs, loss;
884 typename Base::value_type previous_loss(-1.0);
885
886 // Loop over epochs
887 for (int64_t epoch = 0; epoch != options_.max_epoch(); ++epoch) {
888
889 // Update epoch and inputs
890 if (this->epoch(epoch))
891 inputs = this->inputs(epoch);
892
893 auto closure = [&]() {
894 // Reset gradients
895 net_->zero_grad();
896
897 // Execute the model on the inputs
898 outputs = net_->forward(inputs);
899
900 // Compute the loss value
901 loss = this->loss(outputs, epoch);
902
903 // Compute gradients of the loss w.r.t. the model parameters
904 loss.backward({}, true, false);
905
906 return loss;
907 };
908
909#ifdef IGANET_WITH_MPI
910 // Averaging the gradients of the parameters in all the processors
911 // Note: This may lag behind DistributedDataParallel (DDP) in performance
912 // since this synchronizes parameters after backward pass while DDP
913 // overlaps synchronizing parameters and computing gradients in backward
914 // pass
915 std::vector<c10::intrusive_ptr<::c10d::Work>> works;
916 for (auto &param : net_->named_parameters()) {
917 std::vector<torch::Tensor> tmp = {param.value().grad()};
918 works.emplace_back(pg->allreduce(tmp));
919 }
920
921 waitWork(pg, works);
922
923 for (auto &param : net_->named_parameters()) {
924 param.value().grad().data() =
925 param.value().grad().data() / pg->getSize();
926 }
927#endif
928
929 // Update the parameters based on the calculated gradients
930 opt_->step(closure);
931
932 typename Base::value_type current_loss =
933 loss.item<typename Base::value_type>();
934 Log(log::verbose) << "Epoch " << std::to_string(epoch) << ": "
935 << current_loss << std::endl;
936
937 if (current_loss < options_.min_loss() ||
938 std::abs(current_loss - previous_loss) < options_.min_loss_change() ||
939 std::abs(current_loss - previous_loss) / current_loss <
940 options_.min_loss_rel_change() ||
941 loss.isnan().item<bool>()) {
942 Log(log::info) << "Total epochs: " << epoch
943 << ", loss: " << current_loss << std::endl;
944 return;
945 }
946 previous_loss = current_loss;
947 }
948 Log(log::info) << "Max epochs reached: " << options_.max_epoch()
949 << ", loss: " << previous_loss << std::endl;
950 }
951
953 template <typename DataLoader>
954 void train(DataLoader &loader
955#ifdef IGANET_WITH_MPI
956 ,
957 c10::intrusive_ptr<c10d::ProcessGroupMPI> pg =
958 c10d::ProcessGroupMPI::createProcessGroupMPI()
959#endif
960 ) {
961 torch::Tensor inputs, outputs, loss;
962 typename Base::value_type previous_loss(-1.0);
963
964 // Loop over epochs
965 for (int64_t epoch = 0; epoch != options_.max_epoch(); ++epoch) {
966
967 typename Base::value_type current_loss(0);
968
969 for (auto &batch : loader) {
970 inputs = batch.data;
971
972 if (inputs.dim() > 0) {
973 if constexpr (Base::has_GeometryMap && Base::has_RefData) {
974 Base::G_.from_tensor(
975 inputs.slice(1, 0, Base::G_.as_tensor_size()).t());
976 Base::f_.from_tensor(inputs
977 .slice(1, Base::G_.as_tensor_size(),
978 Base::G_.as_tensor_size() +
979 Base::f_.as_tensor_size())
980 .t());
981 } else if constexpr (Base::has_GeometryMap && !Base::has_RefData)
982 Base::G_.from_tensor(
983 inputs.slice(1, 0, Base::G_.as_tensor_size()).t());
984 else if constexpr (!Base::has_GeometryMap && Base::has_RefData)
985 Base::f_.from_tensor(
986 inputs.slice(1, 0, Base::f_.as_tensor_size()).t());
987
988 } else {
989 if constexpr (Base::has_GeometryMap && Base::has_RefData) {
990 Base::G_.from_tensor(
991 inputs.slice(1, 0, Base::G_.as_tensor_size()).flatten());
992 Base::f_.from_tensor(inputs
993 .slice(1, Base::G_.as_tensor_size(),
994 Base::G_.as_tensor_size() +
995 Base::f_.as_tensor_size())
996 .flatten());
997 } else if constexpr (Base::has_GeometryMap && !Base::has_RefData)
998 Base::G_.from_tensor(
999 inputs.slice(1, 0, Base::G_.as_tensor_size()).flatten());
1000 else if constexpr (!Base::has_GeometryMap && Base::has_RefData)
1001 Base::f_.from_tensor(
1002 inputs.slice(1, 0, Base::f_.as_tensor_size()).flatten());
1003 }
1004
1005 this->epoch(epoch);
1006
1007 auto closure = [&]() {
1008 // Reset gradients
1009 net_->zero_grad();
1010
1011 // Execute the model on the inputs
1012 outputs = net_->forward(inputs);
1013
1014 // Compute the loss value
1015 loss = this->loss(outputs, epoch);
1016
1017 // Compute gradients of the loss w.r.t. the model parameters
1018 loss.backward({}, true, false);
1019
1020 return loss;
1021 };
1022
1023 // Update the parameters based on the calculated gradients
1024 opt_->step(closure);
1025
1026 current_loss += loss.item<typename Base::value_type>();
1027 }
1028 Log(log::verbose) << "Epoch " << std::to_string(epoch) << ": "
1029 << current_loss << std::endl;
1030
1031 if (current_loss < options_.min_loss() ||
1032 std::abs(current_loss - previous_loss) < options_.min_loss_change() ||
1033 std::abs(current_loss - previous_loss) / current_loss <
1034 options_.min_loss_rel_change() ||
1035 loss.isnan().item<bool>()) {
1036 Log(log::info) << "Total epochs: " << epoch
1037 << ", loss: " << current_loss << std::endl;
1038 return;
1039 }
1040 previous_loss = current_loss;
1041 }
1042 Log(log::info) << "Max epochs reached: " << options_.max_epoch()
1043 << ", loss: " << previous_loss << std::endl;
1044 }
1045
1047 void eval() {
1048 torch::Tensor inputs = this->inputs(0);
1049 torch::Tensor outputs = net_->forward(inputs);
1050 Base::u_.from_tensor(outputs);
1051 }
1052
1054 inline nlohmann::json to_json() const override {
1055 return "Not implemented yet";
1056 }
1057
1059 inline std::vector<torch::Tensor> parameters() const noexcept {
1060 return net_->parameters();
1061 }
1062
1065 inline torch::OrderedDict<std::string, torch::Tensor>
1066 named_parameters() const noexcept {
1067 return net_->named_parameters();
1068 }
1069
1071 inline std::size_t nparameters() const noexcept {
1072 std::size_t result = 0;
1073 for (const auto &param : this->parameters()) {
1074 result += param.numel();
1075 }
1076 return result;
1077 }
1078
1080 torch::Tensor& register_parameter(std::string name, torch::Tensor tensor, bool requires_grad = true) {
1081 return net_->register_parameter(name, tensor, requires_grad);
1082 }
1083
1085 inline void pretty_print(std::ostream &os) const noexcept override {
1086 os << name() << "(\n"
1087 << "net = " << net_ << "\n";
1088 if constexpr (Base::has_GeometryMap)
1089 os << "G = " << Base::G_ << "\n";
1090 if constexpr (Base::has_RefData)
1091 os << "f = " << Base::f_ << "\n";
1092 if constexpr (Base::has_Solution)
1093 os << "u = " << Base::u_ << "\n)";
1094 }
1095
1097 inline void save(const std::string &filename,
1098 const std::string &key = "iganet") const {
1099 torch::serialize::OutputArchive archive;
1100 write(archive, key).save_to(filename);
1101 }
1102
1104 inline void load(const std::string &filename,
1105 const std::string &key = "iganet") {
1106 torch::serialize::InputArchive archive;
1107 archive.load_from(filename);
1108 read(archive, key);
1109 }
1110
1112 inline torch::serialize::OutputArchive &
1113 write(torch::serialize::OutputArchive &archive,
1114 const std::string &key = "iganet") const {
1115 if constexpr (Base::has_GeometryMap)
1116 Base::G_.write(archive, key + ".geo");
1117 if constexpr (Base::has_RefData)
1118 Base::f_.write(archive, key + ".ref");
1119 if constexpr (Base::has_Solution)
1120 Base::u_.write(archive, key + ".out");
1121
1122 net_->write(archive, key + ".net");
1123 torch::serialize::OutputArchive archive_net;
1124 net_->save(archive_net);
1125 archive.write(key + ".net.data", archive_net);
1126
1127 torch::serialize::OutputArchive archive_opt;
1128 opt_->save(archive_opt);
1129 archive.write(key + ".opt", archive_opt);
1130
1131 return archive;
1132 }
1133
1135 inline torch::serialize::InputArchive &
1136 read(torch::serialize::InputArchive &archive,
1137 const std::string &key = "iganet") {
1138 if constexpr (Base::has_GeometryMap)
1139 Base::G_.read(archive, key + ".geo");
1140 if constexpr (Base::has_RefData)
1141 Base::f_.read(archive, key + ".ref");
1142 if constexpr (Base::has_Solution)
1143 Base::u_.read(archive, key + ".out");
1144
1145 net_->read(archive, key + ".net");
1146 torch::serialize::InputArchive archive_net;
1147 archive.read(key + ".net.data", archive_net);
1148 net_->load(archive_net);
1149
1150 opt_->add_parameters(net_->parameters());
1151 torch::serialize::InputArchive archive_opt;
1152 archive.read(key + ".opt", archive_opt);
1153 opt_->load(archive_opt);
1154
1155 return archive;
1156 }
1157
1159 bool operator==(const IgANet &other) const {
1160 bool result(true);
1161
1162 if constexpr (Base::has_GeometryMap)
1163 result *= (Base::G_ == other.G());
1164 if constexpr (Base::has_RefData)
1165 result *= (Base::f_ == other.f());
1166 if constexpr (Base::has_Solution)
1167 result *= (Base::u_ == other.u());
1168
1169 return result;
1170 }
1171
1173 bool operator!=(const IgANet &other) const { return *this != other; }
1174
#ifdef IGANET_WITH_MPI
private:
/// @brief Blocks until every pending collective operation in @p works has
/// completed.
///
/// If waiting on an item throws, the error is logged and @p pg is aborted.
/// The loop then continues with the remaining items.
static void waitWork(c10::intrusive_ptr<c10d::ProcessGroupMPI> pg,
                     std::vector<c10::intrusive_ptr<c10d::Work>> works) {
  for (std::size_t idx = 0; idx < works.size(); ++idx) {
    try {
      works[idx]->wait();
    } catch (const std::exception &ex) {
      Log(log::error) << "Exception received during waitWork: " << ex.what()
                      << std::endl;
      pg->abort();
    }
  }
}
#endif
1191};
1192
1194template <typename Optimizer, typename GeometryMap, typename Variable>
1195 requires OptimizerType<Optimizer> && FunctionSpaceType<GeometryMap> &&
1196 FunctionSpaceType<Variable>
1197inline std::ostream &
1198operator<<(std::ostream &os,
1200 obj.pretty_print(os);
1201 return os;
1202}
1203
1209template <typename GeometryMap, typename Variable>
1211class [[deprecated("Use novel IgANetCustomizable implementation")]] IgANetCustomizable {
1212public:
1215 decltype(std::declval<GeometryMap>()
1216 .template find_knot_indices<functionspace::interior>(
1217 std::declval<typename GeometryMap::eval_type>()));
1218
1221 decltype(std::declval<GeometryMap>()
1222 .template find_knot_indices<functionspace::boundary>(
1223 std::declval<
1224 typename GeometryMap::boundary_eval_type>()));
1225
1228 decltype(std::declval<Variable>()
1229 .template find_knot_indices<functionspace::interior>(
1230 std::declval<typename Variable::eval_type>()));
1231
1234 decltype(std::declval<Variable>()
1235 .template find_knot_indices<functionspace::boundary>(
1236 std::declval<typename Variable::boundary_eval_type>()));
1237
1240 decltype(std::declval<GeometryMap>()
1241 .template find_coeff_indices<functionspace::interior>(
1242 std::declval<typename GeometryMap::eval_type>()));
1243
1246 decltype(std::declval<GeometryMap>()
1247 .template find_coeff_indices<functionspace::boundary>(
1248 std::declval<
1249 typename GeometryMap::boundary_eval_type>()));
1250
1253 decltype(std::declval<Variable>()
1254 .template find_coeff_indices<functionspace::interior>(
1255 std::declval<typename Variable::eval_type>()));
1256
1259 decltype(std::declval<Variable>()
1260 .template find_coeff_indices<functionspace::boundary>(
1261 std::declval<typename Variable::boundary_eval_type>()));
1262};
1263
1264} // namespace iganet::v1
Boundary treatment.
IgA base class.
Definition iganet.hpp:46
IgANetGenerator.
Definition generator.hpp:927
IgANet.
Definition iganet.hpp:519
The Options class handles the automated determination of dtype from the template argument and the sel...
Definition options.hpp:104
Full qualified name descriptor.
Definition fqn.hpp:22
IgA base class.
Definition iganet.hpp:549
Base::variable_collPts_type variable_collPts_type
Type of the variable collocation points.
Definition iganet.hpp:567
IgABase(std::tuple< std::array< int64_t, GeometryMapNumCoeffs >... > geometryMapNumCoeffs, std::tuple< std::array< int64_t, VariableNumCoeffs >... > variableNumCoeffs, iganet::Options< value_type > options=iganet::Options< value_type >{})
Constructor: number of spline coefficients (different for geometry map and variables)
Definition iganet.hpp:632
GeometryMap geometryMap_type
Type of the geometry map function space(s)
Definition iganet.hpp:558
Variable & f()
Returns a non-constant reference to the spline representation of the reference data.
Definition iganet.hpp:650
Base::value_type value_type
Value type.
Definition iganet.hpp:555
IgABase(std::array< int64_t, GeometryMapNumCoeffs > geometryMapNumCoeffs, std::array< int64_t, VariableNumCoeffs > variableNumCoeffs, iganet::Options< value_type > options=iganet::Options< value_type >{})
Constructor: number of spline coefficients (different for geometry map and variables)
Definition iganet.hpp:624
IgABase(iganet::Options< value_type > options=iganet::Options< value_type >{})
Default constructor.
Definition iganet.hpp:600
const Variable & f() const
Returns a constant reference to the spline representation of the reference data.
Definition iganet.hpp:646
IgABase(std::tuple< std::array< int64_t, NumCoeffs >... > numCoeffs, iganet::Options< value_type > options=iganet::Options< value_type >{})
Constructor: number of spline coefficients (same for geometry map and variables)
Definition iganet.hpp:614
IgABase(std::tuple< std::array< int64_t, GeometryMapNumCoeffs >... > geometryMapNumCoeffs, std::index_sequence< Is... >, std::tuple< std::array< int64_t, VariableNumCoeffs >... > variableNumCoeffs, std::index_sequence< Js... >, iganet::Options< value_type > options=iganet::Options< value_type >{})
Constructor: number of spline coefficients (different for Geometry and Variable types)
Definition iganet.hpp:587
Variable f_
Spline representation of the reference data.
Definition iganet.hpp:580
Variable variable_type
Type of the variable function space(s)
Definition iganet.hpp:561
Base::geometryMap_collPts_type geometryMap_collPts_type
Type of the geometry map collocation points.
Definition iganet.hpp:564
IgABase(std::array< int64_t, NumCoeffs > numCoeffs, iganet::Options< value_type > options=iganet::Options< value_type >{})
Constructor: number of spline coefficients (same for geometry map and variables)
Definition iganet.hpp:608
IgA base class (no reference data)
Definition iganet.hpp:45
virtual ~IgABaseNoRefData()=default
Destructor.
std::pair< typename Variable::eval_type, typename Variable::boundary_eval_type > variable_collPts_type
Type of the variable collocation points.
Definition iganet.hpp:65
IgABaseNoRefData(std::array< int64_t, GeometryMapNumCoeffs > geometryMapNumCoeffs, std::array< int64_t, VariableNumCoeffs > variableNumCoeffs, iganet::Options< value_type > options=iganet::Options< value_type >{})
Constructor: number of spline coefficients (different for geometry map and variables)
Definition iganet.hpp:126
GeometryMap G_
Spline representation of the geometry map.
Definition iganet.hpp:78
Variable variable_type
Type of the variable function space(s)
Definition iganet.hpp:55
Variable & u()
Returns a non-constant reference to the spline representation of the solution.
Definition iganet.hpp:165
GeometryMap & G()
Returns a non-constant reference to the spline representation of the geometry map.
Definition iganet.hpp:157
variable_collPts_type variable_collPts(enum collPts collPtsType, std::index_sequence< Is... >) const
Returns the variable collocation points.
Definition iganet.hpp:304
virtual geometryMap_collPts_type geometryMap_collPts(enum collPts collPts) const
Returns the geometry map collocation points.
Definition iganet.hpp:433
IgABaseNoRefData(iganet::Options< value_type > options=iganet::Options< value_type >{})
Default constructor.
Definition iganet.hpp:101
Variable u_
Spline representation of the solution.
Definition iganet.hpp:81
virtual variable_collPts_type variable_collPts(enum collPts collPts) const
Returns the variable collocation points.
Definition iganet.hpp:491
IgABaseNoRefData(std::tuple< std::array< int64_t, NumCoeffs >... > numCoeffs, iganet::Options< value_type > options=iganet::Options< value_type >{})
Constructor: number of spline coefficients (same for geometry map and variables)
Definition iganet.hpp:116
std::pair< typename GeometryMap::eval_type, typename GeometryMap::boundary_eval_type > geometryMap_collPts_type
Type of the geometry map collocation points.
Definition iganet.hpp:60
GeometryMap geometryMap_type
Type of the geometry map function space(s)
Definition iganet.hpp:52
const GeometryMap & G() const
Returns a constant reference to the spline representation of the geometry map.
Definition iganet.hpp:153
const Variable & u() const
Returns a constant reference to the spline representation of the solution.
Definition iganet.hpp:161
IgABaseNoRefData(std::array< int64_t, NumCoeffs > numCoeffs, iganet::Options< value_type > options=iganet::Options< value_type >{})
Constructor: number of spline coefficients (same for geometry map and variables)
Definition iganet.hpp:109
IgABaseNoRefData(std::tuple< std::array< int64_t, GeometryMapNumCoeffs >... > geometryMapNumCoeffs, std::index_sequence< Is... >, std::tuple< std::array< int64_t, VariableNumCoeffs >... > variableNumCoeffs, std::index_sequence< Js... >, iganet::Options< value_type > options=iganet::Options< value_type >{})
Constructor: number of spline coefficients (different for Geometry and Variable types)
Definition iganet.hpp:88
geometryMap_collPts_type geometryMap_collPts(enum collPts collPtsType, std::index_sequence< Is... >) const
Returns the geometry map collocation points.
Definition iganet.hpp:176
std::common_type_t< typename GeometryMap::value_type, typename Variable::value_type > value_type
Value type.
Definition iganet.hpp:49
IgABaseNoRefData(std::tuple< std::array< int64_t, GeometryMapNumCoeffs >... > geometryMapNumCoeffs, std::tuple< std::array< int64_t, VariableNumCoeffs >... > variableNumCoeffs, iganet::Options< value_type > options=iganet::Options< value_type >{})
Constructor: number of spline coefficients (different for geometry map and variables)
Definition iganet.hpp:135
IgANetCustomizable.
Definition iganet.hpp:1211
decltype(std::declval< GeometryMap >() .template find_knot_indices< functionspace::boundary >(std::declval< typename GeometryMap::boundary_eval_type >())) geometryMap_boundary_knot_indices_type
Type of the knot indices of the geometry map at the boundary.
Definition iganet.hpp:1224
decltype(std::declval< Variable >() .template find_coeff_indices< functionspace::interior >(std::declval< typename Variable::eval_type >())) variable_interior_coeff_indices_type
Type of the coefficient indices of variable type in the interior.
Definition iganet.hpp:1255
decltype(std::declval< GeometryMap >() .template find_coeff_indices< functionspace::interior >(std::declval< typename GeometryMap::eval_type >())) geometryMap_interior_coeff_indices_type
Type of the coefficient indices of geometry type in the interior.
Definition iganet.hpp:1242
decltype(std::declval< Variable >() .template find_coeff_indices< functionspace::boundary >(std::declval< typename Variable::boundary_eval_type >())) variable_boundary_coeff_indices_type
Type of the coefficient indices of variable type at the boundary.
Definition iganet.hpp:1261
decltype(std::declval< Variable >() .template find_knot_indices< functionspace::interior >(std::declval< typename Variable::eval_type >())) variable_interior_knot_indices_type
Type of the knot indices of the variables in the interior.
Definition iganet.hpp:1230
decltype(std::declval< GeometryMap >() .template find_coeff_indices< functionspace::boundary >(std::declval< typename GeometryMap::boundary_eval_type >())) geometryMap_boundary_coeff_indices_type
Type of the coefficient indices of geometry type at the boundary.
Definition iganet.hpp:1249
decltype(std::declval< GeometryMap >() .template find_knot_indices< functionspace::interior >(std::declval< typename GeometryMap::eval_type >())) geometryMap_interior_knot_indices_type
Type of the knot indices of the geometry map in the interior.
Definition iganet.hpp:1217
decltype(std::declval< Variable >() .template find_knot_indices< functionspace::boundary >(std::declval< typename Variable::boundary_eval_type >())) variable_boundary_knot_indices_type
Type of the knot indices of the variables at the boundary.
Definition iganet.hpp:1236
IgANet.
Definition iganet.hpp:662
auto & options()
Returns a non-constant reference to the options structure.
Definition iganet.hpp:851
virtual torch::Tensor inputs(int64_t epoch) const
Returns the network inputs.
Definition iganet.hpp:859
void optimizerOptionsReset(optimizer_options_type &&options)
Resets the optimizer options.
Definition iganet.hpp:822
Optimizer optimizer_type
Type of the optimizer.
Definition iganet.hpp:668
IgANetOptions options_
Options.
Definition iganet.hpp:681
bool operator==(const IgANet &other) const
Returns true if both IgANet objects are the same.
Definition iganet.hpp:1159
optimizer_options_type & optimizerOptions(std::size_t param_group=0)
Returns a non-constant reference to the optimizer options.
Definition iganet.hpp:797
void train(DataLoader &loader)
Trains the IgANet.
Definition iganet.hpp:954
torch::Tensor & register_parameter(std::string name, torch::Tensor tensor, bool requires_grad=true)
Registers a parameter.
Definition iganet.hpp:1080
virtual void train()
Trains the IgANet.
Definition iganet.hpp:877
nlohmann::json to_json() const override
Returns the IgANet object as JSON object.
Definition iganet.hpp:1054
optimizer_type & optimizer()
Returns a non-constant reference to the optimizer.
Definition iganet.hpp:770
void eval()
Evaluate IgANet.
Definition iganet.hpp:1047
IgANet(const IgANetOptions &defaults={}, iganet::Options< typename Base::value_type > options=iganet::Options< typename Base::value_type >{})
Default constructor.
Definition iganet.hpp:685
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="iganet") const
Writes the IgANet into a torch::serialize::OutputArchive object.
Definition iganet.hpp:1113
IgANet(const std::vector< int64_t > &layers, const std::vector< std::vector< std::any > > &activations, std::array< int64_t, GeometryMapNumCoeffs > geometryMapNumCoeffs, std::array< int64_t, VariableNumCoeffs > variableNumCoeffs, IgANetOptions defaults={}, iganet::Options< typename Base::value_type > options=iganet::Options< typename Base::value_type >{})
Constructor: number of layers, activation functions, and number of spline coefficients (different for...
Definition iganet.hpp:723
const IgANetGenerator< typename Base::value_type > & net() const
Returns a constant reference to the IgANet generator.
Definition iganet.hpp:759
void load(const std::string &filename, const std::string &key="iganet")
Loads the IgANet from file.
Definition iganet.hpp:1104
torch::OrderedDict< std::string, torch::Tensor > named_parameters() const noexcept
Returns the named parameters of the IgANet object (returned by value).
Definition iganet.hpp:1066
IgANet(const std::vector< int64_t > &layers, const std::vector< std::vector< std::any > > &activations, std::array< int64_t, NumCoeffs > numCoeffs, IgANetOptions defaults={}, iganet::Options< typename Base::value_type > options=iganet::Options< typename Base::value_type >{})
Constructor: number of layers, activation functions, and number of spline coefficients (same for geom...
Definition iganet.hpp:700
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="iganet")
Loads the IgANet from a torch::serialize::InputArchive object.
Definition iganet.hpp:1136
IgANetGenerator< typename Base::value_type > & net()
Returns a non-constant reference to the IgANet generator.
Definition iganet.hpp:764
const optimizer_options_type & optimizerOptions(std::size_t param_group=0) const
Returns a constant reference to the optimizer options.
Definition iganet.hpp:807
optimizer_options_type< Optimizer >::type optimizer_options_type
Type of the optimizer options.
Definition iganet.hpp:671
std::unique_ptr< optimizer_type > opt_
Optimizer.
Definition iganet.hpp:678
void optimizerReset(const optimizer_options_type &optimizerOptions)
Resets the optimizer.
Definition iganet.hpp:791
bool operator!=(const IgANet &other) const
Returns true if both IgANet objects are different.
Definition iganet.hpp:1173
virtual bool epoch(int64_t)=0
Initializes epoch.
void optimizerReset(bool resetOptions=true)
Resets the optimizer.
Definition iganet.hpp:776
IgANetGenerator< typename Base::value_type > net_
IgANet generator.
Definition iganet.hpp:675
std::size_t nparameters() const noexcept
Returns the total number of parameters of the IgANet object.
Definition iganet.hpp:1071
IgANet(const std::vector< int64_t > &layers, const std::vector< std::vector< std::any > > &activations, std::tuple< std::array< int64_t, GeometryMapNumCoeffs >... > geometryMapNumCoeffs, std::tuple< std::array< int64_t, VariableNumCoeffs >... > variableNumCoeffs, IgANetOptions defaults={}, iganet::Options< typename Base::value_type > options=iganet::Options< typename Base::value_type >{})
Constructor: number of layers, activation functions, and number of spline coefficients (different for...
Definition iganet.hpp:735
IgANet(const std::vector< int64_t > &layers, const std::vector< std::vector< std::any > > &activations, std::tuple< std::array< int64_t, NumCoeffs >... > numCoeffs, IgANetOptions defaults={}, iganet::Options< typename Base::value_type > options=iganet::Options< typename Base::value_type >{})
Constructor: number of layers, activation functions, and number of spline coefficients (same for geom...
Definition iganet.hpp:709
void optimizerOptionsReset(const optimizer_options_type &options, std::size_t param_group)
Resets the optimizer options.
Definition iganet.hpp:828
void optimizerOptionsReset(optimizer_options_type &&options, std::size_t param_group)
Resets the optimizer options.
Definition iganet.hpp:838
void pretty_print(std::ostream &os) const noexcept override
Writes a string representation of the IgANet object to the given output stream.
Definition iganet.hpp:1085
virtual torch::Tensor loss(const torch::Tensor &, int64_t)=0
Computes the loss function.
const optimizer_type & optimizer() const
Returns a constant reference to the optimizer.
Definition iganet.hpp:767
void optimizerOptionsReset(const optimizer_options_type &options)
Resets the optimizer options.
Definition iganet.hpp:816
std::vector< torch::Tensor > parameters() const noexcept
Returns the parameters of the IgANet object (returned by value).
Definition iganet.hpp:1059
const auto & options() const
Returns a constant reference to the options structure.
Definition iganet.hpp:848
void save(const std::string &filename, const std::string &key="iganet") const
Saves the IgANet to file.
Definition iganet.hpp:1097
Isogeometric analysis base class.
Concept to identify template parameters that are derived from iganet::details::FunctionSpaceType.
Definition functionspace.hpp:3255
Concept to identify template parameters that are derived from torch::optim::Optimizer.
Definition optimizer.hpp:26
Container utility functions.
Full qualified name utility functions.
Function spaces.
Network generator.
Definition iganet.hpp:28
std::ostream & operator<<(std::ostream &os, const IgANet< Optimizer, GeometryMap, Variable > &obj)
Prints (as string) an IgANet object.
Definition iganet.hpp:1198
collPts
Enumerator for the collocation point specifier.
Definition collocation.hpp:21
struct iganet::@0 Log
Logger.
init
Enumerator for specifying the initialization of B-spline coefficients.
Definition bspline.hpp:55
Type trait for the optimizer options type.
Definition optimizer.hpp:32
STL namespace.
Options.
Serialization prototype.
Definition serialize.hpp:29
IgANetOptions.
Definition iganet.hpp:31
TORCH_ARG(double, min_loss_rel_change)
TORCH_ARG(int64_t, max_epoch)
TORCH_ARG(int64_t, batch_size)
TORCH_ARG(double, min_loss_change)=0
TORCH_ARG(double, min_loss)
Tuple utility functions.
Zip utility function.