IgANet
IGAnets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
patch.hpp
Go to the documentation of this file.
1
15#pragma once
16
17#include <core/core.hpp>
18
19#include <utils/blocktensor.hpp>
20#include <utils/tensorarray.hpp>
21
22namespace iganet {
23
24 namespace detail {
25
26 // @brief Concept to identify template parameters that have an
27 // as_tensor function
28 template <typename T>
29 concept HasAsTensor = requires(T a) {
30 { a.as_tensor() };
31 };
32
33 // @brief Concept to identify template parameters that have an
34 // as_tensor_size function
35 template <typename T>
36 concept HasAsTensorSize = requires(T a) {
37 { a.as_tensor_size() };
38 };
39
40 // @brief Concept to identify template parameters that have a
41 // from_tensor function
42 template <typename T>
43 concept HasFromTensor = requires(T a) {
44 { a.from_tensor() };
45 };
46
47 } // namespace detail
48
50template <typename real_t, short_t GeoDim, short_t ParDim> class BSplinePatch {
51public:
53 virtual ~BSplinePatch() = default;
54
56 virtual torch::Device device() const noexcept = 0;
57
59 virtual int32_t device_index() const noexcept = 0;
60
62 virtual torch::Dtype dtype() const noexcept = 0;
63
65 virtual torch::Layout layout() const noexcept = 0;
66
68 virtual bool requires_grad() const noexcept = 0;
69
71 virtual bool pinned_memory() const noexcept = 0;
72
74 virtual bool is_sparse() const noexcept = 0;
75
77 virtual BSplinePatch &set_requires_grad(bool requires_grad) noexcept = 0;
78
79 // @brief Returns all coefficients as a single tensor
80 virtual torch::Tensor as_tensor() const noexcept = 0;
81
83 virtual BSplinePatch &from_tensor(const torch::Tensor &tensor) noexcept = 0;
84
87 virtual int64_t as_tensor_size() const noexcept = 0;
88
92 virtual utils::BlockTensor<torch::Tensor, 1, GeoDim>
93 eval_from_precomputed(const torch::Tensor &basfunc,
94 const torch::Tensor &coeff_indices, int64_t numeval,
95 torch::IntArrayRef sizes) const = 0;
96
97 virtual utils::BlockTensor<torch::Tensor, 1, GeoDim>
98 eval_from_precomputed(const utils::TensorArray<ParDim> &basfunc,
99 const torch::Tensor &coeff_indices, int64_t numeval,
100 torch::IntArrayRef sizes) const = 0;
102
104 virtual void
105 pretty_print(std::ostream &os = Log(log::info)) const noexcept = 0;
106};
107
109template <typename real_t, short_t GeoDim, short_t ParDim>
110inline std::ostream &
111operator<<(std::ostream &os, const BSplinePatch<real_t, GeoDim, ParDim> &obj) {
112 obj.pretty_print(os);
113 return os;
114}
115
116} // namespace iganet
Compile-time block tensor.
Abstract patch function base class.
Definition patch.hpp:50
virtual torch::Layout layout() const noexcept=0
Returns the layout property.
virtual bool pinned_memory() const noexcept=0
Returns the pinned_memory property.
virtual torch::Tensor as_tensor() const noexcept=0
Returns all coefficients as a single tensor.
virtual int32_t device_index() const noexcept=0
Returns the device_index property.
virtual bool requires_grad() const noexcept=0
Returns the requires_grad property.
virtual torch::Dtype dtype() const noexcept=0
Returns the dtype property.
virtual BSplinePatch & set_requires_grad(bool requires_grad) noexcept=0
Sets the B-spline object's requires_grad property.
virtual BSplinePatch & from_tensor(const torch::Tensor &tensor) noexcept=0
Sets all coefficients from a single tensor.
virtual bool is_sparse() const noexcept=0
Returns if the layout is sparse.
virtual ~BSplinePatch()=default
Destructor.
virtual int64_t as_tensor_size() const noexcept=0
Returns the size of the single tensor representation of all coefficients.
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept=0
Writes a string representation to the given output stream.
virtual utils::BlockTensor< torch::Tensor, 1, GeoDim > eval_from_precomputed(const torch::Tensor &basfunc, const torch::Tensor &coeff_indices, int64_t numeval, torch::IntArrayRef sizes) const =0
Returns the value of the spline function from precomputed basis functions.
virtual torch::Device device() const noexcept=0
Returns the device property.
Definition patch.hpp:29
Definition patch.hpp:36
Definition patch.hpp:43
Core components.
Definition core.hpp:72
struct iganet::@0 Log
Logger.
log
Enumerator for specifying the logging level.
Definition core.hpp:90
short int short_t
Definition core.hpp:74
STL namespace.
Definition optimizer.hpp:61
TensorArray utility functions.