IgANet
IGAnets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
blocktensor.hpp
Go to the documentation of this file.
1
15#pragma once
16
17#include <array>
18#include <memory>
19#include <type_traits>
20
21#include <core/core.hpp>
22#include <utils/fqn.hpp>
23
24namespace iganet::utils {
25
/// @brief Type trait that detects whether a type is a specialization of
/// std::shared_ptr (primary template: no).
template <typename T>
struct is_shared_ptr : std::integral_constant<bool, false> {};

/// @brief Partial specialization: any std::shared_ptr<U> qualifies.
template <typename U>
struct is_shared_ptr<std::shared_ptr<U>> : std::integral_constant<bool, true> {};

/// @brief Wraps the given argument into a std::shared_ptr.
///
/// If the argument already is a std::shared_ptr it is forwarded out as-is
/// (no double wrapping; note that an lvalue shared_ptr argument is moved
/// from, since it is forwarded over the decayed type). Any other argument
/// is copied/moved into a newly allocated shared_ptr.
template <typename Arg> inline auto make_shared(Arg &&value) {
  using Decayed = std::decay_t<Arg>;
  if constexpr (is_shared_ptr<Decayed>::value)
    return std::forward<Decayed>(value);
  else
    return std::make_shared<Decayed>(std::forward<Arg>(value));
}
41
43template <typename T, std::size_t... Dims> class BlockTensor;
44
// Core storage base for all BlockTensor specializations: owns (Dims * ...)
// blocks of type T as shared pointers in a flat, row-major std::array.
// NOTE(review): the class-head line of this class (its name and base class)
// was lost in the HTML extraction. pretty_print() below is declared
// `override = 0`, so the class must derive from a polymorphic base that
// declares pretty_print -- presumably the fully-qualified-name utility from
// <utils/fqn.hpp> (operator<< relies on Base::name() in the subclasses);
// confirm against the original header.
46template <typename T, std::size_t... Dims>
48
49protected:
// Flat block storage; shared_ptr lets several tensors (e.g. a transpose
// created by tr()) alias the same blocks without copying T.
51 std::array<std::shared_ptr<T>, (Dims * ...)> data_;
52
53public:
// Default constructor: all block pointers start out empty (nullptr).
55 BlockTensorCore() = default;
56
// Converting constructor taking other block tensor(s).
// NOTE(review): the signature line was lost in extraction. The comma fold
// over `...` together with the per-argument std::transform suggests it
// copies/moves the blocks of one or more argument tensors into data_ back
// to back; the lambda advances `it` itself so the output position keeps
// moving across successive arguments. Confirm against the original header.
58 template <typename... Ts, std::size_t... dims>
60 auto it = data_.begin();
61 (std::transform(other.data().begin(), other.data().end(), it,
62 [&it]<typename D>(D &&d) {
63 ++it;
64 return std::forward<D>(d);
65 }),
66 ...);
67 }
68
// Second converting constructor with an identical body; presumably the
// const-lvalue counterpart of the constructor above (signature line lost
// in extraction as well -- confirm against the original header).
70 template <typename... Ts, std::size_t... dims>
72 auto it = data_.begin();
73 (std::transform(other.data().begin(), other.data().end(), it,
74 [&it]<typename D>(D &&d) {
75 ++it;
76 return std::forward<D>(d);
77 }),
78 ...);
79 }
80
// Constructs the block tensor from one entry per block. Each entry is
// routed through utils::make_shared, so both plain values and ready-made
// shared_ptr blocks are accepted.
82 template <typename... Ts>
83 explicit BlockTensorCore(Ts &&...data)
84 : data_({make_shared<Ts>(std::forward<Ts>(data))...}) {}
85
// Returns all block dimensions as the array {Dims...}.
87 inline static constexpr auto dims() {
88 return std::array<std::size_t, sizeof...(Dims)>({Dims...});
89 }
90
// Returns the i-th block dimension, or 0 if i is out of range.
92 template <std::size_t i> inline static constexpr std::size_t dim() {
93 if constexpr (i < sizeof...(Dims))
94 return std::get<i>(std::forward_as_tuple(Dims...));
95 else
96 return 0;
97 }
98
// Returns the number of block dimensions (the block rank).
100 inline static constexpr std::size_t size() { return sizeof...(Dims); }
101
// Returns the total number of blocks, i.e. the product of all Dims.
103 inline static constexpr std::size_t entries() { return (Dims * ...); }
104
// Read-only access to the underlying array of block pointers.
106 inline const std::array<std::shared_ptr<T>, (Dims * ...)> &data() const {
107 return data_;
108 }
109
// Mutable access to the underlying array of block pointers.
111 inline std::array<std::shared_ptr<T>, (Dims * ...)> &data() { return data_; }
112
// Returns the idx-th block pointer (flat, row-major index).
114 inline const std::shared_ptr<T> &operator[](std::size_t idx) const {
115 assert(idx < (Dims * ...));
116 return data_[idx];
117 }
118
// Mutable variant of operator[].
120 inline std::shared_ptr<T> &operator[](std::size_t idx) {
121 assert(idx < (Dims * ...));
122 return data_[idx];
123 }
124
// Returns a reference to the idx-th block itself (dereferences the stored
// pointer; the block must have been set beforehand).
126 inline const T &operator()(std::size_t idx) const {
127 assert(idx < (Dims * ...));
128 return *data_[idx];
129 }
130
// Mutable variant of operator().
132 inline T &operator()(std::size_t idx) {
133 assert(idx < (Dims * ...));
134 return *data_[idx];
135 }
136
// Replaces the idx-th block with the given data (wrapped through
// utils::make_shared) and returns a reference to the stored block.
138 template <typename Data> inline T &set(std::size_t idx, Data &&data) {
139 assert(idx < (Dims * ...));
140 data_[idx] = make_shared<Data>(std::forward<Data>(data));
141 return *data_[idx];
142 }
143
// Pure virtual pretty-printer implemented by the BlockTensor
// specializations; used by the free operator<< below.
145 inline void pretty_print(std::ostream &os) const noexcept override = 0;
146};
147
149template <typename T, std::size_t... Dims>
150inline std::ostream &operator<<(std::ostream &os,
151 const BlockTensorCore<T, Dims...> &obj) {
152 obj.pretty_print(os);
153 return os;
154}
155
157template <typename T, std::size_t Rows>
158class BlockTensor<T, Rows> : public BlockTensorCore<T, Rows> {
159private:
161
162public:
163 using BlockTensorCore<T, Rows>::BlockTensorCore;
164
166 inline static constexpr std::size_t rows() { return Rows; }
167
169 inline void pretty_print(std::ostream &os) const noexcept override {
170 os << Base::name() << "\n";
171 for (std::size_t row = 0; row < Rows; ++row)
172 os << "[" << row << "] = \n" << *Base::data_[row] << "\n";
173 }
174};
175
181template <typename T, std::size_t Rows, std::size_t Cols>
182class BlockTensor<T, Rows, Cols> : public BlockTensorCore<T, Rows, Cols> {
183private:
185
186public:
187 using BlockTensorCore<T, Rows, Cols>::BlockTensorCore;
188
190 inline static constexpr std::size_t rows() { return Rows; }
191
193 inline static constexpr std::size_t cols() { return Cols; }
194
195 using Base::operator();
196
198 inline const T &operator()(std::size_t row, std::size_t col) const {
199 assert(row < Rows && col < Cols);
200 return *Base::data_[Cols * row + col];
201 }
202
204 inline T &operator()(std::size_t row, std::size_t col) {
205 assert(row < Rows && col < Cols);
206 return *Base::data_[Cols * row + col];
207 }
208
209 using Base::set;
210
212 template <typename D>
213 inline T &set(std::size_t row, std::size_t col, D &&data) {
214 assert(row < Rows && col < Cols);
215 Base::data_[Cols * row + col] = make_shared<D>(std::forward<D>(data));
216 return *Base::data_[Cols * row + col];
217 }
218
220 inline auto tr() const {
222 for (std::size_t row = 0; row < Rows; ++row)
223 for (std::size_t col = 0; col < Cols; ++col)
224 result[Rows * col + row] = Base::data_[Cols * row + col];
225 return result;
226 }
227
231
232 inline auto det() const {
233 if constexpr (Rows == 1 && Cols == 1) {
234 auto result = *Base::data_[0];
235 return result;
236 } else if constexpr (Rows == 2 && Cols == 2) {
237 auto result = torch::mul(*Base::data_[0], *Base::data_[3]) -
238 torch::mul(*Base::data_[1], *Base::data_[2]);
239 return result;
240 } else if constexpr (Rows == 3 && Cols == 3) {
241 auto result =
242 torch::mul(*Base::data_[0],
243 torch::mul(*Base::data_[4], *Base::data_[8]) -
244 torch::mul(*Base::data_[5], *Base::data_[7])) -
245 torch::mul(*Base::data_[1],
246 torch::mul(*Base::data_[3], *Base::data_[8]) -
247 torch::mul(*Base::data_[5], *Base::data_[6])) +
248 torch::mul(*Base::data_[2],
249 torch::mul(*Base::data_[3], *Base::data_[7]) -
250 torch::mul(*Base::data_[4], *Base::data_[6]));
251 return result;
252 } else if constexpr (Rows == 4 && Cols == 4) {
253 auto a11 = torch::mul(*Base::data_[5],
254 (torch::mul(*Base::data_[10], *Base::data_[15]) -
255 torch::mul(*Base::data_[11], *Base::data_[14]))) -
256 torch::mul(*Base::data_[9],
257 (torch::mul(*Base::data_[6], *Base::data_[15]) -
258 torch::mul(*Base::data_[7], *Base::data_[14]))) -
259 torch::mul(*Base::data_[13],
260 (torch::mul(*Base::data_[7], *Base::data_[10]) -
261 torch::mul(*Base::data_[6], *Base::data_[11])));
262
263 auto a21 = torch::mul(*Base::data_[4],
264 (torch::mul(*Base::data_[11], *Base::data_[14]) -
265 torch::mul(*Base::data_[10], *Base::data_[15]))) -
266 torch::mul(*Base::data_[8],
267 (torch::mul(*Base::data_[7], *Base::data_[14]) -
268 torch::mul(*Base::data_[6], *Base::data_[15]))) -
269 torch::mul(*Base::data_[12],
270 (torch::mul(*Base::data_[6], *Base::data_[11]) -
271 torch::mul(*Base::data_[7], *Base::data_[10])));
272
273 auto a31 = torch::mul(*Base::data_[4],
274 (torch::mul(*Base::data_[9], *Base::data_[15]) -
275 torch::mul(*Base::data_[11], *Base::data_[13]))) -
276 torch::mul(*Base::data_[8],
277 (torch::mul(*Base::data_[5], *Base::data_[15]) -
278 torch::mul(*Base::data_[7], *Base::data_[13]))) -
279 torch::mul(*Base::data_[12],
280 (torch::mul(*Base::data_[7], *Base::data_[9]) -
281 torch::mul(*Base::data_[5], *Base::data_[11])));
282
283 auto a41 = torch::mul(*Base::data_[4],
284 (torch::mul(*Base::data_[10], *Base::data_[13]) -
285 torch::mul(*Base::data_[9], *Base::data_[14]))) -
286 torch::mul(*Base::data_[8],
287 (torch::mul(*Base::data_[6], *Base::data_[13]) -
288 torch::mul(*Base::data_[5], *Base::data_[14]))) -
289 torch::mul(*Base::data_[12],
290 (torch::mul(*Base::data_[5], *Base::data_[10]) -
291 torch::mul(*Base::data_[6], *Base::data_[9])));
292
293 auto result =
294 torch::mul(*Base::data_[0], a11) + torch::mul(*Base::data_[1], a21) +
295 torch::mul(*Base::data_[2], a31) + torch::mul(*Base::data_[3], a41);
296
297 return result;
298 } else {
299 throw std::runtime_error("Unsupported block tensor dimension");
300 return *this;
301 }
302 }
303
307 inline auto inv() const {
308
309 auto det_ = this->det();
310
311 if constexpr (Rows == 1 && Cols == 1) {
313 result[0] = std::make_shared<T>(torch::reciprocal(*Base::data_[0]));
314 return result;
315 } else if constexpr (Rows == 2 && Cols == 2) {
316
318 result[0] = std::make_shared<T>(torch::div(*Base::data_[3], det_));
319 result[1] = std::make_shared<T>(torch::div(*Base::data_[2], -det_));
320 result[2] = std::make_shared<T>(torch::div(*Base::data_[1], -det_));
321 result[3] = std::make_shared<T>(torch::div(*Base::data_[0], det_));
322 return result;
323 } else if constexpr (Rows == 3 && Cols == 3) {
324
325 auto a11 = torch::mul(*Base::data_[4], *Base::data_[8]) -
326 torch::mul(*Base::data_[5], *Base::data_[7]);
327 auto a12 = torch::mul(*Base::data_[2], *Base::data_[7]) -
328 torch::mul(*Base::data_[1], *Base::data_[8]);
329 auto a13 = torch::mul(*Base::data_[1], *Base::data_[5]) -
330 torch::mul(*Base::data_[2], *Base::data_[4]);
331 auto a21 = torch::mul(*Base::data_[5], *Base::data_[6]) -
332 torch::mul(*Base::data_[3], *Base::data_[8]);
333 auto a22 = torch::mul(*Base::data_[0], *Base::data_[8]) -
334 torch::mul(*Base::data_[2], *Base::data_[6]);
335 auto a23 = torch::mul(*Base::data_[2], *Base::data_[3]) -
336 torch::mul(*Base::data_[0], *Base::data_[5]);
337 auto a31 = torch::mul(*Base::data_[3], *Base::data_[7]) -
338 torch::mul(*Base::data_[4], *Base::data_[6]);
339 auto a32 = torch::mul(*Base::data_[1], *Base::data_[6]) -
340 torch::mul(*Base::data_[0], *Base::data_[7]);
341 auto a33 = torch::mul(*Base::data_[0], *Base::data_[4]) -
342 torch::mul(*Base::data_[1], *Base::data_[3]);
343
345 result[0] = std::make_shared<T>(torch::div(a11, det_));
346 result[1] = std::make_shared<T>(torch::div(a12, det_));
347 result[2] = std::make_shared<T>(torch::div(a13, det_));
348 result[3] = std::make_shared<T>(torch::div(a21, det_));
349 result[4] = std::make_shared<T>(torch::div(a22, det_));
350 result[5] = std::make_shared<T>(torch::div(a23, det_));
351 result[6] = std::make_shared<T>(torch::div(a31, det_));
352 result[7] = std::make_shared<T>(torch::div(a32, det_));
353 result[8] = std::make_shared<T>(torch::div(a33, det_));
354 return result;
355 } else if constexpr (Rows == 4 && Cols == 4) {
356 auto a11 = torch::mul(*Base::data_[5],
357 (torch::mul(*Base::data_[10], *Base::data_[15]) -
358 torch::mul(*Base::data_[11], *Base::data_[14]))) -
359 torch::mul(*Base::data_[9],
360 (torch::mul(*Base::data_[6], *Base::data_[15]) -
361 torch::mul(*Base::data_[7], *Base::data_[14]))) -
362 torch::mul(*Base::data_[13],
363 (torch::mul(*Base::data_[7], *Base::data_[10]) -
364 torch::mul(*Base::data_[6], *Base::data_[11])));
365
366 auto a12 = torch::mul(*Base::data_[1],
367 (torch::mul(*Base::data_[11], *Base::data_[14]) -
368 torch::mul(*Base::data_[10], *Base::data_[15]))) -
369 torch::mul(*Base::data_[9],
370 (torch::mul(*Base::data_[3], *Base::data_[14]) -
371 torch::mul(*Base::data_[2], *Base::data_[15]))) -
372 torch::mul(*Base::data_[13],
373 (torch::mul(*Base::data_[2], *Base::data_[11]) -
374 torch::mul(*Base::data_[3], *Base::data_[10])));
375
376 auto a13 = torch::mul(*Base::data_[1],
377 (torch::mul(*Base::data_[6], *Base::data_[15]) -
378 torch::mul(*Base::data_[7], *Base::data_[14]))) -
379 torch::mul(*Base::data_[5],
380 (torch::mul(*Base::data_[2], *Base::data_[15]) -
381 torch::mul(*Base::data_[3], *Base::data_[14]))) -
382 torch::mul(*Base::data_[13],
383 (torch::mul(*Base::data_[3], *Base::data_[6]) -
384 torch::mul(*Base::data_[2], *Base::data_[7])));
385
386 auto a14 = torch::mul(*Base::data_[1],
387 (torch::mul(*Base::data_[7], *Base::data_[10]) -
388 torch::mul(*Base::data_[6], *Base::data_[11]))) -
389 torch::mul(*Base::data_[5],
390 (torch::mul(*Base::data_[3], *Base::data_[10]) -
391 torch::mul(*Base::data_[2], *Base::data_[11]))) -
392 torch::mul(*Base::data_[9],
393 (torch::mul(*Base::data_[2], *Base::data_[7]) -
394 torch::mul(*Base::data_[3], *Base::data_[6])));
395
396 auto a21 = torch::mul(*Base::data_[4],
397 (torch::mul(*Base::data_[11], *Base::data_[14]) -
398 torch::mul(*Base::data_[10], *Base::data_[15]))) -
399 torch::mul(*Base::data_[8],
400 (torch::mul(*Base::data_[7], *Base::data_[14]) -
401 torch::mul(*Base::data_[6], *Base::data_[15]))) -
402 torch::mul(*Base::data_[12],
403 (torch::mul(*Base::data_[6], *Base::data_[11]) -
404 torch::mul(*Base::data_[7], *Base::data_[10])));
405
406 auto a22 = torch::mul(*Base::data_[0],
407 (torch::mul(*Base::data_[10], *Base::data_[15]) -
408 torch::mul(*Base::data_[11], *Base::data_[14]))) -
409 torch::mul(*Base::data_[8],
410 (torch::mul(*Base::data_[2], *Base::data_[15]) -
411 torch::mul(*Base::data_[3], *Base::data_[14]))) -
412 torch::mul(*Base::data_[12],
413 (torch::mul(*Base::data_[3], *Base::data_[10]) -
414 torch::mul(*Base::data_[2], *Base::data_[11])));
415
416 auto a23 = torch::mul(*Base::data_[0],
417 (torch::mul(*Base::data_[7], *Base::data_[14]) -
418 torch::mul(*Base::data_[6], *Base::data_[15]))) -
419 torch::mul(*Base::data_[4],
420 (torch::mul(*Base::data_[3], *Base::data_[14]) -
421 torch::mul(*Base::data_[2], *Base::data_[15]))) -
422 torch::mul(*Base::data_[12],
423 (torch::mul(*Base::data_[2], *Base::data_[7]) -
424 torch::mul(*Base::data_[3], *Base::data_[6])));
425
426 auto a24 = torch::mul(*Base::data_[0],
427 (torch::mul(*Base::data_[6], *Base::data_[11]) -
428 torch::mul(*Base::data_[7], *Base::data_[10]))) -
429 torch::mul(*Base::data_[4],
430 (torch::mul(*Base::data_[2], *Base::data_[11]) -
431 torch::mul(*Base::data_[3], *Base::data_[10]))) -
432 torch::mul(*Base::data_[8],
433 (torch::mul(*Base::data_[3], *Base::data_[6]) -
434 torch::mul(*Base::data_[2], *Base::data_[7])));
435
436 auto a31 = torch::mul(*Base::data_[4],
437 (torch::mul(*Base::data_[9], *Base::data_[15]) -
438 torch::mul(*Base::data_[11], *Base::data_[13]))) -
439 torch::mul(*Base::data_[8],
440 (torch::mul(*Base::data_[5], *Base::data_[15]) -
441 torch::mul(*Base::data_[7], *Base::data_[13]))) -
442 torch::mul(*Base::data_[12],
443 (torch::mul(*Base::data_[7], *Base::data_[9]) -
444 torch::mul(*Base::data_[5], *Base::data_[11])));
445
446 auto a32 = torch::mul(*Base::data_[0],
447 (torch::mul(*Base::data_[11], *Base::data_[13]) -
448 torch::mul(*Base::data_[9], *Base::data_[15]))) -
449 torch::mul(*Base::data_[8],
450 (torch::mul(*Base::data_[3], *Base::data_[13]) -
451 torch::mul(*Base::data_[1], *Base::data_[15]))) -
452 torch::mul(*Base::data_[12],
453 (torch::mul(*Base::data_[1], *Base::data_[11]) -
454 torch::mul(*Base::data_[3], *Base::data_[9])));
455
456 auto a33 = torch::mul(*Base::data_[0],
457 (torch::mul(*Base::data_[5], *Base::data_[15]) -
458 torch::mul(*Base::data_[7], *Base::data_[13]))) -
459 torch::mul(*Base::data_[4],
460 (torch::mul(*Base::data_[1], *Base::data_[15]) -
461 torch::mul(*Base::data_[3], *Base::data_[13]))) -
462 torch::mul(*Base::data_[12],
463 (torch::mul(*Base::data_[3], *Base::data_[5]) -
464 torch::mul(*Base::data_[1], *Base::data_[7])));
465
466 auto a34 = torch::mul(*Base::data_[0],
467 (torch::mul(*Base::data_[7], *Base::data_[9]) -
468 torch::mul(*Base::data_[5], *Base::data_[11]))) -
469 torch::mul(*Base::data_[4],
470 (torch::mul(*Base::data_[3], *Base::data_[9]) -
471 torch::mul(*Base::data_[1], *Base::data_[11]))) -
472 torch::mul(*Base::data_[8],
473 (torch::mul(*Base::data_[1], *Base::data_[7]) -
474 torch::mul(*Base::data_[3], *Base::data_[5])));
475
476 auto a41 = torch::mul(*Base::data_[4],
477 (torch::mul(*Base::data_[10], *Base::data_[13]) -
478 torch::mul(*Base::data_[9], *Base::data_[14]))) -
479 torch::mul(*Base::data_[8],
480 (torch::mul(*Base::data_[6], *Base::data_[13]) -
481 torch::mul(*Base::data_[5], *Base::data_[14]))) -
482 torch::mul(*Base::data_[12],
483 (torch::mul(*Base::data_[5], *Base::data_[10]) -
484 torch::mul(*Base::data_[6], *Base::data_[9])));
485
486 auto a42 = torch::mul(*Base::data_[0],
487 (torch::mul(*Base::data_[9], *Base::data_[14]) -
488 torch::mul(*Base::data_[10], *Base::data_[13]))) -
489 torch::mul(*Base::data_[8],
490 (torch::mul(*Base::data_[1], *Base::data_[14]) -
491 torch::mul(*Base::data_[2], *Base::data_[13]))) -
492 torch::mul(*Base::data_[12],
493 (torch::mul(*Base::data_[2], *Base::data_[9]) -
494 torch::mul(*Base::data_[1], *Base::data_[10])));
495
496 auto a43 = torch::mul(*Base::data_[0],
497 (torch::mul(*Base::data_[6], *Base::data_[13]) -
498 torch::mul(*Base::data_[5], *Base::data_[14]))) -
499 torch::mul(*Base::data_[4],
500 (torch::mul(*Base::data_[2], *Base::data_[13]) -
501 torch::mul(*Base::data_[1], *Base::data_[14]))) -
502 torch::mul(*Base::data_[12],
503 (torch::mul(*Base::data_[1], *Base::data_[6]) -
504 torch::mul(*Base::data_[2], *Base::data_[5])));
505
506 auto a44 = torch::mul(*Base::data_[0],
507 (torch::mul(*Base::data_[5], *Base::data_[10]) -
508 torch::mul(*Base::data_[6], *Base::data_[9]))) -
509 torch::mul(*Base::data_[4],
510 (torch::mul(*Base::data_[1], *Base::data_[10]) -
511 torch::mul(*Base::data_[2], *Base::data_[9]))) -
512 torch::mul(*Base::data_[8],
513 (torch::mul(*Base::data_[2], *Base::data_[5]) -
514 torch::mul(*Base::data_[1], *Base::data_[6])));
516 result[0] = std::make_shared<T>(torch::div(a11, det_));
517 result[1] = std::make_shared<T>(torch::div(a12, det_));
518 result[2] = std::make_shared<T>(torch::div(a13, det_));
519 result[3] = std::make_shared<T>(torch::div(a14, det_));
520 result[4] = std::make_shared<T>(torch::div(a21, det_));
521 result[5] = std::make_shared<T>(torch::div(a22, det_));
522 result[6] = std::make_shared<T>(torch::div(a23, det_));
523 result[7] = std::make_shared<T>(torch::div(a24, det_));
524 result[8] = std::make_shared<T>(torch::div(a31, det_));
525 result[9] = std::make_shared<T>(torch::div(a32, det_));
526 result[10] = std::make_shared<T>(torch::div(a33, det_));
527 result[11] = std::make_shared<T>(torch::div(a34, det_));
528 result[12] = std::make_shared<T>(torch::div(a41, det_));
529 result[13] = std::make_shared<T>(torch::div(a42, det_));
530 result[14] = std::make_shared<T>(torch::div(a43, det_));
531 result[15] = std::make_shared<T>(torch::div(a44, det_));
532 return result;
533 } else {
534 throw std::runtime_error("Unsupported block tensor dimension");
535 return *this;
536 }
537 }
538
546 inline auto ginv() const {
547 if constexpr (Rows == Cols)
548 return this->inv();
549 else
550 // Compute the generalized inverse, i.e. (A^T A)^{-1} A^T
551 return (this->tr() * (*this)).inv() * this->tr();
552 }
553
559 inline auto invtr() const {
560
561 auto det_ = this->det();
562
563 if constexpr (Rows == 1 && Cols == 1) {
565 result[0] = std::make_shared<T>(torch::reciprocal(*Base::data_[0]));
566 return result;
567 } else if constexpr (Rows == 2 && Cols == 2) {
568
570 result[0] = std::make_shared<T>(torch::div(*Base::data_[3], det_));
571 result[1] = std::make_shared<T>(torch::div(*Base::data_[1], -det_));
572 result[2] = std::make_shared<T>(torch::div(*Base::data_[2], -det_));
573 result[3] = std::make_shared<T>(torch::div(*Base::data_[0], det_));
574 return result;
575 } else if constexpr (Rows == 3 && Cols == 3) {
576
577 auto a11 = torch::mul(*Base::data_[4], *Base::data_[8]) -
578 torch::mul(*Base::data_[5], *Base::data_[7]);
579 auto a12 = torch::mul(*Base::data_[2], *Base::data_[7]) -
580 torch::mul(*Base::data_[1], *Base::data_[8]);
581 auto a13 = torch::mul(*Base::data_[1], *Base::data_[5]) -
582 torch::mul(*Base::data_[2], *Base::data_[4]);
583 auto a21 = torch::mul(*Base::data_[5], *Base::data_[6]) -
584 torch::mul(*Base::data_[3], *Base::data_[8]);
585 auto a22 = torch::mul(*Base::data_[0], *Base::data_[8]) -
586 torch::mul(*Base::data_[2], *Base::data_[6]);
587 auto a23 = torch::mul(*Base::data_[2], *Base::data_[3]) -
588 torch::mul(*Base::data_[0], *Base::data_[5]);
589 auto a31 = torch::mul(*Base::data_[3], *Base::data_[7]) -
590 torch::mul(*Base::data_[4], *Base::data_[6]);
591 auto a32 = torch::mul(*Base::data_[1], *Base::data_[6]) -
592 torch::mul(*Base::data_[0], *Base::data_[7]);
593 auto a33 = torch::mul(*Base::data_[0], *Base::data_[4]) -
594 torch::mul(*Base::data_[1], *Base::data_[3]);
595
597 result[0] = std::make_shared<T>(torch::div(a11, det_));
598 result[1] = std::make_shared<T>(torch::div(a21, det_));
599 result[2] = std::make_shared<T>(torch::div(a31, det_));
600 result[3] = std::make_shared<T>(torch::div(a12, det_));
601 result[4] = std::make_shared<T>(torch::div(a22, det_));
602 result[5] = std::make_shared<T>(torch::div(a32, det_));
603 result[6] = std::make_shared<T>(torch::div(a13, det_));
604 result[7] = std::make_shared<T>(torch::div(a23, det_));
605 result[8] = std::make_shared<T>(torch::div(a33, det_));
606 return result;
607 } else if constexpr (Rows == 4 && Cols == 4) {
608
609 auto a11 = torch::mul(*Base::data_[5],
610 (torch::mul(*Base::data_[10], *Base::data_[15]) -
611 torch::mul(*Base::data_[11], *Base::data_[14]))) -
612 torch::mul(*Base::data_[9],
613 (torch::mul(*Base::data_[6], *Base::data_[15]) -
614 torch::mul(*Base::data_[7], *Base::data_[14]))) -
615 torch::mul(*Base::data_[13],
616 (torch::mul(*Base::data_[7], *Base::data_[10]) -
617 torch::mul(*Base::data_[6], *Base::data_[11])));
618
619 auto a12 = torch::mul(*Base::data_[1],
620 (torch::mul(*Base::data_[11], *Base::data_[14]) -
621 torch::mul(*Base::data_[10], *Base::data_[15]))) -
622 torch::mul(*Base::data_[9],
623 (torch::mul(*Base::data_[3], *Base::data_[14]) -
624 torch::mul(*Base::data_[2], *Base::data_[15]))) -
625 torch::mul(*Base::data_[13],
626 (torch::mul(*Base::data_[2], *Base::data_[11]) -
627 torch::mul(*Base::data_[3], *Base::data_[10])));
628
629 auto a13 = torch::mul(*Base::data_[1],
630 (torch::mul(*Base::data_[6], *Base::data_[15]) -
631 torch::mul(*Base::data_[7], *Base::data_[14]))) -
632 torch::mul(*Base::data_[5],
633 (torch::mul(*Base::data_[2], *Base::data_[15]) -
634 torch::mul(*Base::data_[3], *Base::data_[14]))) -
635 torch::mul(*Base::data_[13],
636 (torch::mul(*Base::data_[3], *Base::data_[6]) -
637 torch::mul(*Base::data_[2], *Base::data_[7])));
638
639 auto a14 = torch::mul(*Base::data_[1],
640 (torch::mul(*Base::data_[7], *Base::data_[10]) -
641 torch::mul(*Base::data_[6], *Base::data_[11]))) -
642 torch::mul(*Base::data_[5],
643 (torch::mul(*Base::data_[3], *Base::data_[10]) -
644 torch::mul(*Base::data_[2], *Base::data_[11]))) -
645 torch::mul(*Base::data_[9],
646 (torch::mul(*Base::data_[2], *Base::data_[7]) -
647 torch::mul(*Base::data_[3], *Base::data_[6])));
648
649 auto a21 = torch::mul(*Base::data_[4],
650 (torch::mul(*Base::data_[11], *Base::data_[14]) -
651 torch::mul(*Base::data_[10], *Base::data_[15]))) -
652 torch::mul(*Base::data_[8],
653 (torch::mul(*Base::data_[7], *Base::data_[14]) -
654 torch::mul(*Base::data_[6], *Base::data_[15]))) -
655 torch::mul(*Base::data_[12],
656 (torch::mul(*Base::data_[6], *Base::data_[11]) -
657 torch::mul(*Base::data_[7], *Base::data_[10])));
658
659 auto a22 = torch::mul(*Base::data_[0],
660 (torch::mul(*Base::data_[10], *Base::data_[15]) -
661 torch::mul(*Base::data_[11], *Base::data_[14]))) -
662 torch::mul(*Base::data_[8],
663 (torch::mul(*Base::data_[2], *Base::data_[15]) -
664 torch::mul(*Base::data_[3], *Base::data_[14]))) -
665 torch::mul(*Base::data_[12],
666 (torch::mul(*Base::data_[3], *Base::data_[10]) -
667 torch::mul(*Base::data_[2], *Base::data_[11])));
668
669 auto a23 = torch::mul(*Base::data_[0],
670 (torch::mul(*Base::data_[7], *Base::data_[14]) -
671 torch::mul(*Base::data_[6], *Base::data_[15]))) -
672 torch::mul(*Base::data_[4],
673 (torch::mul(*Base::data_[3], *Base::data_[14]) -
674 torch::mul(*Base::data_[2], *Base::data_[15]))) -
675 torch::mul(*Base::data_[12],
676 (torch::mul(*Base::data_[2], *Base::data_[7]) -
677 torch::mul(*Base::data_[3], *Base::data_[6])));
678
679 auto a24 = torch::mul(*Base::data_[0],
680 (torch::mul(*Base::data_[6], *Base::data_[11]) -
681 torch::mul(*Base::data_[7], *Base::data_[10]))) -
682 torch::mul(*Base::data_[4],
683 (torch::mul(*Base::data_[2], *Base::data_[11]) -
684 torch::mul(*Base::data_[3], *Base::data_[10]))) -
685 torch::mul(*Base::data_[8],
686 (torch::mul(*Base::data_[3], *Base::data_[6]) -
687 torch::mul(*Base::data_[2], *Base::data_[7])));
688
689 auto a31 = torch::mul(*Base::data_[4],
690 (torch::mul(*Base::data_[9], *Base::data_[15]) -
691 torch::mul(*Base::data_[11], *Base::data_[13]))) -
692 torch::mul(*Base::data_[8],
693 (torch::mul(*Base::data_[5], *Base::data_[15]) -
694 torch::mul(*Base::data_[7], *Base::data_[13]))) -
695 torch::mul(*Base::data_[12],
696 (torch::mul(*Base::data_[7], *Base::data_[9]) -
697 torch::mul(*Base::data_[5], *Base::data_[11])));
698
699 auto a32 = torch::mul(*Base::data_[0],
700 (torch::mul(*Base::data_[11], *Base::data_[13]) -
701 torch::mul(*Base::data_[9], *Base::data_[15]))) -
702 torch::mul(*Base::data_[8],
703 (torch::mul(*Base::data_[3], *Base::data_[13]) -
704 torch::mul(*Base::data_[1], *Base::data_[15]))) -
705 torch::mul(*Base::data_[12],
706 (torch::mul(*Base::data_[1], *Base::data_[11]) -
707 torch::mul(*Base::data_[3], *Base::data_[9])));
708
709 auto a33 = torch::mul(*Base::data_[0],
710 (torch::mul(*Base::data_[5], *Base::data_[15]) -
711 torch::mul(*Base::data_[7], *Base::data_[13]))) -
712 torch::mul(*Base::data_[4],
713 (torch::mul(*Base::data_[1], *Base::data_[15]) -
714 torch::mul(*Base::data_[3], *Base::data_[13]))) -
715 torch::mul(*Base::data_[12],
716 (torch::mul(*Base::data_[3], *Base::data_[5]) -
717 torch::mul(*Base::data_[1], *Base::data_[7])));
718
719 auto a34 = torch::mul(*Base::data_[0],
720 (torch::mul(*Base::data_[7], *Base::data_[9]) -
721 torch::mul(*Base::data_[5], *Base::data_[11]))) -
722 torch::mul(*Base::data_[4],
723 (torch::mul(*Base::data_[3], *Base::data_[9]) -
724 torch::mul(*Base::data_[1], *Base::data_[11]))) -
725 torch::mul(*Base::data_[8],
726 (torch::mul(*Base::data_[1], *Base::data_[7]) -
727 torch::mul(*Base::data_[3], *Base::data_[5])));
728
729 auto a41 = torch::mul(*Base::data_[4],
730 (torch::mul(*Base::data_[10], *Base::data_[13]) -
731 torch::mul(*Base::data_[9], *Base::data_[14]))) -
732 torch::mul(*Base::data_[8],
733 (torch::mul(*Base::data_[6], *Base::data_[13]) -
734 torch::mul(*Base::data_[5], *Base::data_[14]))) -
735 torch::mul(*Base::data_[12],
736 (torch::mul(*Base::data_[5], *Base::data_[10]) -
737 torch::mul(*Base::data_[6], *Base::data_[9])));
738
739 auto a42 = torch::mul(*Base::data_[0],
740 (torch::mul(*Base::data_[9], *Base::data_[14]) -
741 torch::mul(*Base::data_[10], *Base::data_[13]))) -
742 torch::mul(*Base::data_[8],
743 (torch::mul(*Base::data_[1], *Base::data_[14]) -
744 torch::mul(*Base::data_[2], *Base::data_[13]))) -
745 torch::mul(*Base::data_[12],
746 (torch::mul(*Base::data_[2], *Base::data_[9]) -
747 torch::mul(*Base::data_[1], *Base::data_[10])));
748
749 auto a43 = torch::mul(*Base::data_[0],
750 (torch::mul(*Base::data_[6], *Base::data_[13]) -
751 torch::mul(*Base::data_[5], *Base::data_[14]))) -
752 torch::mul(*Base::data_[4],
753 (torch::mul(*Base::data_[2], *Base::data_[13]) -
754 torch::mul(*Base::data_[1], *Base::data_[14]))) -
755 torch::mul(*Base::data_[12],
756 (torch::mul(*Base::data_[1], *Base::data_[6]) -
757 torch::mul(*Base::data_[2], *Base::data_[5])));
758
759 auto a44 = torch::mul(*Base::data_[0],
760 (torch::mul(*Base::data_[5], *Base::data_[10]) -
761 torch::mul(*Base::data_[6], *Base::data_[9]))) -
762 torch::mul(*Base::data_[4],
763 (torch::mul(*Base::data_[1], *Base::data_[10]) -
764 torch::mul(*Base::data_[2], *Base::data_[9]))) -
765 torch::mul(*Base::data_[8],
766 (torch::mul(*Base::data_[2], *Base::data_[5]) -
767 torch::mul(*Base::data_[1], *Base::data_[6])));
768
770 result[0] = std::make_shared<T>(torch::div(a11, det_));
771 result[1] = std::make_shared<T>(torch::div(a21, det_));
772 result[2] = std::make_shared<T>(torch::div(a31, det_));
773 result[3] = std::make_shared<T>(torch::div(a41, det_));
774 result[4] = std::make_shared<T>(torch::div(a12, det_));
775 result[5] = std::make_shared<T>(torch::div(a22, det_));
776 result[6] = std::make_shared<T>(torch::div(a32, det_));
777 result[7] = std::make_shared<T>(torch::div(a42, det_));
778 result[8] = std::make_shared<T>(torch::div(a13, det_));
779 result[9] = std::make_shared<T>(torch::div(a23, det_));
780 result[10] = std::make_shared<T>(torch::div(a33, det_));
781 result[11] = std::make_shared<T>(torch::div(a43, det_));
782 result[12] = std::make_shared<T>(torch::div(a14, det_));
783 result[13] = std::make_shared<T>(torch::div(a24, det_));
784 result[14] = std::make_shared<T>(torch::div(a34, det_));
785 result[15] = std::make_shared<T>(torch::div(a44, det_));
786 return result;
787 } else {
788 throw std::runtime_error("Unsupported block tensor dimension");
789 return *this;
790 }
791 }
792
802 inline auto ginvtr() const {
803 if constexpr (Rows == Cols)
804 return this->invtr();
805 else
806 // Compute the transpose of the generalized inverse, i.e. A (A^T A)^{-T}
807 return (*this) * (this->tr() * (*this)).invtr();
808 }
809
811 inline auto trace() const {
812 static_assert(Rows == Cols, "trace(.) requires square block tensor");
813
814 if constexpr (Rows == 1)
815 return BlockTensor<T, 1, 1>(*Base::data_[0]);
816
817 else if constexpr (Rows == 2)
818 return BlockTensor<T, 1, 1>(*Base::data_[0] + *Base::data_[3]);
819
820 else if constexpr (Rows == 3)
821 return BlockTensor<T, 1, 1>(*Base::data_[0] + *Base::data_[4] +
822 *Base::data_[8]);
823
824 else if constexpr (Rows == 4)
825 return BlockTensor<T, 1, 1>(*Base::data_[0] + *Base::data_[5] +
826 *Base::data_[10] + *Base::data_[15]);
827
828 else
829 throw std::runtime_error("Unsupported block tensor dimension");
830 }
831
832private:
834 template <std::size_t... Is>
835 inline auto norm_(std::index_sequence<Is...>) const {
836 return torch::sqrt(
837 std::apply([](const auto &...tensors) { return (tensors + ...); },
838 std::make_tuple(std::get<Is>(Base::data_)->square()...)));
839 }
840
841public:
843 inline auto norm() const {
845 std::make_shared<T>(norm_(std::make_index_sequence<Rows * Cols>{})));
846 }
847
848private:
850 template <std::size_t... Is>
851 inline auto normalize_(std::index_sequence<Is...> is) const {
852 auto n_ = norm_(is);
854 std::make_shared<T>(*std::get<Is>(Base::data_) / n_)...);
855 }
856
857public:
859 inline auto normalize() const {
860 return normalize_(std::make_index_sequence<Rows * Cols>{});
861 }
862
863private:
865 template <std::size_t... Is>
866 inline auto dot_(std::index_sequence<Is...>,
867 const BlockTensor<T, Rows, Cols> &other) const {
868 return std::apply(
869 [](const auto &...tensors) { return (tensors + ...); },
870 std::make_tuple(torch::mul(*std::get<Is>(Base::data_),
871 *std::get<Is>(other.data_))...));
872 }
873
874public:
876 inline auto dot(const BlockTensor<T, Rows, Cols> &other) const {
877 return BlockTensor<T, 1, 1>(std::make_shared<T>(
878 dot_(std::make_index_sequence<Rows * Cols>{}, other)));
879 }
880
882 inline void pretty_print(std::ostream &os) const noexcept override {
883 os << Base::name() << "\n";
884 for (std::size_t row = 0; row < Rows; ++row)
885 for (std::size_t col = 0; col < Cols; ++col)
886 os << "[" << row << "," << col << "] = \n"
887 << *Base::data_[Cols * row + col] << "\n";
888 }
889};
890
893template <typename T, typename U, std::size_t Rows, std::size_t Common,
894 std::size_t Cols>
896 const BlockTensor<U, Common, Cols> &rhs) {
897 BlockTensor<std::common_type_t<T, U>, Rows, Cols> result;
898 for (std::size_t row = 0; row < Rows; ++row)
899 for (std::size_t col = 0; col < Cols; ++col) {
900 T tmp =
901 (lhs[Common * row]->dim() > rhs[col]->dim()
902 ? torch::mul(*lhs[Common * row], rhs[col]->unsqueeze(-1))
903 : (lhs[Common * row]->dim() < rhs[col]->dim()
904 ? torch::mul(lhs[Common * row]->unsqueeze(-1), *rhs[col])
905 : torch::mul(*lhs[Common * row], *rhs[col])));
906 for (std::size_t idx = 1; idx < Common; ++idx)
907 tmp += (lhs[Common * row]->dim() > rhs[col]->dim()
908 ? torch::mul(*lhs[Common * row + idx],
909 rhs[Cols * idx + col]->unsqueeze(-1))
910 : (lhs[Common * row]->dim() < rhs[col]->dim()
911 ? torch::mul(lhs[Common * row + idx]->unsqueeze(-1),
912 *rhs[Cols * idx + col])
913 : torch::mul(*lhs[Common * row + idx],
914 *rhs[Cols * idx + col])));
915 result[Cols * row + col] = std::make_shared<T>(tmp);
916 }
917 return result;
918}
919
926template <typename T, std::size_t Rows, std::size_t Cols, std::size_t Slices>
927class BlockTensor<T, Rows, Cols, Slices>
928 : public BlockTensorCore<T, Rows, Cols, Slices> {
929private:
931
932public:
933 using BlockTensorCore<T, Rows, Cols, Slices>::BlockTensorCore;
934
936 inline static constexpr std::size_t rows() { return Rows; }
937
939 inline static constexpr std::size_t cols() { return Cols; }
940
942 inline static constexpr std::size_t slices() { return Slices; }
943
944 using Base::operator();
945
947 inline const T &operator()(std::size_t row, std::size_t col,
948 std::size_t slice) const {
949 assert(row < Rows && col < Cols && slice < Slices);
950 return *Base::data_[Rows * Cols * slice + Cols * row + col];
951 }
952
954 inline T &operator()(std::size_t row, std::size_t col, std::size_t slice) {
955 assert(row < Rows && col < Cols && slice < Slices);
956 return *Base::data_[Rows * Cols * slice + Cols * row + col];
957 }
958
959 using Base::set;
960
962 template <typename D>
963 inline T &set(std::size_t row, std::size_t col, std::size_t slice, D &&data) {
964 Base::data_[Rows * Cols * slice + Cols * row + col] =
965 make_shared<D>(std::forward<D>(data));
966 return *Base::data_[Rows * Cols * slice + Cols * row + col];
967 }
968
970 inline auto slice(std::size_t slice) const {
971 assert(slice < Slices);
973 for (std::size_t row = 0; row < Rows; ++row)
974 for (std::size_t col = 0; col < Cols; ++col)
975 result[Cols * row + col] =
976 Base::data_[Rows * Cols * slice + Cols * row + col];
977 return result;
978 }
979
982 inline auto reorder_ikj() const {
984 for (std::size_t slice = 0; slice < Slices; ++slice)
985 for (std::size_t row = 0; row < Rows; ++row)
986 for (std::size_t col = 0; col < Cols; ++col)
987 result[Rows * Slices * col + Slices * row + slice] =
988 Base::data_[Rows * Cols * slice + Cols * row + col];
989 return result;
990 }
991
995 inline auto reorder_jik() const {
997 for (std::size_t slice = 0; slice < Slices; ++slice)
998 for (std::size_t row = 0; row < Rows; ++row)
999 for (std::size_t col = 0; col < Cols; ++col)
1000 result[Rows * Cols * slice + Rows * col + row] =
1001 Base::data_[Rows * Cols * slice + Cols * row + col];
1002 return result;
1003 }
1004
1007 inline auto reorder_kji() const {
1009 for (std::size_t slice = 0; slice < Slices; ++slice)
1010 for (std::size_t row = 0; row < Rows; ++row)
1011 for (std::size_t col = 0; col < Cols; ++col)
1012 result[Slices * Cols * row + Cols * slice + col] =
1013 Base::data_[Rows * Cols * slice + Cols * row + col];
1014 return result;
1015 }
1016
1019 inline auto reorder_kij() const {
1021 for (std::size_t slice = 0; slice < Slices; ++slice)
1022 for (std::size_t row = 0; row < Rows; ++row)
1023 for (std::size_t col = 0; col < Cols; ++col)
1024 result[Slices * Rows * col + Rows * slice + row] =
1025 Base::data_[Rows * Cols * slice + Cols * row + col];
1026 return result;
1027 }
1028
1030 inline void pretty_print(std::ostream &os) const noexcept override {
1031 os << Base::name() << "\n";
1032 for (std::size_t slice = 0; slice < Slices; ++slice)
1033 for (std::size_t row = 0; row < Rows; ++row)
1034 for (std::size_t col = 0; col < Cols; ++col)
1035 os << "[" << row << "," << col << "," << slice << "] = \n"
1036 << *Base::data_[Rows * Cols * slice + Cols * row + col] << "\n";
1037 }
1038};
1039
1042template <typename T, typename U, std::size_t Rows, std::size_t Common,
1043 std::size_t Cols, std::size_t Slices>
1046 BlockTensor<std::common_type_t<T, U>, Rows, Cols, Slices> result;
1047 for (std::size_t slice = 0; slice < Slices; ++slice)
1048 for (std::size_t row = 0; row < Rows; ++row)
1049 for (std::size_t col = 0; col < Cols; ++col) {
1050 T tmp =
1051 (lhs[Common * row]->dim() > rhs[Rows * Cols * slice + col]->dim()
1052 ? torch::mul(*lhs[Common * row],
1053 rhs[Rows * Cols * slice + col]->unsqueeze(-1))
1054 : (lhs[Common * row]->dim() <
1055 rhs[Rows * Cols * slice + col]->dim()
1056 ? torch::mul(lhs[Common * row]->unsqueeze(-1),
1057 *rhs[Rows * Cols * slice + col])
1058 : torch::mul(*lhs[Common * row],
1059 *rhs[Rows * Cols * slice + col])));
1060 for (std::size_t idx = 1; idx < Common; ++idx)
1061 tmp +=
1062 (lhs[Common * row]->dim() > rhs[Rows * Cols * slice + col]->dim()
1063 ? torch::mul(
1064 *lhs[Common * row + idx],
1065 rhs[Rows * Cols * slice + Cols * idx + col]->unsqueeze(
1066 -1))
1067 : (lhs[Common * row]->dim() <
1068 rhs[Rows * Cols * slice + col]->dim()
1069 ? torch::mul(
1070 lhs[Common * row + idx]->unsqueeze(-1),
1071 *rhs[Rows * Cols * slice + Cols * idx + col])
1072 : torch::mul(
1073 *lhs[Common * row + idx],
1074 *rhs[Rows * Cols * slice + Cols * idx + col])));
1075 result[Rows * Cols * slice + Cols * row + col] =
1076 std::make_shared<T>(tmp);
1077 }
1078 return result;
1079}
1080
1083template <typename T, typename U, std::size_t Rows, std::size_t Common,
1084 std::size_t Cols, std::size_t Slices>
1086 const BlockTensor<U, Common, Cols> &rhs) {
1087 BlockTensor<std::common_type_t<T, U>, Rows, Cols, Slices> result;
1088 for (std::size_t slice = 0; slice < Slices; ++slice)
1089 for (std::size_t row = 0; row < Rows; ++row)
1090 for (std::size_t col = 0; col < Cols; ++col) {
1091 T tmp =
1092 (lhs[Rows * Cols * slice + Common * row]->dim() > rhs[col]->dim()
1093 ? torch::mul(*lhs[Rows * Cols * slice + Common * row],
1094 rhs[col]->unsqueeze(-1))
1095 : (lhs[Rows * Cols * slice + Common * row]->dim() <
1096 rhs[col]->dim()
1097 ? torch::mul(lhs[Rows * Cols * slice + Common * row]
1098 ->unsqueeze(-1),
1099 *rhs[col])
1100 : torch::mul(*lhs[Rows * Cols * slice + Common * row],
1101 *rhs[col])));
1102 for (std::size_t idx = 1; idx < Common; ++idx)
1103 tmp +=
1104 (lhs[Rows * Cols * slice + Common * row + idx]->dim() >
1105 rhs[Cols * idx + col]->dim()
1106 ? torch::mul(*lhs[Rows * Cols * slice + Common * row + idx],
1107 rhs[Cols * idx + col])
1108 ->unsqueeze(-1)
1109 : (lhs[Rows * Cols * slice + Common * row + idx]->dim() <
1110 rhs[Cols * idx + col]->dim()
1111 ? torch::mul(
1112 lhs[Rows * Cols * slice + Common * row + idx]
1113 ->unsqueeze(-1),
1114 *rhs[Cols * idx + col])
1115 : torch::mul(
1116 *lhs[Rows * Cols * slice + Common * row + idx],
1117 *rhs[Cols * idx + col])));
1118 result[Rows * Cols * slice + Cols * row + col] =
1119 std::make_shared<T>(tmp);
1120 }
1121 return result;
1122}
1123
/// Defines a block-tensor wrapper that applies torch::name element-wise
#define blocktensor_unary_op(name)                                             \
  template <typename T, std::size_t... Dims>                                   \
  inline auto name(const BlockTensor<T, Dims...> &input) {                     \
    BlockTensor<T, Dims...> result;                                            \
    constexpr std::size_t nentries = (Dims * ...);                             \
    for (std::size_t i = 0; i < nentries; ++i)                                 \
      result[i] = std::make_shared<T>(torch::name(*input[i]));                 \
    return result;                                                             \
  }
