IgANet
IGAnets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
linalg.hpp
Go to the documentation of this file.
1
15#pragma once
16
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

#include <core/core.hpp>
18
19namespace iganet::utils {
20
35template <short_t dim = 0, typename T0, typename T1>
36inline auto dotproduct(T0 &&t0, T1 &&t1) {
37 return torch::sum(torch::mul(t0, t1), dim);
38}
39
59template <short_t dim = 0, typename T0, typename T1>
60inline auto kronproduct(T0 &&t0, T1 &&t1) {
61 switch (t1.sizes().size()) {
62 case 1:
63 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
64 t1.repeat({t0.size(dim)}));
65 case 2:
66 if constexpr (dim == 0)
67 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
68 t1.repeat({t0.size(dim), 1}));
69 else if constexpr (dim == 1 || dim == -1)
70 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
71 t1.repeat({1, t0.size(dim)}));
72 case 3:
73 if constexpr (dim == 0)
74 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
75 t1.repeat({t0.size(dim), 1, 1}));
76 else if constexpr (dim == 1)
77 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
78 t1.repeat({1, t0.size(dim), 1}));
79 else if constexpr (dim == 2 || dim == -1)
80 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
81 t1.repeat({1, 1, t0.size(dim)}));
82 case 4:
83 if constexpr (dim == 0)
84 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
85 t1.repeat({t0.size(dim), 1, 1, 1}));
86 else if constexpr (dim == 1)
87 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
88 t1.repeat({1, t0.size(dim), 1, 1}));
89 else if constexpr (dim == 2)
90 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
91 t1.repeat({1, 1, t0.size(dim), 1}));
92 else if constexpr (dim == 3 || dim == -1)
93 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
94 t1.repeat({1, 1, 1, t0.size(dim)}));
95 case 5:
96 if constexpr (dim == 0)
97 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
98 t1.repeat({t0.size(dim), 1, 1, 1, 1}));
99 else if constexpr (dim == 1)
100 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
101 t1.repeat({1, t0.size(dim), 1, 1, 1}));
102 else if constexpr (dim == 2)
103 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
104 t1.repeat({1, 1, t0.size(dim), 1, 1}));
105 else if constexpr (dim == 3)
106 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
107 t1.repeat({1, 1, 1, t0.size(dim), 1}));
108 else if constexpr (dim == 4 || dim == -1)
109 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
110 t1.repeat({1, 1, 1, 1, t0.size(dim)}));
111 case 6:
112 if constexpr (dim == 0)
113 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
114 t1.repeat({t0.size(dim), 1, 1, 1, 1, 1}));
115 else if constexpr (dim == 1)
116 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
117 t1.repeat({1, t0.size(dim), 1, 1, 1, 1}));
118 else if constexpr (dim == 2)
119 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
120 t1.repeat({1, 1, t0.size(dim), 1, 1, 1}));
121 else if constexpr (dim == 3)
122 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
123 t1.repeat({1, 1, 1, t0.size(dim), 1, 1}));
124 else if constexpr (dim == 4)
125 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
126 t1.repeat({1, 1, 1, 1, t0.size(dim), 1}));
127 else if constexpr (dim == 5 || dim == -1)
128 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
129 t1.repeat({1, 1, 1, 1, 1, t0.size(dim)}));
130 case 7:
131 if constexpr (dim == 0)
132 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
133 t1.repeat({t0.size(dim), 1, 1, 1, 1, 1, 1}));
134 else if constexpr (dim == 1)
135 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
136 t1.repeat({1, t0.size(dim), 1, 1, 1, 1, 1}));
137 else if constexpr (dim == 2)
138 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
139 t1.repeat({1, 1, t0.size(dim), 1, 1, 1, 1}));
140 else if constexpr (dim == 3)
141 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
142 t1.repeat({1, 1, 1, t0.size(dim), 1, 1, 1}));
143 else if constexpr (dim == 4)
144 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
145 t1.repeat({1, 1, 1, 1, t0.size(dim), 1, 1}));
146 else if constexpr (dim == 5)
147 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
148 t1.repeat({1, 1, 1, 1, 1, t0.size(dim), 1}));
149 else if constexpr (dim == 6 || dim == -1)
150 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
151 t1.repeat({1, 1, 1, 1, 1, 1, t0.size(dim)}));
152 case 8:
153 if constexpr (dim == 0)
154 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
155 t1.repeat({t0.size(dim), 1, 1, 1, 1, 1, 1, 1}));
156 else if constexpr (dim == 1)
157 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
158 t1.repeat({1, t0.size(dim), 1, 1, 1, 1, 1, 1}));
159 else if constexpr (dim == 2)
160 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
161 t1.repeat({1, 1, t0.size(dim), 1, 1, 1, 1, 1}));
162 else if constexpr (dim == 3)
163 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
164 t1.repeat({1, 1, 1, t0.size(dim), 1, 1, 1, 1}));
165 else if constexpr (dim == 4)
166 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
167 t1.repeat({1, 1, 1, 1, t0.size(dim), 1, 1, 1}));
168 else if constexpr (dim == 5)
169 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
170 t1.repeat({1, 1, 1, 1, 1, t0.size(dim), 1, 1}));
171 else if constexpr (dim == 6)
172 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
173 t1.repeat({1, 1, 1, 1, 1, 1, t0.size(dim), 1}));
174 else if constexpr (dim == 7 || dim == -1)
175 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
176 t1.repeat({1, 1, 1, 1, 1, 1, 1, t0.size(dim)}));
177 default:
178 throw std::runtime_error("Unsupported tensor dimension");
179 }
180}
181
201template <short_t dim = 0, typename T, typename... Ts>
202inline auto kronproduct(T &&t, Ts &&...ts) {
203 return kronproduct<dim>(std::forward<T>(t),
204 kronproduct<dim>(std::forward<Ts>(ts)...));
205}
206
/// @brief Computes the Kronecker-product between two tensors.
///
/// Thin forwarding wrapper around `torch::kron`; see the LibTorch
/// documentation for the exact semantics.
///
/// @param t0 first tensor operand
/// @param t1 second tensor operand
/// @return tensor holding the Kronecker-product of @p t0 and @p t1
template <typename T0, typename T1> inline auto kron(T0 &&t0, T1 &&t1) {
  return torch::kron(std::forward<T0>(t0), std::forward<T1>(t1));
}
212
/// @brief Computes the Kronecker-product between three or more tensors
/// by folding the binary overload over the argument pack from the right.
template <typename T, typename... Ts> inline auto kron(T &&t, Ts &&...ts) {
  // Combine the trailing pack first, then merge with the head argument.
  auto tail = kron(std::forward<Ts>(ts)...);
  return kron(std::forward<T>(t), std::move(tail));
}
217
/// @brief Computes the (partial) product of all std::array entries.
///
/// @param array array whose entries are multiplied
/// @param start_index index of the first entry included (default 0)
/// @param stop_index index of the last entry included, clamped to the
///                   last valid index (default N - 1)
/// @return product of array[start_index..stop_index]; T{1} (the empty
///         product) for an empty array or an empty index range
template <typename T, std::size_t N>
inline T prod(std::array<T, N> array, std::size_t start_index = 0,
              std::size_t stop_index = N - 1) {
  T result{1};
  if constexpr (N > 0) {
    // Clamp the upper bound: for N == 0 the default stop_index wraps to
    // SIZE_MAX (size_t arithmetic), and callers may pass stop >= N; the
    // original loop read past the end of the array in both cases.
    const std::size_t last = std::min(stop_index, N - 1);
    for (std::size_t i = start_index; i <= last; ++i)
      result *= array[i];
  }
  return result;
}
229
/// @brief Computes the (partial) sum of all std::array entries.
///
/// @param array array whose entries are summed
/// @param start_index index of the first entry included (default 0)
/// @param stop_index index of the last entry included, clamped to the
///                   last valid index (default N - 1)
/// @return sum of array[start_index..stop_index]; T{0} (the empty sum)
///         for an empty array or an empty index range
template <typename T, std::size_t N>
inline T sum(std::array<T, N> array, std::size_t start_index = 0,
             std::size_t stop_index = N - 1) {
  T result{0};
  if constexpr (N > 0) {
    // Clamp the upper bound: for N == 0 the default stop_index wraps to
    // SIZE_MAX (size_t arithmetic), and callers may pass stop >= N; the
    // original loop read past the end of the array in both cases.
    const std::size_t last = std::min(stop_index, N - 1);
    for (std::size_t i = start_index; i <= last; ++i)
      result += array[i];
  }
  return result;
}
241
242} // namespace iganet::utils
Core components.
Definition blocktensor.hpp:24
auto kron(T0 &&t0, T1 &&t1)
Computes the Kronecker-product between two or more tensors.
Definition linalg.hpp:209
auto kronproduct(T0 &&t0, T1 &&t1)
Computes the directional Kronecker-product between two tensors along the given dimension.
Definition linalg.hpp:60
T sum(std::array< T, N > array, std::size_t start_index=0, std::size_t stop_index=N - 1)
Computes the (partial) sum of all std::array entries.
Definition linalg.hpp:232
T prod(std::array< T, N > array, std::size_t start_index=0, std::size_t stop_index=N - 1)
Computes the (partial) product of all std::array entries.
Definition linalg.hpp:220
auto dotproduct(T0 &&t0, T1 &&t1)
Computes the directional dot-product between two tensors with summation along the given dimension.
Definition linalg.hpp:36
short int short_t
Definition core.hpp:74