IgANet
IGAnets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
options.hpp
Go to the documentation of this file.
1
15#pragma once
16
17#include <core/core.hpp>
18#include <utils/fqn.hpp>
19#include <utils/getenv.hpp>
20
21namespace iganet {
22
23struct half {};
24
26template <typename T>
27concept DType =
28 std::is_same_v<T, bool> || std::is_same_v<T, char> ||
29 std::is_same_v<T, short> || std::is_same_v<T, int> ||
30 std::is_same_v<T, long> || std::is_same_v<T, long long> ||
31 std::is_same_v<T, half> || std::is_same_v<T, float> ||
32 std::is_same_v<T, double> || std::is_same_v<T, std::complex<half>> ||
33 std::is_same_v<T, std::complex<float>> ||
34 std::is_same_v<T, std::complex<double>>;
35
42template <typename T>
43 requires DType<T>
44inline constexpr torch::Dtype dtype();
45
46template <> inline constexpr torch::Dtype dtype<bool>() { return torch::kBool; }
47
48template <> inline constexpr torch::Dtype dtype<char>() { return torch::kChar; }
49
50template <> inline constexpr torch::Dtype dtype<short>() {
51 return torch::kShort;
52}
53
54template <> inline constexpr torch::Dtype dtype<int>() { return torch::kInt; }
55
56template <> inline constexpr torch::Dtype dtype<long>() { return torch::kLong; }
57
58template <> inline constexpr torch::Dtype dtype<long long>() {
59 return torch::kLong;
60}
61
62template <> inline constexpr torch::Dtype dtype<half>() { return torch::kHalf; }
63
64template <> inline constexpr torch::Dtype dtype<float>() {
65 return torch::kFloat;
66}
67
68template <> inline constexpr torch::Dtype dtype<double>() {
69 return torch::kDouble;
70}
71
72template <> inline constexpr torch::Dtype dtype<std::complex<half>>() {
73 return at::kComplexHalf;
74}
75
76template <> inline constexpr torch::Dtype dtype<std::complex<float>>() {
77 return at::kComplexFloat;
78}
79
80template <> inline constexpr torch::Dtype dtype<std::complex<double>>() {
81 return at::kComplexDouble;
82}
84
/// @brief Guesses the index of the compute device to be used by this process
///
/// When compiled with MPI support (IGANET_WITH_MPI), processes are
/// distributed round-robin over the devices: the device index is the rank in
/// MPI_COMM_WORLD modulo the number of devices. The number of devices is
/// taken from the environment variable IGANET_DEVICE_COUNT if set. Otherwise
/// it defaults to the number of CUDA devices (if CUDA is available), else the
/// number of XPU devices (if XPU is available), else 1.
///
/// Without MPI support the function always returns 0.
///
/// @result Device index
inline int guess_device_index() {
#ifdef IGANET_WITH_MPI
  int rank = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);

  // Number of devices used if IGANET_DEVICE_COUNT is not set
  const int default_count =
      torch::cuda::is_available()
          ? static_cast<int>(torch::cuda::device_count())
          : (torch::xpu::is_available()
                 ? static_cast<int>(torch::xpu::device_count())
                 : 1);

  const int device_count =
      utils::getenv("IGANET_DEVICE_COUNT", default_count);

  // A non-positive device count (e.g. IGANET_DEVICE_COUNT=0) would make the
  // modulo below undefined behavior; fall back to device 0 in that case.
  return device_count > 0 ? rank % device_count : 0;
#else
  return 0;
#endif
}
97
102template <typename real_t>
103 requires DType<real_t>
105public:
108 : options_(
110 .dtype(::iganet::dtype<real_t>())
111 .device_index(utils::getenv("IGANET_DEVICE_INDEX",
113 .device(
114 (utils::getenv("IGANET_DEVICE", std::string{}) == "CPU")
115 ? torch::kCPU
116 : (utils::getenv("IGANET_DEVICE", std::string{}) == "CUDA")
117 ? torch::kCUDA
118 : (utils::getenv("IGANET_DEVICE", std::string{}) == "HIP")
119 ? torch::kHIP
120 : (utils::getenv("IGANET_DEVICE", std::string{}) == "MPS")
121 ? torch::kMPS
122 : (utils::getenv("IGANET_DEVICE", std::string{}) == "XLA")
123 ? torch::kXLA
124 : (utils::getenv("IGANET_DEVICE", std::string{}) == "XPU")
125 ? torch::kXPU
126 : (torch::cuda::is_available()
127 ? torch::kCUDA
128 : (torch::xpu::is_available() ? torch::kXPU
129 : torch::kCPU)))) {
130 }
131
133 explicit Options(torch::TensorOptions &&options)
134 : options_(options.dtype(::iganet::dtype<real_t>())) {}
135
141 operator torch::TensorOptions() const { return options_; }
142
144 inline torch::Device device() const noexcept { return options_.device(); }
145
147 inline int32_t device_index() const noexcept {
148 return options_.device_index();
149 }
150
152 static inline torch::Dtype dtype() noexcept {
153 return ::iganet::dtype<real_t>();
154 }
155
157 inline torch::Layout layout() const noexcept { return options_.layout(); }
158
160 inline bool requires_grad() const noexcept {
161 return options_.requires_grad();
162 }
163
165 inline bool pinned_memory() const noexcept {
166 return options_.pinned_memory();
167 }
168
170 inline bool is_sparse() const noexcept { return options_.is_sparse(); }
171
173 inline Options<real_t> device(torch::Device device) const noexcept {
174 return Options(options_.device(device));
175 }
176
179 inline Options<real_t> device_index(int16_t device_index) const noexcept {
180 return Options(options_.device_index(device_index));
181 }
182
184 template <typename other_t> inline Options<other_t> dtype() const noexcept {
185 return Options<other_t>(options_.dtype(::iganet::dtype<other_t>()));
186 }
187
189 inline Options<real_t> layout(torch::Layout layout) const noexcept {
190 return Options(options_.layout(layout));
191 }
192
195 inline Options<real_t> requires_grad(bool requires_grad) const noexcept {
196 return Options(options_.requires_grad(requires_grad));
197 }
198
201 inline Options<real_t> pinned_memory(bool pinned_memory) const noexcept {
202 return Options(options_.pinned_memory(pinned_memory));
203 }
204
207 inline Options<real_t>
208 memory_format(torch::MemoryFormat memory_format) const noexcept {
209 return Options(options_.memory_format(memory_format));
210 }
211
213 using value_type = real_t;
214
216 inline void pretty_print(std::ostream &os) const noexcept override {
217 os << name() << "(\noptions = " << options_ << "\n)";
218 }
219
220private:
222 const torch::TensorOptions options_;
223};
224
226template <typename real_t>
227inline std::ostream &operator<<(std::ostream &os, const Options<real_t> &obj) {
228 obj.pretty_print(os);
229 return os;
230}
231
233template <typename real_t>
234class Options<Options<real_t>> : public Options<real_t> {
235 using Options<real_t>::Options;
236};
237
238} // namespace iganet
The Options class handles the automated determination of the dtype from the template argument and the selection of the device.
Definition options.hpp:104
Options< real_t > memory_format(torch::MemoryFormat memory_format) const noexcept
Returns a new Options object with the memory_format property as given.
Definition options.hpp:208
Options()
Default constructor.
Definition options.hpp:107
static torch::Dtype dtype() noexcept
Returns the dtype property.
Definition options.hpp:152
Options(torch::TensorOptions &&options)
Constructor from torch::TensorOptions.
Definition options.hpp:133
Options< real_t > requires_grad(bool requires_grad) const noexcept
Returns a new Options object with the requires_grad property as given.
Definition options.hpp:195
torch::Device device() const noexcept
Returns the device property.
Definition options.hpp:144
void pretty_print(std::ostream &os) const noexcept override
Prints a string representation of the Options object to the given output stream.
Definition options.hpp:216
bool requires_grad() const noexcept
Returns the requires_grad property.
Definition options.hpp:160
bool is_sparse() const noexcept
Returns whether the layout is sparse.
Definition options.hpp:170
operator torch::TensorOptions() const
Implicit conversion operator.
Definition options.hpp:141
Options< real_t > device(torch::Device device) const noexcept
Returns a new Options object with the device property as given.
Definition options.hpp:173
int32_t device_index() const noexcept
Returns the device_index property.
Definition options.hpp:147
real_t value_type
Data type.
Definition options.hpp:213
torch::Layout layout() const noexcept
Returns the layout property.
Definition options.hpp:157
Options< real_t > device_index(int16_t device_index) const noexcept
Returns a new Options object with the device_index property as given.
Definition options.hpp:179
Options< real_t > pinned_memory(bool pinned_memory) const noexcept
Returns a new Options object with the pinned_memory property as given.
Definition options.hpp:201
const torch::TensorOptions options_
Tensor options.
Definition options.hpp:222
Options< other_t > dtype() const noexcept
Returns a new Options object with the dtype property as given.
Definition options.hpp:184
Options< real_t > layout(torch::Layout layout) const noexcept
Returns a new Options object with the layout property as given.
Definition options.hpp:189
bool pinned_memory() const noexcept
Returns the pinned_memory property.
Definition options.hpp:165
Fully qualified name descriptor.
Definition fqn.hpp:22
virtual const std::string & name() const noexcept
Returns the fully qualified name of the object.
Definition fqn.hpp:28
Concept to identify template parameters that are acceptable as DTypes.
Definition options.hpp:27
Core components.
Fully qualified name utility functions.
Environment utility function.
T getenv(std::string variable, const T &default_value)
Returns the value from an environment variable.
Definition getenv.hpp:24
Definition core.hpp:72
constexpr torch::Dtype dtype< char >()
Definition options.hpp:48
constexpr torch::Dtype dtype< float >()
Definition options.hpp:64
constexpr torch::Dtype dtype< bool >()
Definition options.hpp:46
constexpr torch::Dtype dtype< double >()
Definition options.hpp:68
int guess_device_index()
Definition options.hpp:85
constexpr torch::Dtype dtype< long long >()
Definition options.hpp:58
std::ostream & operator<<(std::ostream &os, const MemoryDebugger< id > &obj)
Print (as string) a memory debugger object.
Definition memory.hpp:125
constexpr torch::Dtype dtype< half >()
Definition options.hpp:62
constexpr torch::Dtype dtype< int >()
Definition options.hpp:54
constexpr torch::Dtype dtype()
constexpr torch::Dtype dtype< short >()
Definition options.hpp:50
constexpr torch::Dtype dtype< long >()
Definition options.hpp:56
Definition options.hpp:23
STL namespace.
Definition optimizer.hpp:61