IgANet
IgANets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
options.hpp
Go to the documentation of this file.
1
15#pragma once
16
17#include <core.hpp>
18#include <utils/fqn.hpp>
19#include <utils/getenv.hpp>
20
21#include <torch/torch.h>
22
23namespace iganet {
24
/// Tag type standing in for 16-bit floating-point element values.
///
/// It has no members; in this header it is only used as a type argument, to
/// select `dtype<half>()` and as one of the types accepted by `DType`.
struct half {
};
26
28 template<typename T>
29 concept DType =
30 std::is_same_v<T, bool> ||
31 std::is_same_v<T, char> ||
32 std::is_same_v<T, short> ||
33 std::is_same_v<T, int> ||
34 std::is_same_v<T, long> ||
35 std::is_same_v<T, long long> ||
36 std::is_same_v<T, half> ||
37 std::is_same_v<T, float> ||
38 std::is_same_v<T, double> ||
39 std::is_same_v<T, std::complex<half>> ||
40 std::is_same_v<T, std::complex<float>> ||
41 std::is_same_v<T, std::complex<double>>;
42
49 template <typename T> requires DType<T> inline constexpr torch::Dtype dtype();
50
51template <> inline constexpr torch::Dtype dtype<bool>() { return torch::kBool; }
52
53template <> inline constexpr torch::Dtype dtype<char>() { return torch::kChar; }
54
55template <> inline constexpr torch::Dtype dtype<short>() {
56 return torch::kShort;
57}
58
59template <> inline constexpr torch::Dtype dtype<int>() { return torch::kInt; }
60
61 template <> inline constexpr torch::Dtype dtype<long>() { return torch::kLong; }
62
63 template <> inline constexpr torch::Dtype dtype<long long>() { return torch::kLong; }
64
65template <> inline constexpr torch::Dtype dtype<half>() { return torch::kHalf; }
66
67template <> inline constexpr torch::Dtype dtype<float>() {
68 return torch::kFloat;
69}
70
71template <> inline constexpr torch::Dtype dtype<double>() {
72 return torch::kDouble;
73}
74
75template <> inline constexpr torch::Dtype dtype<std::complex<half>>() {
76 return at::kComplexHalf;
77}
78
79template <> inline constexpr torch::Dtype dtype<std::complex<float>>() {
80 return at::kComplexFloat;
81}
82
83template <> inline constexpr torch::Dtype dtype<std::complex<double>>() {
84 return at::kComplexDouble;
85}
87
/// Guesses the device index this process should use.
///
/// When built with IGANET_WITH_MPI, the index is the rank of this process in
/// MPI_COMM_WORLD modulo the device count. The device count is taken from the
/// IGANET_DEVICE_COUNT environment variable; if that is not set, it defaults
/// to the number of CUDA devices, else the number of XPU devices, else 1.
/// Without MPI support the index is always 0.
///
/// @return the guessed device index; 0 if the device count is not positive
inline int guess_device_index() {
#ifdef IGANET_WITH_MPI
  int rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  const auto device_count = utils::getenv(
      "IGANET_DEVICE_COUNT",
      torch::cuda::is_available()
          ? torch::cuda::device_count()
          : (torch::xpu::is_available() ? torch::xpu::device_count() : 1));

  // A count of 0 (e.g. IGANET_DEVICE_COUNT=0) would make the modulo below
  // undefined behavior; fall back to device 0 instead.
  if (device_count < 1)
    return 0;

  return static_cast<int>(rank % device_count);
#else
  return 0;
#endif
}
100
105 template <typename real_t>
106 requires DType<real_t>
108public:
111 : options_(
113 .dtype(::iganet::dtype<real_t>())
114 .device_index(utils::getenv("IGANET_DEVICE_INDEX",
116 .device(
117 (utils::getenv("IGANET_DEVICE", std::string{}) == "CPU")
118 ? torch::kCPU
119 : (utils::getenv("IGANET_DEVICE", std::string{}) == "CUDA")
120 ? torch::kCUDA
121 : (utils::getenv("IGANET_DEVICE", std::string{}) == "HIP")
122 ? torch::kHIP
123 : (utils::getenv("IGANET_DEVICE", std::string{}) == "MPS")
124 ? torch::kMPS
125 : (utils::getenv("IGANET_DEVICE", std::string{}) == "XLA")
126 ? torch::kXLA
127 : (utils::getenv("IGANET_DEVICE", std::string{}) == "XPU")
128 ? torch::kXPU
129 : (torch::cuda::is_available() ? torch::kCUDA
130 : (torch::xpu::is_available() ? torch::kXPU : torch::kCPU)))) {}
131
133 explicit Options(torch::TensorOptions &&options)
134 : options_(options.dtype(::iganet::dtype<real_t>())) {}
135
137 operator torch::TensorOptions() const { return options_; }
138
140 inline torch::Device device() const noexcept { return options_.device(); }
141
143 inline int32_t device_index() const noexcept { return options_.device_index(); }
144
146 inline torch::Dtype dtype() const noexcept { return ::iganet::dtype<real_t>(); }
147
149 inline torch::Layout layout() const noexcept { return options_.layout(); }
150
152 inline bool requires_grad() const noexcept { return options_.requires_grad(); }
153
155 inline bool pinned_memory() const noexcept { return options_.pinned_memory(); }
156
158 inline bool is_sparse() const noexcept { return options_.is_sparse(); }
159
161 inline Options<real_t> device(torch::Device device) const noexcept {
162 return Options(options_.device(device));
163 }
164
167 inline Options<real_t> device_index(int16_t device_index) const noexcept {
168 return Options(options_.device_index(device_index));
169 }
170
172 template <typename other_t> inline Options<other_t> dtype() const noexcept {
173 return Options<other_t>(options_.dtype(::iganet::dtype<other_t>()));
174 }
175
177 inline Options<real_t> layout(torch::Layout layout) const noexcept {
178 return Options(options_.layout(layout));
179 }
180
183 inline Options<real_t> requires_grad(bool requires_grad) const noexcept {
184 return Options(options_.requires_grad(requires_grad));
185 }
186
189 inline Options<real_t> pinned_memory(bool pinned_memory) const noexcept {
190 return Options(options_.pinned_memory(pinned_memory));
191 }
192
195 inline Options<real_t>
196 memory_format(torch::MemoryFormat memory_format) const noexcept {
197 return Options(options_.memory_format(memory_format));
198 }
199
201 using value_type = real_t;
202
204 inline virtual void
205 pretty_print(std::ostream &os = Log(log::info)) const noexcept override {
206 os << name() << "(\noptions = " << options_ << "\n)";
207 }
208
209private:
211 const torch::TensorOptions options_;
212};
213
215template <typename real_t>
216inline std::ostream &operator<<(std::ostream &os, const Options<real_t> &obj) {
217 obj.pretty_print(os);
218 return os;
219}
220
222template <typename real_t>
223class Options<Options<real_t>> : public Options<real_t> {
224 using Options<real_t>::Options;
225};
226
227} // namespace iganet
The Options class handles the automated determination of dtype from the template argument and the sel...
Definition options.hpp:107
Options< real_t > memory_format(torch::MemoryFormat memory_format) const noexcept
Returns a new Options object with the memory_format property as given.
Definition options.hpp:196
Options()
Default constructor.
Definition options.hpp:110
Options(torch::TensorOptions &&options)
Constructor from torch::TensorOptions.
Definition options.hpp:133
Options< real_t > requires_grad(bool requires_grad) const noexcept
Returns a new Options object with the requires_grad property as given.
Definition options.hpp:183
torch::Device device() const noexcept
Returns the device property.
Definition options.hpp:140
bool requires_grad() const noexcept
Returns the requires_grad property.
Definition options.hpp:152
bool is_sparse() const noexcept
Returns if the layout is sparse.
Definition options.hpp:158
operator torch::TensorOptions() const
Implicit conversion operator.
Definition options.hpp:137
Options< real_t > device(torch::Device device) const noexcept
Returns a new Options object with the device property as given.
Definition options.hpp:161
int32_t device_index() const noexcept
Returns the device_index property.
Definition options.hpp:143
real_t value_type
Data type.
Definition options.hpp:201
torch::Layout layout() const noexcept
Returns the layout property.
Definition options.hpp:149
Options< real_t > device_index(int16_t device_index) const noexcept
Returns a new Options object with the device_index property as given.
Definition options.hpp:167
Options< real_t > pinned_memory(bool pinned_memory) const noexcept
Returns a new Options object with the pinned_memory property as given.
Definition options.hpp:189
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the Options object.
Definition options.hpp:205
const torch::TensorOptions options_
Tensor options.
Definition options.hpp:211
torch::Dtype dtype() const noexcept
Returns the dtype property.
Definition options.hpp:146
Options< other_t > dtype() const noexcept
Returns a new Options object with the dtype property as given.
Definition options.hpp:172
Options< real_t > layout(torch::Layout layout) const noexcept
Returns a new Options object with the layout property as given.
Definition options.hpp:177
bool pinned_memory() const noexcept
Returns the pinned_memory property.
Definition options.hpp:155
Full qualified name descriptor.
Definition fqn.hpp:26
virtual const std::string & name() const noexcept
Returns the full qualified name of the object.
Definition fqn.hpp:31
Concept to identify template parameters that are acceptable as DTypes.
Definition options.hpp:29
Core components.
Full qualified name utility functions.
Environment utility function.
T getenv(std::string variable, const T &default_value)
Returns the value from an environment variable.
Definition getenv.hpp:24
Definition boundary.hpp:22
constexpr torch::Dtype dtype< char >()
Definition options.hpp:53
constexpr torch::Dtype dtype< float >()
Definition options.hpp:67
constexpr torch::Dtype dtype< bool >()
Definition options.hpp:51
constexpr torch::Dtype dtype< double >()
Definition options.hpp:71
int guess_device_index()
Definition options.hpp:88
constexpr torch::Dtype dtype< long long >()
Definition options.hpp:63
constexpr torch::Dtype dtype< half >()
Definition options.hpp:65
struct iganet::@0 Log
Logger.
constexpr torch::Dtype dtype< int >()
Definition options.hpp:59
constexpr torch::Dtype dtype()
constexpr torch::Dtype dtype< short >()
Definition options.hpp:55
constexpr torch::Dtype dtype< long >()
Definition options.hpp:61
std::ostream & operator<<(std::ostream &os, const Boundary< Spline > &obj)
Print (as string) a Boundary object.
Definition boundary.hpp:1963
Definition options.hpp:25
STL namespace.
Definition optimizer.hpp:62