IgANet
IGAnets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
activation.hpp
Go to the documentation of this file.
1
15#pragma once
16
17#include <iostream>
18#include <utility>
19
20#include <core/core.hpp>
21#include <utils/fqn.hpp>
22
23namespace iganet {
24
26enum class activation : short_t {
27 none = 0,
28 batch_norm = 1,
29 celu = 2,
30 elu = 3,
31 gelu = 4,
32 glu = 5,
33 group_norm = 6,
35 hardshrink = 9,
36 hardsigmoid = 8,
37 hardswish = 10,
38 hardtanh = 11,
39 instance_norm = 12,
40 layer_norm = 13,
41 leaky_relu = 14,
43 logsigmoid = 16,
44 logsoftmax = 17,
45 mish = 18,
46 normalize = 19,
47 prelu = 20,
48 relu = 21,
49 relu6 = 22,
50 rrelu = 23,
51 selu = 24,
52 sigmoid = 25,
53 silu = 26,
54 softmax = 27,
55 softmin = 28,
56 softplus = 29,
57 softshrink = 30,
58 softsign = 31,
59 tanh = 32,
60 tanhshrink = 33,
61 threshold = 34
62};
63
66public:
67 ~ActivationFunction() override = default;
68
70 virtual torch::Tensor apply(const torch::Tensor &) const = 0;
71
73 void pretty_print(std::ostream &os) const noexcept override = 0;
74
77 virtual torch::serialize::OutputArchive &
78 write(torch::serialize::OutputArchive &archive,
79 const std::string &key) const = 0;
80
83 virtual torch::serialize::InputArchive &
84 read(torch::serialize::InputArchive &archive, const std::string &key) = 0;
85};
86
88inline std::ostream &operator<<(std::ostream &os,
89 const ActivationFunction &obj) {
90 obj.pretty_print(os);
91 return os;
92}
93
95class None : public ActivationFunction {
96public:
98 inline torch::Tensor apply(const torch::Tensor &input) const override {
99 return input;
100 }
101
103 inline void pretty_print(std::ostream &os) const noexcept override {
105 }
106
109 inline torch::serialize::OutputArchive &
110 write(torch::serialize::OutputArchive &archive,
111 const std::string &key = "none") const override {
112 archive.write(key + ".type",
113 torch::full({1}, static_cast<int64_t>(activation::none)));
114
115 return archive;
116 }
117
120 inline torch::serialize::InputArchive &
121 read(torch::serialize::InputArchive &archive,
122 const std::string &key = "none") override {
123 torch::Tensor tensor;
124
125 archive.read(key + ".type", tensor);
126 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::none))
127 throw std::runtime_error("activation mismatch");
128
129 return archive;
130 }
131};
132
139public:
140 explicit BatchNorm(torch::Tensor running_mean, torch::Tensor running_var,
141 torch::nn::functional::BatchNormFuncOptions options = {})
143 running_var_(std::move(running_var)) {}
144
145 explicit BatchNorm(torch::Tensor running_mean, torch::Tensor running_var,
146 const torch::Tensor &weight, const torch::Tensor &bias,
147 double eps, double momentum, bool training = false)
148 : options_(torch::nn::functional::BatchNormFuncOptions()
149 .weight(weight)
150 .bias(bias)
151 .eps(eps)
152 .momentum(momentum)
153 .training(training)),
155 running_var_(std::move(running_var)) {}
156
157 ~BatchNorm() override = default;
158
160 inline torch::Tensor apply(const torch::Tensor &input) const override {
161 return torch::nn::functional::batch_norm(input, running_mean_, running_var_,
162 options_);
163 }
164
166 inline const torch::Tensor &running_mean() const { return running_mean_; }
167
169 inline torch::Tensor &running_mean() { return running_mean_; }
170
172 inline const torch::Tensor &running_var() const { return running_var_; }
173
175 inline torch::Tensor &running_var() { return running_var_; }
176
178 inline const torch::nn::functional::BatchNormFuncOptions &options() const {
179 return options_;
180 }
181
183 inline torch::nn::functional::BatchNormFuncOptions &options() {
184 return options_;
185 }
186
188 inline void pretty_print(std::ostream &os) const noexcept override {
189 os << utils::FullQualifiedName::name() << "(\n eps=" << options_.eps()
190 << ", momentum="
191 << options_
192 .momentum()
193#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR < 7
194 .value()
195#endif
196 << ", training=" << options_.training();
197
198 if (is_verbose(os)) {
199 os << "\n running_mean = " << running_mean()
200 << "\n running_var = " << running_var()
201 << "\n weight = " << options_.weight()
202 << "\n bias = " << options_.bias();
203 }
204
205 os << "\n)";
206 }
207
210 inline torch::serialize::OutputArchive &
211 write(torch::serialize::OutputArchive &archive,
212 const std::string &key = "batch_norm") const override {
213 archive.write(key + ".type", torch::full({1}, static_cast<int64_t>(
215 archive.write(key + ".running_mean", this->running_mean());
216 archive.write(key + ".running_var", this->running_var());
217 archive.write(key + ".weight", this->options_.weight());
218 archive.write(key + ".bias", this->options_.bias());
219 archive.write(key + ".eps", torch::full({1}, (double)this->options_.eps()));
220 archive.write(key + ".momentum", torch::full({1}, (double)this->options_
221 .momentum()
222#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR < 7
223 .value()
224#endif
225 ));
226 archive.write(key + ".training",
227 torch::full({1}, (bool)this->options_.training()));
228
229 return archive;
230 }
231
234 inline torch::serialize::InputArchive &
235 read(torch::serialize::InputArchive &archive,
236 const std::string &key = "batch_norm") override {
237 torch::Tensor tensor;
238
239 archive.read(key + ".type", tensor);
240 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::batch_norm))
241 throw std::runtime_error("activation mismatch");
242
243 archive.read(key + ".running_mean", this->running_mean());
244 archive.read(key + ".running_var", this->running_var());
245 archive.read(key + ".weight", this->options_.weight());
246 archive.read(key + ".bias", this->options_.bias());
247 archive.read(key + ".eps", tensor);
248 this->options_.eps(tensor.item<double>());
249 archive.read(key + ".momentum", tensor);
250 this->options_.momentum(tensor.item<double>());
251 archive.read(key + ".training", tensor);
252 this->options_.training(tensor.item<bool>());
253
254 return archive;
255 }
256
257private:
258 torch::nn::functional::BatchNormFuncOptions options_;
260};
261
268class CELU : public ActivationFunction {
269public:
270 explicit CELU(torch::nn::functional::CELUFuncOptions options = {})
271 : options_(options) {}
272
273 explicit CELU(double alpha, bool inplace = false)
274 : options_(torch::nn::functional::CELUFuncOptions().alpha(alpha).inplace(
275 inplace)) {}
276
277 ~CELU() override = default;
278
280 inline torch::Tensor apply(const torch::Tensor &input) const override {
281 return torch::nn::functional::celu(input, options_);
282 }
283
285 inline const torch::nn::functional::CELUFuncOptions &options() const {
286 return options_;
287 }
288
290 inline torch::nn::functional::CELUFuncOptions &options() { return options_; }
291
293 inline void pretty_print(std::ostream &os) const noexcept override {
294 os << utils::FullQualifiedName::name() << "(\n alpha=" << options_.alpha()
295 << ", inplace=" << options_.inplace() << "\n)";
296 }
297
300 inline torch::serialize::OutputArchive &
301 write(torch::serialize::OutputArchive &archive,
302 const std::string &key = "celu") const override {
303 archive.write(key + ".type",
304 torch::full({1}, static_cast<int64_t>(activation::celu)));
305 archive.write(key + ".alpha",
306 torch::full({1}, (double)this->options_.alpha()));
307 archive.write(key + ".inplace",
308 torch::full({1}, (bool)this->options_.inplace()));
309
310 return archive;
311 }
312
315 inline torch::serialize::InputArchive &
316 read(torch::serialize::InputArchive &archive,
317 const std::string &key = "celu") override {
318 torch::Tensor tensor;
319
320 archive.read(key + ".type", tensor);
321 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::celu))
322 throw std::runtime_error("activation mismatch");
323
324 archive.read(key + ".alpha", tensor);
325 this->options_.alpha(tensor.item<double>());
326 archive.read(key + ".inplace", tensor);
327 this->options_.inplace(tensor.item<bool>());
328
329 return archive;
330 }
331
332private:
333 torch::nn::functional::CELUFuncOptions options_;
334};
335
345class ELU : public ActivationFunction {
346public:
347 explicit ELU(torch::nn::functional::ELUFuncOptions options = {})
348 : options_(options) {}
349
350 explicit ELU(double alpha, bool inplace = false)
351 : options_(torch::nn::functional::ELUFuncOptions().alpha(alpha).inplace(
352 inplace)) {}
353
354 ~ELU() override = default;
355
357 inline torch::Tensor apply(const torch::Tensor &input) const override {
358 return torch::nn::functional::elu(input, options_);
359 }
360
362 inline const torch::nn::functional::ELUFuncOptions &options() const {
363 return options_;
364 }
365
367 inline torch::nn::functional::ELUFuncOptions &options() { return options_; }
368
370 inline void
371 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
372 os << utils::FullQualifiedName::name() << "(\n alpha=" << options_.alpha()
373 << ", inplace=" << options_.inplace() << "\n)";
374 }
375
378 inline torch::serialize::OutputArchive &
379 write(torch::serialize::OutputArchive &archive,
380 const std::string &key = "elu") const override {
381 archive.write(key + ".type",
382 torch::full({1}, static_cast<int64_t>(activation::elu)));
383 archive.write(key + ".alpha",
384 torch::full({1}, (double)this->options_.alpha()));
385 archive.write(key + ".inplace",
386 torch::full({1}, (bool)this->options_.inplace()));
387
388 return archive;
389 }
390
393 inline torch::serialize::InputArchive &
394 read(torch::serialize::InputArchive &archive,
395 const std::string &key = "elu") override {
396 torch::Tensor tensor;
397
398 archive.read(key + ".type", tensor);
399 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::elu))
400 throw std::runtime_error("activation mismatch");
401
402 archive.read(key + ".alpha", tensor);
403 this->options_.alpha(tensor.item<double>());
404 archive.read(key + ".inplace", tensor);
405 this->options_.inplace(tensor.item<bool>());
406
407 return archive;
408 }
409
410private:
411 torch::nn::functional::ELUFuncOptions options_;
412};
413
422class GELU : public ActivationFunction {
423public:
424 explicit GELU() = default;
425
426 ~GELU() override = default;
427
429 inline torch::Tensor apply(const torch::Tensor &input) const override {
430 return torch::gelu(input);
431 }
432
434 inline void pretty_print(std::ostream &os) const noexcept override {
436 }
437
440 inline torch::serialize::OutputArchive &
441 write(torch::serialize::OutputArchive &archive,
442 const std::string &key = "gelu") const override {
443 archive.write(key + ".type",
444 torch::full({1}, static_cast<int64_t>(activation::gelu)));
445
446 return archive;
447 }
448
451 inline torch::serialize::InputArchive &
452 read(torch::serialize::InputArchive &archive,
453 const std::string &key = "gelu") override {
454 torch::Tensor tensor;
455
456 archive.read(key + ".type", tensor);
457 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::gelu))
458 throw std::runtime_error("activation mismatch");
459
460 return archive;
461 }
462};
463
473class GLU : public ActivationFunction {
474public:
475 explicit GLU(torch::nn::functional::GLUFuncOptions options = {})
476 : options_(options) {}
477
478 explicit GLU(int64_t dim)
479 : options_(torch::nn::functional::GLUFuncOptions().dim(dim)) {}
480
481 ~GLU() override = default;
482
484 inline torch::Tensor apply(const torch::Tensor &input) const override {
485 return torch::nn::functional::glu(input, options_);
486 }
487
489 inline const torch::nn::functional::GLUFuncOptions &options() const {
490 return options_;
491 }
492
494 inline torch::nn::functional::GLUFuncOptions &options() { return options_; }
495
497 inline void pretty_print(std::ostream &os) const noexcept override {
498 os << utils::FullQualifiedName::name() << "(\n dim=" << options_.dim()
499 << "\n)";
500 }
501
504 inline torch::serialize::OutputArchive &
505 write(torch::serialize::OutputArchive &archive,
506 const std::string &key = "glu") const override {
507 archive.write(key + ".type",
508 torch::full({1}, static_cast<int64_t>(activation::glu)));
509 archive.write(key + ".dim",
510 torch::full({1}, static_cast<int>(this->options_.dim())));
511
512 return archive;
513 }
514
517 inline torch::serialize::InputArchive &
518 read(torch::serialize::InputArchive &archive,
519 const std::string &key = "glu") override {
520 torch::Tensor tensor;
521
522 archive.read(key + ".type", tensor);
523 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::glu))
524 throw std::runtime_error("activation mismatch");
525
526 archive.read(key + ".dim", tensor);
527 this->options_.dim(tensor.item<int>());
528
529 return archive;
530 }
531
532private:
533 torch::nn::functional::GLUFuncOptions options_;
534};
535
539public:
540 explicit GroupNorm(int64_t num_groups)
541 : options_(torch::nn::functional::GroupNormFuncOptions(num_groups)) {}
542
543 explicit GroupNorm(torch::nn::functional::GroupNormFuncOptions options)
544 : options_(std::move(options)) {}
545
546 explicit GroupNorm(int64_t num_groups, const torch::Tensor &weight,
547 const torch::Tensor &bias, double eps)
548 : options_(torch::nn::functional::GroupNormFuncOptions(num_groups)
549 .weight(weight)
550 .bias(bias)
551 .eps(eps)) {}
552
553 ~GroupNorm() override = default;
554
556 inline torch::Tensor apply(const torch::Tensor &input) const override {
557 return torch::nn::functional::group_norm(input, options_);
558 }
559
561 inline const torch::nn::functional::GroupNormFuncOptions &options() const {
562 return options_;
563 }
564
566 inline torch::nn::functional::GroupNormFuncOptions &options() {
567 return options_;
568 }
569
571 inline void
572 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
573 os << utils::FullQualifiedName::name() << "(\n eps=" << options_.eps();
574
575 if (is_verbose(os)) {
576 os << "\n weight = " << options_.weight()
577 << "\n bias = " << options_.bias();
578 }
579
580 os << "\n)";
581 }
582
585 inline torch::serialize::OutputArchive &
586 write(torch::serialize::OutputArchive &archive,
587 const std::string &key = "group_norm") const override {
588 archive.write(key + ".type", torch::full({1}, static_cast<int64_t>(
590 archive.write(key + ".weight", this->options_.weight());
591 archive.write(key + ".bias", this->options_.bias());
592 archive.write(key + ".eps", torch::full({1}, (double)this->options_.eps()));
593
594 return archive;
595 }
596
599 inline torch::serialize::InputArchive &
600 read(torch::serialize::InputArchive &archive,
601 const std::string &key = "group_norm") override {
602 torch::Tensor tensor;
603
604 archive.read(key + ".type", tensor);
605 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::group_norm))
606 throw std::runtime_error("activation mismatch");
607
608 archive.read(key + ".weight", this->options_.weight());
609 archive.read(key + ".bias", this->options_.bias());
610 archive.read(key + ".eps", tensor);
611 this->options_.eps(tensor.item<double>());
612
613 return archive;
614 }
615
616private:
617 torch::nn::functional::GroupNormFuncOptions options_;
618};
619
622public:
624 torch::nn::functional::GumbelSoftmaxFuncOptions options = {})
625 : options_(options) {}
626
627 explicit GumbelSoftmax(double tau, int dim, bool hard)
628 : options_(torch::nn::functional::GumbelSoftmaxFuncOptions()
629 .tau(tau)
630 .dim(dim)
631 .hard(hard)) {}
632
633 ~GumbelSoftmax() override = default;
634
636 inline torch::Tensor apply(const torch::Tensor &input) const override {
637 return torch::nn::functional::gumbel_softmax(input, options_);
638 }
639
641 inline const torch::nn::functional::GumbelSoftmaxFuncOptions &
642 options() const {
643 return options_;
644 }
645
647 inline torch::nn::functional::GumbelSoftmaxFuncOptions &options() {
648 return options_;
649 }
650
652 inline void
653 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
654 os << utils::FullQualifiedName::name() << "(\n tau=" << options_.tau()
655 << ", dim=" << options_.dim() << ", hard=" << options_.hard() << "\n)";
656 }
657
660 inline torch::serialize::OutputArchive &
661 write(torch::serialize::OutputArchive &archive,
662 const std::string &key = "gumbel_softmax") const override {
663 archive.write(
664 key + ".type",
665 torch::full({1}, static_cast<int64_t>(activation::gumbel_softmax)));
666 archive.write(key + ".tau", torch::full({1}, (double)this->options_.tau()));
667 archive.write(key + ".dim", torch::full({1}, (int)this->options_.dim()));
668 archive.write(key + ".hard", torch::full({1}, (bool)this->options_.hard()));
669
670 return archive;
671 }
672
675 inline torch::serialize::InputArchive &
676 read(torch::serialize::InputArchive &archive,
677 const std::string &key = "gumbel_softmax") override {
678 torch::Tensor tensor;
679
680 archive.read(key + ".type", tensor);
681 if (tensor.item<int64_t>() !=
682 static_cast<int64_t>(activation::gumbel_softmax))
683 throw std::runtime_error("activation mismatch");
684
685 archive.read(key + ".tau", tensor);
686 this->options_.tau(tensor.item<double>());
687 archive.read(key + ".dim", tensor);
688 this->options_.dim(tensor.item<int>());
689 archive.read(key + ".hard", tensor);
690 this->options_.hard(tensor.item<bool>());
691
692 return archive;
693 }
694
695private:
696 torch::nn::functional::GumbelSoftmaxFuncOptions options_;
697};
698
701public:
702 explicit Hardshrink(torch::nn::functional::HardshrinkFuncOptions options = {})
703 : options_(options) {}
704
705 explicit Hardshrink(double lambda)
706 : options_(
707 torch::nn::functional::HardshrinkFuncOptions().lambda(lambda)) {}
708
709 ~Hardshrink() override = default;
710
712 inline torch::Tensor apply(const torch::Tensor &input) const override {
713 return torch::nn::functional::hardshrink(input, options_);
714 }
715
717 inline const torch::nn::functional::HardshrinkFuncOptions &options() const {
718 return options_;
719 }
720
722 inline torch::nn::functional::HardshrinkFuncOptions &options() {
723 return options_;
724 }
725
727 inline void
728 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
730 << "(\n lambda=" << options_.lambda() << "\n)";
731 }
732
735 inline torch::serialize::OutputArchive &
736 write(torch::serialize::OutputArchive &archive,
737 const std::string &key = "hardshrink") const override {
738 archive.write(key + ".type", torch::full({1}, static_cast<int64_t>(
740 archive.write(key + ".lambda",
741 torch::full({1}, (double)this->options_.lambda()));
742
743 return archive;
744 }
745
748 inline torch::serialize::InputArchive &
749 read(torch::serialize::InputArchive &archive,
750 const std::string &key = "hardshrink") override {
751 torch::Tensor tensor;
752
753 archive.read(key + ".type", tensor);
754 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::hardshrink))
755 throw std::runtime_error("activation mismatch");
756
757 archive.read(key + ".lambda", tensor);
758 this->options_.lambda(tensor.item<double>());
759
760 return archive;
761 }
762
763private:
764 torch::nn::functional::HardshrinkFuncOptions options_;
765};
766
778public:
779 explicit Hardsigmoid() = default;
780
781 ~Hardsigmoid() override = default;
782
784 inline torch::Tensor apply(const torch::Tensor &input) const override {
785 return torch::hardsigmoid(input);
786 }
787
789 inline void
790 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
792 }
793
796 inline torch::serialize::OutputArchive &
797 write(torch::serialize::OutputArchive &archive,
798 const std::string &key = "hardsigmoid") const override {
799 archive.write(
800 key + ".type",
801 torch::full({1}, static_cast<int64_t>(activation::hardsigmoid)));
802
803 return archive;
804 }
805
808 inline torch::serialize::InputArchive &
809 read(torch::serialize::InputArchive &archive,
810 const std::string &key = "hardsigmoid") override {
811 torch::Tensor tensor;
812
813 archive.read(key + ".type", tensor);
814 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::hardsigmoid))
815 throw std::runtime_error("activation mismatch");
816
817 return archive;
818 }
819};
820
832public:
833 explicit Hardswish() = default;
834
835 ~Hardswish() override = default;
836
838 inline torch::Tensor apply(const torch::Tensor &input) const override {
839 return torch::hardswish(input);
840 }
841
843 inline void
844 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
846 }
847
850 inline torch::serialize::OutputArchive &
851 write(torch::serialize::OutputArchive &archive,
852 const std::string &key = "hardswish") const override {
853 archive.write(key + ".type", torch::full({1}, static_cast<int64_t>(
855
856 return archive;
857 }
858
861 inline torch::serialize::InputArchive &
862 read(torch::serialize::InputArchive &archive,
863 const std::string &key = "hardswish") override {
864 torch::Tensor tensor;
865
866 archive.read(key + ".type", tensor);
867 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::hardswish))
868 throw std::runtime_error("activation mismatch");
869
870 return archive;
871 }
872};
873
885public:
886 explicit Hardtanh(
887 const torch::nn::functional::HardtanhFuncOptions &options = {})
888 : options_(options) {}
889
890 explicit Hardtanh(double min_val, double max_val, bool inplace = false)
891 : options_(torch::nn::functional::HardtanhFuncOptions()
892 .min_val(min_val)
893 .max_val(max_val)
894 .inplace(inplace)) {}
895
896 ~Hardtanh() override = default;
897
899 inline torch::Tensor apply(const torch::Tensor &input) const override {
900 return torch::nn::functional::hardtanh(input, options_);
901 }
902
904 inline const torch::nn::functional::HardtanhFuncOptions &options() const {
905 return options_;
906 }
907
909 inline torch::nn::functional::HardtanhFuncOptions &options() {
910 return options_;
911 }
912
914 inline void
915 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
917 << "(\n min_val=" << options_.min_val()
918 << ", max_val=" << options_.max_val()
919 << ", inplace=" << options_.inplace() << "\n)";
920 }
921
924 inline torch::serialize::OutputArchive &
925 write(torch::serialize::OutputArchive &archive,
926 const std::string &key = "hardtanh") const override {
927 archive.write(key + ".type",
928 torch::full({1}, static_cast<int64_t>(activation::hardtanh)));
929 archive.write(key + ".min_val",
930 torch::full({1}, (double)this->options_.min_val()));
931 archive.write(key + ".max_val",
932 torch::full({1}, (double)this->options_.max_val()));
933 archive.write(key + ".inplace",
934 torch::full({1}, (bool)this->options_.inplace()));
935
936 return archive;
937 }
938
941 inline torch::serialize::InputArchive &
942 read(torch::serialize::InputArchive &archive,
943 const std::string &key = "hardtanh") override {
944 torch::Tensor tensor;
945
946 archive.read(key + ".type", tensor);
947 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::hardtanh))
948 throw std::runtime_error("activation mismatch");
949
950 archive.read(key + ".min_val", tensor);
951 this->options_.min_val(tensor.item<double>());
952 archive.read(key + ".max_val", tensor);
953 this->options_.max_val(tensor.item<double>());
954 archive.read(key + ".inplace", tensor);
955 this->options_.inplace(tensor.item<bool>());
956
957 return archive;
958 }
959
960private:
961 torch::nn::functional::HardtanhFuncOptions options_;
962};
963
969public:
970 explicit InstanceNorm(
971 torch::nn::functional::InstanceNormFuncOptions options = {})
972 : options_(std::move(options)) {}
973
974 explicit InstanceNorm(const torch::Tensor &running_mean,
975 const torch::Tensor &running_var,
976 const torch::Tensor &weight, const torch::Tensor &bias,
977 double eps, double momentum,
978 bool use_input_stats = true)
979 : options_(torch::nn::functional::InstanceNormFuncOptions()
980 .running_mean(running_mean)
981 .running_var(running_var)
982 .weight(weight)
983 .bias(bias)
984 .eps(eps)
985 .momentum(momentum)
986 .use_input_stats(use_input_stats)) {}
987
988 ~InstanceNorm() override = default;
989
991 inline torch::Tensor apply(const torch::Tensor &input) const override {
992 return torch::nn::functional::instance_norm(input, options_);
993 }
994
996 inline const torch::nn::functional::InstanceNormFuncOptions &options() const {
997 return options_;
998 }
999
1001 inline torch::nn::functional::InstanceNormFuncOptions &options() {
1002 return options_;
1003 }
1004
1006 inline void
1007 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1008 os << utils::FullQualifiedName::name() << "(\n eps=" << options_.eps()
1009 << ", momentum=" << options_.momentum()
1010 << ", use_input_stats=" << options_.use_input_stats();
1011
1012 if (is_verbose(os)) {
1013 os << "\n running_mean = " << options_.running_mean()
1014 << "\n running_var = " << options_.running_var()
1015 << "\n weight = " << options_.weight()
1016 << "\n bias = " << options_.bias();
1017 }
1018
1019 os << "\n)";
1020 }
1021
1024 inline torch::serialize::OutputArchive &
1025 write(torch::serialize::OutputArchive &archive,
1026 const std::string &key = "instance_norm") const override {
1027 archive.write(
1028 key + ".type",
1029 torch::full({1}, static_cast<int64_t>(activation::instance_norm)));
1030 archive.write(key + ".running_mean", this->options_.running_mean());
1031 archive.write(key + ".var", this->options_.running_var());
1032 archive.write(key + ".weight", this->options_.weight());
1033 archive.write(key + ".bias", this->options_.bias());
1034 archive.write(key + ".eps", torch::full({1}, (double)this->options_.eps()));
1035 archive.write(key + ".momentum",
1036 torch::full({1}, (double)this->options_.momentum()));
1037 archive.write(key + ".use_input_stats",
1038 torch::full({1}, (bool)this->options_.use_input_stats()));
1039
1040 return archive;
1041 }
1042
1045 inline torch::serialize::InputArchive &
1046 read(torch::serialize::InputArchive &archive,
1047 const std::string &key = "instance_norm") override {
1048 torch::Tensor tensor;
1049
1050 archive.read(key + ".type", tensor);
1051 if (tensor.item<int64_t>() !=
1052 static_cast<int64_t>(activation::instance_norm))
1053 throw std::runtime_error("activation mismatch");
1054
1055 archive.read(key + ".running_mean", this->options_.running_mean());
1056 archive.read(key + ".running_var", this->options_.running_var());
1057 archive.read(key + ".weight", this->options_.weight());
1058 archive.read(key + ".bias", this->options_.bias());
1059 archive.read(key + ".eps", tensor);
1060 this->options_.eps(tensor.item<double>());
1061 archive.read(key + ".momentum", tensor);
1062 this->options_.momentum(tensor.item<double>());
1063 archive.read(key + ".use_input_stats", tensor);
1064 this->options_.use_input_stats(tensor.item<bool>());
1065
1066 return archive;
1067 }
1068
1069private:
1070 torch::nn::functional::InstanceNormFuncOptions options_;
1071};
1072
1077public:
1078 explicit LayerNorm(std::vector<int64_t> normalized_shape)
1079 : options_(torch::nn::functional::LayerNormFuncOptions(
1080 std::move(normalized_shape))) {}
1081
1082 explicit LayerNorm(torch::nn::functional::LayerNormFuncOptions options)
1083 : options_(std::move(options)) {}
1084
1085 explicit LayerNorm(std::vector<int64_t> normalized_shape,
1086 const torch::Tensor &weight, const torch::Tensor &bias,
1087 double eps)
1088 : options_(torch::nn::functional::LayerNormFuncOptions(
1089 std::move(normalized_shape))
1090 .weight(weight)
1091 .bias(bias)
1092 .eps(eps)) {}
1093
1094 ~LayerNorm() override = default;
1095
1097 inline torch::Tensor apply(const torch::Tensor &input) const override {
1098 return torch::nn::functional::layer_norm(input, options_);
1099 }
1100
1102 inline const torch::nn::functional::LayerNormFuncOptions &options() const {
1103 return options_;
1104 }
1105
1107 inline torch::nn::functional::LayerNormFuncOptions &options() {
1108 return options_;
1109 }
1110
1112 inline void pretty_print(std::ostream &os) const noexcept override {
1113 os << utils::FullQualifiedName::name() << "(\n eps=" << options_.eps();
1114
1115 if (is_verbose(os)) {
1116 os << "\n normalized_shape = " << options_.normalized_shape()
1117 << "\n weight = " << options_.weight()
1118 << "\n bias = " << options_.bias();
1119 }
1120
1121 os << "\n)";
1122 }
1123
1126 inline torch::serialize::OutputArchive &
1127 write(torch::serialize::OutputArchive &archive,
1128 const std::string &key = "layer_norm") const override {
1129 archive.write(key + ".type", torch::full({1}, static_cast<int64_t>(
1131 archive.write(key + ".weight", this->options_.weight());
1132 archive.write(key + ".bias", this->options_.bias());
1133 archive.write(key + ".eps", torch::full({1}, (double)this->options_.eps()));
1134
1135 return archive;
1136 }
1137
1140 inline torch::serialize::InputArchive &
1141 read(torch::serialize::InputArchive &archive,
1142 const std::string &key = "layer_norm") override {
1143 torch::Tensor tensor;
1144
1145 archive.read(key + ".type", tensor);
1146 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::layer_norm))
1147 throw std::runtime_error("activation mismatch");
1148
1149 archive.read(key + ".weight", this->options_.weight());
1150 archive.read(key + ".bias", this->options_.bias());
1151 archive.read(key + ".eps", tensor);
1152 this->options_.eps(tensor.item<double>());
1153
1154 return archive;
1155 }
1156
1157private:
1158 torch::nn::functional::LayerNormFuncOptions options_;
1159};
1160
1171public:
1172 explicit LeakyReLU(torch::nn::functional::LeakyReLUFuncOptions options = {})
1173 : options_(options) {}
1174
1175 explicit LeakyReLU(double negative_slope, bool inplace = false)
1176 : options_(torch::nn::functional::LeakyReLUFuncOptions()
1177 .negative_slope(negative_slope)
1178 .inplace(inplace)) {}
1179
1180 ~LeakyReLU() override = default;
1181
1183 inline torch::Tensor apply(const torch::Tensor &input) const override {
1184 return torch::nn::functional::leaky_relu(input, options_);
1185 }
1186
1188 inline const torch::nn::functional::LeakyReLUFuncOptions &options() const {
1189 return options_;
1190 }
1191
1193 inline torch::nn::functional::LeakyReLUFuncOptions &options() {
1194 return options_;
1195 }
1196
1198 inline void pretty_print(std::ostream &os) const noexcept override {
1200 << "(\n negative_slope=" << options_.negative_slope()
1201 << ", inplace=" << options_.inplace() << "\n)";
1202 }
1203
1206 inline torch::serialize::OutputArchive &
1207 write(torch::serialize::OutputArchive &archive,
1208 const std::string &key = "leaky_relu") const override {
1209 archive.write(key + ".type", torch::full({1}, static_cast<int64_t>(
1211
1212 archive.write(key + ".negative_slope",
1213 torch::full({1}, (double)this->options_.negative_slope()));
1214 archive.write(key + ".inplace",
1215 torch::full({1}, (bool)this->options_.inplace()));
1216
1217 return archive;
1218 }
1219
1222 inline torch::serialize::InputArchive &
1223 read(torch::serialize::InputArchive &archive,
1224 const std::string &key = "leaky_relu") override {
1225 torch::Tensor tensor;
1226
1227 archive.read(key + ".type", tensor);
1228 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::leaky_relu))
1229 throw std::runtime_error("activation mismatch");
1230
1231 archive.read(key + ".negative_slope", tensor);
1232 this->options_.negative_slope(tensor.item<double>());
1233 archive.read(key + ".inplace", tensor);
1234 this->options_.inplace(tensor.item<bool>());
1235
1236 return archive;
1237 }
1238
1239private:
1240 torch::nn::functional::LeakyReLUFuncOptions options_;
1241};
1242
1245public:
1246 explicit LocalResponseNorm(int64_t size)
1247 : options_(torch::nn::functional::LocalResponseNormFuncOptions(size)) {}
1248
1250 const torch::nn::functional::LocalResponseNormFuncOptions &options)
1251 : options_(options) {}
1252
1253 explicit LocalResponseNorm(int64_t size, double alpha, double beta, double k)
1254 : options_(torch::nn::functional::LocalResponseNormFuncOptions(size)
1255 .alpha(alpha)
1256 .beta(beta)
1257 .k(k)) {}
1258
1259 ~LocalResponseNorm() override = default;
1260
1262 inline torch::Tensor apply(const torch::Tensor &input) const override {
1263 return torch::nn::functional::local_response_norm(input, options_);
1264 }
1265
1267 inline const torch::nn::functional::LocalResponseNormFuncOptions &
1268 options() const {
1269 return options_;
1270 }
1271
1273 inline torch::nn::functional::LocalResponseNormFuncOptions &options() {
1274 return options_;
1275 }
1276
1278 inline void pretty_print(std::ostream &os) const noexcept override {
1279 os << utils::FullQualifiedName::name() << "(\n size=" << options_.size()
1280 << ", alpha=" << options_.alpha() << ", beta=" << options_.beta()
1281 << ", k=" << options_.k() << "\n)";
1282 }
1283
1286 inline torch::serialize::OutputArchive &
1287 write(torch::serialize::OutputArchive &archive,
1288 const std::string &key = "local_response_norm") const override {
1289 archive.write(key + ".type",
1290 torch::full({1}, static_cast<int64_t>(
1292
1293 archive.write(key + ".size",
1294 torch::full({1}, (int64_t)this->options_.size()));
1295 archive.write(key + ".alpha",
1296 torch::full({1}, (double)this->options_.alpha()));
1297 archive.write(key + ".beta",
1298 torch::full({1}, (double)this->options_.beta()));
1299 archive.write(key + ".k", torch::full({1}, (double)this->options_.k()));
1300
1301 return archive;
1302 }
1303
1306 inline torch::serialize::InputArchive &
1307 read(torch::serialize::InputArchive &archive,
1308 const std::string &key = "local_response_norm") override {
1309 torch::Tensor tensor;
1310
1311 archive.read(key + ".type", tensor);
1312 if (tensor.item<int64_t>() !=
1313 static_cast<int64_t>(activation::local_response_norm))
1314 throw std::runtime_error("activation mismatch");
1315
1316 archive.read(key + ".size", tensor);
1317 this->options_.size(tensor.item<int64_t>());
1318 archive.read(key + ".alpha", tensor);
1319 this->options_.alpha(tensor.item<double>());
1320 archive.read(key + ".beta", tensor);
1321 this->options_.beta(tensor.item<double>());
1322 archive.read(key + ".k", tensor);
1323 this->options_.k(tensor.item<double>());
1324
1325 return archive;
1326 }
1327
1328private:
1329 torch::nn::functional::LocalResponseNormFuncOptions options_;
1330};
1331
1338public:
1339 explicit LogSigmoid() = default;
1340
1341 ~LogSigmoid() override = default;
1342
1344 inline torch::Tensor apply(const torch::Tensor &input) const override {
1345 return torch::log_sigmoid(input);
1346 }
1347
1349 inline void pretty_print(std::ostream &os) const noexcept override {
1351 }
1352
1355 inline torch::serialize::OutputArchive &
1356 write(torch::serialize::OutputArchive &archive,
1357 const std::string &key = "logsigmoid") const override {
1358 archive.write(key + ".type", torch::full({1}, static_cast<int64_t>(
1360
1361 return archive;
1362 }
1363
1366 inline torch::serialize::InputArchive &
1367 read(torch::serialize::InputArchive &archive,
1368 const std::string &key = "logsigmoid") override {
1369 torch::Tensor tensor;
1370
1371 archive.read(key + ".type", tensor);
1372 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::logsigmoid))
1373 throw std::runtime_error("activation mismatch");
1374
1375 return archive;
1376 }
1377};
1378
1388public:
1389 explicit LogSoftmax(int64_t dim)
1390 : options_(torch::nn::functional::LogSoftmaxFuncOptions(dim)) {}
1391
1392 explicit LogSoftmax(
1393 const torch::nn::functional::LogSoftmaxFuncOptions &options)
1394 : options_(options) {}
1395
1396 ~LogSoftmax() override = default;
1397
1399 inline torch::Tensor apply(const torch::Tensor &input) const override {
1400 return torch::nn::functional::log_softmax(input, options_);
1401 }
1402
1404 inline const torch::nn::functional::LogSoftmaxFuncOptions &options() const {
1405 return options_;
1406 }
1407
1409 inline torch::nn::functional::LogSoftmaxFuncOptions &options() {
1410 return options_;
1411 }
1412
1414 inline void pretty_print(std::ostream &os) const noexcept override {
1415 os << utils::FullQualifiedName::name() << "(\n dim=" << options_.dim()
1416 << "\n)";
1417 }
1418
1421 inline torch::serialize::OutputArchive &
1422 write(torch::serialize::OutputArchive &archive,
1423 const std::string &key = "logsoftmax") const override {
1424 archive.write(key + ".type", torch::full({1}, static_cast<int64_t>(
1426
1427 return archive;
1428 }
1429
1432 inline torch::serialize::InputArchive &
1433 read(torch::serialize::InputArchive &archive,
1434 const std::string &key = "logsoftmax") override {
1435 torch::Tensor tensor;
1436
1437 archive.read(key + ".type", tensor);
1438 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::logsoftmax))
1439 throw std::runtime_error("activation mismatch");
1440
1441 return archive;
1442 }
1443
1444private:
1445 torch::nn::functional::LogSoftmaxFuncOptions options_;
1446};
1447
1453class Mish : public ActivationFunction {
1454public:
1455 explicit Mish() = default;
1456
1457 ~Mish() override = default;
1458
1460 inline torch::Tensor apply(const torch::Tensor &input) const override {
1461 return torch::mish(input);
1462 }
1463
1465 inline void pretty_print(std::ostream &os) const noexcept override {
1467 }
1468
1471 inline torch::serialize::OutputArchive &
1472 write(torch::serialize::OutputArchive &archive,
1473 const std::string &key = "mish") const override {
1474 archive.write(key + ".type",
1475 torch::full({1}, static_cast<int64_t>(activation::mish)));
1476
1477 return archive;
1478 }
1479
1482 inline torch::serialize::InputArchive &
1483 read(torch::serialize::InputArchive &archive,
1484 const std::string &key = "mish") override {
1485 torch::Tensor tensor;
1486
1487 archive.read(key + ".type", tensor);
1488 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::mish))
1489 throw std::runtime_error("activation mismatch");
1490
1491 return archive;
1492 }
1493};
1494
1497public:
1498 explicit Normalize(torch::nn::functional::NormalizeFuncOptions options = {})
1499 : options_(std::move(options)) {}
1500
1501 explicit Normalize(double p, double eps, int64_t dim)
1502 : options_(
1503 torch::nn::functional::NormalizeFuncOptions().p(p).eps(eps).dim(
1504 dim)) {}
1505
1506 ~Normalize() override = default;
1507
1509 inline torch::Tensor apply(const torch::Tensor &input) const override {
1510 return torch::nn::functional::normalize(input, options_);
1511 }
1512
1514 inline const torch::nn::functional::NormalizeFuncOptions &options() const {
1515 return options_;
1516 }
1517
1519 inline torch::nn::functional::NormalizeFuncOptions &options() {
1520 return options_;
1521 }
1522
1524 inline void pretty_print(std::ostream &os) const noexcept override {
1525 os << utils::FullQualifiedName::name() << "(\n eps=" << options_.eps()
1526 << "(\n p=" << options_.p() << "(\n dim=" << options_.dim() << "\n)";
1527 }
1528
1531 inline torch::serialize::OutputArchive &
1532 write(torch::serialize::OutputArchive &archive,
1533 const std::string &key = "normalize") const override {
1534 archive.write(key + ".type", torch::full({1}, static_cast<int64_t>(
1536 archive.write(key + ".p", torch::full({1}, (double)this->options_.p()));
1537 archive.write(key + ".eps", torch::full({1}, (double)this->options_.eps()));
1538 archive.write(key + ".dim",
1539 torch::full({1}, (int64_t)this->options_.dim()));
1540
1541 return archive;
1542 }
1543
1546 inline torch::serialize::InputArchive &
1547 read(torch::serialize::InputArchive &archive,
1548 const std::string &key = "normalize") override {
1549 torch::Tensor tensor;
1550
1551 archive.read(key + ".type", tensor);
1552 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::normalize))
1553 throw std::runtime_error("activation mismatch");
1554
1555 archive.read(key + ".p", tensor);
1556 this->options_.p(tensor.item<double>());
1557 archive.read(key + ".eps", tensor);
1558 this->options_.eps(tensor.item<double>());
1559 archive.read(key + ".dim", tensor);
1560 this->options_.dim(tensor.item<int64_t>());
1561
1562 return archive;
1563 }
1564
1565private:
1566 torch::nn::functional::NormalizeFuncOptions options_;
1567};
1568
1571public:
1572 explicit PReLU(torch::Tensor weight) : weight_(std::move(weight)) {}
1573
1574 ~PReLU() override = default;
1575
1577 const torch::Tensor &weight() const { return weight_; }
1578
1580 torch::Tensor &weight() { return weight_; }
1581
1583 inline torch::Tensor apply(const torch::Tensor &input) const override {
1584 return torch::nn::functional::prelu(input, weight());
1585 }
1586
1588 inline void pretty_print(std::ostream &os) const noexcept override {
1590
1591 if (is_verbose(os))
1592 os << "(\n weight = " << weight() << "\n)";
1593 }
1594
1597 inline torch::serialize::OutputArchive &
1598 write(torch::serialize::OutputArchive &archive,
1599 const std::string &key = "prelu") const override {
1600 archive.write(key + ".type",
1601 torch::full({1}, static_cast<int64_t>(activation::prelu)));
1602 archive.write(key + ".weight", this->weight());
1603
1604 return archive;
1605 }
1606
1609 inline torch::serialize::InputArchive &
1610 read(torch::serialize::InputArchive &archive,
1611 const std::string &key = "prelu") override {
1612 torch::Tensor tensor;
1613
1614 archive.read(key + ".type", tensor);
1615 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::prelu))
1616 throw std::runtime_error("activation mismatch");
1617
1618 archive.read(key + ".weight", this->weight());
1619
1620 return archive;
1621 }
1622
1623private:
1624 torch::Tensor weight_;
1625};
1626
1632class ReLU : public ActivationFunction {
1633public:
1634 explicit ReLU(torch::nn::functional::ReLUFuncOptions options = {})
1635 : options_(options) {}
1636
1637 explicit ReLU(bool inplace)
1638 : options_(torch::nn::functional::ReLUFuncOptions().inplace(inplace)) {}
1639
1640 ~ReLU() override = default;
1641
1643 inline torch::Tensor apply(const torch::Tensor &input) const override {
1644 return torch::nn::functional::relu(input, options_);
1645 }
1646
1648 inline const torch::nn::functional::ReLUFuncOptions &options() const {
1649 return options_;
1650 }
1651
1653 inline torch::nn::functional::ReLUFuncOptions &options() { return options_; }
1654
1656 inline void pretty_print(std::ostream &os) const noexcept override {
1658 << "(\n inplace=" << options_.inplace() << "\n)";
1659 }
1660
1663 inline torch::serialize::OutputArchive &
1664 write(torch::serialize::OutputArchive &archive,
1665 const std::string &key = "relu") const override {
1666 archive.write(key + ".type",
1667 torch::full({1}, static_cast<int64_t>(activation::relu)));
1668 archive.write(key + ".inplace",
1669 torch::full({1}, (bool)this->options_.inplace()));
1670
1671 return archive;
1672 }
1673
1676 inline torch::serialize::InputArchive &
1677 read(torch::serialize::InputArchive &archive,
1678 const std::string &key = "relu") override {
1679 torch::Tensor tensor;
1680
1681 archive.read(key + ".type", tensor);
1682 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::relu))
1683 throw std::runtime_error("activation mismatch");
1684
1685 archive.read(key + ".inplace", tensor);
1686 this->options_.inplace(tensor.item<bool>());
1687
1688 return archive;
1689 }
1690
1691private:
1692 torch::nn::functional::ReLUFuncOptions options_;
1693};
1694
1701public:
1702 explicit ReLU6(torch::nn::functional::ReLU6FuncOptions options = {})
1703 : options_(options) {}
1704
1705 explicit ReLU6(bool inplace)
1706 : options_(torch::nn::functional::ReLU6FuncOptions().inplace(inplace)) {}
1707
1708 ~ReLU6() override = default;
1709
1711 inline torch::Tensor apply(const torch::Tensor &input) const override {
1712 return torch::nn::functional::relu6(input, options_);
1713 }
1714
1716 inline const torch::nn::functional::ReLU6FuncOptions &options() const {
1717 return options_;
1718 }
1719
1721 inline torch::nn::functional::ReLU6FuncOptions &options() { return options_; }
1722
1724 inline void pretty_print(std::ostream &os) const noexcept override {
1726 << "(\n inplace=" << options_.inplace() << "\n)";
1727 }
1728
1731 inline torch::serialize::OutputArchive &
1732 write(torch::serialize::OutputArchive &archive,
1733 const std::string &key = "relu6") const override {
1734 archive.write(key + ".type",
1735 torch::full({1}, static_cast<int64_t>(activation::relu6)));
1736 archive.write(key + ".inplace",
1737 torch::full({1}, (bool)this->options_.inplace()));
1738
1739 return archive;
1740 }
1741
1744 inline torch::serialize::InputArchive &
1745 read(torch::serialize::InputArchive &archive,
1746 const std::string &key = "relu6") override {
1747 torch::Tensor tensor;
1748
1749 archive.read(key + ".type", tensor);
1750 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::relu6))
1751 throw std::runtime_error("activation mismatch");
1752
1753 archive.read(key + ".inplace", tensor);
1754 this->options_.inplace(tensor.item<bool>());
1755
1756 return archive;
1757 }
1758
1759private:
1760 torch::nn::functional::ReLU6FuncOptions options_;
1761};
1762
1773public:
1774 explicit RReLU(const torch::nn::functional::RReLUFuncOptions &options = {})
1775 : options_(options) {}
1776
1777 explicit RReLU(double lower, double upper, bool inplace = false)
1778 : options_(torch::nn::functional::RReLUFuncOptions()
1779 .lower(lower)
1780 .upper(upper)
1781 .inplace(inplace)) {}
1782
1783 ~RReLU() override = default;
1784
1786 inline torch::Tensor apply(const torch::Tensor &input) const override {
1787 return torch::nn::functional::rrelu(input, options_);
1788 }
1789
1791 inline const torch::nn::functional::RReLUFuncOptions &options() const {
1792 return options_;
1793 }
1794
1796 inline torch::nn::functional::RReLUFuncOptions &options() { return options_; }
1797
1799 inline void pretty_print(std::ostream &os) const noexcept override {
1800 os << utils::FullQualifiedName::name() << "(\n lower=" << options_.lower()
1801 << ", upper=" << options_.upper() << ", inplace=" << options_.inplace()
1802 << "\n)";
1803 }
1804
1807 inline torch::serialize::OutputArchive &
1808 write(torch::serialize::OutputArchive &archive,
1809 const std::string &key = "rrelu") const override {
1810 archive.write(key + ".type",
1811 torch::full({1}, static_cast<int64_t>(activation::rrelu)));
1812 archive.write(key + ".lower",
1813 torch::full({1}, (double)this->options_.lower()));
1814 archive.write(key + ".upper",
1815 torch::full({1}, (double)this->options_.upper()));
1816 archive.write(key + ".inplace",
1817 torch::full({1}, (bool)this->options_.inplace()));
1818
1819 return archive;
1820 }
1821
1824 inline torch::serialize::InputArchive &
1825 read(torch::serialize::InputArchive &archive,
1826 const std::string &key = "rrelu") override {
1827 torch::Tensor tensor;
1828
1829 archive.read(key + ".type", tensor);
1830 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::rrelu))
1831 throw std::runtime_error("activation mismatch");
1832
1833 archive.read(key + ".lower", tensor);
1834 this->options_.lower(tensor.item<double>());
1835 archive.read(key + ".upper", tensor);
1836 this->options_.upper(tensor.item<double>());
1837 archive.read(key + ".inplace", tensor);
1838 this->options_.inplace(tensor.item<bool>());
1839
1840 return archive;
1841 }
1842
1843private:
1844 torch::nn::functional::RReLUFuncOptions options_;
1845};
1846
1855class SELU : public ActivationFunction {
1856public:
1857 explicit SELU(torch::nn::functional::SELUFuncOptions options = {})
1858 : options_(options) {}
1859
1860 explicit SELU(bool inplace)
1861 : options_(torch::nn::functional::SELUFuncOptions().inplace(inplace)) {}
1862
1863 ~SELU() override = default;
1864
1866 inline torch::Tensor apply(const torch::Tensor &input) const override {
1867 return torch::nn::functional::selu(input, options_);
1868 }
1869
1871 inline const torch::nn::functional::SELUFuncOptions &options() const {
1872 return options_;
1873 }
1874
1876 inline torch::nn::functional::SELUFuncOptions &options() { return options_; }
1877
1879 inline void pretty_print(std::ostream &os) const noexcept override {
1881 << "(\n inplace=" << options_.inplace() << "\n)";
1882 }
1883
1886 inline torch::serialize::OutputArchive &
1887 write(torch::serialize::OutputArchive &archive,
1888 const std::string &key = "selu") const override {
1889 archive.write(key + ".type",
1890 torch::full({1}, static_cast<int64_t>(activation::selu)));
1891 archive.write(key + ".inplace",
1892 torch::full({1}, (bool)this->options_.inplace()));
1893
1894 return archive;
1895 }
1896
1899 inline torch::serialize::InputArchive &
1900 read(torch::serialize::InputArchive &archive,
1901 const std::string &key = "selu") override {
1902 torch::Tensor tensor;
1903
1904 archive.read(key + ".type", tensor);
1905 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::selu))
1906 throw std::runtime_error("activation mismatch");
1907
1908 archive.read(key + ".inplace", tensor);
1909 this->options_.inplace(tensor.item<bool>());
1910
1911 return archive;
1912 }
1913
1914private:
1915 torch::nn::functional::SELUFuncOptions options_;
1916};
1917
1924public:
1926 inline torch::Tensor apply(const torch::Tensor &input) const override {
1927 return torch::sigmoid(input);
1928 }
1929
1931 inline void pretty_print(std::ostream &os) const noexcept override {
1933 }
1934
1937 inline torch::serialize::OutputArchive &
1938 write(torch::serialize::OutputArchive &archive,
1939 const std::string &key = "sigmoid") const override {
1940 archive.write(key + ".type",
1941 torch::full({1}, static_cast<int64_t>(activation::sigmoid)));
1942
1943 return archive;
1944 }
1945
1948 inline torch::serialize::InputArchive &
1949 read(torch::serialize::InputArchive &archive,
1950 const std::string &key = "sigmoid") override {
1951 torch::Tensor tensor;
1952
1953 archive.read(key + ".type", tensor);
1954 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::sigmoid))
1955 throw std::runtime_error("activation mismatch");
1956
1957 return archive;
1958 }
1959};
1960
1966class SiLU : public ActivationFunction {
1967public:
1969 inline torch::Tensor apply(const torch::Tensor &input) const override {
1970 return torch::silu(input);
1971 }
1972
1974 inline void pretty_print(std::ostream &os) const noexcept override {
1976 }
1977
1980 inline torch::serialize::OutputArchive &
1981 write(torch::serialize::OutputArchive &archive,
1982 const std::string &key = "silu") const override {
1983 archive.write(key + ".type",
1984 torch::full({1}, static_cast<int64_t>(activation::silu)));
1985
1986 return archive;
1987 }
1988
1991 inline torch::serialize::InputArchive &
1992 read(torch::serialize::InputArchive &archive,
1993 const std::string &key = "silu") override {
1994 torch::Tensor tensor;
1995
1996 archive.read(key + ".type", tensor);
1997 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::silu))
1998 throw std::runtime_error("activation mismatch");
1999
2000 return archive;
2001 }
2002};
2003
2012public:
2013 explicit Softmax(int64_t dim)
2014 : options_(torch::nn::functional::SoftmaxFuncOptions(dim)) {}
2015
2016 explicit Softmax(const torch::nn::functional::SoftmaxFuncOptions &options)
2017 : options_(options) {}
2018
2019 ~Softmax() override = default;
2020
2022 inline torch::Tensor apply(const torch::Tensor &input) const override {
2023 return torch::nn::functional::softmax(input, options_);
2024 }
2025
2027 inline const torch::nn::functional::SoftmaxFuncOptions &options() const {
2028 return options_;
2029 }
2030
2032 inline torch::nn::functional::SoftmaxFuncOptions &options() {
2033 return options_;
2034 }
2035
2037 inline void pretty_print(std::ostream &os) const noexcept override {
2038 os << utils::FullQualifiedName::name() << "(\n dim=" << options_.dim()
2039 << "\n)";
2040 }
2041
2044 inline torch::serialize::OutputArchive &
2045 write(torch::serialize::OutputArchive &archive,
2046 const std::string &key = "softmax") const override {
2047 archive.write(key + ".type",
2048 torch::full({1}, static_cast<int64_t>(activation::softmax)));
2049 archive.write(key + ".dim",
2050 torch::full({1}, (int64_t)this->options_.dim()));
2051
2052 return archive;
2053 }
2054
2057 inline torch::serialize::InputArchive &
2058 read(torch::serialize::InputArchive &archive,
2059 const std::string &key = "softmax") override {
2060 torch::Tensor tensor;
2061
2062 archive.read(key + ".type", tensor);
2063 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::softmax))
2064 throw std::runtime_error("activation mismatch");
2065
2066 archive.read(key + ".dim", tensor);
2067 this->options_.dim(tensor.item<int64_t>());
2068
2069 return archive;
2070 }
2071
2072private:
2073 torch::nn::functional::SoftmaxFuncOptions options_;
2074};
2075
2082public:
2083 explicit Softmin(int64_t dim)
2084 : options_(torch::nn::functional::SoftminFuncOptions(dim)) {}
2085
2086 explicit Softmin(const torch::nn::functional::SoftminFuncOptions &options)
2087 : options_(options) {}
2088
2089 ~Softmin() override = default;
2090
2092 inline torch::Tensor apply(const torch::Tensor &input) const override {
2093 return torch::nn::functional::softmin(input, options_);
2094 }
2095
2097 inline const torch::nn::functional::SoftminFuncOptions &options() const {
2098 return options_;
2099 }
2100
2102 inline torch::nn::functional::SoftminFuncOptions &options() {
2103 return options_;
2104 }
2105
2107 inline void pretty_print(std::ostream &os) const noexcept override {
2108 os << utils::FullQualifiedName::name() << "(\n dim=" << options_.dim()
2109 << "\n)";
2110 }
2111
2114 inline torch::serialize::OutputArchive &
2115 write(torch::serialize::OutputArchive &archive,
2116 const std::string &key = "softmin") const override {
2117 archive.write(key + ".type",
2118 torch::full({1}, static_cast<int64_t>(activation::softmin)));
2119 archive.write(key + ".dim",
2120 torch::full({1}, (int64_t)this->options_.dim()));
2121
2122 return archive;
2123 }
2124
2127 inline torch::serialize::InputArchive &
2128 read(torch::serialize::InputArchive &archive,
2129 const std::string &key = "softmin") override {
2130 torch::Tensor tensor;
2131
2132 archive.read(key + ".type", tensor);
2133 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::softmin))
2134 throw std::runtime_error("activation mismatch");
2135
2136 archive.read(key + ".dim", tensor);
2137 this->options_.dim(tensor.item<int64_t>());
2138
2139 return archive;
2140 }
2141
2142private:
2143 torch::nn::functional::SoftminFuncOptions options_;
2144};
2145
2152public:
2153 explicit Softplus(torch::nn::functional::SoftplusFuncOptions options = {})
2154 : options_(options) {}
2155
2156 explicit Softplus(double beta, double threshold)
2157 : options_(
2158 torch::nn::functional::SoftplusFuncOptions().beta(beta).threshold(
2159 threshold)) {}
2160
2161 ~Softplus() override = default;
2162
2164 inline torch::Tensor apply(const torch::Tensor &input) const override {
2165 return torch::nn::functional::softplus(input, options_);
2166 }
2167
2169 inline const torch::nn::functional::SoftplusFuncOptions &options() const {
2170 return options_;
2171 }
2172
2174 inline torch::nn::functional::SoftplusFuncOptions &options() {
2175 return options_;
2176 }
2177
2179 inline void pretty_print(std::ostream &os) const noexcept override {
2180 os << utils::FullQualifiedName::name() << "(\n beta=" << options_.beta()
2181 << ", theshold=" << options_.threshold() << "\n)";
2182 }
2183
2186 inline torch::serialize::OutputArchive &
2187 write(torch::serialize::OutputArchive &archive,
2188 const std::string &key = "softplus") const override {
2189 archive.write(key + ".type",
2190 torch::full({1}, static_cast<int64_t>(activation::softplus)));
2191 archive.write(key + ".beta",
2192 torch::full({1}, (double)this->options_.beta()));
2193 archive.write(key + ".threshold",
2194 torch::full({1}, (double)this->options_.threshold()));
2195
2196 return archive;
2197 }
2198
2201 inline torch::serialize::InputArchive &
2202 read(torch::serialize::InputArchive &archive,
2203 const std::string &key = "softplus") override {
2204 torch::Tensor tensor;
2205
2206 archive.read(key + ".type", tensor);
2207 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::softplus))
2208 throw std::runtime_error("activation mismatch");
2209
2210 archive.read(key + ".beta", tensor);
2211 this->options_.beta(tensor.item<double>());
2212 archive.read(key + ".threshold", tensor);
2213 this->options_.threshold(tensor.item<double>());
2214
2215 return archive;
2216 }
2217
2218private:
2219 torch::nn::functional::SoftplusFuncOptions options_;
2220};
2221
2233public:
2234 explicit Softshrink(torch::nn::functional::SoftshrinkFuncOptions options = {})
2235 : options_(options) {}
2236
2237 explicit Softshrink(double lambda)
2238 : options_(
2239 torch::nn::functional::SoftshrinkFuncOptions().lambda(lambda)) {}
2240
2241 ~Softshrink() override = default;
2242
2244 inline torch::Tensor apply(const torch::Tensor &input) const override {
2245 return torch::nn::functional::softshrink(input, options_);
2246 }
2247
2249 inline const torch::nn::functional::SoftshrinkFuncOptions &options() const {
2250 return options_;
2251 }
2252
2254 inline torch::nn::functional::SoftshrinkFuncOptions &options() {
2255 return options_;
2256 }
2257
2259 inline void pretty_print(std::ostream &os) const noexcept override {
2261 << "(\n lambda=" << options_.lambda() << "\n)";
2262 }
2263
2266 inline torch::serialize::OutputArchive &
2267 write(torch::serialize::OutputArchive &archive,
2268 const std::string &key = "softshrink") const override {
2269 archive.write(key + ".type", torch::full({1}, static_cast<int64_t>(
2271 archive.write(key + ".lambda",
2272 torch::full({1}, (double)this->options_.lambda()));
2273
2274 return archive;
2275 }
2276
2279 inline torch::serialize::InputArchive &
2280 read(torch::serialize::InputArchive &archive,
2281 const std::string &key = "softshrink") override {
2282 torch::Tensor tensor;
2283
2284 archive.read(key + ".type", tensor);
2285 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::softshrink))
2286 throw std::runtime_error("activation mismatch");
2287
2288 archive.read(key + ".lambda", tensor);
2289 this->options_.lambda(tensor.item<double>());
2290
2291 return archive;
2292 }
2293
2294private:
2295 torch::nn::functional::SoftshrinkFuncOptions options_;
2296};
2297
2304public:
2306 inline torch::Tensor apply(const torch::Tensor &input) const override {
2307 return torch::nn::functional::softsign(input);
2308 }
2309
2311 inline void pretty_print(std::ostream &os) const noexcept override {
2313 }
2314
2317 inline torch::serialize::OutputArchive &
2318 write(torch::serialize::OutputArchive &archive,
2319 const std::string &key = "softsign") const override {
2320 archive.write(key + ".type",
2321 torch::full({1}, static_cast<int64_t>(activation::softsign)));
2322
2323 return archive;
2324 }
2325
2328 inline torch::serialize::InputArchive &
2329 read(torch::serialize::InputArchive &archive,
2330 const std::string &key = "softsign") override {
2331 torch::Tensor tensor;
2332
2333 archive.read(key + ".type", tensor);
2334 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::softsign))
2335 throw std::runtime_error("activation mismatch");
2336
2337 return archive;
2338 }
2339};
2340
2346class Tanh : public ActivationFunction {
2347public:
2349 inline torch::Tensor apply(const torch::Tensor &input) const override {
2350 return torch::tanh(input);
2351 }
2352
2354 inline void pretty_print(std::ostream &os) const noexcept override {
2356 }
2357
2360 inline torch::serialize::OutputArchive &
2361 write(torch::serialize::OutputArchive &archive,
2362 const std::string &key = "tanh") const override {
2363 archive.write(key + ".type",
2364 torch::full({1}, static_cast<int64_t>(activation::tanh)));
2365
2366 return archive;
2367 }
2368
2371 inline torch::serialize::InputArchive &
2372 read(torch::serialize::InputArchive &archive,
2373 const std::string &key = "tanh") override {
2374 torch::Tensor tensor;
2375
2376 archive.read(key + ".type", tensor);
2377 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::tanh))
2378 throw std::runtime_error("activation mismatch");
2379
2380 return archive;
2381 }
2382};
2383
2390public:
2392 inline torch::Tensor apply(const torch::Tensor &input) const override {
2393 return torch::nn::functional::tanhshrink(input);
2394 }
2395
2397 inline void pretty_print(std::ostream &os) const noexcept override {
2399 }
2400
2403 inline torch::serialize::OutputArchive &
2404 write(torch::serialize::OutputArchive &archive,
2405 const std::string &key = "tanhshrink") const override {
2406 archive.write(key + ".type", torch::full({1}, static_cast<int64_t>(
2408
2409 return archive;
2410 }
2411
2414 inline torch::serialize::InputArchive &
2415 read(torch::serialize::InputArchive &archive,
2416 const std::string &key = "tanhshrink") override {
2417 torch::Tensor tensor;
2418
2419 archive.read(key + ".type", tensor);
2420 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::tanhshrink))
2421 throw std::runtime_error("activation mismatch");
2422
2423 return archive;
2424 }
2425};
2426
2437public:
2438 explicit Threshold(const torch::nn::functional::ThresholdFuncOptions &options)
2439 : options_(options) {}
2440
2441 explicit Threshold(double threshold, double value, bool inplace = false)
2442 : options_(torch::nn::functional::ThresholdFuncOptions(threshold, value)
2443 .inplace(inplace)) {}
2444
2445 ~Threshold() override = default;
2446
2448 inline torch::Tensor apply(const torch::Tensor &input) const override {
2449 return torch::nn::functional::threshold(input, options_);
2450 }
2451
2453 inline const torch::nn::functional::ThresholdFuncOptions &options() const {
2454 return options_;
2455 }
2456
2458 inline torch::nn::functional::ThresholdFuncOptions &options() {
2459 return options_;
2460 }
2461
2463 inline void pretty_print(std::ostream &os) const noexcept override {
2465 << "(\n threshold=" << options_.threshold()
2466 << ", value=" << options_.value() << ", inplace=" << options_.inplace()
2467 << "\n)";
2468 }
2469
2472 inline torch::serialize::OutputArchive &
2473 write(torch::serialize::OutputArchive &archive,
2474 const std::string &key = "threshold") const override {
2475 archive.write(key + ".type", torch::full({1}, static_cast<int64_t>(
2477 archive.write(key + ".threshold",
2478 torch::full({1}, this->options_.threshold()));
2479 archive.write(key + ".value", torch::full({1}, this->options_.value()));
2480 archive.write(key + ".inplace", torch::full({1}, this->options_.inplace()));
2481
2482 return archive;
2483 }
2484
2487 inline torch::serialize::InputArchive &
2488 read(torch::serialize::InputArchive &archive,
2489 const std::string &key = "threshold") override {
2490 torch::Tensor tensor;
2491
2492 archive.read(key + ".type", tensor);
2493 if (tensor.item<int64_t>() != static_cast<int64_t>(activation::threshold))
2494 throw std::runtime_error("activation mismatch");
2495
2496 archive.read(key + ".threshold", tensor);
2497 this->options_.threshold(tensor.item<double>());
2498 archive.read(key + ".value", tensor);
2499 this->options_.value(tensor.item<double>());
2500 archive.read(key + ".inplace", tensor);
2501 this->options_.inplace(tensor.item<bool>());
2502
2503 return archive;
2504 }
2505
2506private:
2507 torch::nn::functional::ThresholdFuncOptions options_;
2508};
2509
2510} // namespace iganet
Abstract activation function structure.
Definition activation.hpp:65
~ActivationFunction() override=default
void pretty_print(std::ostream &os) const noexcept override=0
Returns a string representation of the activation function.
virtual torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key) const =0
Writes the activation function into a torch::serialize::OutputArchive object.
virtual torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key)=0
Reads the activation function from a torch::serialize::InputArchive object.
virtual torch::Tensor apply(const torch::Tensor &) const =0
Applies the activation function to the given input.
Batch Normalization as described in the paper.
Definition activation.hpp:138
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:188
torch::Tensor & running_mean()
Returns non-constant reference to running mean.
Definition activation.hpp:169
torch::Tensor running_var_
Definition activation.hpp:259
const torch::Tensor & running_mean() const
Returns constant reference to running mean.
Definition activation.hpp:166
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="batch_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:211
const torch::Tensor & running_var() const
Returns constant reference to running variance.
Definition activation.hpp:172
torch::Tensor running_mean_
Definition activation.hpp:259
const torch::nn::functional::BatchNormFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:178
~BatchNorm() override=default
BatchNorm(torch::Tensor running_mean, torch::Tensor running_var, const torch::Tensor &weight, const torch::Tensor &bias, double eps, double momentum, bool training=false)
Definition activation.hpp:145
torch::nn::functional::BatchNormFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:183
torch::Tensor & running_var()
Returns non-constant reference to running var.
Definition activation.hpp:175
torch::nn::functional::BatchNormFuncOptions options_
Definition activation.hpp:258
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="batch_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:235
BatchNorm(torch::Tensor running_mean, torch::Tensor running_var, torch::nn::functional::BatchNormFuncOptions options={})
Definition activation.hpp:140
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:160
Continuously Differentiable Exponential Linear Units activation function.
Definition activation.hpp:268
torch::nn::functional::CELUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:290
CELU(torch::nn::functional::CELUFuncOptions options={})
Definition activation.hpp:270
torch::nn::functional::CELUFuncOptions options_
Definition activation.hpp:333
const torch::nn::functional::CELUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:285
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="celu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:316
CELU(double alpha, bool inplace=false)
Definition activation.hpp:273
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:293
~CELU() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:280
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="celu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:301
Exponential Linear Units activation function.
Definition activation.hpp:345
const torch::nn::functional::ELUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:362
~ELU() override=default
torch::nn::functional::ELUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:367
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="elu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:379
ELU(double alpha, bool inplace=false)
Definition activation.hpp:350
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:371
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:357
torch::nn::functional::ELUFuncOptions options_
Definition activation.hpp:411
ELU(torch::nn::functional::ELUFuncOptions options={})
Definition activation.hpp:347
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="elu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:394
Gaussian Error Linear Units activation function.
Definition activation.hpp:422
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="gelu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:441
GELU()=default
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:434
~GELU() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:429
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="gelu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:452
Gated Linear Units activation function.
Definition activation.hpp:473
~GLU() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="glu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:518
const torch::nn::functional::GLUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:489
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:484
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="glu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:505
torch::nn::functional::GLUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:494
torch::nn::functional::GLUFuncOptions options_
Definition activation.hpp:533
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:497
GLU(torch::nn::functional::GLUFuncOptions options={})
Definition activation.hpp:475
GLU(int64_t dim)
Definition activation.hpp:478
Group Normalization over a mini-batch of inputs as described in the paper Group Normalization,...
Definition activation.hpp:538
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:556
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="group_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:600
torch::nn::functional::GroupNormFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:566
const torch::nn::functional::GroupNormFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:561
~GroupNorm() override=default
GroupNorm(int64_t num_groups, const torch::Tensor &weight, const torch::Tensor &bias, double eps)
Definition activation.hpp:546
GroupNorm(int64_t num_groups)
Definition activation.hpp:540
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:572
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="group_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:586
torch::nn::functional::GroupNormFuncOptions options_
Definition activation.hpp:617
GroupNorm(torch::nn::functional::GroupNormFuncOptions options)
Definition activation.hpp:543
Gumbel-Softmax distribution activation function.
Definition activation.hpp:621
const torch::nn::functional::GumbelSoftmaxFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:642
GumbelSoftmax(torch::nn::functional::GumbelSoftmaxFuncOptions options={})
Definition activation.hpp:623
~GumbelSoftmax() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:636
torch::nn::functional::GumbelSoftmaxFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:647
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:653
GumbelSoftmax(double tau, int dim, bool hard)
Definition activation.hpp:627
torch::nn::functional::GumbelSoftmaxFuncOptions options_
Definition activation.hpp:696
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="gumbel_softmax") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:661
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="gumbel_softmax") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:676
Hardshrink activation function.
Definition activation.hpp:700
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardshrink") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:749
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:728
~Hardshrink() override=default
torch::nn::functional::HardshrinkFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:722
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardshrink") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:736
Hardshrink(double lambda)
Definition activation.hpp:705
torch::nn::functional::HardshrinkFuncOptions options_
Definition activation.hpp:764
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:712
Hardshrink(torch::nn::functional::HardshrinkFuncOptions options={})
Definition activation.hpp:702
const torch::nn::functional::HardshrinkFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:717
Hardsigmoid activation function.
Definition activation.hpp:777
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:784
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardsigmoid") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:809
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardsigmoid") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:797
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:790
~Hardsigmoid() override=default
Hardswish activation function.
Definition activation.hpp:831
~Hardswish() override=default
Hardswish()=default
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardswish") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:851
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:838
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardswish") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:862
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:844
Hardtanh activation function.
Definition activation.hpp:884
const torch::nn::functional::HardtanhFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:904
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardtanh") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:942
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:899
Hardtanh(const torch::nn::functional::HardtanhFuncOptions &options={})
Definition activation.hpp:886
torch::nn::functional::HardtanhFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:909
~Hardtanh() override=default
Hardtanh(double min_val, double max_val, bool inplace=false)
Definition activation.hpp:890
torch::nn::functional::HardtanhFuncOptions options_
Definition activation.hpp:961
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardtanh") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:925
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:915
Instance Normalization as described in the paper.
Definition activation.hpp:968
void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1007
torch::nn::functional::InstanceNormFuncOptions options_
Definition activation.hpp:1070
torch::nn::functional::InstanceNormFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1001
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:991
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="instance_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1025
const torch::nn::functional::InstanceNormFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:996
InstanceNorm(torch::nn::functional::InstanceNormFuncOptions options={})
Definition activation.hpp:970
InstanceNorm(const torch::Tensor &running_mean, const torch::Tensor &running_var, const torch::Tensor &weight, const torch::Tensor &bias, double eps, double momentum, bool use_input_stats=true)
Definition activation.hpp:974
~InstanceNorm() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="instance_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1046
Layer Normalization as described in the paper.
Definition activation.hpp:1076
const torch::nn::functional::LayerNormFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1102
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="layer_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1127
LayerNorm(std::vector< int64_t > normalized_shape, const torch::Tensor &weight, const torch::Tensor &bias, double eps)
Definition activation.hpp:1085
LayerNorm(std::vector< int64_t > normalized_shape)
Definition activation.hpp:1078
~LayerNorm() override=default
torch::nn::functional::LayerNormFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1107
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1112
LayerNorm(torch::nn::functional::LayerNormFuncOptions options)
Definition activation.hpp:1082
torch::nn::functional::LayerNormFuncOptions options_
Definition activation.hpp:1158
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1097
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="layer_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1141
Leaky ReLU activation function.
Definition activation.hpp:1170
~LeakyReLU() override=default
const torch::nn::functional::LeakyReLUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1188
LeakyReLU(torch::nn::functional::LeakyReLUFuncOptions options={})
Definition activation.hpp:1172
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="leaky_relu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1207
torch::nn::functional::LeakyReLUFuncOptions options_
Definition activation.hpp:1240
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1183
torch::nn::functional::LeakyReLUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1193
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="leaky_relu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1223
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1198
LeakyReLU(double negative_slope, bool inplace=false)
Definition activation.hpp:1175
Local response Normalization.
Definition activation.hpp:1244
torch::nn::functional::LocalResponseNormFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1273
LocalResponseNorm(const torch::nn::functional::LocalResponseNormFuncOptions &options)
Definition activation.hpp:1249
torch::nn::functional::LocalResponseNormFuncOptions options_
Definition activation.hpp:1329
LocalResponseNorm(int64_t size, double alpha, double beta, double k)
Definition activation.hpp:1253
LocalResponseNorm(int64_t size)
Definition activation.hpp:1246
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="local_response_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1287
~LocalResponseNorm() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1262
const torch::nn::functional::LocalResponseNormFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1268
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1278
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="local_response_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1307
LogSigmoid activation function.
Definition activation.hpp:1337
~LogSigmoid() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="logsigmoid") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1367
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1349
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="logsigmoid") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1356
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1344
LogSigmoid()=default
LogSoftmax activation function.
Definition activation.hpp:1387
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1399
torch::nn::functional::LogSoftmaxFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1409
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="logsoftmax") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1422
LogSoftmax(int64_t dim)
Definition activation.hpp:1389
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1414
LogSoftmax(const torch::nn::functional::LogSoftmaxFuncOptions &options)
Definition activation.hpp:1392
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="logsoftmax") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1433
torch::nn::functional::LogSoftmaxFuncOptions options_
Definition activation.hpp:1445
~LogSoftmax() override=default
const torch::nn::functional::LogSoftmaxFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1404
Mish activation function.
Definition activation.hpp:1453
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1460
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1465
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="mish") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1483
~Mish() override=default
Mish()=default
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="mish") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1472
No-op activation function.
Definition activation.hpp:95
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="none") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:110
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:103
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:98
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="none") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:121
Lp Normalization.
Definition activation.hpp:1496
~Normalize() override=default
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="normalize") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1532
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="normalize") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1547
torch::nn::functional::NormalizeFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1519
Normalize(double p, double eps, int64_t dim)
Definition activation.hpp:1501
Normalize(torch::nn::functional::NormalizeFuncOptions options={})
Definition activation.hpp:1498
torch::nn::functional::NormalizeFuncOptions options_
Definition activation.hpp:1566
const torch::nn::functional::NormalizeFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1514
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1509
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1524
PReLU activation function.
Definition activation.hpp:1570
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1583
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1588
PReLU(torch::Tensor weight)
Definition activation.hpp:1572
torch::Tensor weight_
Definition activation.hpp:1624
~PReLU() override=default
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="prelu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1598
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="prelu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1610
const torch::Tensor & weight() const
Returns constant reference to weights.
Definition activation.hpp:1577
torch::Tensor & weight()
Returns non-constant reference to weights.
Definition activation.hpp:1580
Randomized ReLU activation function.
Definition activation.hpp:1772
RReLU(double lower, double upper, bool inplace=false)
Definition activation.hpp:1777
RReLU(const torch::nn::functional::RReLUFuncOptions &options={})
Definition activation.hpp:1774
const torch::nn::functional::RReLUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1791
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1786
torch::nn::functional::RReLUFuncOptions options_
Definition activation.hpp:1844
~RReLU() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="rrelu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1825
torch::nn::functional::RReLUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1796
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1799
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="rrelu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1808
ReLU6 activation function.
Definition activation.hpp:1700
~ReLU6() override=default
ReLU6(torch::nn::functional::ReLU6FuncOptions options={})
Definition activation.hpp:1702
torch::nn::functional::ReLU6FuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1721
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1711
const torch::nn::functional::ReLU6FuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1716
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="relu6") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1732
ReLU6(bool inplace)
Definition activation.hpp:1705
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1724
torch::nn::functional::ReLU6FuncOptions options_
Definition activation.hpp:1760
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="relu6") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1745
ReLU activation function.
Definition activation.hpp:1632
torch::nn::functional::ReLUFuncOptions options_
Definition activation.hpp:1692
~ReLU() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1643
ReLU(bool inplace)
Definition activation.hpp:1637
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="relu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1664
torch::nn::functional::ReLUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1653
ReLU(torch::nn::functional::ReLUFuncOptions options={})
Definition activation.hpp:1634
const torch::nn::functional::ReLUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1648
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1656
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="relu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1677
SELU activation function.
Definition activation.hpp:1855
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1879
~SELU() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="selu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1900
torch::nn::functional::SELUFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:1876
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1866
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="selu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1887
const torch::nn::functional::SELUFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:1871
SELU(bool inplace)
Definition activation.hpp:1860
SELU(torch::nn::functional::SELUFuncOptions options={})
Definition activation.hpp:1857
torch::nn::functional::SELUFuncOptions options_
Definition activation.hpp:1915
Sigmoid Linear Unit activation function.
Definition activation.hpp:1966
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="silu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1981
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1974
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1969
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="silu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1992
Sigmoid activation function.
Definition activation.hpp:1923
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:1931
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:1926
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="sigmoid") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:1949
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="sigmoid") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:1938
Softmax activation function.
Definition activation.hpp:2011
torch::nn::functional::SoftmaxFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:2032
Softmax(const torch::nn::functional::SoftmaxFuncOptions &options)
Definition activation.hpp:2016
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softmax") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2045
const torch::nn::functional::SoftmaxFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:2027
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2037
~Softmax() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2022
Softmax(int64_t dim)
Definition activation.hpp:2013
torch::nn::functional::SoftmaxFuncOptions options_
Definition activation.hpp:2073
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softmax") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2058
Softmin activation function.
Definition activation.hpp:2081
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softmin") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2115
Softmin(int64_t dim)
Definition activation.hpp:2083
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2092
torch::nn::functional::SoftminFuncOptions options_
Definition activation.hpp:2143
const torch::nn::functional::SoftminFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:2097
~Softmin() override=default
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2107
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softmin") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2128
torch::nn::functional::SoftminFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:2102
Softmin(const torch::nn::functional::SoftminFuncOptions &options)
Definition activation.hpp:2086
Softplus activation function.
Definition activation.hpp:2151
torch::nn::functional::SoftplusFuncOptions options_
Definition activation.hpp:2219
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2164
Softplus(torch::nn::functional::SoftplusFuncOptions options={})
Definition activation.hpp:2153
const torch::nn::functional::SoftplusFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:2169
Softplus(double beta, double threshold)
Definition activation.hpp:2156
torch::nn::functional::SoftplusFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:2174
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softplus") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2202
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softplus") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2187
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2179
~Softplus() override=default
Softshrink activation function.
Definition activation.hpp:2232
const torch::nn::functional::SoftshrinkFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:2249
Softshrink(double lambda)
Definition activation.hpp:2237
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softshrink") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2267
Softshrink(torch::nn::functional::SoftshrinkFuncOptions options={})
Definition activation.hpp:2234
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2244
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softshrink") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2280
torch::nn::functional::SoftshrinkFuncOptions options_
Definition activation.hpp:2295
torch::nn::functional::SoftshrinkFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:2254
~Softshrink() override=default
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2259
Softsign activation function.
Definition activation.hpp:2303
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2306
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2311
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softsign") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2329
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softsign") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2318
Tanh activation function.
Definition activation.hpp:2346
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2354
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="tanh") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2361
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2349
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="tanh") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2372
Tanhshrink activation function.
Definition activation.hpp:2389
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2392
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="tanhshrink") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2415
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="tanhshrink") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2404
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2397
Threshold activation function.
Definition activation.hpp:2436
const torch::nn::functional::ThresholdFuncOptions & options() const
Returns constant reference to options.
Definition activation.hpp:2453
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition activation.hpp:2463
Threshold(const torch::nn::functional::ThresholdFuncOptions &options)
Definition activation.hpp:2438
~Threshold() override=default
Threshold(double threshold, double value, bool inplace=false)
Definition activation.hpp:2441
torch::nn::functional::ThresholdFuncOptions & options()
Returns non-constant reference to options.
Definition activation.hpp:2458
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition activation.hpp:2448
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="threshold") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition activation.hpp:2488
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="threshold") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition activation.hpp:2473
torch::nn::functional::ThresholdFuncOptions options_
Definition activation.hpp:2507
Fully qualified name descriptor.
Definition fqn.hpp:22
virtual const std::string & name() const noexcept
Returns the full qualified name of the object.
Definition fqn.hpp:28
Core components.
Fully qualified name utility functions.
Definition core.hpp:72
bool is_verbose(std::ostream &os)
Definition core.hpp:831
std::ostream & operator<<(std::ostream &os, const MemoryDebugger< id > &obj)
Print (as string) a memory debugger object.
Definition memory.hpp:125
struct iganet::@0 Log
Logger.
@ none
Definition boundary.hpp:38
activation
Enumeration of nonlinear activation functions.
Definition activation.hpp:26
short int short_t
Definition core.hpp:74
STL namespace.
Definition optimizer.hpp:61