1132
/// Defines a block-tensor wrapper that applies torch::special::name
/// element-wise
#define blocktensor_unary_special_op(name)                                     \
  template <typename T, std::size_t... Dims>                                   \
  inline auto name(const BlockTensor<T, Dims...> &input) {                     \
    BlockTensor<T, Dims...> result;                                            \
    constexpr std::size_t nentries = (Dims * ...);                             \
    for (std::size_t i = 0; i < nentries; ++i)                                 \
      result[i] = std::make_shared<T>(torch::special::name(*input[i]));        \
    return result;                                                             \
  }
1141
/// Defines a block-tensor wrapper that applies torch::name element-wise to
/// two block tensors of identical dimensions.
/// Fix: allocate the entries as the common type of T and U — the original
/// std::make_shared<T> could not be assigned into the common_type result
/// when T != U.
#define blocktensor_binary_op(name)                                            \
  template <typename T, typename U, std::size_t... Dims>                       \
  inline auto name(const BlockTensor<T, Dims...> &input,                       \
                   const BlockTensor<U, Dims...> &other) {                     \
    using result_t = typename std::common_type<T, U>::type;                    \
    BlockTensor<result_t, Dims...> result;                                     \
    for (std::size_t idx = 0; idx < (Dims * ...); ++idx)                       \
      result[idx] =                                                            \
          std::make_shared<result_t>(torch::name(*input[idx], *other[idx]));   \
    return result;                                                             \
  }
