IgANet
IgANets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
patch.hpp
Go to the documentation of this file.
1
15#pragma once
16
17#include <core.hpp>
18
19#include <utils/blocktensor.hpp>
20#include <utils/tensorarray.hpp>
21
22namespace iganet {
23
25template <typename real_t, short_t GeoDim, short_t ParDim> class BSplinePatch {
26public:
28 virtual torch::Device device() const noexcept = 0;
29
31 virtual int32_t device_index() const noexcept = 0;
32
34 virtual torch::Dtype dtype() const noexcept = 0;
35
37 virtual torch::Layout layout() const noexcept = 0;
38
40 virtual bool requires_grad() const noexcept = 0;
41
43 virtual bool pinned_memory() const noexcept = 0;
44
46 virtual bool is_sparse() const noexcept = 0;
47
49 virtual BSplinePatch &set_requires_grad(bool requires_grad) noexcept = 0;
50
51 // @brief Returns all coefficients as a single tensor
52 virtual torch::Tensor as_tensor() const noexcept = 0;
53
55 virtual BSplinePatch &from_tensor(const torch::Tensor &tensor) noexcept = 0;
56
59 virtual int64_t as_tensor_size() const noexcept = 0;
60
64 virtual utils::BlockTensor<torch::Tensor, 1, GeoDim>
65 eval_from_precomputed(const torch::Tensor &basfunc,
66 const torch::Tensor &coeff_indices, int64_t numeval,
67 torch::IntArrayRef sizes) const = 0;
68
69 virtual utils::BlockTensor<torch::Tensor, 1, GeoDim>
70 eval_from_precomputed(const utils::TensorArray<ParDim> &basfunc,
71 const torch::Tensor &coeff_indices, int64_t numeval,
72 torch::IntArrayRef sizes) const = 0;
74
76 virtual void
77 pretty_print(std::ostream &os = Log(log::info)) const noexcept = 0;
78};
79
81template <typename real_t, short_t GeoDim, short_t ParDim>
82inline std::ostream &
83operator<<(std::ostream &os, const BSplinePatch<real_t, GeoDim, ParDim> &obj) {
84 obj.pretty_print(os);
85 return os;
86}
87
88} // namespace iganet
Compile-time block tensor.
Abstract patch function base class.
Definition patch.hpp:25
virtual torch::Layout layout() const noexcept=0
Returns the layout property.
virtual bool pinned_memory() const noexcept=0
Returns the pinned_memory property.
virtual torch::Tensor as_tensor() const noexcept=0
Returns all coefficients as a single tensor.
virtual int32_t device_index() const noexcept=0
Returns the device_index property.
virtual bool requires_grad() const noexcept=0
Returns the requires_grad property.
virtual torch::Dtype dtype() const noexcept=0
Returns the dtype property.
virtual BSplinePatch & set_requires_grad(bool requires_grad) noexcept=0
Sets the B-spline object's requires_grad property.
virtual BSplinePatch & from_tensor(const torch::Tensor &tensor) noexcept=0
Sets all coefficients from a single tensor.
virtual bool is_sparse() const noexcept=0
Returns if the layout is sparse.
virtual int64_t as_tensor_size() const noexcept=0
Returns the size of the single tensor representation of all coefficients.
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept=0
Writes a human-readable representation to the given output stream.
virtual utils::BlockTensor< torch::Tensor, 1, GeoDim > eval_from_precomputed(const torch::Tensor &basfunc, const torch::Tensor &coeff_indices, int64_t numeval, torch::IntArrayRef sizes) const =0
Returns the value of the spline function from precomputed basis function.
virtual torch::Device device() const noexcept=0
Returns the device property.
Core components.
Definition boundary.hpp:22
struct iganet::@0 Log
Logger.
log
Enumerator for specifying the logging level.
Definition core.hpp:90
short int short_t
Definition core.hpp:74
STL namespace.
Definition optimizer.hpp:62
TensorArray utility functions.