IgANet
IGAnets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
optimizer.hpp
Go to the documentation of this file.
1
15#pragma once
16
17#include <iostream>
18
19#include <core/core.hpp>
20
21namespace iganet {
22
25template <typename T>
26concept OptimizerType = std::is_base_of_v<torch::optim::Optimizer, T>;
27
30template <typename Optimizer>
33
34template <> struct optimizer_options_type<torch::optim::Adagrad> {
35 using type = torch::optim::AdagradOptions;
36};
37
38template <> struct optimizer_options_type<torch::optim::Adam> {
39 using type = torch::optim::AdamOptions;
40};
41
42template <> struct optimizer_options_type<torch::optim::AdamW> {
43 using type = torch::optim::AdamWOptions;
44};
45
46template <> struct optimizer_options_type<torch::optim::LBFGS> {
47 using type = torch::optim::LBFGSOptions;
48};
49
50template <> struct optimizer_options_type<torch::optim::SGD> {
51 using type = torch::optim::SGDOptions;
52};
53
54template <> struct optimizer_options_type<torch::optim::RMSprop> {
55 using type = torch::optim::RMSpropOptions;
56};
58
59} // namespace iganet
60
61namespace torch {
62namespace optim {
63
65inline std::ostream &operator<<(std::ostream &os,
66 const torch::optim::AdagradOptions &obj) {
67 at::optional<std::string> name_ = c10::demangle(typeid(obj).name());
68
69#if defined(_WIN32)
70 // Windows adds "struct" or "class" as a prefix.
71 if (name_->find("struct ") == 0) {
72 name_->erase(name_->begin(), name_->begin() + 7);
73 } else if (name_->find("class ") == 0) {
74 name_->erase(name_->begin(), name_->begin() + 6);
75 }
76#endif // defined(_WIN32)
77
78 os << *name_ << "(\nlr = " << obj.lr() << ", lr_decay = " << obj.lr_decay()
79 << ", weight_decay = " << obj.weight_decay()
80 << ", initial_accumulator_value = " << obj.initial_accumulator_value()
81 << ", eps = " << obj.eps() << "\n)";
82
83 return os;
84}
85
87inline std::ostream &operator<<(std::ostream &os,
88 const torch::optim::AdamOptions &obj) {
89 at::optional<std::string> name_ = c10::demangle(typeid(obj).name());
90
91#if defined(_WIN32)
92 // Windows adds "struct" or "class" as a prefix.
93 if (name_->find("struct ") == 0) {
94 name_->erase(name_->begin(), name_->begin() + 7);
95 } else if (name_->find("class ") == 0) {
96 name_->erase(name_->begin(), name_->begin() + 6);
97 }
98#endif // defined(_WIN32)
99
100 os << *name_ << "(\nlr = " << obj.lr() << ", betas = ["
101 << std::get<0>(obj.betas()) << ", " << std::get<1>(obj.betas()) << "]"
102 << ", weight_decay = " << obj.weight_decay() << ", eps = " << obj.eps()
103 << ", amsgrad = " << obj.amsgrad() << "\n)";
104
105 return os;
106}
107
109inline std::ostream &operator<<(std::ostream &os,
110 const torch::optim::AdamWOptions &obj) {
111 at::optional<std::string> name_ = c10::demangle(typeid(obj).name());
112
113#if defined(_WIN32)
114 // Windows adds "struct" or "class" as a prefix.
115 if (name_->find("struct ") == 0) {
116 name_->erase(name_->begin(), name_->begin() + 7);
117 } else if (name_->find("class ") == 0) {
118 name_->erase(name_->begin(), name_->begin() + 6);
119 }
120#endif // defined(_WIN32)
121
122 os << *name_ << "(\nlr = " << obj.lr() << ", betas = ["
123 << std::get<0>(obj.betas()) << ", " << std::get<1>(obj.betas()) << "]"
124 << ", weight_decay = " << obj.weight_decay() << ", eps = " << obj.eps()
125 << ", amsgrad = " << obj.amsgrad() << "\n)";
126
127 return os;
128}
129
131inline std::ostream &operator<<(std::ostream &os,
132 const torch::optim::LBFGSOptions &obj) {
133 at::optional<std::string> name_ = c10::demangle(typeid(obj).name());
134
135#if defined(_WIN32)
136 // Windows adds "struct" or "class" as a prefix.
137 if (name_->find("struct ") == 0) {
138 name_->erase(name_->begin(), name_->begin() + 7);
139 } else if (name_->find("class ") == 0) {
140 name_->erase(name_->begin(), name_->begin() + 6);
141 }
142#endif // defined(_WIN32)
143
144 os << *name_ << "(\nlr = " << obj.lr() << ", max_iter = " << obj.max_iter()
145 << ", max_eval = "
146 << (obj.max_eval().has_value() ? std::to_string(*obj.max_eval())
147 : "undefined")
148 << ", tolerance_grad = " << obj.tolerance_grad()
149 << ", tolerance_change = " << obj.tolerance_change()
150 << ", history_size = " << obj.history_size() << ", line_search_fn = "
151 << (obj.line_search_fn().has_value() ? *obj.line_search_fn() : "undefined")
152 << "\n)";
153
154 return os;
155}
156
158inline std::ostream &operator<<(std::ostream &os,
159 const torch::optim::RMSpropOptions &obj) {
160 at::optional<std::string> name_ = c10::demangle(typeid(obj).name());
161
162#if defined(_WIN32)
163 // Windows adds "struct" or "class" as a prefix.
164 if (name_->find("struct ") == 0) {
165 name_->erase(name_->begin(), name_->begin() + 7);
166 } else if (name_->find("class ") == 0) {
167 name_->erase(name_->begin(), name_->begin() + 6);
168 }
169#endif // defined(_WIN32)
170
171 os << *name_ << "(\nlr = " << obj.lr() << ", alpha = " << obj.alpha()
172 << ", eps = " << obj.eps() << ", weight_decay = " << obj.weight_decay()
173 << ", momentum = " << obj.momentum() << ", centered = " << obj.centered()
174 << "\n)";
175
176 return os;
177}
178
180inline std::ostream &operator<<(std::ostream &os,
181 const torch::optim::SGDOptions &obj) {
182 at::optional<std::string> name_ = c10::demangle(typeid(obj).name());
183
184#if defined(_WIN32)
185 // Windows adds "struct" or "class" as a prefix.
186 if (name_->find("struct ") == 0) {
187 name_->erase(name_->begin(), name_->begin() + 7);
188 } else if (name_->find("class ") == 0) {
189 name_->erase(name_->begin(), name_->begin() + 6);
190 }
191#endif // defined(_WIN32)
192
193 os << *name_ << "(\nlr = " << obj.lr() << ", momentum = " << obj.momentum()
194 << ", dampening = " << obj.dampening()
195 << ", weight_decay = " << obj.weight_decay()
196 << ", nesterov = " << obj.nesterov() << "\n)";
197
198 return os;
199}
200
201} // namespace optim
202} // namespace torch
Concept to identify template parameters that are derived from torch::optim::Optimizer.
Definition optimizer.hpp:26
Core components.
Definition core.hpp:72
torch::optim::AdamWOptions type
Definition optimizer.hpp:43
torch::optim::SGDOptions type
Definition optimizer.hpp:51
torch::optim::LBFGSOptions type
Definition optimizer.hpp:47
torch::optim::RMSpropOptions type
Definition optimizer.hpp:55
torch::optim::AdagradOptions type
Definition optimizer.hpp:35
torch::optim::AdamOptions type
Definition optimizer.hpp:39
Type trait for the optimizer options type.
Definition optimizer.hpp:32
std::ostream & operator<<(std::ostream &os, const torch::optim::AdagradOptions &obj)
Print (as string) a torch::optim::AdagradOptions object.
Definition optimizer.hpp:65
Definition optimizer.hpp:61