IgANet
IgANets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
optimizer.hpp
Go to the documentation of this file.
1
15#pragma once
16
17namespace iganet {
18
20 template<typename T>
21 concept OptimizerType = std::is_base_of_v<torch::optim::Optimizer, T>;
22
25 template <typename Optimizer>
28
29 template <>
30 struct optimizer_options_type<torch::optim::Adagrad> {
31 using type = torch::optim::AdagradOptions;
32 };
33
34 template <>
35 struct optimizer_options_type<torch::optim::Adam> {
36 using type = torch::optim::AdamOptions;
37 };
38
39 template <>
40 struct optimizer_options_type<torch::optim::AdamW> {
41 using type = torch::optim::AdamWOptions;
42 };
43
44 template <>
45 struct optimizer_options_type<torch::optim::LBFGS> {
46 using type = torch::optim::LBFGSOptions;
47 };
48
49 template <>
50 struct optimizer_options_type<torch::optim::SGD> {
51 using type = torch::optim::SGDOptions;
52 };
53
54 template <>
55 struct optimizer_options_type<torch::optim::RMSprop> {
56 using type = torch::optim::RMSpropOptions;
57 };
59
60} // namespace iganet
61
62namespace torch {
63namespace optim {
64
66inline std::ostream &operator<<(std::ostream &os,
67 const torch::optim::AdagradOptions &obj) {
68 at::optional<std::string> name_ = c10::demangle(typeid(obj).name());
69
70#if defined(_WIN32)
71 // Windows adds "struct" or "class" as a prefix.
72 if (name_->find("struct ") == 0) {
73 name_->erase(name_->begin(), name_->begin() + 7);
74 } else if (name_->find("class ") == 0) {
75 name_->erase(name_->begin(), name_->begin() + 6);
76 }
77#endif // defined(_WIN32)
78
79 os << *name_ << "(\nlr = " << obj.lr()
80 << ", lr_decay = " << obj.lr_decay()
81 << ", weight_decay = " << obj.weight_decay()
82 << ", initial_accumulator_value = " << obj.initial_accumulator_value()
83 << ", eps = " << obj.eps()
84 << "\n)";
85
86 return os;
87}
88
90inline std::ostream &operator<<(std::ostream &os,
91 const torch::optim::AdamOptions &obj) {
92 at::optional<std::string> name_ = c10::demangle(typeid(obj).name());
93
94#if defined(_WIN32)
95 // Windows adds "struct" or "class" as a prefix.
96 if (name_->find("struct ") == 0) {
97 name_->erase(name_->begin(), name_->begin() + 7);
98 } else if (name_->find("class ") == 0) {
99 name_->erase(name_->begin(), name_->begin() + 6);
100 }
101#endif // defined(_WIN32)
102
103 os << *name_ << "(\nlr = " << obj.lr()
104 << ", betas = [" << std::get<0>(obj.betas()) << ", " << std::get<1>(obj.betas()) << "]"
105 << ", weight_decay = " << obj.weight_decay()
106 << ", eps = " << obj.eps()
107 << ", amsgrad = " << obj.amsgrad()
108 << "\n)";
109
110 return os;
111}
112
114inline std::ostream &operator<<(std::ostream &os,
115 const torch::optim::AdamWOptions &obj) {
116 at::optional<std::string> name_ = c10::demangle(typeid(obj).name());
117
118#if defined(_WIN32)
119 // Windows adds "struct" or "class" as a prefix.
120 if (name_->find("struct ") == 0) {
121 name_->erase(name_->begin(), name_->begin() + 7);
122 } else if (name_->find("class ") == 0) {
123 name_->erase(name_->begin(), name_->begin() + 6);
124 }
125#endif // defined(_WIN32)
126
127 os << *name_ << "(\nlr = " << obj.lr()
128 << ", betas = [" << std::get<0>(obj.betas()) << ", " << std::get<1>(obj.betas()) << "]"
129 << ", weight_decay = " << obj.weight_decay()
130 << ", eps = " << obj.eps()
131 << ", amsgrad = " << obj.amsgrad()
132 << "\n)";
133
134 return os;
135}
136
138inline std::ostream &operator<<(std::ostream &os,
139 const torch::optim::LBFGSOptions &obj) {
140 at::optional<std::string> name_ = c10::demangle(typeid(obj).name());
141
142#if defined(_WIN32)
143 // Windows adds "struct" or "class" as a prefix.
144 if (name_->find("struct ") == 0) {
145 name_->erase(name_->begin(), name_->begin() + 7);
146 } else if (name_->find("class ") == 0) {
147 name_->erase(name_->begin(), name_->begin() + 6);
148 }
149#endif // defined(_WIN32)
150
151 os << *name_ << "(\nlr = " << obj.lr()
152 << ", max_iter = " << obj.max_iter()
153 << ", max_eval = " << (obj.max_eval().has_value() ? std::to_string(*obj.max_eval()) : "undefined")
154 << ", tolerance_grad = " << obj.tolerance_grad()
155 << ", tolerance_change = " << obj.tolerance_change()
156 << ", history_size = " << obj.history_size()
157 << ", line_search_fn = " << (obj.line_search_fn().has_value() ? *obj.line_search_fn() : "undefined")
158 << "\n)";
159
160 return os;
161}
162
164inline std::ostream &operator<<(std::ostream &os,
165 const torch::optim::RMSpropOptions &obj) {
166 at::optional<std::string> name_ = c10::demangle(typeid(obj).name());
167
168#if defined(_WIN32)
169 // Windows adds "struct" or "class" as a prefix.
170 if (name_->find("struct ") == 0) {
171 name_->erase(name_->begin(), name_->begin() + 7);
172 } else if (name_->find("class ") == 0) {
173 name_->erase(name_->begin(), name_->begin() + 6);
174 }
175#endif // defined(_WIN32)
176
177 os << *name_ << "(\nlr = " << obj.lr()
178 << ", alpha = " << obj.alpha()
179 << ", eps = " << obj.eps()
180 << ", weight_decay = " << obj.weight_decay()
181 << ", momentum = " << obj.momentum()
182 << ", centered = " << obj.centered()
183 << "\n)";
184
185 return os;
186}
187
189inline std::ostream &operator<<(std::ostream &os,
190 const torch::optim::SGDOptions &obj) {
191 at::optional<std::string> name_ = c10::demangle(typeid(obj).name());
192
193#if defined(_WIN32)
194 // Windows adds "struct" or "class" as a prefix.
195 if (name_->find("struct ") == 0) {
196 name_->erase(name_->begin(), name_->begin() + 7);
197 } else if (name_->find("class ") == 0) {
198 name_->erase(name_->begin(), name_->begin() + 6);
199 }
200#endif // defined(_WIN32)
201
202 os << *name_ << "(\nlr = " << obj.lr()
203 << ", momentum = " << obj.momentum()
204 << ", dampening = " << obj.dampening()
205 << ", weight_decay = " << obj.weight_decay()
206 << ", nesterov = " << obj.nesterov()
207 << "\n)";
208
209 return os;
210}
211
212} // namespace optim
213} // namespace torch
Concept to identify template parameters that are derived from torch::optim::Optimizer.
Definition optimizer.hpp:21
Definition boundary.hpp:22
torch::optim::AdamWOptions type
Definition optimizer.hpp:41
torch::optim::SGDOptions type
Definition optimizer.hpp:51
torch::optim::LBFGSOptions type
Definition optimizer.hpp:46
torch::optim::RMSpropOptions type
Definition optimizer.hpp:56
torch::optim::AdagradOptions type
Definition optimizer.hpp:31
torch::optim::AdamOptions type
Definition optimizer.hpp:36
Type trait for the optimizer options type.
Definition optimizer.hpp:27
std::ostream & operator<<(std::ostream &os, const torch::optim::AdagradOptions &obj)
Print (as string) a torch::optim::AdagradOptions object.
Definition optimizer.hpp:66
Definition optimizer.hpp:62