70 virtual torch::Tensor
apply(
const torch::Tensor &)
const = 0;
77 virtual torch::serialize::OutputArchive &
78 write(torch::serialize::OutputArchive &archive,
79 const std::string &key)
const = 0;
83 virtual torch::serialize::InputArchive &
84 read(torch::serialize::InputArchive &archive,
const std::string &key) = 0;
98 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
109 inline torch::serialize::OutputArchive &
110 write(torch::serialize::OutputArchive &archive,
111 const std::string &key =
"none")
const override {
112 archive.write(key +
".type",
120 inline torch::serialize::InputArchive &
121 read(torch::serialize::InputArchive &archive,
122 const std::string &key =
"none")
override {
123 torch::Tensor tensor;
125 archive.read(key +
".type", tensor);
127 throw std::runtime_error(
"activation mismatch");
141 torch::nn::functional::BatchNormFuncOptions
options = {})
146 const torch::Tensor &weight,
const torch::Tensor &bias,
147 double eps,
double momentum,
bool training =
false)
153 .training(training)),
160 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
178 inline const torch::nn::functional::BatchNormFuncOptions &
options()
const {
183 inline torch::nn::functional::BatchNormFuncOptions &
options() {
193#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR < 7
196 <<
", training=" <<
options_.training();
201 <<
"\n weight = " <<
options_.weight()
210 inline torch::serialize::OutputArchive &
211 write(torch::serialize::OutputArchive &archive,
212 const std::string &key =
"batch_norm")
const override {
213 archive.write(key +
".type", torch::full({1},
static_cast<int64_t
>(
215 archive.write(key +
".running_mean", this->
running_mean());
216 archive.write(key +
".running_var", this->
running_var());
217 archive.write(key +
".weight", this->
options_.weight());
218 archive.write(key +
".bias", this->
options_.bias());
219 archive.write(key +
".eps", torch::full({1}, (double)this->
options_.eps()));
220 archive.write(key +
".momentum", torch::full({1}, (double)this->
options_
222#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR < 7
226 archive.write(key +
".training",
227 torch::full({1}, (bool)this->
options_.training()));
234 inline torch::serialize::InputArchive &
235 read(torch::serialize::InputArchive &archive,
236 const std::string &key =
"batch_norm")
override {
237 torch::Tensor tensor;
239 archive.read(key +
".type", tensor);
241 throw std::runtime_error(
"activation mismatch");
243 archive.read(key +
".running_mean", this->
running_mean());
244 archive.read(key +
".running_var", this->
running_var());
245 archive.read(key +
".weight", this->
options_.weight());
246 archive.read(key +
".bias", this->
options_.bias());
247 archive.read(key +
".eps", tensor);
248 this->
options_.eps(tensor.item<
double>());
249 archive.read(key +
".momentum", tensor);
250 this->
options_.momentum(tensor.item<
double>());
251 archive.read(key +
".training", tensor);
252 this->
options_.training(tensor.item<
bool>());
258 torch::nn::functional::BatchNormFuncOptions
options_;
270 explicit CELU(torch::nn::functional::CELUFuncOptions
options = {})
273 explicit CELU(
double alpha,
bool inplace =
false)
274 :
options_(
torch::nn::functional::CELUFuncOptions().alpha(alpha).inplace(
280 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
281 return torch::nn::functional::celu(input,
options_);
285 inline const torch::nn::functional::CELUFuncOptions &
options()
const {
295 <<
", inplace=" <<
options_.inplace() <<
"\n)";
300 inline torch::serialize::OutputArchive &
301 write(torch::serialize::OutputArchive &archive,
302 const std::string &key =
"celu")
const override {
303 archive.write(key +
".type",
305 archive.write(key +
".alpha",
306 torch::full({1}, (double)this->
options_.alpha()));
307 archive.write(key +
".inplace",
308 torch::full({1}, (bool)this->
options_.inplace()));
315 inline torch::serialize::InputArchive &
316 read(torch::serialize::InputArchive &archive,
317 const std::string &key =
"celu")
override {
318 torch::Tensor tensor;
320 archive.read(key +
".type", tensor);
322 throw std::runtime_error(
"activation mismatch");
324 archive.read(key +
".alpha", tensor);
325 this->
options_.alpha(tensor.item<
double>());
326 archive.read(key +
".inplace", tensor);
327 this->
options_.inplace(tensor.item<
bool>());
347 explicit ELU(torch::nn::functional::ELUFuncOptions
options = {})
350 explicit ELU(
double alpha,
bool inplace =
false)
351 :
options_(
torch::nn::functional::ELUFuncOptions().alpha(alpha).inplace(
357 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
358 return torch::nn::functional::elu(input,
options_);
362 inline const torch::nn::functional::ELUFuncOptions &
options()
const {
373 <<
", inplace=" <<
options_.inplace() <<
"\n)";
378 inline torch::serialize::OutputArchive &
379 write(torch::serialize::OutputArchive &archive,
380 const std::string &key =
"elu")
const override {
381 archive.write(key +
".type",
383 archive.write(key +
".alpha",
384 torch::full({1}, (double)this->
options_.alpha()));
385 archive.write(key +
".inplace",
386 torch::full({1}, (bool)this->
options_.inplace()));
393 inline torch::serialize::InputArchive &
394 read(torch::serialize::InputArchive &archive,
395 const std::string &key =
"elu")
override {
396 torch::Tensor tensor;
398 archive.read(key +
".type", tensor);
400 throw std::runtime_error(
"activation mismatch");
402 archive.read(key +
".alpha", tensor);
403 this->
options_.alpha(tensor.item<
double>());
404 archive.read(key +
".inplace", tensor);
405 this->
options_.inplace(tensor.item<
bool>());
429 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
430 return torch::gelu(input);
440 inline torch::serialize::OutputArchive &
441 write(torch::serialize::OutputArchive &archive,
442 const std::string &key =
"gelu")
const override {
443 archive.write(key +
".type",
451 inline torch::serialize::InputArchive &
452 read(torch::serialize::InputArchive &archive,
453 const std::string &key =
"gelu")
override {
454 torch::Tensor tensor;
456 archive.read(key +
".type", tensor);
458 throw std::runtime_error(
"activation mismatch");
475 explicit GLU(torch::nn::functional::GLUFuncOptions
options = {})
479 :
options_(
torch::nn::functional::GLUFuncOptions().dim(dim)) {}
484 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
485 return torch::nn::functional::glu(input,
options_);
489 inline const torch::nn::functional::GLUFuncOptions &
options()
const {
504 inline torch::serialize::OutputArchive &
505 write(torch::serialize::OutputArchive &archive,
506 const std::string &key =
"glu")
const override {
507 archive.write(key +
".type",
509 archive.write(key +
".dim",
510 torch::full({1},
static_cast<int>(this->
options_.dim())));
517 inline torch::serialize::InputArchive &
518 read(torch::serialize::InputArchive &archive,
519 const std::string &key =
"glu")
override {
520 torch::Tensor tensor;
522 archive.read(key +
".type", tensor);
524 throw std::runtime_error(
"activation mismatch");
526 archive.read(key +
".dim", tensor);
527 this->
options_.dim(tensor.item<
int>());
541 :
options_(
torch::nn::functional::GroupNormFuncOptions(num_groups)) {}
546 explicit GroupNorm(int64_t num_groups,
const torch::Tensor &weight,
547 const torch::Tensor &bias,
double eps)
548 :
options_(
torch::nn::functional::GroupNormFuncOptions(num_groups)
556 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
557 return torch::nn::functional::group_norm(input,
options_);
561 inline const torch::nn::functional::GroupNormFuncOptions &
options()
const {
566 inline torch::nn::functional::GroupNormFuncOptions &
options() {
576 os <<
"\n weight = " <<
options_.weight()
585 inline torch::serialize::OutputArchive &
586 write(torch::serialize::OutputArchive &archive,
587 const std::string &key =
"group_norm")
const override {
588 archive.write(key +
".type", torch::full({1},
static_cast<int64_t
>(
590 archive.write(key +
".weight", this->
options_.weight());
591 archive.write(key +
".bias", this->
options_.bias());
592 archive.write(key +
".eps", torch::full({1}, (double)this->
options_.eps()));
599 inline torch::serialize::InputArchive &
600 read(torch::serialize::InputArchive &archive,
601 const std::string &key =
"group_norm")
override {
602 torch::Tensor tensor;
604 archive.read(key +
".type", tensor);
606 throw std::runtime_error(
"activation mismatch");
608 archive.read(key +
".weight", this->
options_.weight());
609 archive.read(key +
".bias", this->
options_.bias());
610 archive.read(key +
".eps", tensor);
611 this->
options_.eps(tensor.item<
double>());
617 torch::nn::functional::GroupNormFuncOptions
options_;
624 torch::nn::functional::GumbelSoftmaxFuncOptions
options = {})
636 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
637 return torch::nn::functional::gumbel_softmax(input,
options_);
641 inline const torch::nn::functional::GumbelSoftmaxFuncOptions &
647 inline torch::nn::functional::GumbelSoftmaxFuncOptions &
options() {
660 inline torch::serialize::OutputArchive &
661 write(torch::serialize::OutputArchive &archive,
662 const std::string &key =
"gumbel_softmax")
const override {
666 archive.write(key +
".tau", torch::full({1}, (double)this->
options_.tau()));
667 archive.write(key +
".dim", torch::full({1}, (int)this->
options_.dim()));
668 archive.write(key +
".hard", torch::full({1}, (bool)this->
options_.hard()));
675 inline torch::serialize::InputArchive &
676 read(torch::serialize::InputArchive &archive,
677 const std::string &key =
"gumbel_softmax")
override {
678 torch::Tensor tensor;
680 archive.read(key +
".type", tensor);
681 if (tensor.item<int64_t>() !=
683 throw std::runtime_error(
"activation mismatch");
685 archive.read(key +
".tau", tensor);
686 this->
options_.tau(tensor.item<
double>());
687 archive.read(key +
".dim", tensor);
688 this->
options_.dim(tensor.item<
int>());
689 archive.read(key +
".hard", tensor);
690 this->
options_.hard(tensor.item<
bool>());
696 torch::nn::functional::GumbelSoftmaxFuncOptions
options_;
707 torch::nn::functional::HardshrinkFuncOptions().lambda(lambda)) {}
712 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
713 return torch::nn::functional::hardshrink(input,
options_);
717 inline const torch::nn::functional::HardshrinkFuncOptions &
options()
const {
722 inline torch::nn::functional::HardshrinkFuncOptions &
options() {
730 <<
"(\n lambda=" <<
options_.lambda() <<
"\n)";
735 inline torch::serialize::OutputArchive &
736 write(torch::serialize::OutputArchive &archive,
737 const std::string &key =
"hardshrink")
const override {
738 archive.write(key +
".type", torch::full({1},
static_cast<int64_t
>(
740 archive.write(key +
".lambda",
741 torch::full({1}, (double)this->
options_.lambda()));
748 inline torch::serialize::InputArchive &
749 read(torch::serialize::InputArchive &archive,
750 const std::string &key =
"hardshrink")
override {
751 torch::Tensor tensor;
753 archive.read(key +
".type", tensor);
755 throw std::runtime_error(
"activation mismatch");
757 archive.read(key +
".lambda", tensor);
758 this->
options_.lambda(tensor.item<
double>());
764 torch::nn::functional::HardshrinkFuncOptions
options_;
784 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
785 return torch::hardsigmoid(input);
796 inline torch::serialize::OutputArchive &
797 write(torch::serialize::OutputArchive &archive,
798 const std::string &key =
"hardsigmoid")
const override {
808 inline torch::serialize::InputArchive &
809 read(torch::serialize::InputArchive &archive,
810 const std::string &key =
"hardsigmoid")
override {
811 torch::Tensor tensor;
813 archive.read(key +
".type", tensor);
815 throw std::runtime_error(
"activation mismatch");
838 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
839 return torch::hardswish(input);
850 inline torch::serialize::OutputArchive &
851 write(torch::serialize::OutputArchive &archive,
852 const std::string &key =
"hardswish")
const override {
853 archive.write(key +
".type", torch::full({1},
static_cast<int64_t
>(
861 inline torch::serialize::InputArchive &
862 read(torch::serialize::InputArchive &archive,
863 const std::string &key =
"hardswish")
override {
864 torch::Tensor tensor;
866 archive.read(key +
".type", tensor);
868 throw std::runtime_error(
"activation mismatch");
887 const torch::nn::functional::HardtanhFuncOptions &
options = {})
890 explicit Hardtanh(
double min_val,
double max_val,
bool inplace =
false)
894 .inplace(inplace)) {}
899 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
900 return torch::nn::functional::hardtanh(input,
options_);
904 inline const torch::nn::functional::HardtanhFuncOptions &
options()
const {
909 inline torch::nn::functional::HardtanhFuncOptions &
options() {
917 <<
"(\n min_val=" <<
options_.min_val()
918 <<
", max_val=" <<
options_.max_val()
919 <<
", inplace=" <<
options_.inplace() <<
"\n)";
924 inline torch::serialize::OutputArchive &
925 write(torch::serialize::OutputArchive &archive,
926 const std::string &key =
"hardtanh")
const override {
927 archive.write(key +
".type",
929 archive.write(key +
".min_val",
930 torch::full({1}, (double)this->
options_.min_val()));
931 archive.write(key +
".max_val",
932 torch::full({1}, (double)this->
options_.max_val()));
933 archive.write(key +
".inplace",
934 torch::full({1}, (bool)this->
options_.inplace()));
941 inline torch::serialize::InputArchive &
942 read(torch::serialize::InputArchive &archive,
943 const std::string &key =
"hardtanh")
override {
944 torch::Tensor tensor;
946 archive.read(key +
".type", tensor);
948 throw std::runtime_error(
"activation mismatch");
950 archive.read(key +
".min_val", tensor);
951 this->
options_.min_val(tensor.item<
double>());
952 archive.read(key +
".max_val", tensor);
953 this->
options_.max_val(tensor.item<
double>());
954 archive.read(key +
".inplace", tensor);
955 this->
options_.inplace(tensor.item<
bool>());
961 torch::nn::functional::HardtanhFuncOptions
options_;
971 torch::nn::functional::InstanceNormFuncOptions
options = {})
975 const torch::Tensor &running_var,
976 const torch::Tensor &weight,
const torch::Tensor &bias,
977 double eps,
double momentum,
978 bool use_input_stats =
true)
980 .running_mean(running_mean)
981 .running_var(running_var)
986 .use_input_stats(use_input_stats)) {}
991 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
992 return torch::nn::functional::instance_norm(input,
options_);
996 inline const torch::nn::functional::InstanceNormFuncOptions &
options()
const {
1001 inline torch::nn::functional::InstanceNormFuncOptions &
options() {
1009 <<
", momentum=" <<
options_.momentum()
1010 <<
", use_input_stats=" <<
options_.use_input_stats();
1013 os <<
"\n running_mean = " <<
options_.running_mean()
1014 <<
"\n running_var = " <<
options_.running_var()
1015 <<
"\n weight = " <<
options_.weight()
1016 <<
"\n bias = " <<
options_.bias();
1024 inline torch::serialize::OutputArchive &
1025 write(torch::serialize::OutputArchive &archive,
1026 const std::string &key =
"instance_norm")
const override {
1030 archive.write(key +
".running_mean", this->
options_.running_mean());
1031 archive.write(key +
".var", this->
options_.running_var());
1032 archive.write(key +
".weight", this->
options_.weight());
1033 archive.write(key +
".bias", this->
options_.bias());
1034 archive.write(key +
".eps", torch::full({1}, (double)this->
options_.eps()));
1035 archive.write(key +
".momentum",
1036 torch::full({1}, (double)this->
options_.momentum()));
1037 archive.write(key +
".use_input_stats",
1038 torch::full({1}, (bool)this->
options_.use_input_stats()));
1045 inline torch::serialize::InputArchive &
1046 read(torch::serialize::InputArchive &archive,
1047 const std::string &key =
"instance_norm")
override {
1048 torch::Tensor tensor;
1050 archive.read(key +
".type", tensor);
1051 if (tensor.item<int64_t>() !=
1053 throw std::runtime_error(
"activation mismatch");
1055 archive.read(key +
".running_mean", this->
options_.running_mean());
1056 archive.read(key +
".running_var", this->
options_.running_var());
1057 archive.read(key +
".weight", this->
options_.weight());
1058 archive.read(key +
".bias", this->
options_.bias());
1059 archive.read(key +
".eps", tensor);
1060 this->
options_.eps(tensor.item<
double>());
1061 archive.read(key +
".momentum", tensor);
1062 this->
options_.momentum(tensor.item<
double>());
1063 archive.read(key +
".use_input_stats", tensor);
1064 this->
options_.use_input_stats(tensor.item<
bool>());
1070 torch::nn::functional::InstanceNormFuncOptions
options_;
1080 std::move(normalized_shape))) {}
1086 const torch::Tensor &weight,
const torch::Tensor &bias,
1089 std::move(normalized_shape))
1097 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1098 return torch::nn::functional::layer_norm(input,
options_);
1102 inline const torch::nn::functional::LayerNormFuncOptions &
options()
const {
1107 inline torch::nn::functional::LayerNormFuncOptions &
options() {
1116 os <<
"\n normalized_shape = " <<
options_.normalized_shape()
1117 <<
"\n weight = " <<
options_.weight()
1118 <<
"\n bias = " <<
options_.bias();
1126 inline torch::serialize::OutputArchive &
1127 write(torch::serialize::OutputArchive &archive,
1128 const std::string &key =
"layer_norm")
const override {
1129 archive.write(key +
".type", torch::full({1},
static_cast<int64_t
>(
1131 archive.write(key +
".weight", this->
options_.weight());
1132 archive.write(key +
".bias", this->
options_.bias());
1133 archive.write(key +
".eps", torch::full({1}, (double)this->
options_.eps()));
1140 inline torch::serialize::InputArchive &
1141 read(torch::serialize::InputArchive &archive,
1142 const std::string &key =
"layer_norm")
override {
1143 torch::Tensor tensor;
1145 archive.read(key +
".type", tensor);
1147 throw std::runtime_error(
"activation mismatch");
1149 archive.read(key +
".weight", this->
options_.weight());
1150 archive.read(key +
".bias", this->
options_.bias());
1151 archive.read(key +
".eps", tensor);
1152 this->
options_.eps(tensor.item<
double>());
1175 explicit LeakyReLU(
double negative_slope,
bool inplace =
false)
1177 .negative_slope(negative_slope)
1178 .inplace(inplace)) {}
1183 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1184 return torch::nn::functional::leaky_relu(input,
options_);
1188 inline const torch::nn::functional::LeakyReLUFuncOptions &
options()
const {
1193 inline torch::nn::functional::LeakyReLUFuncOptions &
options() {
1200 <<
"(\n negative_slope=" <<
options_.negative_slope()
1201 <<
", inplace=" <<
options_.inplace() <<
"\n)";
1206 inline torch::serialize::OutputArchive &
1207 write(torch::serialize::OutputArchive &archive,
1208 const std::string &key =
"leaky_relu")
const override {
1209 archive.write(key +
".type", torch::full({1},
static_cast<int64_t
>(
1212 archive.write(key +
".negative_slope",
1213 torch::full({1}, (double)this->
options_.negative_slope()));
1214 archive.write(key +
".inplace",
1215 torch::full({1}, (bool)this->
options_.inplace()));
1222 inline torch::serialize::InputArchive &
1223 read(torch::serialize::InputArchive &archive,
1224 const std::string &key =
"leaky_relu")
override {
1225 torch::Tensor tensor;
1227 archive.read(key +
".type", tensor);
1229 throw std::runtime_error(
"activation mismatch");
1231 archive.read(key +
".negative_slope", tensor);
1232 this->
options_.negative_slope(tensor.item<
double>());
1233 archive.read(key +
".inplace", tensor);
1234 this->
options_.inplace(tensor.item<
bool>());
1247 :
options_(
torch::nn::functional::LocalResponseNormFuncOptions(size)) {}
1250 const torch::nn::functional::LocalResponseNormFuncOptions &
options)
1254 :
options_(
torch::nn::functional::LocalResponseNormFuncOptions(size)
1262 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1263 return torch::nn::functional::local_response_norm(input,
options_);
1267 inline const torch::nn::functional::LocalResponseNormFuncOptions &
1273 inline torch::nn::functional::LocalResponseNormFuncOptions &
options() {
1281 <<
", k=" <<
options_.k() <<
"\n)";
1286 inline torch::serialize::OutputArchive &
1287 write(torch::serialize::OutputArchive &archive,
1288 const std::string &key =
"local_response_norm")
const override {
1289 archive.write(key +
".type",
1290 torch::full({1},
static_cast<int64_t
>(
1293 archive.write(key +
".size",
1294 torch::full({1}, (int64_t)this->
options_.size()));
1295 archive.write(key +
".alpha",
1296 torch::full({1}, (double)this->
options_.alpha()));
1297 archive.write(key +
".beta",
1298 torch::full({1}, (double)this->
options_.beta()));
1299 archive.write(key +
".k", torch::full({1}, (double)this->
options_.k()));
1306 inline torch::serialize::InputArchive &
1307 read(torch::serialize::InputArchive &archive,
1308 const std::string &key =
"local_response_norm")
override {
1309 torch::Tensor tensor;
1311 archive.read(key +
".type", tensor);
1312 if (tensor.item<int64_t>() !=
1314 throw std::runtime_error(
"activation mismatch");
1316 archive.read(key +
".size", tensor);
1317 this->
options_.size(tensor.item<int64_t>());
1318 archive.read(key +
".alpha", tensor);
1319 this->
options_.alpha(tensor.item<
double>());
1320 archive.read(key +
".beta", tensor);
1321 this->
options_.beta(tensor.item<
double>());
1322 archive.read(key +
".k", tensor);
1323 this->
options_.k(tensor.item<
double>());
1329 torch::nn::functional::LocalResponseNormFuncOptions
options_;
1344 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1345 return torch::log_sigmoid(input);
1355 inline torch::serialize::OutputArchive &
1356 write(torch::serialize::OutputArchive &archive,
1357 const std::string &key =
"logsigmoid")
const override {
1358 archive.write(key +
".type", torch::full({1},
static_cast<int64_t
>(
1366 inline torch::serialize::InputArchive &
1367 read(torch::serialize::InputArchive &archive,
1368 const std::string &key =
"logsigmoid")
override {
1369 torch::Tensor tensor;
1371 archive.read(key +
".type", tensor);
1373 throw std::runtime_error(
"activation mismatch");
1390 :
options_(
torch::nn::functional::LogSoftmaxFuncOptions(dim)) {}
1393 const torch::nn::functional::LogSoftmaxFuncOptions &
options)
1399 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1400 return torch::nn::functional::log_softmax(input,
options_);
1404 inline const torch::nn::functional::LogSoftmaxFuncOptions &
options()
const {
1409 inline torch::nn::functional::LogSoftmaxFuncOptions &
options() {
1421 inline torch::serialize::OutputArchive &
1422 write(torch::serialize::OutputArchive &archive,
1423 const std::string &key =
"logsoftmax")
const override {
1424 archive.write(key +
".type", torch::full({1},
static_cast<int64_t
>(
1432 inline torch::serialize::InputArchive &
1433 read(torch::serialize::InputArchive &archive,
1434 const std::string &key =
"logsoftmax")
override {
1435 torch::Tensor tensor;
1437 archive.read(key +
".type", tensor);
1439 throw std::runtime_error(
"activation mismatch");
1460 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1461 return torch::mish(input);
1471 inline torch::serialize::OutputArchive &
1472 write(torch::serialize::OutputArchive &archive,
1473 const std::string &key =
"mish")
const override {
1474 archive.write(key +
".type",
1482 inline torch::serialize::InputArchive &
1483 read(torch::serialize::InputArchive &archive,
1484 const std::string &key =
"mish")
override {
1485 torch::Tensor tensor;
1487 archive.read(key +
".type", tensor);
1489 throw std::runtime_error(
"activation mismatch");
1503 torch::nn::functional::NormalizeFuncOptions().p(p).eps(eps).dim(
1509 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1510 return torch::nn::functional::normalize(input,
options_);
1514 inline const torch::nn::functional::NormalizeFuncOptions &
options()
const {
1519 inline torch::nn::functional::NormalizeFuncOptions &
options() {
1531 inline torch::serialize::OutputArchive &
1532 write(torch::serialize::OutputArchive &archive,
1533 const std::string &key =
"normalize")
const override {
1534 archive.write(key +
".type", torch::full({1},
static_cast<int64_t
>(
1536 archive.write(key +
".p", torch::full({1}, (double)this->
options_.p()));
1537 archive.write(key +
".eps", torch::full({1}, (double)this->
options_.eps()));
1538 archive.write(key +
".dim",
1539 torch::full({1}, (int64_t)this->
options_.dim()));
1546 inline torch::serialize::InputArchive &
1547 read(torch::serialize::InputArchive &archive,
1548 const std::string &key =
"normalize")
override {
1549 torch::Tensor tensor;
1551 archive.read(key +
".type", tensor);
1553 throw std::runtime_error(
"activation mismatch");
1555 archive.read(key +
".p", tensor);
1556 this->
options_.p(tensor.item<
double>());
1557 archive.read(key +
".eps", tensor);
1558 this->
options_.eps(tensor.item<
double>());
1559 archive.read(key +
".dim", tensor);
1560 this->
options_.dim(tensor.item<int64_t>());
1583 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1584 return torch::nn::functional::prelu(input,
weight());
1592 os <<
"(\n weight = " <<
weight() <<
"\n)";
1597 inline torch::serialize::OutputArchive &
1598 write(torch::serialize::OutputArchive &archive,
1599 const std::string &key =
"prelu")
const override {
1600 archive.write(key +
".type",
1602 archive.write(key +
".weight", this->
weight());
1609 inline torch::serialize::InputArchive &
1610 read(torch::serialize::InputArchive &archive,
1611 const std::string &key =
"prelu")
override {
1612 torch::Tensor tensor;
1614 archive.read(key +
".type", tensor);
1616 throw std::runtime_error(
"activation mismatch");
1618 archive.read(key +
".weight", this->
weight());
1638 :
options_(
torch::nn::functional::ReLUFuncOptions().inplace(inplace)) {}
1643 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1644 return torch::nn::functional::relu(input,
options_);
1648 inline const torch::nn::functional::ReLUFuncOptions &
options()
const {
1658 <<
"(\n inplace=" <<
options_.inplace() <<
"\n)";
1663 inline torch::serialize::OutputArchive &
1664 write(torch::serialize::OutputArchive &archive,
1665 const std::string &key =
"relu")
const override {
1666 archive.write(key +
".type",
1668 archive.write(key +
".inplace",
1669 torch::full({1}, (bool)this->
options_.inplace()));
1676 inline torch::serialize::InputArchive &
1677 read(torch::serialize::InputArchive &archive,
1678 const std::string &key =
"relu")
override {
1679 torch::Tensor tensor;
1681 archive.read(key +
".type", tensor);
1683 throw std::runtime_error(
"activation mismatch");
1685 archive.read(key +
".inplace", tensor);
1686 this->
options_.inplace(tensor.item<
bool>());
1706 :
options_(
torch::nn::functional::ReLU6FuncOptions().inplace(inplace)) {}
1711 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1712 return torch::nn::functional::relu6(input,
options_);
1716 inline const torch::nn::functional::ReLU6FuncOptions &
options()
const {
1726 <<
"(\n inplace=" <<
options_.inplace() <<
"\n)";
1731 inline torch::serialize::OutputArchive &
1732 write(torch::serialize::OutputArchive &archive,
1733 const std::string &key =
"relu6")
const override {
1734 archive.write(key +
".type",
1736 archive.write(key +
".inplace",
1737 torch::full({1}, (bool)this->
options_.inplace()));
1744 inline torch::serialize::InputArchive &
1745 read(torch::serialize::InputArchive &archive,
1746 const std::string &key =
"relu6")
override {
1747 torch::Tensor tensor;
1749 archive.read(key +
".type", tensor);
1751 throw std::runtime_error(
"activation mismatch");
1753 archive.read(key +
".inplace", tensor);
1754 this->
options_.inplace(tensor.item<
bool>());
1774 explicit RReLU(
const torch::nn::functional::RReLUFuncOptions &
options = {})
1777 explicit RReLU(
double lower,
double upper,
bool inplace =
false)
1781 .inplace(inplace)) {}
1786 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1787 return torch::nn::functional::rrelu(input,
options_);
1791 inline const torch::nn::functional::RReLUFuncOptions &
options()
const {
1807 inline torch::serialize::OutputArchive &
1808 write(torch::serialize::OutputArchive &archive,
1809 const std::string &key =
"rrelu")
const override {
1810 archive.write(key +
".type",
1812 archive.write(key +
".lower",
1813 torch::full({1}, (double)this->
options_.lower()));
1814 archive.write(key +
".upper",
1815 torch::full({1}, (double)this->
options_.upper()));
1816 archive.write(key +
".inplace",
1817 torch::full({1}, (bool)this->
options_.inplace()));
1824 inline torch::serialize::InputArchive &
1825 read(torch::serialize::InputArchive &archive,
1826 const std::string &key =
"rrelu")
override {
1827 torch::Tensor tensor;
1829 archive.read(key +
".type", tensor);
1831 throw std::runtime_error(
"activation mismatch");
1833 archive.read(key +
".lower", tensor);
1834 this->
options_.lower(tensor.item<
double>());
1835 archive.read(key +
".upper", tensor);
1836 this->
options_.upper(tensor.item<
double>());
1837 archive.read(key +
".inplace", tensor);
1838 this->
options_.inplace(tensor.item<
bool>());
1861 :
options_(
torch::nn::functional::SELUFuncOptions().inplace(inplace)) {}
1866 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1867 return torch::nn::functional::selu(input,
options_);
1871 inline const torch::nn::functional::SELUFuncOptions &
options()
const {
1881 <<
"(\n inplace=" <<
options_.inplace() <<
"\n)";
1886 inline torch::serialize::OutputArchive &
1887 write(torch::serialize::OutputArchive &archive,
1888 const std::string &key =
"selu")
const override {
1889 archive.write(key +
".type",
1891 archive.write(key +
".inplace",
1892 torch::full({1}, (bool)this->
options_.inplace()));
1899 inline torch::serialize::InputArchive &
1900 read(torch::serialize::InputArchive &archive,
1901 const std::string &key =
"selu")
override {
1902 torch::Tensor tensor;
1904 archive.read(key +
".type", tensor);
1906 throw std::runtime_error(
"activation mismatch");
1908 archive.read(key +
".inplace", tensor);
1909 this->
options_.inplace(tensor.item<
bool>());
1926 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1927 return torch::sigmoid(input);
1937 inline torch::serialize::OutputArchive &
1938 write(torch::serialize::OutputArchive &archive,
1939 const std::string &key =
"sigmoid")
const override {
1940 archive.write(key +
".type",
1948 inline torch::serialize::InputArchive &
1949 read(torch::serialize::InputArchive &archive,
1950 const std::string &key =
"sigmoid")
override {
1951 torch::Tensor tensor;
1953 archive.read(key +
".type", tensor);
1955 throw std::runtime_error(
"activation mismatch");
1969 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
1970 return torch::silu(input);
1980 inline torch::serialize::OutputArchive &
1981 write(torch::serialize::OutputArchive &archive,
1982 const std::string &key =
"silu")
const override {
1983 archive.write(key +
".type",
1991 inline torch::serialize::InputArchive &
1992 read(torch::serialize::InputArchive &archive,
1993 const std::string &key =
"silu")
override {
1994 torch::Tensor tensor;
1996 archive.read(key +
".type", tensor);
1998 throw std::runtime_error(
"activation mismatch");
2014 :
options_(
torch::nn::functional::SoftmaxFuncOptions(dim)) {}
2022 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
2023 return torch::nn::functional::softmax(input,
options_);
2027 inline const torch::nn::functional::SoftmaxFuncOptions &
options()
const {
2032 inline torch::nn::functional::SoftmaxFuncOptions &
options() {
2044 inline torch::serialize::OutputArchive &
2045 write(torch::serialize::OutputArchive &archive,
2046 const std::string &key =
"softmax")
const override {
2047 archive.write(key +
".type",
2049 archive.write(key +
".dim",
2050 torch::full({1}, (int64_t)this->
options_.dim()));
2057 inline torch::serialize::InputArchive &
2058 read(torch::serialize::InputArchive &archive,
2059 const std::string &key =
"softmax")
override {
2060 torch::Tensor tensor;
2062 archive.read(key +
".type", tensor);
2064 throw std::runtime_error(
"activation mismatch");
2066 archive.read(key +
".dim", tensor);
2067 this->
options_.dim(tensor.item<int64_t>());
2084 :
options_(
torch::nn::functional::SoftminFuncOptions(dim)) {}
2092 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
2093 return torch::nn::functional::softmin(input,
options_);
2097 inline const torch::nn::functional::SoftminFuncOptions &
options()
const {
2102 inline torch::nn::functional::SoftminFuncOptions &
options() {
2114 inline torch::serialize::OutputArchive &
2115 write(torch::serialize::OutputArchive &archive,
2116 const std::string &key =
"softmin")
const override {
2117 archive.write(key +
".type",
2119 archive.write(key +
".dim",
2120 torch::full({1}, (int64_t)this->
options_.dim()));
2127 inline torch::serialize::InputArchive &
2128 read(torch::serialize::InputArchive &archive,
2129 const std::string &key =
"softmin")
override {
2130 torch::Tensor tensor;
2132 archive.read(key +
".type", tensor);
2134 throw std::runtime_error(
"activation mismatch");
2136 archive.read(key +
".dim", tensor);
2137 this->
options_.dim(tensor.item<int64_t>());
2158 torch::nn::functional::SoftplusFuncOptions().beta(beta).
threshold(
2164 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
2165 return torch::nn::functional::softplus(input,
options_);
2169 inline const torch::nn::functional::SoftplusFuncOptions &
options()
const {
2174 inline torch::nn::functional::SoftplusFuncOptions &
options() {
2181 <<
", theshold=" <<
options_.threshold() <<
"\n)";
2186 inline torch::serialize::OutputArchive &
2187 write(torch::serialize::OutputArchive &archive,
2188 const std::string &key =
"softplus")
const override {
2189 archive.write(key +
".type",
2191 archive.write(key +
".beta",
2192 torch::full({1}, (double)this->
options_.beta()));
2193 archive.write(key +
".threshold",
2194 torch::full({1}, (double)this->
options_.threshold()));
2201 inline torch::serialize::InputArchive &
2202 read(torch::serialize::InputArchive &archive,
2203 const std::string &key =
"softplus")
override {
2204 torch::Tensor tensor;
2206 archive.read(key +
".type", tensor);
2208 throw std::runtime_error(
"activation mismatch");
2210 archive.read(key +
".beta", tensor);
2211 this->
options_.beta(tensor.item<
double>());
2212 archive.read(key +
".threshold", tensor);
2213 this->
options_.threshold(tensor.item<
double>());
2239 torch::nn::functional::SoftshrinkFuncOptions().lambda(lambda)) {}
2244 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
2245 return torch::nn::functional::softshrink(input,
options_);
2249 inline const torch::nn::functional::SoftshrinkFuncOptions &
options()
const {
2254 inline torch::nn::functional::SoftshrinkFuncOptions &
options() {
2261 <<
"(\n lambda=" <<
options_.lambda() <<
"\n)";
2266 inline torch::serialize::OutputArchive &
2267 write(torch::serialize::OutputArchive &archive,
2268 const std::string &key =
"softshrink")
const override {
2269 archive.write(key +
".type", torch::full({1},
static_cast<int64_t
>(
2271 archive.write(key +
".lambda",
2272 torch::full({1}, (double)this->
options_.lambda()));
2279 inline torch::serialize::InputArchive &
2280 read(torch::serialize::InputArchive &archive,
2281 const std::string &key =
"softshrink")
override {
2282 torch::Tensor tensor;
2284 archive.read(key +
".type", tensor);
2286 throw std::runtime_error(
"activation mismatch");
2288 archive.read(key +
".lambda", tensor);
2289 this->
options_.lambda(tensor.item<
double>());
2306 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
2307 return torch::nn::functional::softsign(input);
2317 inline torch::serialize::OutputArchive &
2318 write(torch::serialize::OutputArchive &archive,
2319 const std::string &key =
"softsign")
const override {
2320 archive.write(key +
".type",
2328 inline torch::serialize::InputArchive &
2329 read(torch::serialize::InputArchive &archive,
2330 const std::string &key =
"softsign")
override {
2331 torch::Tensor tensor;
2333 archive.read(key +
".type", tensor);
2335 throw std::runtime_error(
"activation mismatch");
2349 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
2350 return torch::tanh(input);
2360 inline torch::serialize::OutputArchive &
2361 write(torch::serialize::OutputArchive &archive,
2362 const std::string &key =
"tanh")
const override {
2363 archive.write(key +
".type",
2371 inline torch::serialize::InputArchive &
2372 read(torch::serialize::InputArchive &archive,
2373 const std::string &key =
"tanh")
override {
2374 torch::Tensor tensor;
2376 archive.read(key +
".type", tensor);
2378 throw std::runtime_error(
"activation mismatch");
2392 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
2393 return torch::nn::functional::tanhshrink(input);
2403 inline torch::serialize::OutputArchive &
2404 write(torch::serialize::OutputArchive &archive,
2405 const std::string &key =
"tanhshrink")
const override {
2406 archive.write(key +
".type", torch::full({1},
static_cast<int64_t
>(
2414 inline torch::serialize::InputArchive &
2415 read(torch::serialize::InputArchive &archive,
2416 const std::string &key =
"tanhshrink")
override {
2417 torch::Tensor tensor;
2419 archive.read(key +
".type", tensor);
2421 throw std::runtime_error(
"activation mismatch");
2443 .inplace(inplace)) {}
2448 inline torch::Tensor
apply(
const torch::Tensor &input)
const override {
2449 return torch::nn::functional::threshold(input,
options_);
2453 inline const torch::nn::functional::ThresholdFuncOptions &
options()
const {
2458 inline torch::nn::functional::ThresholdFuncOptions &
options() {
2465 <<
"(\n threshold=" <<
options_.threshold()
2472 inline torch::serialize::OutputArchive &
2473 write(torch::serialize::OutputArchive &archive,
2474 const std::string &key =
"threshold")
const override {
2475 archive.write(key +
".type", torch::full({1},
static_cast<int64_t
>(
2477 archive.write(key +
".threshold",
2478 torch::full({1}, this->
options_.threshold()));
2479 archive.write(key +
".value", torch::full({1}, this->
options_.value()));
2480 archive.write(key +
".inplace", torch::full({1}, this->
options_.inplace()));
2487 inline torch::serialize::InputArchive &
2488 read(torch::serialize::InputArchive &archive,
2489 const std::string &key =
"threshold")
override {
2490 torch::Tensor tensor;
2492 archive.read(key +
".type", tensor);
2494 throw std::runtime_error(
"activation mismatch");
2496 archive.read(key +
".threshold", tensor);
2497 this->
options_.threshold(tensor.item<
double>());
2498 archive.read(key +
".value", tensor);
2499 this->
options_.value(tensor.item<
double>());
2500 archive.read(key +
".inplace", tensor);
2501 this->
options_.inplace(tensor.item<
bool>());
Abstract activation function structure.
Definition activation.hpp:65
~ActivationFunction() override=default
void pretty_print(std::ostream &os) const noexcept override=0
Returns a string representation of the activation function.
virtual torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key) const =0
Writes the activation function into a torch::serialize::OutputArchive object.
virtual torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key)=0
Reads the activation function from a torch::serialize::InputArchive object.
virtual torch::Tensor apply(const torch::Tensor &) const =0
Applies the activation function to the given input.
Batch Normalization as described in the paper.
Definition activation.hpp:138
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:188
torch::Tensor & running_mean()
Returns non-constant reference to running mean.
Definition activation.hpp:169
torch::Tensor running_var_
Definition activation.hpp:259
const torch::Tensor & running_mean() const
Returns constant reference to running mean.
Definition activation.hpp:166
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="batch_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:211
const torch::Tensor & running_var() const
Returns constant reference to running variance.
Definition activation.hpp:172
torch::Tensor running_mean_
Definition activation.hpp:259
const torch::nn::functional::BatchNormFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:178
~BatchNorm() override=default
BatchNorm(torch::Tensor running_mean, torch::Tensor running_var, const torch::Tensor &weight, const torch::Tensor &bias, double eps, double momentum, bool training=false)
Definition activation.hpp:145
torch::nn::functional::BatchNormFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:183
torch::Tensor & running_var()
Returns non-constant reference to running variance.
Definition activation.hpp:175
torch::nn::functional::BatchNormFuncOptions options_
Definition activation.hpp:258
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="batch_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:235
BatchNorm(torch::Tensor running_mean, torch::Tensor running_var, torch::nn::functional::BatchNormFuncOptions options={})
Definition activation.hpp:140
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:160
Continuously Differentiable Exponential Linear Units activation function.
Definition activation.hpp:268
torch::nn::functional::CELUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:290
CELU(torch::nn::functional::CELUFuncOptions options={})
Definition activation.hpp:270
torch::nn::functional::CELUFuncOptions options_
Definition activation.hpp:333
const torch::nn::functional::CELUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:285
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="celu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:316
CELU(double alpha, bool inplace=false)
Definition activation.hpp:273
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:293
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:280
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="celu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:301
Exponential Linear Units activation function.
Definition activation.hpp:345
const torch::nn::functional::ELUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:362
torch::nn::functional::ELUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:367
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="elu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:379
ELU(double alpha, bool inplace=false)
Definition activation.hpp:350
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:371
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:357
torch::nn::functional::ELUFuncOptions options_
Definition activation.hpp:411
ELU(torch::nn::functional::ELUFuncOptions options={})
Definition activation.hpp:347
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="elu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:394
Gaussian Error Linear Units activation function.
Definition activation.hpp:422
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="gelu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:441
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:434
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:429
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="gelu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:452
Gated Linear Units activation function.
Definition activation.hpp:473
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="glu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:518
const torch::nn::functional::GLUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:489
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:484
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="glu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:505
torch::nn::functional::GLUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:494
torch::nn::functional::GLUFuncOptions options_
Definition activation.hpp:533
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:497
GLU(torch::nn::functional::GLUFuncOptions options={})
Definition activation.hpp:475
GLU(int64_t dim)
Definition activation.hpp:478
Group Normalization over a mini-batch of inputs as described in the paper Group Normalization,...
Definition activation.hpp:538
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:556
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="group_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:600
torch::nn::functional::GroupNormFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:566
const torch::nn::functional::GroupNormFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:561
~GroupNorm() override=default
GroupNorm(int64_t num_groups, const torch::Tensor &weight, const torch::Tensor &bias, double eps)
Definition activation.hpp:546
GroupNorm(int64_t num_groups)
Definition activation.hpp:540
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:572
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="group_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:586
torch::nn::functional::GroupNormFuncOptions options_
Definition activation.hpp:617
GroupNorm(torch::nn::functional::GroupNormFuncOptions options)
Definition activation.hpp:543
Gumbel-Softmax distribution activation function.
Definition activation.hpp:621
const torch::nn::functional::GumbelSoftmaxFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:642
GumbelSoftmax(torch::nn::functional::GumbelSoftmaxFuncOptions options={})
Definition activation.hpp:623
~GumbelSoftmax() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:636
torch::nn::functional::GumbelSoftmaxFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:647
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:653
GumbelSoftmax(double tau, int dim, bool hard)
Definition activation.hpp:627
torch::nn::functional::GumbelSoftmaxFuncOptions options_
Definition activation.hpp:696
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="gumbel_softmax") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:661
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="gumbel_softmax") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:676
Hard shrinkage activation function.
Definition activation.hpp:700
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardshrink") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:749
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:728
~Hardshrink() override=default
torch::nn::functional::HardshrinkFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:722
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardshrink") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:736
Hardshrink(double lambda)
Definition activation.hpp:705
torch::nn::functional::HardshrinkFuncOptions options_
Definition activation.hpp:764
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:712
Hardshrink(torch::nn::functional::HardshrinkFuncOptions options={})
Definition activation.hpp:702
const torch::nn::functional::HardshrinkFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:717
Hardsigmoid activation function.
Definition activation.hpp:777
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:784
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardsigmoid") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:809
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardsigmoid") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:797
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:790
~Hardsigmoid() override=default
Hardswish activation function.
Definition activation.hpp:831
~Hardswish() override=default
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardswish") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:851
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:838
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardswish") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:862
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:844
Hardtanh activation function.
Definition activation.hpp:884
const torch::nn::functional::HardtanhFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:904
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardtanh") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:942
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:899
Hardtanh(const torch::nn::functional::HardtanhFuncOptions &options={})
Definition activation.hpp:886
torch::nn::functional::HardtanhFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:909
~Hardtanh() override=default
Hardtanh(double min_val, double max_val, bool inplace=false)
Definition activation.hpp:890
torch::nn::functional::HardtanhFuncOptions options_
Definition activation.hpp:961
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardtanh") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:925
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:915
Instance Normalization as described in the paper.
Definition activation.hpp:968
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1007
torch::nn::functional::InstanceNormFuncOptions options_
Definition activation.hpp:1070
torch::nn::functional::InstanceNormFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1001
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:991
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="instance_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1025
const torch::nn::functional::InstanceNormFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:996
InstanceNorm(torch::nn::functional::InstanceNormFuncOptions options={})
Definition activation.hpp:970
InstanceNorm(const torch::Tensor &running_mean, const torch::Tensor &running_var, const torch::Tensor &weight, const torch::Tensor &bias, double eps, double momentum, bool use_input_stats=true)
Definition activation.hpp:974
~InstanceNorm() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="instance_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1046
Layer Normalization as described in the paper.
Definition activation.hpp:1076
const torch::nn::functional::LayerNormFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1102
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="layer_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1127
LayerNorm(std::vector< int64_t > normalized_shape, const torch::Tensor &weight, const torch::Tensor &bias, double eps)
Definition activation.hpp:1085
LayerNorm(std::vector< int64_t > normalized_shape)
Definition activation.hpp:1078
~LayerNorm() override=default
torch::nn::functional::LayerNormFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1107
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1112
LayerNorm(torch::nn::functional::LayerNormFuncOptions options)
Definition activation.hpp:1082
torch::nn::functional::LayerNormFuncOptions options_
Definition activation.hpp:1158
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1097
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="layer_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1141
Leaky ReLU activation function.
Definition activation.hpp:1170
~LeakyReLU() override=default
const torch::nn::functional::LeakyReLUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1188
LeakyReLU(torch::nn::functional::LeakyReLUFuncOptions options={})
Definition activation.hpp:1172
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="leaky_relu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1207
torch::nn::functional::LeakyReLUFuncOptions options_
Definition activation.hpp:1240
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1183
torch::nn::functional::LeakyReLUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1193
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="leaky_relu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1223
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1198
LeakyReLU(double negative_slope, bool inplace=false)
Definition activation.hpp:1175
Local Response Normalization.
Definition activation.hpp:1244
torch::nn::functional::LocalResponseNormFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1273
LocalResponseNorm(const torch::nn::functional::LocalResponseNormFuncOptions &options)
Definition activation.hpp:1249
torch::nn::functional::LocalResponseNormFuncOptions options_
Definition activation.hpp:1329
LocalResponseNorm(int64_t size, double alpha, double beta, double k)
Definition activation.hpp:1253
LocalResponseNorm(int64_t size)
Definition activation.hpp:1246
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="local_response_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1287
~LocalResponseNorm() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1262
const torch::nn::functional::LocalResponseNormFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1268
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1278
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="local_response_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1307
LogSigmoid activation function.
Definition activation.hpp:1337
~LogSigmoid() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="logsigmoid") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1367
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1349
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="logsigmoid") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1356
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1344
LogSoftmax activation function.
Definition activation.hpp:1387
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1399
torch::nn::functional::LogSoftmaxFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1409
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="logsoftmax") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1422
LogSoftmax(int64_t dim)
Definition activation.hpp:1389
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1414
LogSoftmax(const torch::nn::functional::LogSoftmaxFuncOptions &options)
Definition activation.hpp:1392
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="logsoftmax") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1433
torch::nn::functional::LogSoftmaxFuncOptions options_
Definition activation.hpp:1445
~LogSoftmax() override=default
const torch::nn::functional::LogSoftmaxFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1404
Mish activation function.
Definition activation.hpp:1453
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1460
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1465
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="mish") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1483
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="mish") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1472
No-op activation function.
Definition activation.hpp:95
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="none") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:110
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:103
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:98
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="none") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:121
Lp Normalization.
Definition activation.hpp:1496
~Normalize() override=default
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="normalize") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1532
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="normalize") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1547
torch::nn::functional::NormalizeFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1519
Normalize(double p, double eps, int64_t dim)
Definition activation.hpp:1501
Normalize(torch::nn::functional::NormalizeFuncOptions options={})
Definition activation.hpp:1498
torch::nn::functional::NormalizeFuncOptions options_
Definition activation.hpp:1566
const torch::nn::functional::NormalizeFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1514
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1509
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1524
PReLU activation function.
Definition activation.hpp:1570
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1583
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1588
PReLU(torch::Tensor weight)
Definition activation.hpp:1572
torch::Tensor weight_
Definition activation.hpp:1624
~PReLU() override=default
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="prelu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1598
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="prelu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1610
const torch::Tensor & weight() const
Returns constant reference to weights.
Definition activation.hpp:1577
torch::Tensor & weight()
Returns non-constant reference to weights.
Definition activation.hpp:1580
Randomized ReLU activation function.
Definition activation.hpp:1772
RReLU(double lower, double upper, bool inplace=false)
Definition activation.hpp:1777
RReLU(const torch::nn::functional::RReLUFuncOptions &options={})
Definition activation.hpp:1774
const torch::nn::functional::RReLUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1791
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1786
torch::nn::functional::RReLUFuncOptions options_
Definition activation.hpp:1844
~RReLU() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="rrelu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1825
torch::nn::functional::RReLUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1796
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1799
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="rrelu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1808
ReLU6 activation function.
Definition activation.hpp:1700
~ReLU6() override=default
ReLU6(torch::nn::functional::ReLU6FuncOptions options={})
Definition activation.hpp:1702
torch::nn::functional::ReLU6FuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1721
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1711
const torch::nn::functional::ReLU6FuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1716
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="relu6") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1732
ReLU6(bool inplace)
Definition activation.hpp:1705
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1724
torch::nn::functional::ReLU6FuncOptions options_
Definition activation.hpp:1760
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="relu6") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1745
ReLU activation function.
Definition activation.hpp:1632
torch::nn::functional::ReLUFuncOptions options_
Definition activation.hpp:1692
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1643
ReLU(bool inplace)
Definition activation.hpp:1637
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="relu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1664
torch::nn::functional::ReLUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1653
ReLU(torch::nn::functional::ReLUFuncOptions options={})
Definition activation.hpp:1634
const torch::nn::functional::ReLUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1648
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1656
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="relu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1677
SELU activation function.
Definition activation.hpp:1855
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1879
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="selu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1900
torch::nn::functional::SELUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1876
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1866
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="selu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1887
const torch::nn::functional::SELUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1871
SELU(bool inplace)
Definition activation.hpp:1860
SELU(torch::nn::functional::SELUFuncOptions options={})
Definition activation.hpp:1857
torch::nn::functional::SELUFuncOptions options_
Definition activation.hpp:1915
Sigmoid Linear Unit activation function.
Definition activation.hpp:1966
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="silu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1981
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1974
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1969
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="silu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1992
Sigmoid activation function.
Definition activation.hpp:1923
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1931
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1926
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="sigmoid") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1949
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="sigmoid") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1938
Softmax activation function.
Definition activation.hpp:2011
torch::nn::functional::SoftmaxFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:2032
Softmax(const torch::nn::functional::SoftmaxFuncOptions &options)
Definition activation.hpp:2016
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softmax") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2045
const torch::nn::functional::SoftmaxFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:2027
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2037
~Softmax() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2022
Softmax(int64_t dim)
Definition activation.hpp:2013
torch::nn::functional::SoftmaxFuncOptions options_
Definition activation.hpp:2073
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softmax") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2058
Softmin activation function.
Definition activation.hpp:2081
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softmin") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2115
Softmin(int64_t dim)
Definition activation.hpp:2083
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2092
torch::nn::functional::SoftminFuncOptions options_
Definition activation.hpp:2143
const torch::nn::functional::SoftminFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:2097
~Softmin() override=default
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2107
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softmin") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2128
torch::nn::functional::SoftminFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:2102
Softmin(const torch::nn::functional::SoftminFuncOptions &options)
Definition activation.hpp:2086
Softplus activation function.
Definition activation.hpp:2151
torch::nn::functional::SoftplusFuncOptions options_
Definition activation.hpp:2219
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2164
Softplus(torch::nn::functional::SoftplusFuncOptions options={})
Definition activation.hpp:2153
const torch::nn::functional::SoftplusFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:2169
Softplus(double beta, double threshold)
Definition activation.hpp:2156
torch::nn::functional::SoftplusFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:2174
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softplus") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2202
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softplus") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2187
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2179
~Softplus() override=default
Softshrink activation function.
Definition activation.hpp:2232
const torch::nn::functional::SoftshrinkFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:2249
Softshrink(double lambda)
Definition activation.hpp:2237
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softshrink") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2267
Softshrink(torch::nn::functional::SoftshrinkFuncOptions options={})
Definition activation.hpp:2234
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2244
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softshrink") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2280
torch::nn::functional::SoftshrinkFuncOptions options_
Definition activation.hpp:2295
torch::nn::functional::SoftshrinkFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:2254
~Softshrink() override=default
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2259
Softsign activation function.
Definition activation.hpp:2303
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2306
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2311
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softsign") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2329
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softsign") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2318
Tanh activation function.
Definition activation.hpp:2346
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2354
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="tanh") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2361
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2349
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="tanh") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2372
Tanhshrink activation function.
Definition activation.hpp:2389
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2392
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="tanhshrink") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2415
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="tanhshrink") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2404
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2397
Threshold activation function.
Definition activation.hpp:2436
const torch::nn::functional::ThresholdFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:2453
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2463
Threshold(const torch::nn::functional::ThresholdFuncOptions &options)
Definition activation.hpp:2438
~Threshold() override=default
Threshold(double threshold, double value, bool inplace=false)
Definition activation.hpp:2441
torch::nn::functional::ThresholdFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:2458
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2448
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="threshold") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2488
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="threshold") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2473
torch::nn::functional::ThresholdFuncOptions options_
Definition activation.hpp:2507
Fully qualified name descriptor.
Definition fqn.hpp:22
virtual const std::string & name() const noexcept
Returns the fully qualified name of the object.
Definition fqn.hpp:28
Fully qualified name utility functions.
bool is_verbose(std::ostream &os)
Definition core.hpp:831
std::ostream & operator<<(std::ostream &os, const MemoryDebugger< id > &obj)
Print (as string) a memory debugger object.
Definition memory.hpp:125
struct iganet::@0 Log
Logger.
@ none
Definition boundary.hpp:38
activation
Enumerator for nonlinear activation functions.
Definition activation.hpp:26
short int short_t
Definition core.hpp:74
Definition optimizer.hpp:61