67 virtual torch::Tensor
apply(
const torch::Tensor &)
const = 0;
74 virtual torch::serialize::OutputArchive &
76 const std::string &
key)
const = 0;
80 virtual torch::serialize::InputArchive &
95 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
106 inline torch::serialize::OutputArchive &
108 const std::string &
key =
"none")
const override {
116 inline torch::serialize::InputArchive &
118 const std::string &
key =
"none")
override {
123 throw std::runtime_error(
"activation mismatch");
138 torch::nn::functional::BatchNormFuncOptions
options = {})
144 const torch::Tensor &weight,
const torch::Tensor &
bias,
157 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
175 inline const torch::nn::functional::BatchNormFuncOptions &
options()
const {
180 inline torch::nn::functional::BatchNormFuncOptions &
options() {
188 <<
", momentum=" <<
options_.momentum()
189#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR < 7
192 <<
", training=" <<
options_.training();
197 <<
"\n weight = " <<
options_.weight()
206 inline torch::serialize::OutputArchive &
208 const std::string &
key =
"batch_norm")
const override {
213 archive.write(
key +
".weight", this->
options_.weight());
214 archive.write(
key +
".bias", this->
options_.bias());
216 archive.write(
key +
".momentum",
218#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR < 7
222 archive.write(
key +
".training",
230 inline torch::serialize::InputArchive &
232 const std::string &
key =
"batch_norm")
override {
237 throw std::runtime_error(
"activation mismatch");
241 archive.read(
key +
".weight", this->
options_.weight());
242 archive.read(
key +
".bias", this->
options_.bias());
244 this->
options_.eps(tensor.item<
double>());
246 this->
options_.momentum(tensor.item<
double>());
248 this->
options_.training(tensor.item<
bool>());
254 torch::nn::functional::BatchNormFuncOptions
options_;
266 explicit CELU(torch::nn::functional::CELUFuncOptions
options = {})
276 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
281 inline const torch::nn::functional::CELUFuncOptions &
options()
const {
292 <<
", inplace=" <<
options_.inplace() <<
"\n)";
297 inline torch::serialize::OutputArchive &
299 const std::string &
key =
"celu")
const override {
303 archive.write(
key +
".inplace",
311 inline torch::serialize::InputArchive &
313 const std::string &
key =
"celu")
override {
318 throw std::runtime_error(
"activation mismatch");
321 this->
options_.alpha(tensor.item<
double>());
323 this->
options_.inplace(tensor.item<
bool>());
343 explicit ELU(torch::nn::functional::ELUFuncOptions
options = {})
353 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
358 inline const torch::nn::functional::ELUFuncOptions &
options()
const {
369 <<
", inplace=" <<
options_.inplace() <<
"\n)";
374 inline torch::serialize::OutputArchive &
376 const std::string &
key =
"elu")
const override {
380 archive.write(
key +
".inplace",
388 inline torch::serialize::InputArchive &
390 const std::string &
key =
"elu")
override {
395 throw std::runtime_error(
"activation mismatch");
398 this->
options_.alpha(tensor.item<
double>());
400 this->
options_.inplace(tensor.item<
bool>());
424 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
425 return torch::gelu(
input);
436 inline torch::serialize::OutputArchive &
438 const std::string &
key =
"gelu")
const override {
446 inline torch::serialize::InputArchive &
448 const std::string &
key =
"gelu")
override {
453 throw std::runtime_error(
"activation mismatch");
470 explicit GLU(torch::nn::functional::GLUFuncOptions
options = {})
479 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
484 inline const torch::nn::functional::GLUFuncOptions &
options()
const {
500 inline torch::serialize::OutputArchive &
502 const std::string &
key =
"glu")
const override {
511 inline torch::serialize::InputArchive &
513 const std::string &
key =
"glu")
override {
518 throw std::runtime_error(
"activation mismatch");
521 this->
options_.dim(tensor.item<
int>());
541 const torch::Tensor &
bias,
double eps)
550 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
555 inline const torch::nn::functional::GroupNormFuncOptions &
options()
const {
560 inline torch::nn::functional::GroupNormFuncOptions &
options() {
579 inline torch::serialize::OutputArchive &
581 const std::string &
key =
"group_norm")
const override {
585 archive.write(
key +
".bias", this->
options_.bias());
593 inline torch::serialize::InputArchive &
595 const std::string &
key =
"group_norm")
override {
600 throw std::runtime_error(
"activation mismatch");
603 archive.read(
key +
".bias", this->
options_.bias());
605 this->
options_.eps(tensor.item<
double>());
611 torch::nn::functional::GroupNormFuncOptions
options_;
618 torch::nn::functional::GumbelSoftmaxFuncOptions
options = {})
630 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
631 return torch::nn::functional::gumbel_softmax(
input,
options_);
635 inline const torch::nn::functional::GumbelSoftmaxFuncOptions &
641 inline torch::nn::functional::GumbelSoftmaxFuncOptions &
options() {
654 inline torch::serialize::OutputArchive &
656 const std::string &
key =
"gumbel_softmax")
const override {
660 archive.write(
key +
".dim", torch::full({1}, (
int)this->
options_.dim()));
661 archive.write(
key +
".hard", torch::full({1}, (
bool)this->
options_.hard()));
668 inline torch::serialize::InputArchive &
670 const std::string &
key =
"gumbel_softmax")
override {
675 throw std::runtime_error(
"activation mismatch");
678 this->
options_.tau(tensor.item<
double>());
680 this->
options_.dim(tensor.item<
int>());
682 this->
options_.hard(tensor.item<
bool>());
688 torch::nn::functional::GumbelSoftmaxFuncOptions
options_;
704 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
709 inline const torch::nn::functional::HardshrinkFuncOptions &
options()
const {
714 inline torch::nn::functional::HardshrinkFuncOptions &
options() {
722 <<
"(\n lambda=" <<
options_.lambda() <<
"\n)";
727 inline torch::serialize::OutputArchive &
729 const std::string &
key =
"hardshrink")
const override {
740 inline torch::serialize::InputArchive &
742 const std::string &
key =
"hardshrink")
override {
747 throw std::runtime_error(
"activation mismatch");
750 this->
options_.lambda(tensor.item<
double>());
756 torch::nn::functional::HardshrinkFuncOptions
options_;
776 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
777 return torch::hardsigmoid(
input);
788 inline torch::serialize::OutputArchive &
790 const std::string &
key =
"hardsigmoid")
const override {
799 inline torch::serialize::InputArchive &
801 const std::string &
key =
"hardsigmoid")
override {
806 throw std::runtime_error(
"activation mismatch");
829 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
830 return torch::hardswish(
input);
841 inline torch::serialize::OutputArchive &
843 const std::string &
key =
"hardswish")
const override {
852 inline torch::serialize::InputArchive &
854 const std::string &
key =
"hardswish")
override {
859 throw std::runtime_error(
"activation mismatch");
889 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
894 inline const torch::nn::functional::HardtanhFuncOptions &
options()
const {
899 inline torch::nn::functional::HardtanhFuncOptions &
options() {
907 <<
"(\n min_val=" <<
options_.min_val()
908 <<
", max_val=" <<
options_.max_val()
909 <<
", inplace=" <<
options_.inplace() <<
"\n)";
914 inline torch::serialize::OutputArchive &
916 const std::string &
key =
"hardtanh")
const override {
921 archive.write(
key +
".max_val",
923 archive.write(
key +
".inplace",
931 inline torch::serialize::InputArchive &
933 const std::string &
key =
"hardtanh")
override {
938 throw std::runtime_error(
"activation mismatch");
941 this->
options_.min_val(tensor.item<
double>());
943 this->
options_.max_val(tensor.item<
double>());
945 this->
options_.inplace(tensor.item<
bool>());
951 torch::nn::functional::HardtanhFuncOptions
options_;
961 torch::nn::functional::InstanceNormFuncOptions
options = {})
965 const torch::Tensor &running_var,
966 const torch::Tensor &weight,
const torch::Tensor &
bias,
970 .running_mean(running_mean)
971 .running_var(running_var)
981 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
982 return torch::nn::functional::instance_norm(
input,
options_);
986 inline const torch::nn::functional::InstanceNormFuncOptions &
options()
const {
991 inline torch::nn::functional::InstanceNormFuncOptions &
options() {
999 <<
", momentum=" <<
options_.momentum()
1000 <<
", use_input_stats=" <<
options_.use_input_stats();
1003 os <<
"\n running_mean = " <<
options_.running_mean()
1004 <<
"\n running_var = " <<
options_.running_var()
1005 <<
"\n weight = " <<
options_.weight()
1006 <<
"\n bias = " <<
options_.bias();
1014 inline torch::serialize::OutputArchive &
1016 const std::string &
key =
"instance_norm")
const override {
1020 archive.write(
key +
".var", this->
options_.running_var());
1021 archive.write(
key +
".weight", this->
options_.weight());
1022 archive.write(
key +
".bias", this->
options_.bias());
1024 archive.write(
key +
".momentum",
1026 archive.write(
key +
".use_input_stats",
1027 torch::full({1}, (
bool)this->
options_.use_input_stats()));
1034 inline torch::serialize::InputArchive &
1036 const std::string &
key =
"instance_norm")
override {
1041 throw std::runtime_error(
"activation mismatch");
1044 archive.read(
key +
".running_var", this->
options_.running_var());
1045 archive.read(
key +
".weight", this->
options_.weight());
1046 archive.read(
key +
".bias", this->
options_.bias());
1048 this->
options_.eps(tensor.item<
double>());
1050 this->
options_.momentum(tensor.item<
double>());
1052 this->
options_.use_input_stats(tensor.item<
bool>());
1058 torch::nn::functional::InstanceNormFuncOptions
options_;
1074 const torch::Tensor &weight,
const torch::Tensor &
bias,
1085 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1090 inline const torch::nn::functional::LayerNormFuncOptions &
options()
const {
1095 inline torch::nn::functional::LayerNormFuncOptions &
options() {
1105 os <<
"\n normalized_shape = " <<
options_.normalized_shape()
1106 <<
"\n weight = " <<
options_.weight()
1107 <<
"\n bias = " <<
options_.bias();
1115 inline torch::serialize::OutputArchive &
1117 const std::string &
key =
"layer_norm")
const override {
1121 archive.write(
key +
".bias", this->
options_.bias());
1129 inline torch::serialize::InputArchive &
1131 const std::string &
key =
"layer_norm")
override {
1136 throw std::runtime_error(
"activation mismatch");
1139 archive.read(
key +
".bias", this->
options_.bias());
1141 this->
options_.eps(tensor.item<
double>());
1172 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1177 inline const torch::nn::functional::LeakyReLUFuncOptions &
options()
const {
1182 inline torch::nn::functional::LeakyReLUFuncOptions &
options() {
1190 <<
"(\n negative_slope=" <<
options_.negative_slope()
1191 <<
", inplace=" <<
options_.inplace() <<
"\n)";
1196 inline torch::serialize::OutputArchive &
1198 const std::string &
key =
"leaky_relu")
const override {
1204 archive.write(
key +
".inplace",
1212 inline torch::serialize::InputArchive &
1214 const std::string &
key =
"leaky_relu")
override {
1219 throw std::runtime_error(
"activation mismatch");
1222 this->
options_.negative_slope(tensor.item<
double>());
1224 this->
options_.inplace(tensor.item<
bool>());
1240 const torch::nn::functional::LocalResponseNormFuncOptions &
options)
1252 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1253 return torch::nn::functional::local_response_norm(
input,
options_);
1257 inline const torch::nn::functional::LocalResponseNormFuncOptions &
1263 inline torch::nn::functional::LocalResponseNormFuncOptions &
options() {
1272 <<
", k=" <<
options_.k() <<
"\n)";
1277 inline torch::serialize::OutputArchive &
1279 const std::string &
key =
"local_response_norm")
const override {
1285 archive.write(
key +
".alpha",
1287 archive.write(
key +
".beta",
1296 inline torch::serialize::InputArchive &
1298 const std::string &
key =
"local_response_norm")
override {
1303 throw std::runtime_error(
"activation mismatch");
1308 this->
options_.alpha(tensor.item<
double>());
1310 this->
options_.beta(tensor.item<
double>());
1312 this->
options_.k(tensor.item<
double>());
1318 torch::nn::functional::LocalResponseNormFuncOptions
options_;
1333 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1334 return torch::log_sigmoid(
input);
1345 inline torch::serialize::OutputArchive &
1347 const std::string &
key =
"logsigmoid")
const override {
1356 inline torch::serialize::InputArchive &
1358 const std::string &
key =
"logsigmoid")
override {
1363 throw std::runtime_error(
"activation mismatch");
1383 const torch::nn::functional::LogSoftmaxFuncOptions &
options)
1389 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1394 inline const torch::nn::functional::LogSoftmaxFuncOptions &
options()
const {
1399 inline torch::nn::functional::LogSoftmaxFuncOptions &
options() {
1412 inline torch::serialize::OutputArchive &
1414 const std::string &
key =
"logsoftmax")
const override {
1423 inline torch::serialize::InputArchive &
1425 const std::string &
key =
"logsoftmax")
override {
1430 throw std::runtime_error(
"activation mismatch");
1451 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1452 return torch::mish(
input);
1463 inline torch::serialize::OutputArchive &
1465 const std::string &
key =
"mish")
const override {
1473 inline torch::serialize::InputArchive &
1475 const std::string &
key =
"mish")
override {
1480 throw std::runtime_error(
"activation mismatch");
1500 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1505 inline const torch::nn::functional::NormalizeFuncOptions &
options()
const {
1510 inline torch::nn::functional::NormalizeFuncOptions &
options() {
1523 inline torch::serialize::OutputArchive &
1525 const std::string &
key =
"normalize")
const override {
1530 archive.write(
key +
".dim",
1538 inline torch::serialize::InputArchive &
1540 const std::string &
key =
"normalize")
override {
1545 throw std::runtime_error(
"activation mismatch");
1548 this->
options_.p(tensor.item<
double>());
1550 this->
options_.eps(tensor.item<
double>());
1575 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1576 return torch::nn::functional::prelu(
input,
weight());
1585 os <<
"(\n weight = " <<
weight() <<
"\n)";
1590 inline torch::serialize::OutputArchive &
1592 const std::string &
key =
"prelu")
const override {
1601 inline torch::serialize::InputArchive &
1603 const std::string &
key =
"prelu")
override {
1608 throw std::runtime_error(
"activation mismatch");
1635 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1640 inline const torch::nn::functional::ReLUFuncOptions &
options()
const {
1651 <<
"(\n inplace=" <<
options_.inplace() <<
"\n)";
1656 inline torch::serialize::OutputArchive &
1658 const std::string &
key =
"relu")
const override {
1668 inline torch::serialize::InputArchive &
1670 const std::string &
key =
"relu")
override {
1675 throw std::runtime_error(
"activation mismatch");
1678 this->
options_.inplace(tensor.item<
bool>());
1703 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1708 inline const torch::nn::functional::ReLU6FuncOptions &
options()
const {
1719 <<
"(\n inplace=" <<
options_.inplace() <<
"\n)";
1724 inline torch::serialize::OutputArchive &
1726 const std::string &
key =
"relu6")
const override {
1736 inline torch::serialize::InputArchive &
1738 const std::string &
key =
"relu6")
override {
1743 throw std::runtime_error(
"activation mismatch");
1746 this->
options_.inplace(tensor.item<
bool>());
1778 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1783 inline const torch::nn::functional::RReLUFuncOptions &
options()
const {
1800 inline torch::serialize::OutputArchive &
1802 const std::string &
key =
"rrelu")
const override {
1806 archive.write(
key +
".upper",
1808 archive.write(
key +
".inplace",
1816 inline torch::serialize::InputArchive &
1818 const std::string &
key =
"rrelu")
override {
1823 throw std::runtime_error(
"activation mismatch");
1826 this->
options_.lower(tensor.item<
double>());
1828 this->
options_.upper(tensor.item<
double>());
1830 this->
options_.inplace(tensor.item<
bool>());
1858 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1863 inline const torch::nn::functional::SELUFuncOptions &
options()
const {
1874 <<
"(\n inplace=" <<
options_.inplace() <<
"\n)";
1879 inline torch::serialize::OutputArchive &
1881 const std::string &
key =
"selu")
const override {
1891 inline torch::serialize::InputArchive &
1893 const std::string &
key =
"selu")
override {
1898 throw std::runtime_error(
"activation mismatch");
1901 this->
options_.inplace(tensor.item<
bool>());
1918 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1919 return torch::sigmoid(
input);
1930 inline torch::serialize::OutputArchive &
1932 const std::string &
key =
"sigmoid")
const override {
1941 inline torch::serialize::InputArchive &
1943 const std::string &
key =
"sigmoid")
override {
1948 throw std::runtime_error(
"activation mismatch");
1962 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
1963 return torch::silu(
input);
1974 inline torch::serialize::OutputArchive &
1976 const std::string &
key =
"silu")
const override {
1984 inline torch::serialize::InputArchive &
1986 const std::string &
key =
"silu")
override {
1991 throw std::runtime_error(
"activation mismatch");
2015 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
2020 inline const torch::nn::functional::SoftmaxFuncOptions &
options()
const {
2025 inline torch::nn::functional::SoftmaxFuncOptions &
options() {
2038 inline torch::serialize::OutputArchive &
2040 const std::string &
key =
"softmax")
const override {
2051 inline torch::serialize::InputArchive &
2053 const std::string &
key =
"softmax")
override {
2058 throw std::runtime_error(
"activation mismatch");
2086 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
2091 inline const torch::nn::functional::SoftminFuncOptions &
options()
const {
2096 inline torch::nn::functional::SoftminFuncOptions &
options() {
2109 inline torch::serialize::OutputArchive &
2111 const std::string &
key =
"softmin")
const override {
2122 inline torch::serialize::InputArchive &
2124 const std::string &
key =
"softmin")
override {
2129 throw std::runtime_error(
"activation mismatch");
2159 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
2164 inline const torch::nn::functional::SoftplusFuncOptions &
options()
const {
2169 inline torch::nn::functional::SoftplusFuncOptions &
options() {
2177 <<
", theshold=" <<
options_.threshold() <<
"\n)";
2182 inline torch::serialize::OutputArchive &
2184 const std::string &
key =
"softplus")
const override {
2189 archive.write(
key +
".threshold",
2197 inline torch::serialize::InputArchive &
2199 const std::string &
key =
"softplus")
override {
2204 throw std::runtime_error(
"activation mismatch");
2207 this->
options_.beta(tensor.item<
double>());
2209 this->
options_.threshold(tensor.item<
double>());
2240 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
2245 inline const torch::nn::functional::SoftshrinkFuncOptions &
options()
const {
2250 inline torch::nn::functional::SoftshrinkFuncOptions &
options() {
2258 <<
"(\n lambda=" <<
options_.lambda() <<
"\n)";
2263 inline torch::serialize::OutputArchive &
2265 const std::string &
key =
"softshrink")
const override {
2276 inline torch::serialize::InputArchive &
2278 const std::string &
key =
"softshrink")
override {
2283 throw std::runtime_error(
"activation mismatch");
2286 this->
options_.lambda(tensor.item<
double>());
2303 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
2304 return torch::nn::functional::softsign(
input);
2315 inline torch::serialize::OutputArchive &
2317 const std::string &
key =
"softsign")
const override {
2326 inline torch::serialize::InputArchive &
2328 const std::string &
key =
"softsign")
override {
2333 throw std::runtime_error(
"activation mismatch");
2347 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
2348 return torch::tanh(
input);
2359 inline torch::serialize::OutputArchive &
2361 const std::string &
key =
"tanh")
const override {
2369 inline torch::serialize::InputArchive &
2371 const std::string &
key =
"tanh")
override {
2376 throw std::runtime_error(
"activation mismatch");
2390 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
2391 return torch::nn::functional::tanhshrink(
input);
2402 inline torch::serialize::OutputArchive &
2404 const std::string &
key =
"tanhshrink")
const override {
2413 inline torch::serialize::InputArchive &
2415 const std::string &
key =
"tanhshrink")
override {
2420 throw std::runtime_error(
"activation mismatch");
2447 inline torch::Tensor
apply(
const torch::Tensor &
input)
const override {
2452 inline const torch::nn::functional::ThresholdFuncOptions &
options()
const {
2457 inline torch::nn::functional::ThresholdFuncOptions &
options() {
2465 <<
"(\n threshold=" <<
options_.threshold()
2472 inline torch::serialize::OutputArchive &
2474 const std::string &
key =
"threshold")
const override {
2479 archive.write(
key +
".value",
2481 archive.write(
key +
".inplace",
2489 inline torch::serialize::InputArchive &
2491 const std::string &
key =
"threshold")
override {
2496 throw std::runtime_error(
"activation mismatch");
2499 this->
options_.threshold(tensor.item<
double>());
2501 this->
options_.value(tensor.item<
double>());
2503 this->
options_.inplace(tensor.item<
bool>());
Abstract activation function structure.
Definition layer.hpp:62
virtual torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key) const =0
Writes the activation function into a torch::serialize::OutputArchive object.
virtual torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key)=0
Reads the activation function from a torch::serialize::InputArchive object.
virtual void pretty_print(std::ostream &os) const noexcept=0
Prints a string representation of the activation function to the given output stream.
virtual torch::Tensor apply(const torch::Tensor &) const =0
Applies the activation function to the given input.
virtual ~ActivationFunction()=default
Batch Normalization as described in the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift".
Definition layer.hpp:134
BatchNorm(const torch::Tensor &running_mean, const torch::Tensor &running_var, torch::nn::functional::BatchNormFuncOptions options={})
Definition layer.hpp:136
torch::Tensor & running_mean()
Returns non-constant reference to running mean.
Definition layer.hpp:166
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Prints a string representation of the activation function to the given output stream.
Definition layer.hpp:186
torch::Tensor running_var_
Definition layer.hpp:255
BatchNorm(const torch::Tensor &running_mean, const torch::Tensor &running_var, const torch::Tensor &weight, const torch::Tensor &bias, double eps, double momentum, bool training=false)
Definition layer.hpp:142
const torch::Tensor & running_mean() const
Returns constant reference to running mean.
Definition layer.hpp:163
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="batch_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:207
const torch::Tensor & running_var() const
Returns constant reference to running variance.
Definition layer.hpp:169
torch::Tensor running_mean_
Definition layer.hpp:255
const torch::nn::functional::BatchNormFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:175
~BatchNorm() override=default
torch::nn::functional::BatchNormFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:180
torch::Tensor & running_var()
Returns non-constant reference to running variance.
Definition layer.hpp:172
torch::nn::functional::BatchNormFuncOptions options_
Definition layer.hpp:254
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="batch_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:231
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:157
Continuously Differentiable Exponential Linear Units activation function.
Definition layer.hpp:264
torch::nn::functional::CELUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:286
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Prints a string representation of the activation function to the given output stream.
Definition layer.hpp:290
CELU(torch::nn::functional::CELUFuncOptions options={})
Definition layer.hpp:266
torch::nn::functional::CELUFuncOptions options_
Definition layer.hpp:329
const torch::nn::functional::CELUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:281
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="celu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:312
CELU(double alpha, bool inplace=false)
Definition layer.hpp:269
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:276
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="celu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:298
Exponential Linear Units activation function.
Definition layer.hpp:341
const torch::nn::functional::ELUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:358
torch::nn::functional::ELUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:363
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="elu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:375
ELU(double alpha, bool inplace=false)
Definition layer.hpp:346
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:353
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Prints a string representation of the activation function to the given output stream.
Definition layer.hpp:367
torch::nn::functional::ELUFuncOptions options_
Definition layer.hpp:406
ELU(torch::nn::functional::ELUFuncOptions options={})
Definition layer.hpp:343
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="elu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:389
Gaussian Error Linear Units activation function.
Definition layer.hpp:417
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="gelu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:437
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Prints a string representation of the activation function to the given output stream.
Definition layer.hpp:430
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:424
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="gelu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:447
Gated Linear Units activation function.
Definition layer.hpp:468
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="glu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:512
const torch::nn::functional::GLUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:484
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:479
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="glu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:501
torch::nn::functional::GLUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:489
torch::nn::functional::GLUFuncOptions options_
Definition layer.hpp:527
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Prints a string representation of the activation function to the given output stream.
Definition layer.hpp:493
GLU(torch::nn::functional::GLUFuncOptions options={})
Definition layer.hpp:470
GLU(int64_t dim)
Definition layer.hpp:473
Group Normalization over a mini-batch of inputs as described in the paper Group Normalization,...
Definition layer.hpp:532
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:550
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Prints a string representation of the activation function to the given output stream.
Definition layer.hpp:566
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="group_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:594
torch::nn::functional::GroupNormFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:560
const torch::nn::functional::GroupNormFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:555
~GroupNorm() override=default
GroupNorm(int64_t num_groups, const torch::Tensor &weight, const torch::Tensor &bias, double eps)
Definition layer.hpp:540
GroupNorm(int64_t num_groups)
Definition layer.hpp:534
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="group_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:580
torch::nn::functional::GroupNormFuncOptions options_
Definition layer.hpp:611
GroupNorm(torch::nn::functional::GroupNormFuncOptions options)
Definition layer.hpp:537
Gumbel-Softmax distribution activation function.
Definition layer.hpp:615
const torch::nn::functional::GumbelSoftmaxFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:636
GumbelSoftmax(torch::nn::functional::GumbelSoftmaxFuncOptions options={})
Definition layer.hpp:617
~GumbelSoftmax() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:630
torch::nn::functional::GumbelSoftmaxFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:641
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:647
GumbelSoftmax(double tau, int dim, bool hard)
Definition layer.hpp:621
torch::nn::functional::GumbelSoftmaxFuncOptions options_
Definition layer.hpp:688
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="gumbel_softmax") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:655
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="gumbel_softmax") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:669
Hard shrink (Hardshrink) activation function.
Definition layer.hpp:692
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardshrink") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:741
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:720
~Hardshrink() override=default
torch::nn::functional::HardshrinkFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:714
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardshrink") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:728
Hardshrink(double lambda)
Definition layer.hpp:697
torch::nn::functional::HardshrinkFuncOptions options_
Definition layer.hpp:756
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:704
Hardshrink(torch::nn::functional::HardshrinkFuncOptions options={})
Definition layer.hpp:694
const torch::nn::functional::HardshrinkFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:709
Hardsigmoid activation function.
Definition layer.hpp:769
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:776
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardsigmoid") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:800
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardsigmoid") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:789
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:782
~Hardsigmoid() override=default
Hardswish activation function.
Definition layer.hpp:822
~Hardswish() override=default
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardswish") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:842
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:829
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardswish") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:853
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:835
Hardtanh activation function.
Definition layer.hpp:875
const torch::nn::functional::HardtanhFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:894
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardtanh") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:932
Hardtanh(torch::nn::functional::HardtanhFuncOptions options={})
Definition layer.hpp:877
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:905
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:889
torch::nn::functional::HardtanhFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:899
~Hardtanh() override=default
Hardtanh(double min_val, double max_val, bool inplace=false)
Definition layer.hpp:880
torch::nn::functional::HardtanhFuncOptions options_
Definition layer.hpp:951
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardtanh") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:915
Instance Normalization as described in the paper.
Definition layer.hpp:958
torch::nn::functional::InstanceNormFuncOptions options_
Definition layer.hpp:1058
torch::nn::functional::InstanceNormFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:991
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:981
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="instance_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1015
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:997
const torch::nn::functional::InstanceNormFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:986
InstanceNorm(torch::nn::functional::InstanceNormFuncOptions options={})
Definition layer.hpp:960
InstanceNorm(const torch::Tensor &running_mean, const torch::Tensor &running_var, const torch::Tensor &weight, const torch::Tensor &bias, double eps, double momentum, bool use_input_stats=true)
Definition layer.hpp:964
~InstanceNorm() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="instance_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1035
Layer Normalization as described in the paper.
Definition layer.hpp:1064
const torch::nn::functional::LayerNormFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1090
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="layer_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1116
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1101
LayerNorm(std::vector< int64_t > normalized_shape, const torch::Tensor &weight, const torch::Tensor &bias, double eps)
Definition layer.hpp:1073
LayerNorm(std::vector< int64_t > normalized_shape)
Definition layer.hpp:1066
~LayerNorm() override=default
torch::nn::functional::LayerNormFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1095
LayerNorm(torch::nn::functional::LayerNormFuncOptions options)
Definition layer.hpp:1070
torch::nn::functional::LayerNormFuncOptions options_
Definition layer.hpp:1147
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1085
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="layer_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1130
Leaky ReLU activation function.
Definition layer.hpp:1159
~LeakyReLU() override=default
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1188
const torch::nn::functional::LeakyReLUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1177
LeakyReLU(torch::nn::functional::LeakyReLUFuncOptions options={})
Definition layer.hpp:1161
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="leaky_relu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1197
torch::nn::functional::LeakyReLUFuncOptions options_
Definition layer.hpp:1230
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1172
torch::nn::functional::LeakyReLUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1182
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="leaky_relu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1213
LeakyReLU(double negative_slope, bool inplace=false)
Definition layer.hpp:1164
Local Response Normalization.
Definition layer.hpp:1234
torch::nn::functional::LocalResponseNormFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1263
LocalResponseNorm(const torch::nn::functional::LocalResponseNormFuncOptions &options)
Definition layer.hpp:1239
torch::nn::functional::LocalResponseNormFuncOptions options_
Definition layer.hpp:1318
LocalResponseNorm(int64_t size, double alpha, double beta, double k)
Definition layer.hpp:1243
LocalResponseNorm(int64_t size)
Definition layer.hpp:1236
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="local_response_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1278
~LocalResponseNorm() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1252
const torch::nn::functional::LocalResponseNormFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1258
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="local_response_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1297
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1269
LogSigmoid activation function.
Definition layer.hpp:1326
~LogSigmoid() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="logsigmoid") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1357
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="logsigmoid") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1346
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1339
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1333
LogSoftmax activation function.
Definition layer.hpp:1377
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1389
torch::nn::functional::LogSoftmaxFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1399
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="logsoftmax") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1413
LogSoftmax(int64_t dim)
Definition layer.hpp:1379
LogSoftmax(const torch::nn::functional::LogSoftmaxFuncOptions &options)
Definition layer.hpp:1382
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1405
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="logsoftmax") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1424
torch::nn::functional::LogSoftmaxFuncOptions options_
Definition layer.hpp:1436
~LogSoftmax() override=default
const torch::nn::functional::LogSoftmaxFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1394
Mish activation function.
Definition layer.hpp:1444
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1451
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="mish") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1474
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="mish") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1464
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1457
No-op activation function.
Definition layer.hpp:92
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="none") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:107
virtual void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:100
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:95
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="none") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:117
Lp Normalization.
Definition layer.hpp:1487
~Normalize() override=default
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1516
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="normalize") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1524
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="normalize") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1539
torch::nn::functional::NormalizeFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1510
Normalize(double p, double eps, int64_t dim)
Definition layer.hpp:1492
Normalize(torch::nn::functional::NormalizeFuncOptions options={})
Definition layer.hpp:1489
torch::nn::functional::NormalizeFuncOptions options_
Definition layer.hpp:1558
const torch::nn::functional::NormalizeFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1505
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1500
PReLU activation function.
Definition layer.hpp:1562
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1575
torch::Tensor weight_
Definition layer.hpp:1616
~PReLU() override=default
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="prelu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1591
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="prelu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1602
const torch::Tensor & weight() const
Returns constant reference to weights.
Definition layer.hpp:1569
PReLU(const torch::Tensor &weight)
Definition layer.hpp:1564
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1581
torch::Tensor & weight()
Returns non-constant reference to weights.
Definition layer.hpp:1572
Randomized ReLU activation function.
Definition layer.hpp:1764
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1792
RReLU(double lower, double upper, bool inplace=false)
Definition layer.hpp:1769
const torch::nn::functional::RReLUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1783
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1778
torch::nn::functional::RReLUFuncOptions options_
Definition layer.hpp:1836
~RReLU() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="rrelu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1817
torch::nn::functional::RReLUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1788
RReLU(torch::nn::functional::RReLUFuncOptions options={})
Definition layer.hpp:1766
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="rrelu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1801
ReLU6 activation function.
Definition layer.hpp:1692
~ReLU6() override=default
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1717
ReLU6(torch::nn::functional::ReLU6FuncOptions options={})
Definition layer.hpp:1694
torch::nn::functional::ReLU6FuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1713
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1703
const torch::nn::functional::ReLU6FuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1708
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="relu6") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1725
ReLU6(bool inplace)
Definition layer.hpp:1697
torch::nn::functional::ReLU6FuncOptions options_
Definition layer.hpp:1752
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="relu6") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1737
ReLU activation function.
Definition layer.hpp:1624
torch::nn::functional::ReLUFuncOptions options_
Definition layer.hpp:1684
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1635
ReLU(bool inplace)
Definition layer.hpp:1629
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="relu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1657
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1649
torch::nn::functional::ReLUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1645
ReLU(torch::nn::functional::ReLUFuncOptions options={})
Definition layer.hpp:1626
const torch::nn::functional::ReLUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1640
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="relu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1669
SELU activation function.
Definition layer.hpp:1847
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="selu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1892
torch::nn::functional::SELUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1868
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1858
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1872
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="selu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1880
const torch::nn::functional::SELUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1863
SELU(bool inplace)
Definition layer.hpp:1852
SELU(torch::nn::functional::SELUFuncOptions options={})
Definition layer.hpp:1849
torch::nn::functional::SELUFuncOptions options_
Definition layer.hpp:1907
Sigmoid Linear Unit activation function.
Definition layer.hpp:1959
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="silu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1975
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1962
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="silu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1985
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1968
Sigmoid activation function.
Definition layer.hpp:1915
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1924
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1918
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="sigmoid") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1942
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="sigmoid") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1931
Softmax activation function.
Definition layer.hpp:2004
torch::nn::functional::SoftmaxFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:2025
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2031
Softmax(const torch::nn::functional::SoftmaxFuncOptions &options)
Definition layer.hpp:2009
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softmax") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2039
const torch::nn::functional::SoftmaxFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:2020
~Softmax() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2015
Softmax(int64_t dim)
Definition layer.hpp:2006
torch::nn::functional::SoftmaxFuncOptions options_
Definition layer.hpp:2067
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softmax") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2052
Softmin activation function.
Definition layer.hpp:2075
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softmin") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2110
Softmin(int64_t dim)
Definition layer.hpp:2077
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2086
torch::nn::functional::SoftminFuncOptions options_
Definition layer.hpp:2138
const torch::nn::functional::SoftminFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:2091
~Softmin() override=default
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2102
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softmin") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2123
torch::nn::functional::SoftminFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:2096
Softmin(const torch::nn::functional::SoftminFuncOptions &options)
Definition layer.hpp:2080
Softplus activation function.
Definition layer.hpp:2146
torch::nn::functional::SoftplusFuncOptions options_
Definition layer.hpp:2215
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2159
Softplus(torch::nn::functional::SoftplusFuncOptions options={})
Definition layer.hpp:2148
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2175
const torch::nn::functional::SoftplusFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:2164
Softplus(double beta, double threshold)
Definition layer.hpp:2151
torch::nn::functional::SoftplusFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:2169
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softplus") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2198
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softplus") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2183
~Softplus() override=default
Softshrink activation function.
Definition layer.hpp:2228
const torch::nn::functional::SoftshrinkFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:2245
Softshrink(double lambda)
Definition layer.hpp:2233
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softshrink") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2264
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2256
Softshrink(torch::nn::functional::SoftshrinkFuncOptions options={})
Definition layer.hpp:2230
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2240
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softshrink") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2277
torch::nn::functional::SoftshrinkFuncOptions options_
Definition layer.hpp:2292
torch::nn::functional::SoftshrinkFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:2250
~Softshrink() override=default
Softsign activation function.
Definition layer.hpp:2300
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2303
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softsign") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2327
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2309
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softsign") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2316
Tanh activation function.
Definition layer.hpp:2344
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="tanh") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2360
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2347
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="tanh") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2370
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2353
Tanhshrink activation function.
Definition layer.hpp:2387
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2390
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="tanhshrink") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2414
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2396
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="tanhshrink") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2403
Threshold activation function.
Definition layer.hpp:2435
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2463
const torch::nn::functional::ThresholdFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:2452
~Threshold() override=default
Threshold(double threshold, double value, bool inplace=false)
Definition layer.hpp:2440
torch::nn::functional::ThresholdFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:2457
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2447
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="threshold") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2490
Threshold(torch::nn::functional::ThresholdFuncOptions options)
Definition layer.hpp:2437
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="threshold") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2473
torch::nn::functional::ThresholdFuncOptions options_
Definition layer.hpp:2509
Fully qualified name descriptor.
Definition fqn.hpp:26
virtual const std::string & name() const noexcept
Returns the fully qualified name of the object.
Definition fqn.hpp:31
Fully qualified name utility functions.
Definition boundary.hpp:22
bool is_verbose(std::ostream &os)
Definition core.hpp:831
constexpr bool is_SplineType_v
Alias to the value of is_SplineType.
Definition bspline.hpp:3243
std::ostream & operator<<(std::ostream &os, const Boundary< Spline > &obj)
Print (as string) a Boundary object.
Definition boundary.hpp:1978
struct iganet::@0 Log
Logger.
@ none
Definition boundary.hpp:38
activation
Enumerator for nonlinear activation functions.
Definition layer.hpp:23
short int short_t
Definition core.hpp:74