1152
/// Defines a block-tensor wrapper that applies torch::special::name
/// element-wise to two block tensors of identical dimensions.
/// Fix: allocate the entries as the common type of T and U — the original
/// std::make_shared<T> could not be assigned into the common_type result
/// when T != U.
#define blocktensor_binary_special_op(name)                                    \
  template <typename T, typename U, std::size_t... Dims>                       \
  inline auto name(const BlockTensor<T, Dims...> &input,                       \
                   const BlockTensor<U, Dims...> &other) {                     \
    using result_t = typename std::common_type<T, U>::type;                    \
    BlockTensor<result_t, Dims...> result;                                     \
    for (std::size_t idx = 0; idx < (Dims * ...); ++idx)                       \
      result[idx] = std::make_shared<result_t>(                                \
          torch::special::name(*input[idx], *other[idx]));                     \
    return result;                                                             \
  }
1163
1167
1170
1174
1177
1181
1184
1187template <typename T, typename U, typename V, std::size_t... Dims>
1188inline auto add(const BlockTensor<T, Dims...> &input,
1189 const BlockTensor<U, Dims...> &other, V alpha = 1.0) {
1190 BlockTensor<std::common_type_t<T, U>, Dims...> result;
1191 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1192 result[idx] =
1193 std::make_shared<T>(torch::add(*input[idx], *other[idx], alpha));
1194 return result;
1195}
1196
1199template <typename T, typename U, typename V, std::size_t... Dims>
1200inline auto add(const BlockTensor<T, Dims...> &input, U other, V alpha = 1.0) {
1201 BlockTensor<T, Dims...> result;
1202 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1203 result[idx] = std::make_shared<T>(torch::add(*input[idx], other, alpha));
1204 return result;
1205}
1206
1209template <typename T, typename U, typename V, std::size_t... Dims>
1210inline auto add(T input, const BlockTensor<U, Dims...> &other, V alpha = 1.0) {
1211 BlockTensor<U, Dims...> result;
1212 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1213 result[idx] = std::make_shared<T>(torch::add(input, *other[idx], alpha));
1214 return result;
1215}
1216
1220template <typename T, typename U, typename V, typename W, std::size_t... Dims>
1221inline auto addcdiv(const BlockTensor<T, Dims...> &input,
1222 const BlockTensor<U, Dims...> &tensor1,
1223 const BlockTensor<V, Dims...> &tensor2, W value = 1.0) {
1225 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1226 result[idx] = std::make_shared<T>(
1227 torch::addcdiv(*input[idx], *tensor1[idx], *tensor2[idx], value));
1228 return result;
1229}
1230
1235template <typename T, typename U, typename V, typename W, std::size_t... Dims>
1236inline auto addcmul(const BlockTensor<T, Dims...> &input,
1237 const BlockTensor<U, Dims...> &tensor1,
1238 const BlockTensor<V, Dims...> &tensor2, W value = 1.0) {
1240 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1241 result[idx] = std::make_shared<T>(
1242 torch::addcmul(*input[idx], *tensor1[idx], *tensor2[idx], value));
1243 return result;
1244}
1245
1249
1253
1256
1260
1263
1267
1270
1274
1275
1277
1282
// torch::arctan2 first appeared in PyTorch 1.11; guard the wrapper so older
// LibTorch versions still compile. Parentheses make the intended grouping
// explicit (the original relied on && binding tighter than ||).
#if (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 11) ||                 \
    TORCH_VERSION_MAJOR >= 2
