46template <
typename,
typename,
typename =
void>
class IgABase;
50class IgABase<
std::tuple<Inputs...>, std::tuple<Outputs...>,
51 std::tuple<CollPts...>> {
54 using value_type = std::common_type_t<
typename Inputs::value_type...,
55 typename Outputs::value_type...>;
80 template <
typename... Objs, std::size_t... NumCoeffs, std::size_t... Is>
82 const std::tuple<std::array<int64_t, NumCoeffs>...> &numCoeffs,
84 std::index_sequence<Is...>) {
85 static_assert(
sizeof...(Objs) ==
sizeof...(NumCoeffs));
86 return std::make_tuple(std::apply(
87 [&]<
typename... Args>(Args &&...args) {
88 return Objs(std::forward<Args>(args)...,
init, options);
90 std::get<Is>(numCoeffs))...);
93 template <
typename... Objs, std::size_t... NumCoeffs>
95 const std::tuple<std::array<int64_t, NumCoeffs>...> &numCoeffs,
97 return construct_tuple_from_arrays_impl<Objs...>(
98 numCoeffs,
init, options, std::index_sequence_for<Objs...>{});
105 template <
typename... Objs,
typename... NumCoeffsTuples, std::size_t... Is>
107 const std::tuple<NumCoeffsTuples...> &numCoeffs,
enum init init,
109 static_assert(
sizeof...(Objs) ==
sizeof...(NumCoeffsTuples));
110 return std::make_tuple(std::apply(
111 [&]<
typename... Args>(Args &&...args) {
112 return Objs(std::forward<Args>(args)...,
init, options);
114 std::get<Is>(numCoeffs))...);
117 template <
typename... Objs,
typename... NumCoeffsTuples>
122 return construct_tuple_from_tuples_impl<Objs...>(
123 numCoeffs,
init, options, std::index_sequence_for<Objs...>{});
131 : inputs_(), outputs_(), collPts_() {}
137 template <std::
size_t NumCoeffs>
139 const std::array<int64_t, NumCoeffs> &ncoeffs,
149 template <std::size_t NumCoeffsInputs, std::size_t NumCoeffsOutputs,
150 std::size_t NumCoeffsCollPts>
151 IgABase(
const std::array<int64_t, NumCoeffsInputs> &ncoeffsInputs,
152 const std::array<int64_t, NumCoeffsOutputs> &ncoeffsOutputs,
153 const std::array<int64_t, NumCoeffsCollPts> &ncoeffsCollPts,
156 :
IgABase(
std::tuple{ncoeffsInputs},
std::tuple{ncoeffsOutputs},
157 std::tuple{ncoeffsCollPts},
init, options) {}
164 template <std::size_t... NumCoeffs>
166 const std::tuple<std::array<int64_t, NumCoeffs>...> &ncoeffs,
169 :
IgABase(ncoeffs, ncoeffs, ncoeffs,
init, options) {}
175 template <std::size_t... NumCoeffsInputs, std::size_t... NumCoeffsOutputs,
176 std::size_t... NumCoeffsCollPts>
178 const std::tuple<std::array<int64_t, NumCoeffsInputs>...> &ncoeffsInputs,
179 const std::tuple<std::array<int64_t, NumCoeffsOutputs>...>
181 const std::tuple<std::array<int64_t, NumCoeffsCollPts>...>
185 : inputs_(construct_tuple_from_arrays<Inputs...>(ncoeffsInputs,
init,
187 outputs_(construct_tuple_from_arrays<Outputs...>(ncoeffsOutputs,
init,
189 collPts_(construct_tuple_from_arrays<Outputs...>(ncoeffsCollPts,
init,
197 template <
typename... CoeffsInputs,
typename... CoeffsOutputs,
198 typename... CoeffsCollPts>
199 IgABase(
const std::tuple<CoeffsInputs...> &coeffsInputs,
200 const std::tuple<CoeffsOutputs...> &coeffsOutputs,
201 const std::tuple<CoeffsCollPts...> &coeffsCollPts,
204 : inputs_(construct_tuple_from_tuples<Inputs...>(coeffsInputs,
init,
206 outputs_(construct_tuple_from_tuples<Outputs...>(coeffsOutputs,
init,
208 collPts_(construct_tuple_from_tuples<CollPts...>(coeffsCollPts,
init,
212 inline static constexpr std::size_t
ninputs() noexcept {
213 return sizeof...(Inputs);
217 inline constexpr const auto &
inputs()
const {
return inputs_; }
220 inline constexpr auto &
inputs() {
return inputs_; }
223 template <std::
size_t index>
inline constexpr const auto &
input()
const {
224 static_assert(index <
sizeof...(Inputs));
225 return std::get<index>(inputs_);
229 template <std::
size_t index>
inline constexpr auto &
input() {
230 static_assert(index <
sizeof...(Inputs));
231 return std::get<index>(inputs_);
235 inline static constexpr std::size_t
noutputs() noexcept {
236 return sizeof...(Outputs);
240 inline constexpr const auto &
outputs()
const {
return outputs_; }
243 inline constexpr auto &
outputs() {
return outputs_; }
246 template <std::
size_t index>
inline constexpr const auto &
output()
const {
247 static_assert(index <
sizeof...(Outputs));
248 return std::get<index>(outputs_);
252 template <std::
size_t index>
inline constexpr auto &
output() {
253 static_assert(index <
sizeof...(Outputs));
254 return std::get<index>(outputs_);
259 inline static constexpr std::size_t
ncollPts() noexcept {
260 return sizeof...(CollPts);
265 inline constexpr const auto &
collPts()
const {
return collPts_; }
269 inline constexpr auto &
collPts() {
return collPts_; }
277 template <std::
size_t index>
278 std::tuple_element_t<index, collPts_type>
284template <detail::HasAsTensor... Inputs, detail::HasAsTensor... Outputs>
285class IgABase<
std::tuple<Inputs...>, std::tuple<Outputs...>, void> {
288 using value_type = std::common_type_t<
typename Inputs::value_type...,
289 typename Outputs::value_type...>;
295 template <std::
size_t index>
296 using input_t = std::tuple_element_t<index, inputs_type>;
302 template <std::
size_t index>
303 using output_t = std::tuple_element_t<index, outputs_type>;
309 template <std::
size_t index>
310 using collPts_t = std::tuple_element_t<index, collPts_type>;
323 template <
typename... Objs, std::size_t... NumCoeffs, std::size_t... Is>
325 const std::tuple<std::array<int64_t, NumCoeffs>...> &numCoeffs,
327 std::index_sequence<Is...>) {
328 static_assert(
sizeof...(Objs) ==
sizeof...(NumCoeffs));
329 return std::make_tuple(Objs(std::get<Is>(numCoeffs),
init, options)...);
332 template <
typename... Objs, std::size_t... NumCoeffs>
334 const std::tuple<std::array<int64_t, NumCoeffs>...> &numCoeffs,
336 return construct_tuple_from_arrays_impl<Objs...>(
337 numCoeffs,
init, options, std::index_sequence_for<Objs...>{});
344 template <
typename... Objs,
typename... NumCoeffs, std::size_t... Is>
346 const std::tuple<NumCoeffs...> &numCoeffs,
enum init init,
348 static_assert(
sizeof...(Objs) ==
sizeof...(NumCoeffs));
349 return std::make_tuple(std::apply(
350 [&]<
typename... Args>(Args &&...args) {
351 return Objs(std::forward<Args>(args)...,
init, options);
353 std::get<Is>(numCoeffs))...);
356 template <
typename... Objs,
typename... NumCoeffsTuples>
361 return construct_tuple_from_tuples_impl<Objs...>(
362 numCoeffs,
init, options, std::index_sequence_for<Objs...>{});
370 : inputs_(), outputs_() {}
376 template <std::
size_t NumCoeffs>
378 const std::array<int64_t, NumCoeffs> &ncoeffs,
387 template <std::
size_t NumCoeffsInputs, std::
size_t NumCoeffsOutputs>
388 IgABase(
const std::array<int64_t, NumCoeffsInputs> &ncoeffsInputs,
389 const std::array<int64_t, NumCoeffsOutputs> &ncoeffsOutputs,
400 template <std::size_t... NumCoeffs>
402 const std::tuple<std::array<int64_t, NumCoeffs>...> &ncoeffs,
411 template <std::size_t... NumCoeffsInputs, std::size_t... NumCoeffsOutputs>
413 const std::tuple<std::array<int64_t, NumCoeffsInputs>...> &ncoeffsInputs,
414 const std::tuple<std::array<int64_t, NumCoeffsOutputs>...>
418 : inputs_(construct_tuple_from_arrays<Inputs...>(ncoeffsInputs,
init,
420 outputs_(construct_tuple_from_arrays<Outputs...>(ncoeffsOutputs,
init,
427 template <
typename... CoeffsInputs,
typename... CoeffsOutputs>
428 IgABase(
const std::tuple<CoeffsInputs...> &coeffsInputs,
429 const std::tuple<CoeffsOutputs...> &coeffsOutputs,
432 : inputs_(construct_tuple_from_tuples<Inputs...>(coeffsInputs,
init,
434 outputs_(construct_tuple_from_tuples<Outputs...>(coeffsOutputs,
init,
438 inline static constexpr std::size_t
ninputs() noexcept {
439 return sizeof...(Inputs);
443 inline constexpr const auto &
inputs()
const {
return inputs_; }
446 inline constexpr auto &
inputs() {
return inputs_; }
449 template <std::
size_t index>
inline constexpr const auto &
input()
const {
450 static_assert(index <
sizeof...(Inputs));
451 return std::get<index>(inputs_);
455 template <std::
size_t index>
inline constexpr auto &
input() {
456 static_assert(index <
sizeof...(Inputs));
457 return std::get<index>(inputs_);
461 inline static constexpr std::size_t
noutputs() noexcept {
462 return sizeof...(Outputs);
466 inline constexpr const auto &
outputs()
const {
return outputs_; }
469 inline constexpr auto &
outputs() {
return outputs_; }
472 template <std::
size_t index>
inline constexpr const auto &
output()
const {
473 static_assert(index <
sizeof...(Outputs));
474 return std::get<index>(outputs_);
478 template <std::
size_t index>
inline constexpr auto &
output() {
479 static_assert(index <
sizeof...(Outputs));
480 return std::get<index>(outputs_);
485 inline static constexpr std::size_t
ncollPts() noexcept {
486 return sizeof...(Outputs);
491 inline constexpr const auto &
collPts()
const {
return outputs_; }
495 inline constexpr auto &
collPts() {
return outputs_; }
503 template <std::
size_t index>
504 std::tuple_element_t<index, collPts_type>
514template <
typename Optimizer,
typename Inputs,
typename Outputs,
515 typename CollPts =
void>
516 requires OptimizerType<Optimizer>
538 std::unique_ptr<optimizer_type>
opt_;
557 template <
typename NumCoeffs>
558 IgANet(
const std::vector<int64_t> &layers,
559 const std::vector<std::vector<std::any>> &activations,
564 :
IgANet(layers, activations, numCoeffs, numCoeffs,
init, defaults,
569 template <
typename NumCoeffsInputs,
typename NumCoeffsOutputs>
570 IgANet(
const std::vector<int64_t> &layers,
571 const std::vector<std::vector<std::any>> &activations,
572 const NumCoeffsInputs &numCoeffsInputs,
573 const NumCoeffsOutputs &numCoeffsOutputs,
581 std::vector<int64_t>{
inputs( 0).size(0)}, layers,
582 std::vector<int64_t>{
outputs( 0).size(0)}),
586 opt_(std::make_unique<optimizer_type>(
net_->parameters())),
611 opt_ = std::make_unique<optimizer_type>(
net_->parameters());
613 std::vector<optimizer_options_type>
options;
614 for (
auto &group :
opt_->param_groups())
617 opt_ = std::make_unique<optimizer_type>(
net_->parameters());
631 if (param_group < opt_->param_groups().size())
633 opt_->param_groups()[param_group].options());
635 throw std::runtime_error(
"Index exceeds number of parameter groups");
641 if (param_group < opt_->param_groups().size())
643 opt_->param_groups()[param_group].options());
645 throw std::runtime_error(
"Index exceeds number of parameter groups");
650 for (
auto &group :
opt_->param_groups())
656 for (
auto &group :
opt_->param_groups())
662 std::size_t param_group) {
663 if (param_group < opt_->param_groups().size())
667 throw std::runtime_error(
"Index exceeds number of parameter groups");
672 std::size_t param_group) {
673 if (param_group < opt_->param_groups().size())
677 throw std::runtime_error(
"Index exceeds number of parameter groups");
687 inline constexpr const auto &
inputs()
const {
return Base::inputs(); }
690 inline constexpr auto &
inputs() {
return Base::inputs(); }
693 inline constexpr const auto &
outputs()
const {
return Base::outputs(); }
696 inline constexpr auto &
outputs() {
return Base::outputs(); }
701 Base::inputs_, [](
const auto &obj) {
return obj.as_tensor(); });
707 Base::outputs_, [](
const auto &obj) {
return obj.as_tensor(); });
711 virtual void inputs(
const torch::Tensor &tensor) {
713 Base::inputs_, tensor,
714 [](
const auto &obj) {
return obj.as_tensor_size(); },
715 [](
auto &obj,
const auto &tensor) {
return obj.from_tensor(tensor); });
719 virtual void outputs(
const torch::Tensor &tensor) {
721 Base::outputs_, tensor,
722 [](
const auto &obj) {
return obj.as_tensor_size(); },
723 [](
auto &obj,
const auto &tensor) {
return obj.from_tensor(tensor); });
730 virtual torch::Tensor
loss(
const torch::Tensor &, int64_t) = 0;
734#ifdef IGANET_WITH_MPI
735 c10::intrusive_ptr<c10d::ProcessGroupMPI> pg =
736 c10d::ProcessGroupMPI::createProcessGroupMPI()
740 typename Base::value_type previous_loss(-1.0);
749 auto closure = [&]() {
760 loss.backward({},
true,
false);
765#ifdef IGANET_WITH_MPI
771 std::vector<c10::intrusive_ptr<::c10d::Work>> works;
772 for (
auto ¶m :
net_->named_parameters()) {
773 std::vector<torch::Tensor> tmp = {param.value().grad()};
774 works.emplace_back(pg->allreduce(tmp));
779 for (
auto ¶m :
net_->named_parameters()) {
780 param.value().grad().data() =
781 param.value().grad().data() / pg->getSize();
788 typename Base::value_type current_loss =
789 loss.item<
typename Base::value_type>();
791 << current_loss << std::endl;
793 if (current_loss <
options_.min_loss() ||
794 std::abs(current_loss - previous_loss) <
options_.min_loss_change() ||
795 std::abs(current_loss - previous_loss) / current_loss <
797 loss.isnan().item<
bool>()) {
799 <<
", loss: " << current_loss << std::endl;
802 previous_loss = current_loss;
805 <<
", loss: " << previous_loss << std::endl;
809 template <
typename DataLoader>
811#ifdef IGANET_WITH_MPI
813 c10::intrusive_ptr<c10d::ProcessGroupMPI> pg =
814 c10d::ProcessGroupMPI::createProcessGroupMPI()
818 typename Base::value_type previous_loss(-1.0);
823 typename Base::value_type current_loss(0);
825 for (
auto &batch : loader) {
863 auto closure = [&]() {
874 loss.backward({},
true,
false);
882 current_loss +=
loss.item<
typename Base::value_type>();
885 << current_loss << std::endl;
887 if (current_loss <
options_.min_loss() ||
888 std::abs(current_loss - previous_loss) <
options_.min_loss_change() ||
889 std::abs(current_loss - previous_loss) / current_loss <
891 loss.isnan().item<
bool>()) {
893 <<
", loss: " << current_loss << std::endl;
896 previous_loss = current_loss;
899 <<
", loss: " << previous_loss << std::endl;
910 inline nlohmann::json
to_json()
const override {
911 return "Not implemented yet";
915 inline std::vector<torch::Tensor>
parameters() const noexcept {
916 return net_->parameters();
921 inline torch::OrderedDict<std::string, torch::Tensor>
923 return net_->named_parameters();
928 std::size_t result = 0;
929 for (
const auto ¶m : this->
parameters()) {
930 result += param.numel();
937 return net_->register_parameter(
name, tensor, requires_grad);
942 os <<
name() <<
"(\n"
943 <<
"net = " <<
net_ <<
"\n";
945 os <<
"inputs[" << Base::ninputs() <<
"] = (";
946 std::apply([&os](
const auto &...elems) { ((os << elems <<
"\n"), ...); },
950 os <<
"outputs [" << Base::noutputs() <<
"]= (";
951 std::apply([&os](
const auto &...elems) { ((os << elems <<
"\n"), ...); },
955 os <<
"collPts [" << Base::ncollPts() <<
"]= (";
956 std::apply([&os](
const auto &...elems) { ((os << elems <<
"\n"), ...); },
962 inline void save(
const std::string &filename,
963 const std::string &key =
"iganet")
const {
964 torch::serialize::OutputArchive archive;
965 write(archive, key).save_to(filename);
969 inline void load(
const std::string &filename,
970 const std::string &key =
"iganet") {
971 torch::serialize::InputArchive archive;
972 archive.load_from(filename);
977 inline torch::serialize::OutputArchive &
978 write(torch::serialize::OutputArchive &archive,
979 const std::string &key =
"iganet")
const {
982 [&](
auto &&...elems) {
983 std::size_t counter = 0;
984 (elems.write(archive,
985 key +
".input[" + std::to_string(counter++) +
"]"),
991 [&](
auto &&...elems) {
992 std::size_t counter = 0;
993 (elems.write(archive,
994 key +
".output[" + std::to_string(counter++) +
"]"),
999 if constexpr (!std::is_void_v<CollPts>) {
1001 [&](
auto &&...elems) {
1002 std::size_t counter = 0;
1003 (elems.write(archive,
1004 key +
".collpts[" + std::to_string(counter++) +
"]"),
1010 net_->write(archive, key +
".net");
1011 torch::serialize::OutputArchive archive_net;
1012 net_->save(archive_net);
1013 archive.write(key +
".net.data", archive_net);
1015 torch::serialize::OutputArchive archive_opt;
1016 opt_->save(archive_opt);
1017 archive.write(key +
".opt", archive_opt);
1023 inline torch::serialize::InputArchive &
1024 read(torch::serialize::InputArchive &archive,
1025 const std::string &key =
"iganet") {
1028 [&](
auto &&...elems) {
1029 std::size_t counter = 0;
1030 (elems.read(archive,
1031 key +
".input[" + std::to_string(counter++) +
"]"),
1037 [&](
auto &&...elems) {
1038 std::size_t counter = 0;
1039 (elems.read(archive,
1040 key +
".output[" + std::to_string(counter++) +
"]"),
1045 if constexpr (!std::is_void_v<CollPts>) {
1047 [&](
auto &&...elems) {
1048 std::size_t counter = 0;
1049 (elems.read(archive,
1050 key +
".collpts[" + std::to_string(counter++) +
"]"),
1056 net_->read(archive, key +
".net");
1057 torch::serialize::InputArchive archive_net;
1058 archive.read(key +
".net.data", archive_net);
1059 net_->load(archive_net);
1061 opt_->add_parameters(
net_->parameters());
1062 torch::serialize::InputArchive archive_opt;
1063 archive.read(key +
".opt", archive_opt);
1064 opt_->load(archive_opt);
1073 result *= std::apply(
1074 [&](
auto &&...elemsThis) {
1076 [&](
auto &&...elemsOther) {
1077 return ((elemsThis == elemsOther) && ...);
1089#ifdef IGANET_WITH_MPI
1092 static void waitWork(c10::intrusive_ptr<c10d::ProcessGroupMPI> pg,
1093 std::vector<c10::intrusive_ptr<c10d::Work>> works) {
1094 for (
auto &work : works) {
1097 }
catch (
const std::exception &ex) {
1098 Log(
log::error) <<
"Exception received during waitWork: " << ex.what()
1108template <
typename Optimizer,
typename Inputs,
typename Outputs,
1110 requires OptimizerType<Optimizer>
1111inline std::ostream &
1133 []<
typename... Elems>(Elems &&...elems) {
1134 return std::make_tuple(([&] {
1135 using T = std::decay_t<Elems>;
1137 return elems.template find_knot_indices<functionspace::interior>(
1138 typename T::eval_type{});
1141 return elems.find_knot_indices(
typename T::eval_type{});
1150 []<
typename... Elems>(Elems &&...elems) {
1151 return std::make_tuple(([&] {
1152 using T = std::decay_t<Elems>;
1154 return elems.template find_knot_indices<functionspace::boundary>(
1155 typename T::boundary_eval_type{});
1158 return elems.find_knot_indices(
typename T::eval_type{});
1167 []<
typename... Elems>(Elems &&...elems) {
1168 return std::make_tuple(([&] {
1169 using T = std::decay_t<Elems>;
1171 return elems.template find_coeff_indices<functionspace::interior>(
1172 typename T::eval_type{});
1175 return elems.find_coeff_indices(
typename T::eval_type{});
1184 []<
typename... Elems>(Elems &&...elems) {
1185 return std::make_tuple(([&] {
1186 using T = std::decay_t<Elems>;
1188 return elems.template find_coeff_indices<functionspace::boundary>(
1189 typename T::boundary_eval_type{});
1192 return elems.find_coeff_indices(
typename T::eval_type{});
1201 std::declval<std::tuple<Inputs...>>()));
1205 template <std::
size_t index>
1207 std::tuple_element_t<index, inputs_interior_knot_indices_type>;
1211 std::declval<std::tuple<Inputs...>>()));
1215 template <std::
size_t index>
1217 std::tuple_element_t<index, inputs_boundary_knot_indices_type>;
1221 decltype(find_interior_knot_indices(
1222 std::declval<std::tuple<Outputs...>>()));
1226 template <std::
size_t index>
1228 std::tuple_element_t<index, outputs_interior_knot_indices_type>;
1232 decltype(find_boundary_knot_indices(
1233 std::declval<std::tuple<Outputs...>>()));
1237 template <std::
size_t index>
1239 std::tuple_element_t<index, outputs_boundary_knot_indices_type>;
1243 decltype(find_interior_coeff_indices(
1244 std::declval<std::tuple<Inputs...>>()));
1248 template <std::
size_t index>
1250 std::tuple_element_t<index, inputs_interior_coeff_indices_type>;
1254 decltype(find_boundary_coeff_indices(
1255 std::declval<std::tuple<Inputs...>>()));
1259 template <std::
size_t index>
1261 std::tuple_element_t<index, inputs_boundary_coeff_indices_type>;
1265 decltype(find_interior_coeff_indices(
1266 std::declval<std::tuple<Outputs...>>()));
1270 template <std::
size_t index>
1272 std::tuple_element_t<index, outputs_interior_coeff_indices_type>;
1276 decltype(find_boundary_coeff_indices(
1277 std::declval<std::tuple<Outputs...>>()));
1281 template <std::
size_t index>
1283 std::tuple_element_t<index, outputs_boundary_coeff_indices_type>;
1289 std::tuple<CollPts...>>
1296 decltype(std::declval<CollPts>()
1297 .template find_knot_indices<functionspace::interior>(
1298 std::declval<typename CollPts::eval_type>()))...>;
1303 decltype(std::declval<CollPts>()
1304 .template find_knot_indices<functionspace::boundary>(
1306 typename CollPts::boundary_eval_type>()))...>;
1311 decltype(std::declval<CollPts>()
1312 .template find_coeff_indices<functionspace::interior>(
1313 std::declval<typename CollPts::eval_type>()))...>;
1318 decltype(std::declval<CollPts>()
1319 .template find_coeff_indices<functionspace::boundary>(
1321 typename CollPts::boundary_eval_type>()))...>;
Definition unittest_iganet.cxx:24
IgA base class.
Definition iganet.hpp:46
IgANetGenerator.
Definition generator.hpp:927
IgANet.
Definition iganet.hpp:519
void optimizerReset(bool resetOptions=true)
Resets the optimizer.
Definition iganet.hpp:609
void optimizerReset(const optimizer_options_type &optimizerOptions)
Resets the optimizer.
Definition iganet.hpp:624
void load(const std::string &filename, const std::string &key="iganet")
Loads the IgANet from file.
Definition iganet.hpp:969
constexpr auto & inputs()
Returns a non-constant reference to the tuple of input objects.
Definition iganet.hpp:690
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="iganet") const
Writes the IgANet into a torch::serialize::OutputArchive object.
Definition iganet.hpp:978
virtual void outputs(const torch::Tensor &tensor)
Attaches the given tensor to the outputs.
Definition iganet.hpp:719
void train(DataLoader &loader)
Trains the IgANet.
Definition iganet.hpp:810
const optimizer_type & optimizer() const
Returns a constant reference to the optimizer.
Definition iganet.hpp:600
void optimizerOptionsReset(optimizer_options_type &&options, std::size_t param_group)
Resets the optimizer options.
Definition iganet.hpp:671
torch::OrderedDict< std::string, torch::Tensor > named_parameters() const noexcept
Returns a constant reference to the named parameters of the IgANet object.
Definition iganet.hpp:922
Base::value_type value_type
Value type.
Definition iganet.hpp:525
std::vector< torch::Tensor > parameters() const noexcept
Returns a constant reference to the parameters of the IgANet object.
Definition iganet.hpp:915
torch::Tensor & register_parameter(std::string name, torch::Tensor tensor, bool requires_grad=true)
Registers a parameter.
Definition iganet.hpp:936
void save(const std::string &filename, const std::string &key="iganet") const
Saves the IgANet to file.
Definition iganet.hpp:962
void optimizerOptionsReset(const optimizer_options_type &options)
Resets the optimizer options.
Definition iganet.hpp:649
IgANet(const IgANetOptions &defaults={}, iganet::Options< typename Base::value_type > options=iganet::Options< typename Base::value_type >{})
Default constructor.
Definition iganet.hpp:545
optimizer_type & optimizer()
Returns a non-constant reference to the optimizer.
Definition iganet.hpp:603
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="iganet")
Loads the IgANet from a torch::serialize::InputArchive object.
Definition iganet.hpp:1024
IgANet(const std::vector< int64_t > &layers, const std::vector< std::vector< std::any > > &activations, const NumCoeffsInputs &numCoeffsInputs, const NumCoeffsOutputs &numCoeffsOutputs, enum init init=init::greville, IgANetOptions defaults={}, iganet::Options< typename Base::value_type > options=iganet::Options< typename Base::value_type >{})
Constructor: number of layers, activation functions, and number of spline coefficients (same for all ...
Definition iganet.hpp:570
virtual bool epoch(int64_t)=0
Initializes epoch.
const IgANetGenerator< typename Base::value_type > & net() const
Returns a constant reference to the IgANet generator.
Definition iganet.hpp:592
auto & options()
Returns a non-constant reference to the options structure.
Definition iganet.hpp:684
IgANetGenerator< typename Base::value_type > net_
IgANet generator.
Definition iganet.hpp:535
IgABase< Inputs, Outputs, CollPts > Base
Base type.
Definition iganet.hpp:522
std::unique_ptr< optimizer_type > opt_
Optimizer.
Definition iganet.hpp:538
bool operator!=(const IgANet &other) const
Returns true if both IgANet objects are different.
Definition iganet.hpp:1087
IgANetOptions options_
Options.
Definition iganet.hpp:541
constexpr auto & outputs()
Returns a non-constant reference to the tuple of output objects.
Definition iganet.hpp:696
optimizer_options_type & optimizerOptions(std::size_t param_group=0)
Returns a non-constant reference to the optimizer options.
Definition iganet.hpp:630
virtual torch::Tensor outputs(int64_t epoch) const
Returns the network outputs as tensor.
Definition iganet.hpp:705
IgANet(const std::vector< int64_t > &layers, const std::vector< std::vector< std::any > > &activations, const NumCoeffs &numCoeffs, enum init init=init::greville, IgANetOptions defaults={}, iganet::Options< typename Base::value_type > options=iganet::Options< typename Base::value_type >{})
Constructor: number of layers, activation functions, and number of spline coefficients (same for all ...
Definition iganet.hpp:558
virtual void train()
Trains the IgANet.
Definition iganet.hpp:733
std::size_t nparameters() const noexcept
Returns the total number of parameters of the IgANet object.
Definition iganet.hpp:927
Optimizer optimizer_type
Type of the optimizer.
Definition iganet.hpp:528
void optimizerOptionsReset(optimizer_options_type &&options)
Resets the optimizer options.
Definition iganet.hpp:655
optimizer_options_type< Optimizer >::type optimizer_options_type
Type of the optimizer options.
Definition iganet.hpp:531
virtual torch::Tensor inputs(int64_t epoch) const
Returns the network inputs as tensor.
Definition iganet.hpp:699
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the IgANet object.
Definition iganet.hpp:941
virtual void inputs(const torch::Tensor &tensor)
Attaches the given tensor to the inputs.
Definition iganet.hpp:711
bool operator==(const IgANet &other) const
Returns true if both IgANet objects are the same.
Definition iganet.hpp:1070
nlohmann::json to_json() const override
Returns the IgANet object as JSON object.
Definition iganet.hpp:910
void optimizerOptionsReset(const optimizer_options_type &options, std::size_t param_group)
Resets the optimizer options.
Definition iganet.hpp:661
constexpr const auto & inputs() const
Returns a constant reference to the tuple of input objects.
Definition iganet.hpp:687
constexpr const auto & outputs() const
Returns a constant reference to the tuple of output objects.
Definition iganet.hpp:693
void eval()
Evaluate IgANet.
Definition iganet.hpp:903
const auto & options() const
Returns a constant reference to the options structure.
Definition iganet.hpp:681
IgANetGenerator< typename Base::value_type > & net()
Returns a non-constant reference to the IgANet generator.
Definition iganet.hpp:597
const optimizer_options_type & optimizerOptions(std::size_t param_group=0) const
Returns a constant reference to the optimizer options.
Definition iganet.hpp:640
virtual torch::Tensor loss(const torch::Tensor &, int64_t)=0
Computes the loss function.
The Options class handles the automated determination of dtype from the template argument and the sel...
Definition options.hpp:104
Full qualified name descriptor.
Definition fqn.hpp:22
virtual const std::string & name() const noexcept
Returns the full qualified name of the object.
Definition fqn.hpp:28
Isogeometric analysis base class.
Definition bspline.hpp:119
Definition bspline.hpp:112
Definition functionspace.hpp:54
Definition functionspace.hpp:47
Container utility functions.
Full qualified name utility functions.
auto zip(T &&...seqs)
Definition zip.hpp:97
void slice_tensor_into_tuple(std::tuple< Tensors... > &tuple, const torch::Tensor &tensor, FuncSize &&funcSize, FuncAssign &&funcAssign, int64_t &offset, int64_t dim=0)
Slices the given tensor into the objects of the std::tuple.
Definition tuple.hpp:119
torch::Tensor cat_tuple_into_tensor(const std::tuple< Tensors... > &tensors, int64_t dim=0)
Concatenates the entries of a std::tuple object into a single Torch tensor along the given dimension.
Definition tuple.hpp:80
collPts
Enumerator for the collocation point specifier.
Definition collocation.hpp:21
std::ostream & operator<<(std::ostream &os, const MemoryDebugger< id > &obj)
Print (as string) a memory debugger object.
Definition memory.hpp:125
struct iganet::@0 Log
Logger.
init
Enumerator for specifying the initialization of B-spline coefficients.
Definition bspline.hpp:55
Collocation points helper
Definition collocation.hpp:35
IgANetCustomizable.
Definition iganet.hpp:1125
Type trait for the optimizer options type.
Definition optimizer.hpp:32
IgANetOptions.
Definition iganet.hpp:34
TORCH_ARG(double, min_loss_change)=0
TORCH_ARG(int64_t, batch_size)
TORCH_ARG(double, min_loss)
TORCH_ARG(double, min_loss_rel_change)
TORCH_ARG(int64_t, max_epoch)
Serialization prototype.
Definition serialize.hpp:29