IgANet
IgANets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
linalg.hpp
Go to the documentation of this file.
1
15#pragma once
16
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

#include <torch/torch.h>
18
19namespace iganet {
20namespace utils {
21
36template <short_t dim = 0, typename T0, typename T1>
37inline auto dotproduct(T0 &&t0, T1 &&t1) {
38 return torch::sum(torch::mul(t0, t1), dim);
39}
40
60template <short_t dim = 0, typename T0, typename T1>
61inline auto kronproduct(T0 &&t0, T1 &&t1) {
62 switch (t1.sizes().size()) {
63 case 1:
64 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
65 t1.repeat({t0.size(dim)}));
66 case 2:
67 if constexpr (dim == 0)
68 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
69 t1.repeat({t0.size(dim), 1}));
70 else if constexpr (dim == 1)
71 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
72 t1.repeat({1, t0.size(dim)}));
73 case 3:
74 if constexpr (dim == 0)
75 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
76 t1.repeat({t0.size(dim), 1, 1}));
77 else if constexpr (dim == 1)
78 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
79 t1.repeat({1, t0.size(dim), 1}));
80 else if constexpr (dim == 2)
81 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
82 t1.repeat({1, 1, t0.size(dim)}));
83 case 4:
84 if constexpr (dim == 0)
85 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
86 t1.repeat({t0.size(dim), 1, 1, 1}));
87 else if constexpr (dim == 1)
88 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
89 t1.repeat({1, t0.size(dim), 1, 1}));
90 else if constexpr (dim == 2)
91 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
92 t1.repeat({1, 1, t0.size(dim), 1}));
93 else if constexpr (dim == 3)
94 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
95 t1.repeat({1, 1, 1, t0.size(dim)}));
96 case 5:
97 if constexpr (dim == 0)
98 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
99 t1.repeat({t0.size(dim), 1, 1, 1, 1}));
100 else if constexpr (dim == 1)
101 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
102 t1.repeat({1, t0.size(dim), 1, 1, 1}));
103 else if constexpr (dim == 2)
104 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
105 t1.repeat({1, 1, t0.size(dim), 1, 1}));
106 else if constexpr (dim == 3)
107 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
108 t1.repeat({1, 1, 1, t0.size(dim), 1}));
109 else if constexpr (dim == 4)
110 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
111 t1.repeat({1, 1, 1, 1, t0.size(dim)}));
112 case 6:
113 if constexpr (dim == 0)
114 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
115 t1.repeat({t0.size(dim), 1, 1, 1, 1, 1}));
116 else if constexpr (dim == 1)
117 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
118 t1.repeat({1, t0.size(dim), 1, 1, 1, 1}));
119 else if constexpr (dim == 2)
120 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
121 t1.repeat({1, 1, t0.size(dim), 1, 1, 1}));
122 else if constexpr (dim == 3)
123 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
124 t1.repeat({1, 1, 1, t0.size(dim), 1, 1}));
125 else if constexpr (dim == 4)
126 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
127 t1.repeat({1, 1, 1, 1, t0.size(dim), 1}));
128 else if constexpr (dim == 5)
129 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
130 t1.repeat({1, 1, 1, 1, 1, t0.size(dim)}));
131 case 7:
132 if constexpr (dim == 0)
133 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
134 t1.repeat({t0.size(dim), 1, 1, 1, 1, 1, 1}));
135 else if constexpr (dim == 1)
136 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
137 t1.repeat({1, t0.size(dim), 1, 1, 1, 1, 1}));
138 else if constexpr (dim == 2)
139 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
140 t1.repeat({1, 1, t0.size(dim), 1, 1, 1, 1}));
141 else if constexpr (dim == 3)
142 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
143 t1.repeat({1, 1, 1, t0.size(dim), 1, 1, 1}));
144 else if constexpr (dim == 4)
145 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
146 t1.repeat({1, 1, 1, 1, t0.size(dim), 1, 1}));
147 else if constexpr (dim == 5)
148 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
149 t1.repeat({1, 1, 1, 1, 1, t0.size(dim), 1}));
150 else if constexpr (dim == 6)
151 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
152 t1.repeat({1, 1, 1, 1, 1, 1, t0.size(dim)}));
153 case 8:
154 if constexpr (dim == 0)
155 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
156 t1.repeat({t0.size(dim), 1, 1, 1, 1, 1, 1, 1}));
157 else if constexpr (dim == 1)
158 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
159 t1.repeat({1, t0.size(dim), 1, 1, 1, 1, 1, 1}));
160 else if constexpr (dim == 2)
161 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
162 t1.repeat({1, 1, t0.size(dim), 1, 1, 1, 1, 1}));
163 else if constexpr (dim == 3)
164 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
165 t1.repeat({1, 1, 1, t0.size(dim), 1, 1, 1, 1}));
166 else if constexpr (dim == 4)
167 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
168 t1.repeat({1, 1, 1, 1, t0.size(dim), 1, 1, 1}));
169 else if constexpr (dim == 5)
170 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
171 t1.repeat({1, 1, 1, 1, 1, t0.size(dim), 1, 1}));
172 else if constexpr (dim == 6)
173 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
174 t1.repeat({1, 1, 1, 1, 1, 1, t0.size(dim), 1}));
175 else if constexpr (dim == 7)
176 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
177 t1.repeat({1, 1, 1, 1, 1, 1, 1, t0.size(dim)}));
178 default:
179 throw std::runtime_error("Unsupported tensor dimension");
180 }
181}
182
202template <short_t dim = 0, typename T, typename... Ts>
203inline auto kronproduct(T &&t, Ts &&...ts) {
204 return kronproduct<dim>(std::forward<T>(t),
205 kronproduct<dim>(std::forward<Ts>(ts)...));
206}
207
210template <typename T0, typename T1> inline auto kron(T0 &&t0, T1 &&t1) {
211 return torch::kron(std::forward<T0>(t0), std::forward<T1>(t1));
212}
213
/// @brief Computes the Kronecker-product between more than two tensors
///
/// The trailing operands are combined first and the result is then
/// combined with the leading operand, i.e. the chain is evaluated
/// from right to left.
///
/// @param[in] t Leading operand
///
/// @param[in] ts Remaining operands
///
/// @result Kronecker-product of all operands
template <typename T, typename... Ts>
inline auto kron(T &&t, Ts &&...ts) {
  auto tail = kron(std::forward<Ts>(ts)...);
  return kron(std::forward<T>(t), std::move(tail));
}
218
/// @brief Computes the (partial) product of all std::array entries
///
/// Multiplies the entries with indices in the closed range
/// `[start_index, stop_index]`. Indices past the last entry are
/// ignored, so an empty array or an empty range yields `T{1}`.
///
/// @param[in] array Array whose entries are multiplied
///
/// @param[in] start_index Index of the first entry to include
///
/// @param[in] stop_index Index of the last entry to include
///
/// @result Product of the selected entries
template <typename T, std::size_t N>
inline T prod(const std::array<T, N> &array, std::size_t start_index = 0,
              std::size_t stop_index = N - 1) {
  T result{1};

  // The `i < N` bound keeps the loop finite and in range when N == 0
  // (the default stop_index then wraps around to the largest size_t)
  // and when stop_index lies beyond the last entry
  for (std::size_t i = start_index; i < N && i <= stop_index; ++i)
    result *= array[i];

  return result;
}
230
/// @brief Computes the (partial) sum of all std::array entries
///
/// Adds up the entries with indices in the closed range
/// `[start_index, stop_index]`. Indices past the last entry are
/// ignored, so an empty array or an empty range yields `T{0}`.
///
/// @param[in] array Array whose entries are summed
///
/// @param[in] start_index Index of the first entry to include
///
/// @param[in] stop_index Index of the last entry to include
///
/// @result Sum of the selected entries
template <typename T, std::size_t N>
inline T sum(const std::array<T, N> &array, std::size_t start_index = 0,
             std::size_t stop_index = N - 1) {
  T result{0};

  // The `i < N` bound keeps the loop finite and in range when N == 0
  // (the default stop_index then wraps around to the largest size_t)
  // and when stop_index lies beyond the last entry
  for (std::size_t i = start_index; i < N && i <= stop_index; ++i)
    result += array[i];

  return result;
}
242
243} // namespace utils
244} // namespace iganet
auto kron(T0 &&t0, T1 &&t1)
Computes the Kronecker-product between two or more tensors.
Definition linalg.hpp:210
auto kronproduct(T0 &&t0, T1 &&t1)
Computes the directional Kronecker-product between two tensors along the given dimension.
Definition linalg.hpp:61
T sum(std::array< T, N > array, std::size_t start_index=0, std::size_t stop_index=N - 1)
Computes the (partial) sum of all std::array entries.
Definition linalg.hpp:233
T prod(std::array< T, N > array, std::size_t start_index=0, std::size_t stop_index=N - 1)
Computes the (partial) product of all std::array entries.
Definition linalg.hpp:221
auto dotproduct(T0 &&t0, T1 &&t1)
Computes the directional dot-product between two tensors with summation along the given dimension.
Definition linalg.hpp:37
Definition boundary.hpp:22
constexpr bool is_SplineType_v
Alias to the value of is_SplineType.
Definition bspline.hpp:3243
short int short_t
Definition core.hpp:74