blocktensor_binary_op(arctan2);
#endif
1288
1292
1296
1300
1304
1308
1312
1317
1320template <typename T, typename U, std::size_t... Dims>
1321inline auto clamp(const BlockTensor<T, Dims...> &input, U min, U max) {
1322 BlockTensor<T, Dims...> result;
1323 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1324 result[idx] = std::make_shared<T>(torch::clamp(*input[idx], min, max));
1325 return result;
1326}
1327
1329template <typename T, typename U, std::size_t... Dims>
1330inline auto clip(const BlockTensor<T, Dims...> &input, U min, U max) {
1331 BlockTensor<T, Dims...> result;
1332 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1333 result[idx] = std::make_shared<T>(torch::clip(*input[idx], min, max));
1334 return result;
1335}
1336
1340
1344
1348
1352
1356
1357
1360
1363
1367
1370template <typename T, std::size_t Rows, std::size_t Cols>
1371inline auto dot(const BlockTensor<T, Rows, Cols> &input,
1372 const BlockTensor<T, Rows, Cols> &tensor) {
1373 return input.dot(tensor);
1374}
1375
1379
1383
1387
1391
1395
1399
1402
1407
1411
1415
1419
1423
1427
1431
1436
1440
1444
1448
1452
1456
1460
1464
1465
1468
1472
1476
1478
1481
1486
1490
1493
1497
1500
1504
1507
1511
1514
1518
1521
1525
1529
1533
1537
1541
1545
1549
1553
1556
1560
1564
1568
1572
1576
1580
1584
1588
1590template <typename T, typename U, typename V, std::size_t... Dims>
1591inline auto sub(const BlockTensor<T, Dims...> &input,
1592 const BlockTensor<U, Dims...> &other, V alpha = 1.0) {
1593 BlockTensor<std::common_type_t<T, U>, Dims...> result;
1594 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1595 result[idx] =
1596 std::make_shared<T>(torch::sub(*input[idx], *other[idx], alpha));
1597 return result;
1598}
1599
1601template <typename T, typename U, typename V, std::size_t... Dims>
1602inline auto subtract(const BlockTensor<T, Dims...> &input,
1603 const BlockTensor<U, Dims...> &other, V alpha = 1.0) {
1604 BlockTensor<std::common_type_t<T, U>, Dims...> result;
1605 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1606 result[idx] =
1607 std::make_shared<T>(torch::sub(*input[idx], *other[idx], alpha));
1608 return result;
1609}
1610
1614
1618
1622
1623
1625
1628template <typename T, typename U, std::size_t... Dims>
1629inline auto operator+(const BlockTensor<T, Dims...> &lhs,
1630 const BlockTensor<U, Dims...> &rhs) {
1631 BlockTensor<std::common_type_t<T, U>, Dims...> result;
1632 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1633 result[idx] = std::make_shared<T>(*lhs[idx] + *rhs[idx]);
1634 return result;
1635}
1636
1639template <typename T, typename U, std::size_t... Dims>
1640inline auto operator+(const BlockTensor<T, Dims...> &lhs, const U &rhs) {
1641 BlockTensor<T, Dims...> result;
1642 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1643 result[idx] = std::make_shared<T>(*lhs[idx] + rhs);
1644 return result;
1645}
1646
1649template <typename T, typename U, std::size_t... Dims>
1650inline auto operator+(const T &lhs, const BlockTensor<U, Dims...> &rhs) {
1651 BlockTensor<U, Dims...> result;
1652 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1653 result[idx] = std::make_shared<U>(lhs + *rhs[idx]);
1654 return result;
1655}
1656
1658template <typename T, typename U, std::size_t... Dims>
1660 const BlockTensor<U, Dims...> &rhs) {
1661 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1662 lhs[idx] = std::make_shared<T>(*lhs[idx] + *rhs[idx]);
1663 return lhs;
1664}
1665
1667template <typename T, typename U, std::size_t... Dims>
1668inline auto operator+=(BlockTensor<T, Dims...> &lhs, const U &rhs) {
1669 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1670 lhs[idx] = std::make_shared<T>(*lhs[idx] + rhs);
1671 return lhs;
1672}
1673
1676template <typename T, typename U, std::size_t... Dims>
1677inline auto operator-(const BlockTensor<T, Dims...> &lhs,
1678 const BlockTensor<U, Dims...> &rhs) {
1679 BlockTensor<std::common_type_t<T, U>, Dims...> result;
1680 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1681 result[idx] = std::make_shared<T>(*lhs[idx] - *rhs[idx]);
1682 return result;
1683}
1684
1687template <typename T, typename U, std::size_t... Dims>
1688inline auto operator-(const BlockTensor<T, Dims...> &lhs, const U &rhs) {
1689 BlockTensor<T, Dims...> result;
1690 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1691 result[idx] = std::make_shared<T>(*lhs[idx] - rhs);
1692 return result;
1693}
1694
1697template <typename T, typename U, std::size_t... Dims>
1698inline auto operator-(const T &lhs, const BlockTensor<U, Dims...> &rhs) {
1699 BlockTensor<U, Dims...> result;
1700 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1701 result[idx] = std::make_shared<U>(lhs - *rhs[idx]);
1702 return result;
1703}
1704
1706template <typename T, typename U, std::size_t... Dims>
1708 const BlockTensor<U, Dims...> &rhs) {
1709 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1710 lhs[idx] = std::make_shared<T>(*lhs[idx] - *rhs[idx]);
1711 return lhs;
1712}
1713
1715template <typename T, typename U, std::size_t... Dims>
1716inline auto operator-=(BlockTensor<T, Dims...> &lhs, const U &rhs) {
1717 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1718 lhs[idx] = std::make_shared<T>(*lhs[idx] - rhs);
1719 return lhs;
1720}
1721
1724template <typename T, typename U, std::size_t... Dims>
1725inline auto operator*(const BlockTensor<T, Dims...> &lhs, const U &rhs) {
1726 BlockTensor<T, Dims...> result;
1727 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1728 result[idx] =
1729 (lhs[idx]->dim() > rhs.dim()
1730 ? std::make_shared<T>(*lhs[idx] * rhs.unsqueeze(-1))
1731 : (lhs[idx]->dim() < rhs.dim()
1732 ? std::make_shared<T>(lhs[idx]->unsqueeze(-1) * rhs)
1733 : std::make_shared<T>(*lhs[idx] * rhs)));
1734 ;
1735 return result;
1736}
1737
1740template <typename T, typename U, std::size_t... Dims>
1741inline auto operator*(const T &lhs, const BlockTensor<U, Dims...> &rhs) {
1742 BlockTensor<U, Dims...> result;
1743 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1744 result[idx] =
1745 (lhs.dim() > rhs[idx]->dim()
1746 ? std::make_shared<U>(lhs * rhs[idx]->unsqueeze(-1))
1747 : (lhs.dim() < rhs[idx]->dim()
1748 ? std::make_shared<U>(lhs.unsqueeze(-1) * *rhs[idx])
1749 : std::make_shared<U>(lhs * *rhs[idx])));
1750 return result;
1751}
1752
1754template <typename T, typename U, std::size_t... TDims, std::size_t... UDims>
1756 const BlockTensor<U, UDims...> &rhs) {
1757 if constexpr ((sizeof...(TDims) != sizeof...(UDims)) ||
1758 ((TDims != UDims) || ...))
1759 return false;
1760
1761 bool result = true;
1762 for (std::size_t idx = 0; idx < (TDims * ...); ++idx)
1763 result = result && torch::equal(*lhs[idx], *rhs[idx]);
1764
1765 return result;
1766}
1767
1769template <typename T, typename U, std::size_t... TDims, std::size_t... UDims>
1771 const BlockTensor<U, UDims...> &rhs) {
1772 return !(lhs == rhs);
1773}
1774
1775} // namespace iganet::utils
#define blocktensor_unary_op(name)
Definition blocktensor.hpp:1124
#define blocktensor_unary_special_op(name)
Definition blocktensor.hpp:1133
#define blocktensor_binary_op(name)
Definition blocktensor.hpp:1142
#define blocktensor_binary_special_op(name)
Definition blocktensor.hpp:1153
static constexpr std::size_t slices()
Returns the number of slices.
Definition blocktensor.hpp:942
auto reorder_jik() const
Returns a new block tensor with rows and columns transposed and slices remaining fixed....
Definition blocktensor.hpp:995
const T & operator()(std::size_t row, std::size_t col, std::size_t slice) const
Returns a constant reference to entry (row, col, slice)
Definition blocktensor.hpp:947
T & set(std::size_t row, std::size_t col, std::size_t slice, D &&data)
Stores the given data object at the given position.
Definition blocktensor.hpp:963
auto reorder_kij() const
Returns a new block tensor with rows, columns, and slices permuted according to (i,...
Definition blocktensor.hpp:1019
auto reorder_ikj() const
Returns a new block tensor with rows, columns, and slices permuted according to (i,...
Definition blocktensor.hpp:982
static constexpr std::size_t rows()
Returns the number of rows.
Definition blocktensor.hpp:936
T & operator()(std::size_t row, std::size_t col, std::size_t slice)
Returns a non-constant reference to entry (row, col, slice)
Definition blocktensor.hpp:954
auto reorder_kji() const
Returns a new block tensor with rows, columns, and slices permuted according to (i,...
Definition blocktensor.hpp:1007
auto slice(std::size_t slice) const
Returns a rank-2 tensor of the k-th slice.
Definition blocktensor.hpp:970
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the BlockTensor object.
Definition blocktensor.hpp:1030
static constexpr std::size_t cols()
Returns the number of columns.
Definition blocktensor.hpp:939
auto inv() const
Returns the inverse of the block tensor.
Definition blocktensor.hpp:307
auto dot_(std::index_sequence< Is... >, const BlockTensor< T, Rows, Cols > &other) const
Returns the dot product of two BlockTensor objects.
Definition blocktensor.hpp:866
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the BlockTensor object.
Definition blocktensor.hpp:882
auto ginv() const
Returns the (generalized) inverse of the block tensor.
Definition blocktensor.hpp:546
auto ginvtr() const
Returns the transpose of the (generalized) inverse of the block tensor.
Definition blocktensor.hpp:802
auto tr() const
Returns the transpose of the block tensor.
Definition blocktensor.hpp:220
auto invtr() const
Returns the transpose of the inverse of the block tensor.
Definition blocktensor.hpp:559
static constexpr std::size_t cols()
Returns the number of columns.
Definition blocktensor.hpp:193
auto normalize_(std::index_sequence< Is... > is) const
Returns the normalized BlockTensor object.
Definition blocktensor.hpp:851
static constexpr std::size_t rows()
Returns the number of rows.
Definition blocktensor.hpp:190
auto norm() const
Returns the norm of the BlockTensor object.
Definition blocktensor.hpp:843
const T & operator()(std::size_t row, std::size_t col) const
Returns a constant reference to entry (row, col)
Definition blocktensor.hpp:198
T & set(std::size_t row, std::size_t col, D &&data)
Stores the given data object at the given position.
Definition blocktensor.hpp:213
auto dot(const BlockTensor< T, Rows, Cols > &other) const
Returns the dot product of two BlockTensor objects.
Definition blocktensor.hpp:876
auto trace() const
Returns the trace of the block tensor.
Definition blocktensor.hpp:811
auto det() const
Returns the determinant of a square block tensor.
Definition blocktensor.hpp:232
T & operator()(std::size_t row, std::size_t col)
Returns a non-constant reference to entry (row, col)
Definition blocktensor.hpp:204
auto norm_(std::index_sequence< Is... >) const
Returns the norm of the BlockTensor object.
Definition blocktensor.hpp:835
auto normalize() const
Returns the normalized BlockTensor object.
Definition blocktensor.hpp:859
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the BlockTensor object.
Definition blocktensor.hpp:169
static constexpr std::size_t rows()
Returns the number of rows.
Definition blocktensor.hpp:166
Compile-time block tensor core.
Definition blocktensor.hpp:47
const std::array< std::shared_ptr< T >,(Dims *...)> & data() const
Returns a constant reference to the data array.
Definition blocktensor.hpp:106
static constexpr auto dims()
Returns all dimensions as array.
Definition blocktensor.hpp:87
BlockTensorCore(BlockTensorCore< Ts, dims... > &&...other)
Constructor from BlockTensorCore objects.
Definition blocktensor.hpp:59
BlockTensorCore(Ts &&...data)
Constructor from variadic templates.
Definition blocktensor.hpp:83
BlockTensorCore()=default
Default constructor.
std::shared_ptr< T > & operator[](std::size_t idx)
Returns a non-constant shared pointer to entry (idx)
Definition blocktensor.hpp:120
void pretty_print(std::ostream &os) const noexcept override=0
Returns a string representation of the BlockTensorCore object.
T & set(std::size_t idx, Data &&data)
Stores the given data object at the given index.
Definition blocktensor.hpp:138
static constexpr std::size_t dim()
Returns the i-th dimension.
Definition blocktensor.hpp:92
const std::shared_ptr< T > & operator[](std::size_t idx) const
Returns a constant shared pointer to entry (idx)
Definition blocktensor.hpp:114
BlockTensorCore(BlockTensor< Ts, dims... > &&...other)
Constructor from BlockTensor objects.
Definition blocktensor.hpp:71
static constexpr std::size_t entries()
Returns the total number of entries.
Definition blocktensor.hpp:103
std::array< std::shared_ptr< T >,(Dims *...)> & data()
Returns a non-constant reference to the data array.
Definition blocktensor.hpp:111
static constexpr std::size_t size()
Returns the number of dimensions.
Definition blocktensor.hpp:100
const T & operator()(std::size_t idx) const
Returns a constant reference to entry (idx)
Definition blocktensor.hpp:126
std::array< std::shared_ptr< T >,(Dims *...)> data_
Array storing the data.
Definition blocktensor.hpp:51
T & operator()(std::size_t idx)
Returns a non-constant reference to entry (idx)
Definition blocktensor.hpp:132
Full qualified name descriptor.
Definition fqn.hpp:22
Core components.
Full qualified name utility functions.
Definition blocktensor.hpp:24
auto addcmul(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &tensor1, const BlockTensor< V, Dims... > &tensor2, W value=1.0)
Returns a new block tensor with the elements of tensor1 multiplied by the elements of tensor2,...
Definition blocktensor.hpp:1236
auto log2(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the logarithm to the base-2 of the elements of input
Definition blocktensor.hpp:1451
auto tan(const BlockTensor< T, Dims... > &input)
Returns a new tensor with the tangent of the elements of input.
Definition blocktensor.hpp:1613
auto square(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the square of the elements of input
Definition blocktensor.hpp:1587
auto mul(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the product of each element of input and other
Definition blocktensor.hpp:1503
auto divide(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Alias for div()
Definition blocktensor.hpp:1362
auto exp2(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the base-2 exponential of the elements of input
Definition blocktensor.hpp:1394
auto frexp(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the decomposition of the elements of input into mantissae and exponen...
Definition blocktensor.hpp:1422
auto xlogy(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Computes input * log(other)
Definition blocktensor.hpp:1624
auto operator-(const BlockTensor< T, Dims... > &lhs, const BlockTensor< U, Dims... > &rhs)
Subtracts one compile-time block tensor from another and returns a new compile-time block tensor.
Definition blocktensor.hpp:1677
auto floor(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the floor of the elements of input, the largest integer less than or ...
Definition blocktensor.hpp:1410
auto bitwise_not(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the bitwise NOT of the elements of input
Definition blocktensor.hpp:1291
auto i0(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the element-wise zeroth order modified Bessel function of the first k...
Definition blocktensor.hpp:1485
auto float_power(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the elements of input raised to the power of exponent,...
Definition blocktensor.hpp:1406
auto round(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the elements of input rounded to the nearest integer.
Definition blocktensor.hpp:1544
auto hypot(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the element-wise hypotenuse of the elements of input and other.
Definition blocktensor.hpp:1480
auto imag(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the imaginary values of the elements of input
Definition blocktensor.hpp:1426
auto atan(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the arctangent of the elements of input
Definition blocktensor.hpp:1266
auto copysign(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the magnitude of the elements of input and the sign of the elements o...
Definition blocktensor.hpp:1343
auto add(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other, V alpha=1.0)
Returns a new block tensor with the elements of other, scaled by alpha, added to the elements of inpu...
Definition blocktensor.hpp:1188
auto asin(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the arcsine of the elements of input
Definition blocktensor.hpp:1252
bool operator==(const BlockTensor< T, TDims... > &lhs, const BlockTensor< U, UDims... > &rhs)
Returns true if both compile-time block tensors are equal.
Definition blocktensor.hpp:1755
auto angle(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the angle (in radians) of the elements of input
Definition blocktensor.hpp:1248
auto nextafter(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Return a new block tensor with the next elementwise floating-point value after input towards other
Definition blocktensor.hpp:1517
auto arcsinh(const BlockTensor< T, Dims... > &input)
Alias for asinh()
Definition blocktensor.hpp:1262
auto sub(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other, V alpha=1.0)
Subtracts other, scaled by alpha, from input.
Definition blocktensor.hpp:1591
auto sign(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the signs of the elements of input
Definition blocktensor.hpp:1559
auto logical_and(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the element-wise logical AND of the elements of input and other
Definition blocktensor.hpp:1463
auto absolute(const BlockTensor< T, Dims... > &input)
Alias for abs()
Definition blocktensor.hpp:1169
auto fix(const BlockTensor< T, Dims... > &input)
Alias for trunc()
Definition blocktensor.hpp:1401
auto sinc(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the normalized sinc of the elements of input
Definition blocktensor.hpp:1575
auto logical_not(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the element-wise logical NOT of the elements of input
Definition blocktensor.hpp:1467
auto positive(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the elements of input unchanged.
Definition blocktensor.hpp:1520
auto sqrt(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the square-root of the elements of input
Definition blocktensor.hpp:1583
auto reciprocal(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the reciprocal of the elements of input
Definition blocktensor.hpp:1536
auto clip(const BlockTensor< T, Dims... > &input, U min, U max)
Alias for clamp()
Definition blocktensor.hpp:1330
auto arcsin(const BlockTensor< T, Dims... > &input)
Alias for asin()
Definition blocktensor.hpp:1255
auto bitwise_or(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the bitwise OR of the elements of input and other
Definition blocktensor.hpp:1299
auto atanh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse hyperbolic tangent of the elements of input
Definition blocktensor.hpp:1273
auto subtract(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other, V alpha=1.0)
Alias for sub()
Definition blocktensor.hpp:1602
auto atan2(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the arctangent of the elements in input and other with consideration ...
Definition blocktensor.hpp:1281
auto expit(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the expit (also known as the logistic sigmoid function) of the elemen...
Definition blocktensor.hpp:1552
auto rsqrt(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the reciprocal of the square-root of the elements of input
Definition blocktensor.hpp:1548
auto sin(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the sine of the elements of input
Definition blocktensor.hpp:1571
auto cosh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the hyperbolic cosine of the elements of input
Definition blocktensor.hpp:1351
std::ostream & operator<<(std::ostream &os, const BlockTensorCore< T, Dims... > &obj)
Prints (as string) a compile-time block tensor object.
Definition blocktensor.hpp:150
auto bitwise_and(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the bitwise AND of the elements of input and other
Definition blocktensor.hpp:1295
auto erfc(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the complementary error function of the elements of input
Definition blocktensor.hpp:1382
auto gammainc(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the regularized lower incomplete gamma function of each element of in...
Definition blocktensor.hpp:1489
auto operator-=(BlockTensor< T, Dims... > &lhs, const BlockTensor< U, Dims... > &rhs)
Decrements one compile-time block tensor by another.
Definition blocktensor.hpp:1707
auto arccos(const BlockTensor< T, Dims... > &input)
Alias for acos()
Definition blocktensor.hpp:1176
auto negative(const BlockTensor< T, Dims... > &input)
Alias for neg()
Definition blocktensor.hpp:1513
auto multiply(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Alias for mul()
Definition blocktensor.hpp:1506
auto dot(const BlockTensor< T, Rows, Cols > &input, const BlockTensor< T, Rows, Cols > &tensor)
Returns a new block tensor with the dot product of the two input block tensors.
Definition blocktensor.hpp:1371
auto logaddexp2(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block-vector with the logarithm of the sum of exponentiations of the elements of input ...
Definition blocktensor.hpp:1459
auto pow(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the power of each element in input with exponent other
Definition blocktensor.hpp:1524
auto bitwise_xor(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the bitwise XOR of the elements of input and other
Definition blocktensor.hpp:1303
auto ldexp(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the elements of input multiplied by 2**other.
Definition blocktensor.hpp:1430
auto igammac(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Alias for gammainc()
Definition blocktensor.hpp:1499
auto neg(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the negative of the elements of input
Definition blocktensor.hpp:1510
auto exp(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the exponential of the elements of input
Definition blocktensor.hpp:1390
auto arctan(const BlockTensor< T, Dims... > &input)
Alias for atan()
Definition blocktensor.hpp:1269
auto log1p(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the natural logarithm of (1 + the elements of input)
Definition blocktensor.hpp:1447
auto arctanh(const BlockTensor< T, Dims... > &input)
Alias for atanh()
Definition blocktensor.hpp:1276
auto ceil(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the ceil of the elements of input, the smallest integer greater than ...
Definition blocktensor.hpp:1316
auto trunc(const BlockTensor< T, Dims... > &input)
Returns a new tensor with the truncated integer values of the elements of input.
Definition blocktensor.hpp:1621
auto cos(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the cosine of the elements of input
Definition blocktensor.hpp:1347
auto erfinv(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse error function of the elements of input
Definition blocktensor.hpp:1386
bool operator!=(const BlockTensor< T, TDims... > &lhs, const BlockTensor< U, UDims... > &rhs)
Returns true if both compile-time block tensors are not equal.
Definition blocktensor.hpp:1770
auto expm1(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the exponential minus 1 of the elements of input
Definition blocktensor.hpp:1398
auto signbit(const BlockTensor< T, Dims... > &input)
Tests if each element of input has its sign bit set (is less than zero) or not.
Definition blocktensor.hpp:1567
auto conj_physical(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the conjugate of the elements of input tensor.
Definition blocktensor.hpp:1339
auto sinh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the hyperbolic sine of the elements of input
Definition blocktensor.hpp:1579
auto remainder(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the modulus of the elements of input
Definition blocktensor.hpp:1540
auto digamma(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the logarithmic derivative of the gamma function of the elements of i...
Definition blocktensor.hpp:1366
auto asinh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse hyperbolic sine of the elements of input
Definition blocktensor.hpp:1259
auto div(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the elements of input divided by the elements of other
Definition blocktensor.hpp:1359
auto igamma(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Alias for gammainc()
Definition blocktensor.hpp:1492
auto bitwise_left_shift(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the left arithmetic shift of the elements of input by other bits.
Definition blocktensor.hpp:1307
auto rad2deg(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with each of the elements of input converted from angles in radians to deg...
Definition blocktensor.hpp:1528
auto real(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the real values of the elements of input
Definition blocktensor.hpp:1532
auto lgamma(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the natural logarithm of the absolute value of the gamma function of ...
Definition blocktensor.hpp:1435
auto gammaincc(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the regularized upper incomplete gamma function of each element of in...
Definition blocktensor.hpp:1496
auto acos(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse cosine of the elements of input
Definition blocktensor.hpp:1173
auto fmod(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the fmod of the elements of input and other
Definition blocktensor.hpp:1414
auto bitwise_right_shift(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the right arithmetic shift of the elements of input by other bits.
Definition blocktensor.hpp:1311
auto operator+=(BlockTensor< T, Dims... > &lhs, const BlockTensor< U, Dims... > &rhs)
Increments one compile-time block tensor by another.
Definition blocktensor.hpp:1659
auto sgn(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the signs of the elements of input, extension to complex value.
Definition blocktensor.hpp:1563
auto arccosh(const BlockTensor< T, Dims... > &input)
Alias for acosh()
Definition blocktensor.hpp:1183
auto abs(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the absolute value of the elements of input
Definition blocktensor.hpp:1166
auto logical_xor(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the element-wise logical XOR of the elements of input and other
Definition blocktensor.hpp:1475
auto addcdiv(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &tensor1, const BlockTensor< V, Dims... > &tensor2, W value=1.0)
Returns a new block tensor with the elements of tensor1 divided by the elements of tensor2,...
Definition blocktensor.hpp:1221
auto deg2rad(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the elements of input converted from angles in degrees to radians.
Definition blocktensor.hpp:1355
auto logaddexp(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block-vector with the logarithm of the sum of exponentiations of the elements of input
Definition blocktensor.hpp:1455
auto erf(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the error function of the elements of input
Definition blocktensor.hpp:1378
auto operator*(const BlockTensor< T, Rows, Common > &lhs, const BlockTensor< U, Common, Cols > &rhs)
Multiplies one compile-time rank-2 block tensor with another compile-time rank-2 block tensor.
Definition blocktensor.hpp:895
auto log10(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the logarithm to the base-10 of the elements of input
Definition blocktensor.hpp:1443
auto make_shared(T &&arg)
Returns a std::shared_ptr<T> object from arg.
Definition blocktensor.hpp:35
auto acosh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse hyperbolic cosine of the elements of input
Definition blocktensor.hpp:1180
auto clamp(const BlockTensor< T, Dims... > &input, U min, U max)
Returns a new block tensor with the elements of input clamped into the range [ min,...
Definition blocktensor.hpp:1321
auto logical_or(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the element-wise logical OR of the elements of input and other
Definition blocktensor.hpp:1471
auto frac(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the fractional portion of the elements of input
Definition blocktensor.hpp:1418
Forward declaration of BlockTensor.
Definition blocktensor.hpp:43
constexpr auto operator+(deriv lhs, deriv rhs)
Adds two enumerators for specifying the derivative of B-spline evaluation.
Definition bspline.hpp:89
log
Enumerator for specifying the logging level.
Definition core.hpp:90
STL namespace.
Type trait checks if template argument is of type std::shared_ptr<T>
Definition blocktensor.hpp:28