30template <
typename Optimizer>
35 using type = torch::optim::AdagradOptions;
39 using type = torch::optim::AdamOptions;
43 using type = torch::optim::AdamWOptions;
47 using type = torch::optim::LBFGSOptions;
51 using type = torch::optim::SGDOptions;
55 using type = torch::optim::RMSpropOptions;
66 const torch::optim::AdagradOptions &obj) {
67 at::optional<std::string> name_ = c10::demangle(
typeid(obj).name());
71 if (name_->find(
"struct ") == 0) {
72 name_->erase(name_->begin(), name_->begin() + 7);
73 }
else if (name_->find(
"class ") == 0) {
74 name_->erase(name_->begin(), name_->begin() + 6);
78 os << *name_ <<
"(\nlr = " << obj.lr() <<
", lr_decay = " << obj.lr_decay()
79 <<
", weight_decay = " << obj.weight_decay()
80 <<
", initial_accumulator_value = " << obj.initial_accumulator_value()
81 <<
", eps = " << obj.eps() <<
"\n)";
88 const torch::optim::AdamOptions &obj) {
89 at::optional<std::string> name_ = c10::demangle(
typeid(obj).name());
93 if (name_->find(
"struct ") == 0) {
94 name_->erase(name_->begin(), name_->begin() + 7);
95 }
else if (name_->find(
"class ") == 0) {
96 name_->erase(name_->begin(), name_->begin() + 6);
100 os << *name_ <<
"(\nlr = " << obj.lr() <<
", betas = ["
101 << std::get<0>(obj.betas()) <<
", " << std::get<1>(obj.betas()) <<
"]"
102 <<
", weight_decay = " << obj.weight_decay() <<
", eps = " << obj.eps()
103 <<
", amsgrad = " << obj.amsgrad() <<
"\n)";
110 const torch::optim::AdamWOptions &obj) {
111 at::optional<std::string> name_ = c10::demangle(
typeid(obj).name());
115 if (name_->find(
"struct ") == 0) {
116 name_->erase(name_->begin(), name_->begin() + 7);
117 }
else if (name_->find(
"class ") == 0) {
118 name_->erase(name_->begin(), name_->begin() + 6);
122 os << *name_ <<
"(\nlr = " << obj.lr() <<
", betas = ["
123 << std::get<0>(obj.betas()) <<
", " << std::get<1>(obj.betas()) <<
"]"
124 <<
", weight_decay = " << obj.weight_decay() <<
", eps = " << obj.eps()
125 <<
", amsgrad = " << obj.amsgrad() <<
"\n)";
132 const torch::optim::LBFGSOptions &obj) {
133 at::optional<std::string> name_ = c10::demangle(
typeid(obj).name());
137 if (name_->find(
"struct ") == 0) {
138 name_->erase(name_->begin(), name_->begin() + 7);
139 }
else if (name_->find(
"class ") == 0) {
140 name_->erase(name_->begin(), name_->begin() + 6);
144 os << *name_ <<
"(\nlr = " << obj.lr() <<
", max_iter = " << obj.max_iter()
146 << (obj.max_eval().has_value() ? std::to_string(*obj.max_eval())
148 <<
", tolerance_grad = " << obj.tolerance_grad()
149 <<
", tolerance_change = " << obj.tolerance_change()
150 <<
", history_size = " << obj.history_size() <<
", line_search_fn = "
151 << (obj.line_search_fn().has_value() ? *obj.line_search_fn() :
"undefined")
159 const torch::optim::RMSpropOptions &obj) {
160 at::optional<std::string> name_ = c10::demangle(
typeid(obj).name());
164 if (name_->find(
"struct ") == 0) {
165 name_->erase(name_->begin(), name_->begin() + 7);
166 }
else if (name_->find(
"class ") == 0) {
167 name_->erase(name_->begin(), name_->begin() + 6);
171 os << *name_ <<
"(\nlr = " << obj.lr() <<
", alpha = " << obj.alpha()
172 <<
", eps = " << obj.eps() <<
", weight_decay = " << obj.weight_decay()
173 <<
", momentum = " << obj.momentum() <<
", centered = " << obj.centered()
181 const torch::optim::SGDOptions &obj) {
182 at::optional<std::string> name_ = c10::demangle(
typeid(obj).name());
186 if (name_->find(
"struct ") == 0) {
187 name_->erase(name_->begin(), name_->begin() + 7);
188 }
else if (name_->find(
"class ") == 0) {
189 name_->erase(name_->begin(), name_->begin() + 6);
193 os << *name_ <<
"(\nlr = " << obj.lr() <<
", momentum = " << obj.momentum()
194 <<
", dampening = " << obj.dampening()
195 <<
", weight_decay = " << obj.weight_decay()
196 <<
", nesterov = " << obj.nesterov() <<
"\n)";
Concept to identify template parameters that are derived from torch::optim::Optimizer.
Definition optimizer.hpp:26
torch::optim::AdamWOptions type
Definition optimizer.hpp:43
torch::optim::SGDOptions type
Definition optimizer.hpp:51
torch::optim::LBFGSOptions type
Definition optimizer.hpp:47
torch::optim::RMSpropOptions type
Definition optimizer.hpp:55
torch::optim::AdagradOptions type
Definition optimizer.hpp:35
torch::optim::AdamOptions type
Definition optimizer.hpp:39
Type trait for the optimizer options type.
Definition optimizer.hpp:32
std::ostream & operator<<(std::ostream &os, const torch::optim::AdagradOptions &obj)
Print a torch::optim::AdagradOptions object, formatted as a string, to the given output stream.
Definition optimizer.hpp:65
Definition optimizer.hpp:61