IgANet
IgANets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
layer.hpp
Go to the documentation of this file.
1
15#pragma once
16
17#include <core.hpp>
18#include <utils/fqn.hpp>
19
20namespace iganet {
21
23enum class activation : short_t {
24 none = 0,
25 batch_norm = 1,
26 celu = 2,
27 elu = 3,
28 gelu = 4,
29 glu = 5,
30 group_norm = 6,
32 hardshrink = 9,
33 hardsigmoid = 8,
34 hardswish = 10,
35 hardtanh = 11,
36 instance_norm = 12,
37 layer_norm = 13,
38 leaky_relu = 14,
40 logsigmoid = 16,
41 logsoftmax = 17,
42 mish = 18,
43 normalize = 19,
44 prelu = 20,
45 relu = 21,
46 relu6 = 22,
47 rrelu = 23,
48 selu = 24,
49 sigmoid = 25,
50 silu = 26,
51 softmax = 27,
52 softmin = 28,
53 softplus = 29,
54 softshrink = 30,
55 softsign = 31,
56 tanh = 32,
57 tanhshrink = 33,
58 threshold = 34
59};
60
63public:
64 virtual ~ActivationFunction() = default;
65
67 virtual torch::Tensor apply(const torch::Tensor &) const = 0;
68
70 virtual void pretty_print(std::ostream &os) const noexcept = 0;
71
74 virtual torch::serialize::OutputArchive &
75 write(torch::serialize::OutputArchive &archive,
76 const std::string &key) const = 0;
77
80 virtual torch::serialize::InputArchive &
81 read(torch::serialize::InputArchive &archive, const std::string &key) = 0;
82};
83
85inline std::ostream &operator<<(std::ostream &os,
86 const ActivationFunction &obj) {
87 obj.pretty_print(os);
88 return os;
89}
90
92class None : public ActivationFunction {
93public:
95 inline torch::Tensor apply(const torch::Tensor &input) const override {
96 return input;
97 }
98
100 inline virtual void pretty_print(std::ostream &os) const noexcept override {
102 }
103
106 inline torch::serialize::OutputArchive &
107 write(torch::serialize::OutputArchive &archive,
108 const std::string &key = "none") const override {
109 archive.write(key + ".type", torch::full({1}, (int64_t)activation::none));
110
111 return archive;
112 }
113
116 inline torch::serialize::InputArchive &
117 read(torch::serialize::InputArchive &archive,
118 const std::string &key = "none") override {
119 torch::Tensor tensor;
120
121 archive.read(key + ".type", tensor);
122 if (tensor.item<int64_t>() != (int64_t)activation::none)
123 throw std::runtime_error("activation mismatch");
124
125 return archive;
126 }
127};
128
135public:
136 explicit BatchNorm(const torch::Tensor &running_mean,
137 const torch::Tensor &running_var,
138 torch::nn::functional::BatchNormFuncOptions options = {})
141
142 explicit BatchNorm(const torch::Tensor &running_mean,
143 const torch::Tensor &running_var,
144 const torch::Tensor &weight, const torch::Tensor &bias,
145 double eps, double momentum, bool training = false)
148 .weight(weight)
149 .bias(bias)
150 .eps(eps)
152 .training(training)) {}
153
154 ~BatchNorm() override = default;
155
157 inline torch::Tensor apply(const torch::Tensor &input) const override {
158 return torch::nn::functional::batch_norm(input, running_mean_, running_var_,
159 options_);
160 }
161
163 inline const torch::Tensor &running_mean() const { return running_mean_; }
164
166 inline torch::Tensor &running_mean() { return running_mean_; }
167
169 inline const torch::Tensor &running_var() const { return running_var_; }
170
172 inline torch::Tensor &running_var() { return running_var_; }
173
175 inline const torch::nn::functional::BatchNormFuncOptions &options() const {
176 return options_;
177 }
178
180 inline torch::nn::functional::BatchNormFuncOptions &options() {
181 return options_;
182 }
183
185 inline virtual void
186 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
187 os << utils::FullQualifiedName::name() << "(\n eps=" << options_.eps()
188 << ", momentum=" << options_.momentum()
189#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR < 7
190 .value()
191#endif
192 << ", training=" << options_.training();
193
194 if (is_verbose(os)) {
195 os << "\n running_mean = " << running_mean()
196 << "\n running_var = " << running_var()
197 << "\n weight = " << options_.weight()
198 << "\n bias = " << options_.bias();
199 }
200
201 os << "\n)";
202 }
203
206 inline torch::serialize::OutputArchive &
207 write(torch::serialize::OutputArchive &archive,
208 const std::string &key = "batch_norm") const override {
209 archive.write(key + ".type",
210 torch::full({1}, (int64_t)activation::batch_norm));
211 archive.write(key + ".running_mean", this->running_mean());
212 archive.write(key + ".running_var", this->running_var());
213 archive.write(key + ".weight", this->options_.weight());
214 archive.write(key + ".bias", this->options_.bias());
215 archive.write(key + ".eps", torch::full({1}, (double)this->options_.eps()));
216 archive.write(key + ".momentum",
217 torch::full({1}, (double)this->options_.momentum()
218#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR < 7
219 .value()
220#endif
221 ));
222 archive.write(key + ".training",
223 torch::full({1}, (bool)this->options_.training()));
224
225 return archive;
226 }
227
230 inline torch::serialize::InputArchive &
231 read(torch::serialize::InputArchive &archive,
232 const std::string &key = "batch_norm") override {
233 torch::Tensor tensor;
234
235 archive.read(key + ".type", tensor);
237 throw std::runtime_error("activation mismatch");
238
239 archive.read(key + ".running_mean", this->running_mean());
240 archive.read(key + ".running_var", this->running_var());
241 archive.read(key + ".weight", this->options_.weight());
242 archive.read(key + ".bias", this->options_.bias());
243 archive.read(key + ".eps", tensor);
244 this->options_.eps(tensor.item<double>());
245 archive.read(key + ".momentum", tensor);
246 this->options_.momentum(tensor.item<double>());
247 archive.read(key + ".training", tensor);
248 this->options_.training(tensor.item<bool>());
249
250 return archive;
251 }
252
253private:
254 torch::nn::functional::BatchNormFuncOptions options_;
256};
257
264class CELU : public ActivationFunction {
265public:
266 explicit CELU(torch::nn::functional::CELUFuncOptions options = {})
267 : options_(options) {}
268
269 explicit CELU(double alpha, bool inplace = false)
271 inplace)) {}
272
273 ~CELU() override = default;
274
276 inline torch::Tensor apply(const torch::Tensor &input) const override {
277 return torch::nn::functional::celu(input, options_);
278 }
279
281 inline const torch::nn::functional::CELUFuncOptions &options() const {
282 return options_;
283 }
284
286 inline torch::nn::functional::CELUFuncOptions &options() { return options_; }
287
289 inline virtual void
290 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
291 os << utils::FullQualifiedName::name() << "(\n alpha=" << options_.alpha()
292 << ", inplace=" << options_.inplace() << "\n)";
293 }
294
297 inline torch::serialize::OutputArchive &
298 write(torch::serialize::OutputArchive &archive,
299 const std::string &key = "celu") const override {
300 archive.write(key + ".type", torch::full({1}, (int64_t)activation::celu));
301 archive.write(key + ".alpha",
302 torch::full({1}, (double)this->options_.alpha()));
303 archive.write(key + ".inplace",
304 torch::full({1}, (bool)this->options_.inplace()));
305
306 return archive;
307 }
308
311 inline torch::serialize::InputArchive &
312 read(torch::serialize::InputArchive &archive,
313 const std::string &key = "celu") override {
314 torch::Tensor tensor;
315
316 archive.read(key + ".type", tensor);
317 if (tensor.item<int64_t>() != (int64_t)activation::celu)
318 throw std::runtime_error("activation mismatch");
319
320 archive.read(key + ".alpha", tensor);
321 this->options_.alpha(tensor.item<double>());
322 archive.read(key + ".inplace", tensor);
323 this->options_.inplace(tensor.item<bool>());
324
325 return archive;
326 }
327
328private:
329 torch::nn::functional::CELUFuncOptions options_;
330};
331
341class ELU : public ActivationFunction {
342public:
343 explicit ELU(torch::nn::functional::ELUFuncOptions options = {})
344 : options_(options) {}
345
346 explicit ELU(double alpha, bool inplace = false)
348 inplace)) {}
349
350 ~ELU() override = default;
351
353 inline torch::Tensor apply(const torch::Tensor &input) const override {
354 return torch::nn::functional::elu(input, options_);
355 }
356
358 inline const torch::nn::functional::ELUFuncOptions &options() const {
359 return options_;
360 }
361
363 inline torch::nn::functional::ELUFuncOptions &options() { return options_; }
364
366 inline virtual void
367 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
368 os << utils::FullQualifiedName::name() << "(\n alpha=" << options_.alpha()
369 << ", inplace=" << options_.inplace() << "\n)";
370 }
371
374 inline torch::serialize::OutputArchive &
375 write(torch::serialize::OutputArchive &archive,
376 const std::string &key = "elu") const override {
377 archive.write(key + ".type", torch::full({1}, (int64_t)activation::elu));
378 archive.write(key + ".alpha",
379 torch::full({1}, (double)this->options_.alpha()));
380 archive.write(key + ".inplace",
381 torch::full({1}, (bool)this->options_.inplace()));
382
383 return archive;
384 }
385
388 inline torch::serialize::InputArchive &
389 read(torch::serialize::InputArchive &archive,
390 const std::string &key = "elu") override {
391 torch::Tensor tensor;
392
393 archive.read(key + ".type", tensor);
394 if (tensor.item<int64_t>() != (int64_t)activation::elu)
395 throw std::runtime_error("activation mismatch");
396
397 archive.read(key + ".alpha", tensor);
398 this->options_.alpha(tensor.item<double>());
399 archive.read(key + ".inplace", tensor);
400 this->options_.inplace(tensor.item<bool>());
401
402 return archive;
403 }
404
405private:
406 torch::nn::functional::ELUFuncOptions options_;
407};
408
417class GELU : public ActivationFunction {
418public:
419 explicit GELU() = default;
420
421 ~GELU() override = default;
422
424 inline torch::Tensor apply(const torch::Tensor &input) const override {
425 return torch::gelu(input);
426 }
427
429 inline virtual void
430 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
432 }
433
436 inline torch::serialize::OutputArchive &
437 write(torch::serialize::OutputArchive &archive,
438 const std::string &key = "gelu") const override {
439 archive.write(key + ".type", torch::full({1}, (int64_t)activation::gelu));
440
441 return archive;
442 }
443
446 inline torch::serialize::InputArchive &
447 read(torch::serialize::InputArchive &archive,
448 const std::string &key = "gelu") override {
449 torch::Tensor tensor;
450
451 archive.read(key + ".type", tensor);
452 if (tensor.item<int64_t>() != (int64_t)activation::gelu)
453 throw std::runtime_error("activation mismatch");
454
455 return archive;
456 }
457};
458
468class GLU : public ActivationFunction {
469public:
470 explicit GLU(torch::nn::functional::GLUFuncOptions options = {})
471 : options_(options) {}
472
473 explicit GLU(int64_t dim)
474 : options_(torch::nn::functional::GLUFuncOptions().dim(dim)) {}
475
476 ~GLU() override = default;
477
479 inline torch::Tensor apply(const torch::Tensor &input) const override {
480 return torch::nn::functional::glu(input, options_);
481 }
482
484 inline const torch::nn::functional::GLUFuncOptions &options() const {
485 return options_;
486 }
487
489 inline torch::nn::functional::GLUFuncOptions &options() { return options_; }
490
492 inline virtual void
493 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
494 os << utils::FullQualifiedName::name() << "(\n dim=" << options_.dim()
495 << "\n)";
496 }
497
500 inline torch::serialize::OutputArchive &
501 write(torch::serialize::OutputArchive &archive,
502 const std::string &key = "glu") const override {
503 archive.write(key + ".type", torch::full({1}, (int64_t)activation::glu));
504 archive.write(key + ".dim", torch::full({1}, (int)this->options_.dim()));
505
506 return archive;
507 }
508
511 inline torch::serialize::InputArchive &
512 read(torch::serialize::InputArchive &archive,
513 const std::string &key = "glu") override {
514 torch::Tensor tensor;
515
516 archive.read(key + ".type", tensor);
517 if (tensor.item<int64_t>() != (int64_t)activation::glu)
518 throw std::runtime_error("activation mismatch");
519
520 archive.read(key + ".dim", tensor);
521 this->options_.dim(tensor.item<int>());
522
523 return archive;
524 }
525
526private:
527 torch::nn::functional::GLUFuncOptions options_;
528};
529
533public:
536
537 explicit GroupNorm(torch::nn::functional::GroupNormFuncOptions options)
538 : options_(std::move(options)) {}
539
540 explicit GroupNorm(int64_t num_groups, const torch::Tensor &weight,
541 const torch::Tensor &bias, double eps)
543 .weight(weight)
544 .bias(bias)
545 .eps(eps)) {}
546
547 ~GroupNorm() override = default;
548
550 inline torch::Tensor apply(const torch::Tensor &input) const override {
551 return torch::nn::functional::group_norm(input, options_);
552 }
553
555 inline const torch::nn::functional::GroupNormFuncOptions &options() const {
556 return options_;
557 }
558
560 inline torch::nn::functional::GroupNormFuncOptions &options() {
561 return options_;
562 }
563
565 inline virtual void
566 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
567 os << utils::FullQualifiedName::name() << "(\n eps=" << options_.eps();
568
569 if (is_verbose(os)) {
570 os << "\n weight = " << options_.weight()
571 << "\n bias = " << options_.bias();
572 }
573
574 os << "\n)";
575 }
576
579 inline torch::serialize::OutputArchive &
580 write(torch::serialize::OutputArchive &archive,
581 const std::string &key = "group_norm") const override {
582 archive.write(key + ".type",
583 torch::full({1}, (int64_t)activation::group_norm));
584 archive.write(key + ".weight", this->options_.weight());
585 archive.write(key + ".bias", this->options_.bias());
586 archive.write(key + ".eps", torch::full({1}, (double)this->options_.eps()));
587
588 return archive;
589 }
590
593 inline torch::serialize::InputArchive &
594 read(torch::serialize::InputArchive &archive,
595 const std::string &key = "group_norm") override {
596 torch::Tensor tensor;
597
598 archive.read(key + ".type", tensor);
600 throw std::runtime_error("activation mismatch");
601
602 archive.read(key + ".weight", this->options_.weight());
603 archive.read(key + ".bias", this->options_.bias());
604 archive.read(key + ".eps", tensor);
605 this->options_.eps(tensor.item<double>());
606
607 return archive;
608 }
609
610private:
611 torch::nn::functional::GroupNormFuncOptions options_;
612};
613
616public:
618 torch::nn::functional::GumbelSoftmaxFuncOptions options = {})
619 : options_(options) {}
620
621 explicit GumbelSoftmax(double tau, int dim, bool hard)
623 .tau(tau)
624 .dim(dim)
625 .hard(hard)) {}
626
627 ~GumbelSoftmax() override = default;
628
630 inline torch::Tensor apply(const torch::Tensor &input) const override {
631 return torch::nn::functional::gumbel_softmax(input, options_);
632 }
633
635 inline const torch::nn::functional::GumbelSoftmaxFuncOptions &
636 options() const {
637 return options_;
638 }
639
641 inline torch::nn::functional::GumbelSoftmaxFuncOptions &options() {
642 return options_;
643 }
644
646 inline virtual void
647 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
648 os << utils::FullQualifiedName::name() << "(\n tau=" << options_.tau()
649 << ", dim=" << options_.dim() << ", hard=" << options_.hard() << "\n)";
650 }
651
654 inline torch::serialize::OutputArchive &
655 write(torch::serialize::OutputArchive &archive,
656 const std::string &key = "gumbel_softmax") const override {
657 archive.write(key + ".type",
658 torch::full({1}, (int64_t)activation::gumbel_softmax));
659 archive.write(key + ".tau", torch::full({1}, (double)this->options_.tau()));
660 archive.write(key + ".dim", torch::full({1}, (int)this->options_.dim()));
661 archive.write(key + ".hard", torch::full({1}, (bool)this->options_.hard()));
662
663 return archive;
664 }
665
668 inline torch::serialize::InputArchive &
669 read(torch::serialize::InputArchive &archive,
670 const std::string &key = "gumbel_softmax") override {
671 torch::Tensor tensor;
672
673 archive.read(key + ".type", tensor);
675 throw std::runtime_error("activation mismatch");
676
677 archive.read(key + ".tau", tensor);
678 this->options_.tau(tensor.item<double>());
679 archive.read(key + ".dim", tensor);
680 this->options_.dim(tensor.item<int>());
681 archive.read(key + ".hard", tensor);
682 this->options_.hard(tensor.item<bool>());
683
684 return archive;
685 }
686
687private:
688 torch::nn::functional::GumbelSoftmaxFuncOptions options_;
689};
690
693public:
694 explicit Hardshrink(torch::nn::functional::HardshrinkFuncOptions options = {})
695 : options_(options) {}
696
697 explicit Hardshrink(double lambda)
698 : options_(
700
701 ~Hardshrink() override = default;
702
704 inline torch::Tensor apply(const torch::Tensor &input) const override {
705 return torch::nn::functional::hardshrink(input, options_);
706 }
707
709 inline const torch::nn::functional::HardshrinkFuncOptions &options() const {
710 return options_;
711 }
712
714 inline torch::nn::functional::HardshrinkFuncOptions &options() {
715 return options_;
716 }
717
719 inline virtual void
720 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
722 << "(\n lambda=" << options_.lambda() << "\n)";
723 }
724
727 inline torch::serialize::OutputArchive &
728 write(torch::serialize::OutputArchive &archive,
729 const std::string &key = "hardshrink") const override {
730 archive.write(key + ".type",
731 torch::full({1}, (int64_t)activation::hardshrink));
732 archive.write(key + ".lambda",
733 torch::full({1}, (double)this->options_.lambda()));
734
735 return archive;
736 }
737
740 inline torch::serialize::InputArchive &
741 read(torch::serialize::InputArchive &archive,
742 const std::string &key = "hardshrink") override {
743 torch::Tensor tensor;
744
745 archive.read(key + ".type", tensor);
747 throw std::runtime_error("activation mismatch");
748
749 archive.read(key + ".lambda", tensor);
750 this->options_.lambda(tensor.item<double>());
751
752 return archive;
753 }
754
755private:
756 torch::nn::functional::HardshrinkFuncOptions options_;
757};
758
770public:
771 explicit Hardsigmoid() = default;
772
773 ~Hardsigmoid() override = default;
774
776 inline torch::Tensor apply(const torch::Tensor &input) const override {
777 return torch::hardsigmoid(input);
778 }
779
781 inline virtual void
782 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
784 }
785
788 inline torch::serialize::OutputArchive &
789 write(torch::serialize::OutputArchive &archive,
790 const std::string &key = "hardsigmoid") const override {
791 archive.write(key + ".type",
792 torch::full({1}, (int64_t)activation::hardsigmoid));
793
794 return archive;
795 }
796
799 inline torch::serialize::InputArchive &
800 read(torch::serialize::InputArchive &archive,
801 const std::string &key = "hardsigmoid") override {
802 torch::Tensor tensor;
803
804 archive.read(key + ".type", tensor);
806 throw std::runtime_error("activation mismatch");
807
808 return archive;
809 }
810};
811
823public:
824 explicit Hardswish() = default;
825
826 ~Hardswish() override = default;
827
829 inline torch::Tensor apply(const torch::Tensor &input) const override {
830 return torch::hardswish(input);
831 }
832
834 inline virtual void
835 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
837 }
838
841 inline torch::serialize::OutputArchive &
842 write(torch::serialize::OutputArchive &archive,
843 const std::string &key = "hardswish") const override {
844 archive.write(key + ".type",
845 torch::full({1}, (int64_t)activation::hardswish));
846
847 return archive;
848 }
849
852 inline torch::serialize::InputArchive &
853 read(torch::serialize::InputArchive &archive,
854 const std::string &key = "hardswish") override {
855 torch::Tensor tensor;
856
857 archive.read(key + ".type", tensor);
859 throw std::runtime_error("activation mismatch");
860
861 return archive;
862 }
863};
864
876public:
877 explicit Hardtanh(torch::nn::functional::HardtanhFuncOptions options = {})
878 : options_(options) {}
879
880 explicit Hardtanh(double min_val, double max_val, bool inplace = false)
884 .inplace(inplace)) {}
885
886 ~Hardtanh() override = default;
887
889 inline torch::Tensor apply(const torch::Tensor &input) const override {
890 return torch::nn::functional::hardtanh(input, options_);
891 }
892
894 inline const torch::nn::functional::HardtanhFuncOptions &options() const {
895 return options_;
896 }
897
899 inline torch::nn::functional::HardtanhFuncOptions &options() {
900 return options_;
901 }
902
904 inline virtual void
905 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
907 << "(\n min_val=" << options_.min_val()
908 << ", max_val=" << options_.max_val()
909 << ", inplace=" << options_.inplace() << "\n)";
910 }
911
914 inline torch::serialize::OutputArchive &
915 write(torch::serialize::OutputArchive &archive,
916 const std::string &key = "hardtanh") const override {
917 archive.write(key + ".type",
918 torch::full({1}, (int64_t)activation::hardtanh));
919 archive.write(key + ".min_val",
920 torch::full({1}, (double)this->options_.min_val()));
921 archive.write(key + ".max_val",
922 torch::full({1}, (double)this->options_.max_val()));
923 archive.write(key + ".inplace",
924 torch::full({1}, (bool)this->options_.inplace()));
925
926 return archive;
927 }
928
931 inline torch::serialize::InputArchive &
932 read(torch::serialize::InputArchive &archive,
933 const std::string &key = "hardtanh") override {
934 torch::Tensor tensor;
935
936 archive.read(key + ".type", tensor);
938 throw std::runtime_error("activation mismatch");
939
940 archive.read(key + ".min_val", tensor);
941 this->options_.min_val(tensor.item<double>());
942 archive.read(key + ".max_val", tensor);
943 this->options_.max_val(tensor.item<double>());
944 archive.read(key + ".inplace", tensor);
945 this->options_.inplace(tensor.item<bool>());
946
947 return archive;
948 }
949
950private:
951 torch::nn::functional::HardtanhFuncOptions options_;
952};
953
959public:
960 explicit InstanceNorm(
961 torch::nn::functional::InstanceNormFuncOptions options = {})
962 : options_(std::move(options)) {}
963
964 explicit InstanceNorm(const torch::Tensor &running_mean,
965 const torch::Tensor &running_var,
966 const torch::Tensor &weight, const torch::Tensor &bias,
967 double eps, double momentum,
968 bool use_input_stats = true)
970 .running_mean(running_mean)
971 .running_var(running_var)
972 .weight(weight)
973 .bias(bias)
974 .eps(eps)
977
978 ~InstanceNorm() override = default;
979
981 inline torch::Tensor apply(const torch::Tensor &input) const override {
982 return torch::nn::functional::instance_norm(input, options_);
983 }
984
986 inline const torch::nn::functional::InstanceNormFuncOptions &options() const {
987 return options_;
988 }
989
991 inline torch::nn::functional::InstanceNormFuncOptions &options() {
992 return options_;
993 }
994
996 inline virtual void
997 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
998 os << utils::FullQualifiedName::name() << "(\n eps=" << options_.eps()
999 << ", momentum=" << options_.momentum()
1000 << ", use_input_stats=" << options_.use_input_stats();
1001
1002 if (is_verbose(os)) {
1003 os << "\n running_mean = " << options_.running_mean()
1004 << "\n running_var = " << options_.running_var()
1005 << "\n weight = " << options_.weight()
1006 << "\n bias = " << options_.bias();
1007 }
1008
1009 os << "\n)";
1010 }
1011
1014 inline torch::serialize::OutputArchive &
1015 write(torch::serialize::OutputArchive &archive,
1016 const std::string &key = "instance_norm") const override {
1017 archive.write(key + ".type",
1018 torch::full({1}, (int64_t)activation::instance_norm));
1019 archive.write(key + ".running_mean", this->options_.running_mean());
1020 archive.write(key + ".var", this->options_.running_var());
1021 archive.write(key + ".weight", this->options_.weight());
1022 archive.write(key + ".bias", this->options_.bias());
1023 archive.write(key + ".eps", torch::full({1}, (double)this->options_.eps()));
1024 archive.write(key + ".momentum",
1025 torch::full({1}, (double)this->options_.momentum()));
1026 archive.write(key + ".use_input_stats",
1027 torch::full({1}, (bool)this->options_.use_input_stats()));
1028
1029 return archive;
1030 }
1031
1034 inline torch::serialize::InputArchive &
1035 read(torch::serialize::InputArchive &archive,
1036 const std::string &key = "instance_norm") override {
1037 torch::Tensor tensor;
1038
1039 archive.read(key + ".type", tensor);
1041 throw std::runtime_error("activation mismatch");
1042
1043 archive.read(key + ".running_mean", this->options_.running_mean());
1044 archive.read(key + ".running_var", this->options_.running_var());
1045 archive.read(key + ".weight", this->options_.weight());
1046 archive.read(key + ".bias", this->options_.bias());
1047 archive.read(key + ".eps", tensor);
1048 this->options_.eps(tensor.item<double>());
1049 archive.read(key + ".momentum", tensor);
1050 this->options_.momentum(tensor.item<double>());
1051 archive.read(key + ".use_input_stats", tensor);
1052 this->options_.use_input_stats(tensor.item<bool>());
1053
1054 return archive;
1055 }
1056
1057private:
1058 torch::nn::functional::InstanceNormFuncOptions options_;
1059};
1060
1065public:
1066 explicit LayerNorm(std::vector<int64_t> normalized_shape)
1069
1070 explicit LayerNorm(torch::nn::functional::LayerNormFuncOptions options)
1071 : options_(std::move(options)) {}
1072
1073 explicit LayerNorm(std::vector<int64_t> normalized_shape,
1074 const torch::Tensor &weight, const torch::Tensor &bias,
1075 double eps)
1078 .weight(weight)
1079 .bias(bias)
1080 .eps(eps)) {}
1081
1082 ~LayerNorm() override = default;
1083
1085 inline torch::Tensor apply(const torch::Tensor &input) const override {
1086 return torch::nn::functional::layer_norm(input, options_);
1087 }
1088
1090 inline const torch::nn::functional::LayerNormFuncOptions &options() const {
1091 return options_;
1092 }
1093
1095 inline torch::nn::functional::LayerNormFuncOptions &options() {
1096 return options_;
1097 }
1098
1100 inline virtual void
1101 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1102 os << utils::FullQualifiedName::name() << "(\n eps=" << options_.eps();
1103
1104 if (is_verbose(os)) {
1105 os << "\n normalized_shape = " << options_.normalized_shape()
1106 << "\n weight = " << options_.weight()
1107 << "\n bias = " << options_.bias();
1108 }
1109
1110 os << "\n)";
1111 }
1112
1115 inline torch::serialize::OutputArchive &
1116 write(torch::serialize::OutputArchive &archive,
1117 const std::string &key = "layer_norm") const override {
1118 archive.write(key + ".type",
1119 torch::full({1}, (int64_t)activation::layer_norm));
1120 archive.write(key + ".weight", this->options_.weight());
1121 archive.write(key + ".bias", this->options_.bias());
1122 archive.write(key + ".eps", torch::full({1}, (double)this->options_.eps()));
1123
1124 return archive;
1125 }
1126
1129 inline torch::serialize::InputArchive &
1130 read(torch::serialize::InputArchive &archive,
1131 const std::string &key = "layer_norm") override {
1132 torch::Tensor tensor;
1133
1134 archive.read(key + ".type", tensor);
1136 throw std::runtime_error("activation mismatch");
1137
1138 archive.read(key + ".weight", this->options_.weight());
1139 archive.read(key + ".bias", this->options_.bias());
1140 archive.read(key + ".eps", tensor);
1141 this->options_.eps(tensor.item<double>());
1142
1143 return archive;
1144 }
1145
1146private:
1147 torch::nn::functional::LayerNormFuncOptions options_;
1148};
1149
1160public:
1161 explicit LeakyReLU(torch::nn::functional::LeakyReLUFuncOptions options = {})
1162 : options_(options) {}
1163
1164 explicit LeakyReLU(double negative_slope, bool inplace = false)
1167 .inplace(inplace)) {}
1168
1169 ~LeakyReLU() override = default;
1170
1172 inline torch::Tensor apply(const torch::Tensor &input) const override {
1173 return torch::nn::functional::leaky_relu(input, options_);
1174 }
1175
1177 inline const torch::nn::functional::LeakyReLUFuncOptions &options() const {
1178 return options_;
1179 }
1180
1182 inline torch::nn::functional::LeakyReLUFuncOptions &options() {
1183 return options_;
1184 }
1185
1187 inline virtual void
1188 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1190 << "(\n negative_slope=" << options_.negative_slope()
1191 << ", inplace=" << options_.inplace() << "\n)";
1192 }
1193
1196 inline torch::serialize::OutputArchive &
1197 write(torch::serialize::OutputArchive &archive,
1198 const std::string &key = "leaky_relu") const override {
1199 archive.write(key + ".type",
1200 torch::full({1}, (int64_t)activation::leaky_relu));
1201
1202 archive.write(key + ".negative_slope",
1203 torch::full({1}, (double)this->options_.negative_slope()));
1204 archive.write(key + ".inplace",
1205 torch::full({1}, (bool)this->options_.inplace()));
1206
1207 return archive;
1208 }
1209
1212 inline torch::serialize::InputArchive &
1213 read(torch::serialize::InputArchive &archive,
1214 const std::string &key = "leaky_relu") override {
1215 torch::Tensor tensor;
1216
1217 archive.read(key + ".type", tensor);
1219 throw std::runtime_error("activation mismatch");
1220
1221 archive.read(key + ".negative_slope", tensor);
1222 this->options_.negative_slope(tensor.item<double>());
1223 archive.read(key + ".inplace", tensor);
1224 this->options_.inplace(tensor.item<bool>());
1225
1226 return archive;
1227 }
1228
1229private:
1230 torch::nn::functional::LeakyReLUFuncOptions options_;
1231};
1232
1235public:
1237 : options_(torch::nn::functional::LocalResponseNormFuncOptions(size)) {}
1238
1240 const torch::nn::functional::LocalResponseNormFuncOptions &options)
1241 : options_(options) {}
1242
1243 explicit LocalResponseNorm(int64_t size, double alpha, double beta, double k)
1245 .alpha(alpha)
1246 .beta(beta)
1247 .k(k)) {}
1248
1249 ~LocalResponseNorm() override = default;
1250
1252 inline torch::Tensor apply(const torch::Tensor &input) const override {
1253 return torch::nn::functional::local_response_norm(input, options_);
1254 }
1255
1257 inline const torch::nn::functional::LocalResponseNormFuncOptions &
1258 options() const {
1259 return options_;
1260 }
1261
1263 inline torch::nn::functional::LocalResponseNormFuncOptions &options() {
1264 return options_;
1265 }
1266
1268 inline virtual void
1269 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1270 os << utils::FullQualifiedName::name() << "(\n size=" << options_.size()
1271 << ", alpha=" << options_.alpha() << ", beta=" << options_.beta()
1272 << ", k=" << options_.k() << "\n)";
1273 }
1274
1277 inline torch::serialize::OutputArchive &
1278 write(torch::serialize::OutputArchive &archive,
1279 const std::string &key = "local_response_norm") const override {
1280 archive.write(key + ".type",
1281 torch::full({1}, (int64_t)activation::local_response_norm));
1282
1283 archive.write(key + ".size",
1284 torch::full({1}, (int64_t)this->options_.size()));
1285 archive.write(key + ".alpha",
1286 torch::full({1}, (double)this->options_.alpha()));
1287 archive.write(key + ".beta",
1288 torch::full({1}, (double)this->options_.beta()));
1289 archive.write(key + ".k", torch::full({1}, (double)this->options_.k()));
1290
1291 return archive;
1292 }
1293
1296 inline torch::serialize::InputArchive &
1297 read(torch::serialize::InputArchive &archive,
1298 const std::string &key = "local_response_norm") override {
1299 torch::Tensor tensor;
1300
1301 archive.read(key + ".type", tensor);
1303 throw std::runtime_error("activation mismatch");
1304
1305 archive.read(key + ".size", tensor);
1306 this->options_.size(tensor.item<int64_t>());
1307 archive.read(key + ".alpha", tensor);
1308 this->options_.alpha(tensor.item<double>());
1309 archive.read(key + ".beta", tensor);
1310 this->options_.beta(tensor.item<double>());
1311 archive.read(key + ".k", tensor);
1312 this->options_.k(tensor.item<double>());
1313
1314 return archive;
1315 }
1316
1317private:
1318 torch::nn::functional::LocalResponseNormFuncOptions options_;
1319};
1320
1327public:
1328 explicit LogSigmoid() = default;
1329
1330 ~LogSigmoid() override = default;
1331
1333 inline torch::Tensor apply(const torch::Tensor &input) const override {
1334 return torch::log_sigmoid(input);
1335 }
1336
1338 inline virtual void
1339 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1341 }
1342
1345 inline torch::serialize::OutputArchive &
1346 write(torch::serialize::OutputArchive &archive,
1347 const std::string &key = "logsigmoid") const override {
1348 archive.write(key + ".type",
1349 torch::full({1}, (int64_t)activation::logsigmoid));
1350
1351 return archive;
1352 }
1353
1356 inline torch::serialize::InputArchive &
1357 read(torch::serialize::InputArchive &archive,
1358 const std::string &key = "logsigmoid") override {
1359 torch::Tensor tensor;
1360
1361 archive.read(key + ".type", tensor);
1363 throw std::runtime_error("activation mismatch");
1364
1365 return archive;
1366 }
1367};
1368
1378public:
1379 explicit LogSoftmax(int64_t dim)
1380 : options_(torch::nn::functional::LogSoftmaxFuncOptions(dim)) {}
1381
1382 explicit LogSoftmax(
1383 const torch::nn::functional::LogSoftmaxFuncOptions &options)
1384 : options_(options) {}
1385
1386 ~LogSoftmax() override = default;
1387
1389 inline torch::Tensor apply(const torch::Tensor &input) const override {
1390 return torch::nn::functional::log_softmax(input, options_);
1391 }
1392
1394 inline const torch::nn::functional::LogSoftmaxFuncOptions &options() const {
1395 return options_;
1396 }
1397
1399 inline torch::nn::functional::LogSoftmaxFuncOptions &options() {
1400 return options_;
1401 }
1402
1404 inline virtual void
1405 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1406 os << utils::FullQualifiedName::name() << "(\n dim=" << options_.dim()
1407 << "\n)";
1408 }
1409
1412 inline torch::serialize::OutputArchive &
1413 write(torch::serialize::OutputArchive &archive,
1414 const std::string &key = "logsoftmax") const override {
1415 archive.write(key + ".type",
1416 torch::full({1}, (int64_t)activation::logsoftmax));
1417
1418 return archive;
1419 }
1420
1423 inline torch::serialize::InputArchive &
1424 read(torch::serialize::InputArchive &archive,
1425 const std::string &key = "logsoftmax") override {
1426 torch::Tensor tensor;
1427
1428 archive.read(key + ".type", tensor);
1430 throw std::runtime_error("activation mismatch");
1431
1432 return archive;
1433 }
1434
1435private:
1436 torch::nn::functional::LogSoftmaxFuncOptions options_;
1437};
1438
1444class Mish : public ActivationFunction {
1445public:
1446 explicit Mish() = default;
1447
1448 ~Mish() override = default;
1449
1451 inline torch::Tensor apply(const torch::Tensor &input) const override {
1452 return torch::mish(input);
1453 }
1454
1456 inline virtual void
1457 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1459 }
1460
1463 inline torch::serialize::OutputArchive &
1464 write(torch::serialize::OutputArchive &archive,
1465 const std::string &key = "mish") const override {
1466 archive.write(key + ".type", torch::full({1}, (int64_t)activation::mish));
1467
1468 return archive;
1469 }
1470
1473 inline torch::serialize::InputArchive &
1474 read(torch::serialize::InputArchive &archive,
1475 const std::string &key = "mish") override {
1476 torch::Tensor tensor;
1477
1478 archive.read(key + ".type", tensor);
1479 if (tensor.item<int64_t>() != (int64_t)activation::mish)
1480 throw std::runtime_error("activation mismatch");
1481
1482 return archive;
1483 }
1484};
1485
1488public:
1489 explicit Normalize(torch::nn::functional::NormalizeFuncOptions options = {})
1490 : options_(std::move(options)) {}
1491
1492 explicit Normalize(double p, double eps, int64_t dim)
1493 : options_(
1494 torch::nn::functional::NormalizeFuncOptions().p(p).eps(eps).dim(
1495 dim)) {}
1496
1497 ~Normalize() override = default;
1498
1500 inline torch::Tensor apply(const torch::Tensor &input) const override {
1501 return torch::nn::functional::normalize(input, options_);
1502 }
1503
1505 inline const torch::nn::functional::NormalizeFuncOptions &options() const {
1506 return options_;
1507 }
1508
1510 inline torch::nn::functional::NormalizeFuncOptions &options() {
1511 return options_;
1512 }
1513
1515 inline virtual void
1516 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1517 os << utils::FullQualifiedName::name() << "(\n eps=" << options_.eps()
1518 << "(\n p=" << options_.p() << "(\n dim=" << options_.dim() << "\n)";
1519 }
1520
1523 inline torch::serialize::OutputArchive &
1524 write(torch::serialize::OutputArchive &archive,
1525 const std::string &key = "normalize") const override {
1526 archive.write(key + ".type",
1527 torch::full({1}, (int64_t)activation::normalize));
1528 archive.write(key + ".p", torch::full({1}, (double)this->options_.p()));
1529 archive.write(key + ".eps", torch::full({1}, (double)this->options_.eps()));
1530 archive.write(key + ".dim",
1531 torch::full({1}, (int64_t)this->options_.dim()));
1532
1533 return archive;
1534 }
1535
1538 inline torch::serialize::InputArchive &
1539 read(torch::serialize::InputArchive &archive,
1540 const std::string &key = "normalize") override {
1541 torch::Tensor tensor;
1542
1543 archive.read(key + ".type", tensor);
1545 throw std::runtime_error("activation mismatch");
1546
1547 archive.read(key + ".p", tensor);
1548 this->options_.p(tensor.item<double>());
1549 archive.read(key + ".eps", tensor);
1550 this->options_.eps(tensor.item<double>());
1551 archive.read(key + ".dim", tensor);
1552 this->options_.dim(tensor.item<int64_t>());
1553
1554 return archive;
1555 }
1556
1557private:
1558 torch::nn::functional::NormalizeFuncOptions options_;
1559};
1560
1563public:
1564 explicit PReLU(const torch::Tensor &weight) : weight_(weight) {}
1565
1566 ~PReLU() override = default;
1567
1569 const torch::Tensor &weight() const { return weight_; }
1570
1572 torch::Tensor &weight() { return weight_; }
1573
1575 inline torch::Tensor apply(const torch::Tensor &input) const override {
1576 return torch::nn::functional::prelu(input, weight());
1577 }
1578
1580 inline virtual void
1581 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1583
1584 if (is_verbose(os))
1585 os << "(\n weight = " << weight() << "\n)";
1586 }
1587
1590 inline torch::serialize::OutputArchive &
1591 write(torch::serialize::OutputArchive &archive,
1592 const std::string &key = "prelu") const override {
1593 archive.write(key + ".type", torch::full({1}, (int64_t)activation::prelu));
1594 archive.write(key + ".weight", this->weight());
1595
1596 return archive;
1597 }
1598
1601 inline torch::serialize::InputArchive &
1602 read(torch::serialize::InputArchive &archive,
1603 const std::string &key = "prelu") override {
1604 torch::Tensor tensor;
1605
1606 archive.read(key + ".type", tensor);
1607 if (tensor.item<int64_t>() != (int64_t)activation::prelu)
1608 throw std::runtime_error("activation mismatch");
1609
1610 archive.read(key + ".weight", this->weight());
1611
1612 return archive;
1613 }
1614
1615private:
1616 torch::Tensor weight_;
1617};
1618
1624class ReLU : public ActivationFunction {
1625public:
1626 explicit ReLU(torch::nn::functional::ReLUFuncOptions options = {})
1627 : options_(options) {}
1628
1629 explicit ReLU(bool inplace)
1631
1632 ~ReLU() override = default;
1633
1635 inline torch::Tensor apply(const torch::Tensor &input) const override {
1636 return torch::nn::functional::relu(input, options_);
1637 }
1638
1640 inline const torch::nn::functional::ReLUFuncOptions &options() const {
1641 return options_;
1642 }
1643
1645 inline torch::nn::functional::ReLUFuncOptions &options() { return options_; }
1646
1648 inline virtual void
1649 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1651 << "(\n inplace=" << options_.inplace() << "\n)";
1652 }
1653
1656 inline torch::serialize::OutputArchive &
1657 write(torch::serialize::OutputArchive &archive,
1658 const std::string &key = "relu") const override {
1659 archive.write(key + ".type", torch::full({1}, (int64_t)activation::relu));
1660 archive.write(key + ".inplace",
1661 torch::full({1}, (bool)this->options_.inplace()));
1662
1663 return archive;
1664 }
1665
1668 inline torch::serialize::InputArchive &
1669 read(torch::serialize::InputArchive &archive,
1670 const std::string &key = "relu") override {
1671 torch::Tensor tensor;
1672
1673 archive.read(key + ".type", tensor);
1674 if (tensor.item<int64_t>() != (int64_t)activation::relu)
1675 throw std::runtime_error("activation mismatch");
1676
1677 archive.read(key + ".inplace", tensor);
1678 this->options_.inplace(tensor.item<bool>());
1679
1680 return archive;
1681 }
1682
1683private:
1684 torch::nn::functional::ReLUFuncOptions options_;
1685};
1686
1693public:
1694 explicit ReLU6(torch::nn::functional::ReLU6FuncOptions options = {})
1695 : options_(options) {}
1696
1697 explicit ReLU6(bool inplace)
1699
1700 ~ReLU6() override = default;
1701
1703 inline torch::Tensor apply(const torch::Tensor &input) const override {
1704 return torch::nn::functional::relu6(input, options_);
1705 }
1706
1708 inline const torch::nn::functional::ReLU6FuncOptions &options() const {
1709 return options_;
1710 }
1711
1713 inline torch::nn::functional::ReLU6FuncOptions &options() { return options_; }
1714
1716 inline virtual void
1717 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1719 << "(\n inplace=" << options_.inplace() << "\n)";
1720 }
1721
1724 inline torch::serialize::OutputArchive &
1725 write(torch::serialize::OutputArchive &archive,
1726 const std::string &key = "relu6") const override {
1727 archive.write(key + ".type", torch::full({1}, (int64_t)activation::relu6));
1728 archive.write(key + ".inplace",
1729 torch::full({1}, (bool)this->options_.inplace()));
1730
1731 return archive;
1732 }
1733
1736 inline torch::serialize::InputArchive &
1737 read(torch::serialize::InputArchive &archive,
1738 const std::string &key = "relu6") override {
1739 torch::Tensor tensor;
1740
1741 archive.read(key + ".type", tensor);
1742 if (tensor.item<int64_t>() != (int64_t)activation::relu6)
1743 throw std::runtime_error("activation mismatch");
1744
1745 archive.read(key + ".inplace", tensor);
1746 this->options_.inplace(tensor.item<bool>());
1747
1748 return archive;
1749 }
1750
1751private:
1752 torch::nn::functional::ReLU6FuncOptions options_;
1753};
1754
1765public:
1766 explicit RReLU(torch::nn::functional::RReLUFuncOptions options = {})
1767 : options_(options) {}
1768
1769 explicit RReLU(double lower, double upper, bool inplace = false)
1770 : options_(torch::nn::functional::RReLUFuncOptions()
1771 .lower(lower)
1772 .upper(upper)
1773 .inplace(inplace)) {}
1774
1775 ~RReLU() override = default;
1776
1778 inline torch::Tensor apply(const torch::Tensor &input) const override {
1779 return torch::nn::functional::rrelu(input, options_);
1780 }
1781
1783 inline const torch::nn::functional::RReLUFuncOptions &options() const {
1784 return options_;
1785 }
1786
1788 inline torch::nn::functional::RReLUFuncOptions &options() { return options_; }
1789
1791 inline virtual void
1792 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1793 os << utils::FullQualifiedName::name() << "(\n lower=" << options_.lower()
1794 << ", upper=" << options_.upper() << ", inplace=" << options_.inplace()
1795 << "\n)";
1796 }
1797
1800 inline torch::serialize::OutputArchive &
1801 write(torch::serialize::OutputArchive &archive,
1802 const std::string &key = "rrelu") const override {
1803 archive.write(key + ".type", torch::full({1}, (int64_t)activation::rrelu));
1804 archive.write(key + ".lower",
1805 torch::full({1}, (double)this->options_.lower()));
1806 archive.write(key + ".upper",
1807 torch::full({1}, (double)this->options_.upper()));
1808 archive.write(key + ".inplace",
1809 torch::full({1}, (bool)this->options_.inplace()));
1810
1811 return archive;
1812 }
1813
1816 inline torch::serialize::InputArchive &
1817 read(torch::serialize::InputArchive &archive,
1818 const std::string &key = "rrelu") override {
1819 torch::Tensor tensor;
1820
1821 archive.read(key + ".type", tensor);
1822 if (tensor.item<int64_t>() != (int64_t)activation::rrelu)
1823 throw std::runtime_error("activation mismatch");
1824
1825 archive.read(key + ".lower", tensor);
1826 this->options_.lower(tensor.item<double>());
1827 archive.read(key + ".upper", tensor);
1828 this->options_.upper(tensor.item<double>());
1829 archive.read(key + ".inplace", tensor);
1830 this->options_.inplace(tensor.item<bool>());
1831
1832 return archive;
1833 }
1834
1835private:
1836 torch::nn::functional::RReLUFuncOptions options_;
1837};
1838
1847class SELU : public ActivationFunction {
1848public:
1849 explicit SELU(torch::nn::functional::SELUFuncOptions options = {})
1850 : options_(options) {}
1851
1852 explicit SELU(bool inplace)
1854
1855 ~SELU() override = default;
1856
1858 inline torch::Tensor apply(const torch::Tensor &input) const override {
1859 return torch::nn::functional::selu(input, options_);
1860 }
1861
1863 inline const torch::nn::functional::SELUFuncOptions &options() const {
1864 return options_;
1865 }
1866
1868 inline torch::nn::functional::SELUFuncOptions &options() { return options_; }
1869
1871 inline virtual void
1872 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1874 << "(\n inplace=" << options_.inplace() << "\n)";
1875 }
1876
1879 inline torch::serialize::OutputArchive &
1880 write(torch::serialize::OutputArchive &archive,
1881 const std::string &key = "selu") const override {
1882 archive.write(key + ".type", torch::full({1}, (int64_t)activation::selu));
1883 archive.write(key + ".inplace",
1884 torch::full({1}, (bool)this->options_.inplace()));
1885
1886 return archive;
1887 }
1888
1891 inline torch::serialize::InputArchive &
1892 read(torch::serialize::InputArchive &archive,
1893 const std::string &key = "selu") override {
1894 torch::Tensor tensor;
1895
1896 archive.read(key + ".type", tensor);
1897 if (tensor.item<int64_t>() != (int64_t)activation::selu)
1898 throw std::runtime_error("activation mismatch");
1899
1900 archive.read(key + ".inplace", tensor);
1901 this->options_.inplace(tensor.item<bool>());
1902
1903 return archive;
1904 }
1905
1906private:
1907 torch::nn::functional::SELUFuncOptions options_;
1908};
1909
1916public:
1918 inline torch::Tensor apply(const torch::Tensor &input) const override {
1919 return torch::sigmoid(input);
1920 }
1921
1923 inline virtual void
1924 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1926 }
1927
1930 inline torch::serialize::OutputArchive &
1931 write(torch::serialize::OutputArchive &archive,
1932 const std::string &key = "sigmoid") const override {
1933 archive.write(key + ".type",
1934 torch::full({1}, (int64_t)activation::sigmoid));
1935
1936 return archive;
1937 }
1938
1941 inline torch::serialize::InputArchive &
1942 read(torch::serialize::InputArchive &archive,
1943 const std::string &key = "sigmoid") override {
1944 torch::Tensor tensor;
1945
1946 archive.read(key + ".type", tensor);
1947 if (tensor.item<int64_t>() != (int64_t)activation::sigmoid)
1948 throw std::runtime_error("activation mismatch");
1949
1950 return archive;
1951 }
1952};
1953
1959class SiLU : public ActivationFunction {
1960public:
1962 inline torch::Tensor apply(const torch::Tensor &input) const override {
1963 return torch::silu(input);
1964 }
1965
1967 inline virtual void
1968 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
1970 }
1971
1974 inline torch::serialize::OutputArchive &
1975 write(torch::serialize::OutputArchive &archive,
1976 const std::string &key = "silu") const override {
1977 archive.write(key + ".type", torch::full({1}, (int64_t)activation::silu));
1978
1979 return archive;
1980 }
1981
1984 inline torch::serialize::InputArchive &
1985 read(torch::serialize::InputArchive &archive,
1986 const std::string &key = "silu") override {
1987 torch::Tensor tensor;
1988
1989 archive.read(key + ".type", tensor);
1990 if (tensor.item<int64_t>() != (int64_t)activation::silu)
1991 throw std::runtime_error("activation mismatch");
1992
1993 return archive;
1994 }
1995};
1996
2005public:
2006 explicit Softmax(int64_t dim)
2007 : options_(torch::nn::functional::SoftmaxFuncOptions(dim)) {}
2008
2009 explicit Softmax(const torch::nn::functional::SoftmaxFuncOptions &options)
2010 : options_(options) {}
2011
2012 ~Softmax() override = default;
2013
2015 inline torch::Tensor apply(const torch::Tensor &input) const override {
2016 return torch::nn::functional::softmax(input, options_);
2017 }
2018
2020 inline const torch::nn::functional::SoftmaxFuncOptions &options() const {
2021 return options_;
2022 }
2023
2025 inline torch::nn::functional::SoftmaxFuncOptions &options() {
2026 return options_;
2027 }
2028
2030 inline virtual void
2031 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
2032 os << utils::FullQualifiedName::name() << "(\n dim=" << options_.dim()
2033 << "\n)";
2034 }
2035
2038 inline torch::serialize::OutputArchive &
2039 write(torch::serialize::OutputArchive &archive,
2040 const std::string &key = "softmax") const override {
2041 archive.write(key + ".type",
2042 torch::full({1}, (int64_t)activation::softmax));
2043 archive.write(key + ".dim",
2044 torch::full({1}, (int64_t)this->options_.dim()));
2045
2046 return archive;
2047 }
2048
2051 inline torch::serialize::InputArchive &
2052 read(torch::serialize::InputArchive &archive,
2053 const std::string &key = "softmax") override {
2054 torch::Tensor tensor;
2055
2056 archive.read(key + ".type", tensor);
2057 if (tensor.item<int64_t>() != (int64_t)activation::softmax)
2058 throw std::runtime_error("activation mismatch");
2059
2060 archive.read(key + ".dim", tensor);
2061 this->options_.dim(tensor.item<int64_t>());
2062
2063 return archive;
2064 }
2065
2066private:
2067 torch::nn::functional::SoftmaxFuncOptions options_;
2068};
2069
2076public:
2077 explicit Softmin(int64_t dim)
2078 : options_(torch::nn::functional::SoftminFuncOptions(dim)) {}
2079
2080 explicit Softmin(const torch::nn::functional::SoftminFuncOptions &options)
2081 : options_(options) {}
2082
2083 ~Softmin() override = default;
2084
2086 inline torch::Tensor apply(const torch::Tensor &input) const override {
2087 return torch::nn::functional::softmin(input, options_);
2088 }
2089
2091 inline const torch::nn::functional::SoftminFuncOptions &options() const {
2092 return options_;
2093 }
2094
2096 inline torch::nn::functional::SoftminFuncOptions &options() {
2097 return options_;
2098 }
2099
2101 inline virtual void
2102 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
2103 os << utils::FullQualifiedName::name() << "(\n dim=" << options_.dim()
2104 << "\n)";
2105 }
2106
2109 inline torch::serialize::OutputArchive &
2110 write(torch::serialize::OutputArchive &archive,
2111 const std::string &key = "softmin") const override {
2112 archive.write(key + ".type",
2113 torch::full({1}, (int64_t)activation::softmin));
2114 archive.write(key + ".dim",
2115 torch::full({1}, (int64_t)this->options_.dim()));
2116
2117 return archive;
2118 }
2119
2122 inline torch::serialize::InputArchive &
2123 read(torch::serialize::InputArchive &archive,
2124 const std::string &key = "softmin") override {
2125 torch::Tensor tensor;
2126
2127 archive.read(key + ".type", tensor);
2128 if (tensor.item<int64_t>() != (int64_t)activation::softmin)
2129 throw std::runtime_error("activation mismatch");
2130
2131 archive.read(key + ".dim", tensor);
2132 this->options_.dim(tensor.item<int64_t>());
2133
2134 return archive;
2135 }
2136
2137private:
2138 torch::nn::functional::SoftminFuncOptions options_;
2139};
2140
2147public:
2148 explicit Softplus(torch::nn::functional::SoftplusFuncOptions options = {})
2149 : options_(options) {}
2150
2151 explicit Softplus(double beta, double threshold)
2152 : options_(
2154 threshold)) {}
2155
2156 ~Softplus() override = default;
2157
2159 inline torch::Tensor apply(const torch::Tensor &input) const override {
2160 return torch::nn::functional::softplus(input, options_);
2161 }
2162
2164 inline const torch::nn::functional::SoftplusFuncOptions &options() const {
2165 return options_;
2166 }
2167
2169 inline torch::nn::functional::SoftplusFuncOptions &options() {
2170 return options_;
2171 }
2172
2174 inline virtual void
2175 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
2176 os << utils::FullQualifiedName::name() << "(\n beta=" << options_.beta()
2177 << ", theshold=" << options_.threshold() << "\n)";
2178 }
2179
2182 inline torch::serialize::OutputArchive &
2183 write(torch::serialize::OutputArchive &archive,
2184 const std::string &key = "softplus") const override {
2185 archive.write(key + ".type",
2186 torch::full({1}, (int64_t)activation::softplus));
2187 archive.write(key + ".beta",
2188 torch::full({1}, (double)this->options_.beta()));
2189 archive.write(key + ".threshold",
2190 torch::full({1}, (double)this->options_.threshold()));
2191
2192 return archive;
2193 }
2194
2197 inline torch::serialize::InputArchive &
2198 read(torch::serialize::InputArchive &archive,
2199 const std::string &key = "softplus") override {
2200 torch::Tensor tensor;
2201
2202 archive.read(key + ".type", tensor);
2204 throw std::runtime_error("activation mismatch");
2205
2206 archive.read(key + ".beta", tensor);
2207 this->options_.beta(tensor.item<double>());
2208 archive.read(key + ".threshold", tensor);
2209 this->options_.threshold(tensor.item<double>());
2210
2211 return archive;
2212 }
2213
2214private:
2215 torch::nn::functional::SoftplusFuncOptions options_;
2216};
2217
2229public:
2230 explicit Softshrink(torch::nn::functional::SoftshrinkFuncOptions options = {})
2231 : options_(options) {}
2232
2233 explicit Softshrink(double lambda)
2234 : options_(
2236
2237 ~Softshrink() override = default;
2238
2240 inline torch::Tensor apply(const torch::Tensor &input) const override {
2241 return torch::nn::functional::softshrink(input, options_);
2242 }
2243
2245 inline const torch::nn::functional::SoftshrinkFuncOptions &options() const {
2246 return options_;
2247 }
2248
2250 inline torch::nn::functional::SoftshrinkFuncOptions &options() {
2251 return options_;
2252 }
2253
2255 inline virtual void
2256 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
2258 << "(\n lambda=" << options_.lambda() << "\n)";
2259 }
2260
2263 inline torch::serialize::OutputArchive &
2264 write(torch::serialize::OutputArchive &archive,
2265 const std::string &key = "softshrink") const override {
2266 archive.write(key + ".type",
2267 torch::full({1}, (int64_t)activation::softshrink));
2268 archive.write(key + ".lambda",
2269 torch::full({1}, (double)this->options_.lambda()));
2270
2271 return archive;
2272 }
2273
2276 inline torch::serialize::InputArchive &
2277 read(torch::serialize::InputArchive &archive,
2278 const std::string &key = "softshrink") override {
2279 torch::Tensor tensor;
2280
2281 archive.read(key + ".type", tensor);
2283 throw std::runtime_error("activation mismatch");
2284
2285 archive.read(key + ".lambda", tensor);
2286 this->options_.lambda(tensor.item<double>());
2287
2288 return archive;
2289 }
2290
2291private:
2292 torch::nn::functional::SoftshrinkFuncOptions options_;
2293};
2294
2301public:
2303 inline torch::Tensor apply(const torch::Tensor &input) const override {
2304 return torch::nn::functional::softsign(input);
2305 }
2306
2308 inline virtual void
2309 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
2311 }
2312
2315 inline torch::serialize::OutputArchive &
2316 write(torch::serialize::OutputArchive &archive,
2317 const std::string &key = "softsign") const override {
2318 archive.write(key + ".type",
2319 torch::full({1}, (int64_t)activation::softsign));
2320
2321 return archive;
2322 }
2323
2326 inline torch::serialize::InputArchive &
2327 read(torch::serialize::InputArchive &archive,
2328 const std::string &key = "softsign") override {
2329 torch::Tensor tensor;
2330
2331 archive.read(key + ".type", tensor);
2333 throw std::runtime_error("activation mismatch");
2334
2335 return archive;
2336 }
2337};
2338
2344class Tanh : public ActivationFunction {
2345public:
2347 inline torch::Tensor apply(const torch::Tensor &input) const override {
2348 return torch::tanh(input);
2349 }
2350
2352 inline virtual void
2353 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
2355 }
2356
2359 inline torch::serialize::OutputArchive &
2360 write(torch::serialize::OutputArchive &archive,
2361 const std::string &key = "tanh") const override {
2362 archive.write(key + ".type", torch::full({1}, (int64_t)activation::tanh));
2363
2364 return archive;
2365 }
2366
2369 inline torch::serialize::InputArchive &
2370 read(torch::serialize::InputArchive &archive,
2371 const std::string &key = "tanh") override {
2372 torch::Tensor tensor;
2373
2374 archive.read(key + ".type", tensor);
2375 if (tensor.item<int64_t>() != (int64_t)activation::tanh)
2376 throw std::runtime_error("activation mismatch");
2377
2378 return archive;
2379 }
2380};
2381
2388public:
2390 inline torch::Tensor apply(const torch::Tensor &input) const override {
2391 return torch::nn::functional::tanhshrink(input);
2392 }
2393
2395 inline virtual void
2396 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
2398 }
2399
2402 inline torch::serialize::OutputArchive &
2403 write(torch::serialize::OutputArchive &archive,
2404 const std::string &key = "tanhshrink") const override {
2405 archive.write(key + ".type",
2406 torch::full({1}, (int64_t)activation::tanhshrink));
2407
2408 return archive;
2409 }
2410
2413 inline torch::serialize::InputArchive &
2414 read(torch::serialize::InputArchive &archive,
2415 const std::string &key = "tanhshrink") override {
2416 torch::Tensor tensor;
2417
2418 archive.read(key + ".type", tensor);
2420 throw std::runtime_error("activation mismatch");
2421
2422 return archive;
2423 }
2424};
2425
2436public:
2437 explicit Threshold(torch::nn::functional::ThresholdFuncOptions options)
2438 : options_(options) {}
2439
2440 explicit Threshold(double threshold, double value, bool inplace = false)
2442 .inplace(inplace)) {}
2443
2444 ~Threshold() override = default;
2445
2447 inline torch::Tensor apply(const torch::Tensor &input) const override {
2448 return torch::nn::functional::threshold(input, options_);
2449 }
2450
2452 inline const torch::nn::functional::ThresholdFuncOptions &options() const {
2453 return options_;
2454 }
2455
2457 inline torch::nn::functional::ThresholdFuncOptions &options() {
2458 return options_;
2459 }
2460
2462 inline virtual void
2463 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
2465 << "(\n threshold=" << options_.threshold()
2466 << ", value=" << options_.value() << ", inplace=" << options_.inplace()
2467 << "\n)";
2468 }
2469
2472 inline torch::serialize::OutputArchive &
2473 write(torch::serialize::OutputArchive &archive,
2474 const std::string &key = "threshold") const override {
2475 archive.write(key + ".type",
2476 torch::full({1}, (int64_t)activation::threshold));
2477 archive.write(key + ".threshold",
2478 torch::full({1}, (double)this->options_.threshold()));
2479 archive.write(key + ".value",
2480 torch::full({1}, (double)this->options_.value()));
2481 archive.write(key + ".inplace",
2482 torch::full({1}, (bool)this->options_.inplace()));
2483
2484 return archive;
2485 }
2486
2489 inline torch::serialize::InputArchive &
2490 read(torch::serialize::InputArchive &archive,
2491 const std::string &key = "threshold") override {
2492 torch::Tensor tensor;
2493
2494 archive.read(key + ".type", tensor);
2496 throw std::runtime_error("activation mismatch");
2497
2498 archive.read(key + ".threshold", tensor);
2499 this->options_.threshold(tensor.item<double>());
2500 archive.read(key + ".value", tensor);
2501 this->options_.value(tensor.item<double>());
2502 archive.read(key + ".inplace", tensor);
2503 this->options_.inplace(tensor.item<bool>());
2504
2505 return archive;
2506 }
2507
2508private:
2509 torch::nn::functional::ThresholdFuncOptions options_;
2510};
2511
2512} // namespace iganet
Abstract activation function structure.
Definition layer.hpp:62
virtual torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key) const =0
Writes the activation function into a torch::serialize::OutputArchive object.
virtual torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key)=0
Reads the activation function from a torch::serialize::InputArchive object.
virtual void pretty_print(std::ostream &os) const noexcept=0
Returns a string representation of the activation function.
virtual torch::Tensor apply(const torch::Tensor &) const =0
Applies the activation function to the given input.
virtual ~ActivationFunction()=default
Batch Normalization as described in the paper.
Definition layer.hpp:134
BatchNorm(const torch::Tensor &running_mean, const torch::Tensor &running_var, torch::nn::functional::BatchNormFuncOptions options={})
Definition layer.hpp:136
torch::Tensor & running_mean()
Returns non-constant reference to running mean.
Definition layer.hpp:166
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:186
torch::Tensor running_var_
Definition layer.hpp:255
BatchNorm(const torch::Tensor &running_mean, const torch::Tensor &running_var, const torch::Tensor &weight, const torch::Tensor &bias, double eps, double momentum, bool training=false)
Definition layer.hpp:142
const torch::Tensor & running_mean() const
Returns constant reference to running mean.
Definition layer.hpp:163
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="batch_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:207
const torch::Tensor & running_var() const
Returns constant reference to running variance.
Definition layer.hpp:169
torch::Tensor running_mean_
Definition layer.hpp:255
const torch::nn::functional::BatchNormFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:175
~BatchNorm() override=default
torch::nn::functional::BatchNormFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:180
torch::Tensor & running_var()
Returns non-constant reference to running var.
Definition layer.hpp:172
torch::nn::functional::BatchNormFuncOptions options_
Definition layer.hpp:254
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="batch_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:231
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:157
Continuously Differentiable Exponential Linear Units activation function.
Definition layer.hpp:264
torch::nn::functional::CELUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:286
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:290
CELU(torch::nn::functional::CELUFuncOptions options={})
Definition layer.hpp:266
torch::nn::functional::CELUFuncOptions options_
Definition layer.hpp:329
const torch::nn::functional::CELUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:281
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="celu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:312
CELU(double alpha, bool inplace=false)
Definition layer.hpp:269
~CELU() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:276
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="celu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:298
Exponential Linear Units activation function.
Definition layer.hpp:341
const torch::nn::functional::ELUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:358
~ELU() override=default
torch::nn::functional::ELUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:363
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="elu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:375
ELU(double alpha, bool inplace=false)
Definition layer.hpp:346
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:353
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:367
torch::nn::functional::ELUFuncOptions options_
Definition layer.hpp:406
ELU(torch::nn::functional::ELUFuncOptions options={})
Definition layer.hpp:343
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="elu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:389
Gaussian Error Linear Units activation function.
Definition layer.hpp:417
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="gelu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:437
GELU()=default
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:430
~GELU() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:424
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="gelu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:447
Gated Linear Units activation function.
Definition layer.hpp:468
~GLU() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="glu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:512
const torch::nn::functional::GLUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:484
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:479
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="glu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:501
torch::nn::functional::GLUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:489
torch::nn::functional::GLUFuncOptions options_
Definition layer.hpp:527
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:493
GLU(torch::nn::functional::GLUFuncOptions options={})
Definition layer.hpp:470
GLU(int64_t dim)
Definition layer.hpp:473
Group Normalization over a mini-batch of inputs as described in the paper Group Normalization,...
Definition layer.hpp:532
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:550
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:566
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="group_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:594
torch::nn::functional::GroupNormFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:560
const torch::nn::functional::GroupNormFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:555
~GroupNorm() override=default
GroupNorm(int64_t num_groups, const torch::Tensor &weight, const torch::Tensor &bias, double eps)
Definition layer.hpp:540
GroupNorm(int64_t num_groups)
Definition layer.hpp:534
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="group_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:580
torch::nn::functional::GroupNormFuncOptions options_
Definition layer.hpp:611
GroupNorm(torch::nn::functional::GroupNormFuncOptions options)
Definition layer.hpp:537
Gumbel-Softmax distribution activation function.
Definition layer.hpp:615
const torch::nn::functional::GumbelSoftmaxFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:636
GumbelSoftmax(torch::nn::functional::GumbelSoftmaxFuncOptions options={})
Definition layer.hpp:617
~GumbelSoftmax() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:630
torch::nn::functional::GumbelSoftmaxFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:641
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:647
GumbelSoftmax(double tau, int dim, bool hard)
Definition layer.hpp:621
torch::nn::functional::GumbelSoftmaxFuncOptions options_
Definition layer.hpp:688
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="gumbel_softmax") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:655
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="gumbel_softmax") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:669
Hard shrinkage activation function.
Definition layer.hpp:692
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardshrink") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:741
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:720
~Hardshrink() override=default
torch::nn::functional::HardshrinkFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:714
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardshrink") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:728
Hardshrink(double lambda)
Definition layer.hpp:697
torch::nn::functional::HardshrinkFuncOptions options_
Definition layer.hpp:756
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:704
Hardshrink(torch::nn::functional::HardshrinkFuncOptions options={})
Definition layer.hpp:694
const torch::nn::functional::HardshrinkFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:709
Hardsigmoid activation function.
Definition layer.hpp:769
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:776
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardsigmoid") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:800
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardsigmoid") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:789
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:782
~Hardsigmoid() override=default
Hardswish activation function.
Definition layer.hpp:822
~Hardswish() override=default
Hardswish()=default
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardswish") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:842
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:829
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardswish") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:853
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:835
Hardtanh activation function.
Definition layer.hpp:875
const torch::nn::functional::HardtanhFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:894
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="hardtanh") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:932
Hardtanh(torch::nn::functional::HardtanhFuncOptions options={})
Definition layer.hpp:877
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:905
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:889
torch::nn::functional::HardtanhFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:899
~Hardtanh() override=default
Hardtanh(double min_val, double max_val, bool inplace=false)
Definition layer.hpp:880
torch::nn::functional::HardtanhFuncOptions options_
Definition layer.hpp:951
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="hardtanh") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:915
Instance Normalization as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization.
Definition layer.hpp:958
torch::nn::functional::InstanceNormFuncOptions options_
Definition layer.hpp:1058
torch::nn::functional::InstanceNormFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:991
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:981
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="instance_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1015
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:997
const torch::nn::functional::InstanceNormFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:986
InstanceNorm(torch::nn::functional::InstanceNormFuncOptions options={})
Definition layer.hpp:960
InstanceNorm(const torch::Tensor &running_mean, const torch::Tensor &running_var, const torch::Tensor &weight, const torch::Tensor &bias, double eps, double momentum, bool use_input_stats=true)
Definition layer.hpp:964
~InstanceNorm() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="instance_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1035
Layer Normalization as described in the paper Layer Normalization.
Definition layer.hpp:1064
const torch::nn::functional::LayerNormFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1090
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="layer_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1116
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1101
LayerNorm(std::vector< int64_t > normalized_shape, const torch::Tensor &weight, const torch::Tensor &bias, double eps)
Definition layer.hpp:1073
LayerNorm(std::vector< int64_t > normalized_shape)
Definition layer.hpp:1066
~LayerNorm() override=default
torch::nn::functional::LayerNormFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1095
LayerNorm(torch::nn::functional::LayerNormFuncOptions options)
Definition layer.hpp:1070
torch::nn::functional::LayerNormFuncOptions options_
Definition layer.hpp:1147
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1085
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="layer_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1130
Leaky ReLU activation function.
Definition layer.hpp:1159
~LeakyReLU() override=default
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1188
const torch::nn::functional::LeakyReLUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1177
LeakyReLU(torch::nn::functional::LeakyReLUFuncOptions options={})
Definition layer.hpp:1161
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="leaky_relu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1197
torch::nn::functional::LeakyReLUFuncOptions options_
Definition layer.hpp:1230
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1172
torch::nn::functional::LeakyReLUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1182
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="leaky_relu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1213
LeakyReLU(double negative_slope, bool inplace=false)
Definition layer.hpp:1164
Local Response Normalization.
Definition layer.hpp:1234
torch::nn::functional::LocalResponseNormFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1263
LocalResponseNorm(const torch::nn::functional::LocalResponseNormFuncOptions &options)
Definition layer.hpp:1239
torch::nn::functional::LocalResponseNormFuncOptions options_
Definition layer.hpp:1318
LocalResponseNorm(int64_t size, double alpha, double beta, double k)
Definition layer.hpp:1243
LocalResponseNorm(int64_t size)
Definition layer.hpp:1236
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="local_response_norm") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1278
~LocalResponseNorm() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1252
const torch::nn::functional::LocalResponseNormFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1258
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="local_response_norm") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1297
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1269
LogSigmoid activation function.
Definition layer.hpp:1326
~LogSigmoid() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="logsigmoid") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1357
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="logsigmoid") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1346
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1339
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1333
LogSigmoid()=default
LogSoftmax activation function.
Definition layer.hpp:1377
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1389
torch::nn::functional::LogSoftmaxFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1399
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="logsoftmax") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1413
LogSoftmax(int64_t dim)
Definition layer.hpp:1379
LogSoftmax(const torch::nn::functional::LogSoftmaxFuncOptions &options)
Definition layer.hpp:1382
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1405
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="logsoftmax") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1424
torch::nn::functional::LogSoftmaxFuncOptions options_
Definition layer.hpp:1436
~LogSoftmax() override=default
const torch::nn::functional::LogSoftmaxFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1394
Mish activation function.
Definition layer.hpp:1444
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1451
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="mish") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1474
~Mish() override=default
Mish()=default
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="mish") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1464
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1457
No-op activation function.
Definition layer.hpp:92
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="none") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:107
virtual void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:100
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:95
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="none") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:117
Lp Normalization.
Definition layer.hpp:1487
~Normalize() override=default
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1516
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="normalize") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1524
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="normalize") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1539
torch::nn::functional::NormalizeFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1510
Normalize(double p, double eps, int64_t dim)
Definition layer.hpp:1492
Normalize(torch::nn::functional::NormalizeFuncOptions options={})
Definition layer.hpp:1489
torch::nn::functional::NormalizeFuncOptions options_
Definition layer.hpp:1558
const torch::nn::functional::NormalizeFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1505
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1500
PReLU activation function.
Definition layer.hpp:1562
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1575
torch::Tensor weight_
Definition layer.hpp:1616
~PReLU() override=default
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="prelu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1591
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="prelu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1602
const torch::Tensor & weight() const
Returns constant reference to weights.
Definition layer.hpp:1569
PReLU(const torch::Tensor &weight)
Definition layer.hpp:1564
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1581
torch::Tensor & weight()
Returns non-constant reference to weights.
Definition layer.hpp:1572
Randomized ReLU activation function.
Definition layer.hpp:1764
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1792
RReLU(double lower, double upper, bool inplace=false)
Definition layer.hpp:1769
const torch::nn::functional::RReLUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1783
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1778
torch::nn::functional::RReLUFuncOptions options_
Definition layer.hpp:1836
~RReLU() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="rrelu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1817
torch::nn::functional::RReLUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1788
RReLU(torch::nn::functional::RReLUFuncOptions options={})
Definition layer.hpp:1766
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="rrelu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1801
ReLU6 activation function.
Definition layer.hpp:1692
~ReLU6() override=default
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1717
ReLU6(torch::nn::functional::ReLU6FuncOptions options={})
Definition layer.hpp:1694
torch::nn::functional::ReLU6FuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1713
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1703
const torch::nn::functional::ReLU6FuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1708
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="relu6") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1725
ReLU6(bool inplace)
Definition layer.hpp:1697
torch::nn::functional::ReLU6FuncOptions options_
Definition layer.hpp:1752
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="relu6") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1737
ReLU activation function.
Definition layer.hpp:1624
torch::nn::functional::ReLUFuncOptions options_
Definition layer.hpp:1684
~ReLU() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1635
ReLU(bool inplace)
Definition layer.hpp:1629
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="relu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1657
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1649
torch::nn::functional::ReLUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1645
ReLU(torch::nn::functional::ReLUFuncOptions options={})
Definition layer.hpp:1626
const torch::nn::functional::ReLUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1640
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="relu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1669
SELU activation function.
Definition layer.hpp:1847
~SELU() override=default
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="selu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1892
torch::nn::functional::SELUFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:1868
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1858
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1872
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="selu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1880
const torch::nn::functional::SELUFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:1863
SELU(bool inplace)
Definition layer.hpp:1852
SELU(torch::nn::functional::SELUFuncOptions options={})
Definition layer.hpp:1849
torch::nn::functional::SELUFuncOptions options_
Definition layer.hpp:1907
Sigmoid Linear Unit activation function.
Definition layer.hpp:1959
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="silu") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1975
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1962
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="silu") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1985
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1968
Sigmoid activation function.
Definition layer.hpp:1915
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:1924
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:1918
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="sigmoid") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:1942
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="sigmoid") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:1931
Softmax activation function.
Definition layer.hpp:2004
torch::nn::functional::SoftmaxFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:2025
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2031
Softmax(const torch::nn::functional::SoftmaxFuncOptions &options)
Definition layer.hpp:2009
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softmax") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2039
const torch::nn::functional::SoftmaxFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:2020
~Softmax() override=default
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2015
Softmax(int64_t dim)
Definition layer.hpp:2006
torch::nn::functional::SoftmaxFuncOptions options_
Definition layer.hpp:2067
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softmax") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2052
Softmin activation function.
Definition layer.hpp:2075
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softmin") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2110
Softmin(int64_t dim)
Definition layer.hpp:2077
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2086
torch::nn::functional::SoftminFuncOptions options_
Definition layer.hpp:2138
const torch::nn::functional::SoftminFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:2091
~Softmin() override=default
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2102
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softmin") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2123
torch::nn::functional::SoftminFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:2096
Softmin(const torch::nn::functional::SoftminFuncOptions &options)
Definition layer.hpp:2080
Softplus activation function.
Definition layer.hpp:2146
torch::nn::functional::SoftplusFuncOptions options_
Definition layer.hpp:2215
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2159
Softplus(torch::nn::functional::SoftplusFuncOptions options={})
Definition layer.hpp:2148
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2175
const torch::nn::functional::SoftplusFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:2164
Softplus(double beta, double threshold)
Definition layer.hpp:2151
torch::nn::functional::SoftplusFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:2169
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softplus") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2198
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softplus") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2183
~Softplus() override=default
Softshrink activation function.
Definition layer.hpp:2228
const torch::nn::functional::SoftshrinkFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:2245
Softshrink(double lambda)
Definition layer.hpp:2233
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softshrink") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2264
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2256
Softshrink(torch::nn::functional::SoftshrinkFuncOptions options={})
Definition layer.hpp:2230
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2240
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softshrink") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2277
torch::nn::functional::SoftshrinkFuncOptions options_
Definition layer.hpp:2292
torch::nn::functional::SoftshrinkFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:2250
~Softshrink() override=default
Softsign activation function.
Definition layer.hpp:2300
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2303
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="softsign") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2327
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2309
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="softsign") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2316
Tanh activation function.
Definition layer.hpp:2344
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="tanh") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2360
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2347
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="tanh") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2370
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2353
Tanhshrink activation function.
Definition layer.hpp:2387
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2390
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="tanhshrink") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2414
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2396
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="tanhshrink") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2403
Threshold activation function.
Definition layer.hpp:2435
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the activation function.
Definition layer.hpp:2463
const torch::nn::functional::ThresholdFuncOptions & options() const
Returns constant reference to options.
Definition layer.hpp:2452
~Threshold() override=default
Threshold(double threshold, double value, bool inplace=false)
Definition layer.hpp:2440
torch::nn::functional::ThresholdFuncOptions & options()
Returns non-constant reference to options.
Definition layer.hpp:2457
torch::Tensor apply(const torch::Tensor &input) const override
Applies the activation function to the given input.
Definition layer.hpp:2447
torch::serialize::InputArchive & read(torch::serialize::InputArchive &archive, const std::string &key="threshold") override
Reads the activation function from a torch::serialize::InputArchive object.
Definition layer.hpp:2490
Threshold(torch::nn::functional::ThresholdFuncOptions options)
Definition layer.hpp:2437
torch::serialize::OutputArchive & write(torch::serialize::OutputArchive &archive, const std::string &key="threshold") const override
Writes the activation function into a torch::serialize::OutputArchive object.
Definition layer.hpp:2473
torch::nn::functional::ThresholdFuncOptions options_
Definition layer.hpp:2509
Full qualified name descriptor.
Definition fqn.hpp:26
virtual const std::string & name() const noexcept
Returns the full qualified name of the object.
Definition fqn.hpp:31
Core components.
Full qualified name utility functions.
Definition boundary.hpp:22
bool is_verbose(std::ostream &os)
Definition core.hpp:831
constexpr bool is_SplineType_v
Alias to the value of is_SplineType.
Definition bspline.hpp:3243
std::ostream & operator<<(std::ostream &os, const Boundary< Spline > &obj)
Print (as string) a Boundary object.
Definition boundary.hpp:1978
struct iganet::@0 Log
Logger.
@ none
Definition boundary.hpp:38
activation
Enumerator for nonlinear activation functions.
Definition layer.hpp:23
short int short_t
Definition core.hpp:74
STL namespace.