19#include <initializer_list>
40 return std::forward<
typename std::decay<T>::type>(arg);
42 return std::make_shared<typename std::decay<T>::type>(std::forward<T>(arg));
49template <
typename T, std::size_t... Dims>
54 std::array<std::shared_ptr<T>, (Dims * ...)>
data_;
61 template <
typename... Ts, std::size_t...
dims>
63 auto it =
data_.begin();
64 (std::transform(other.data().begin(), other.data().end(), it,
73 template <
typename... Ts, std::size_t...
dims>
75 auto it =
data_.begin();
76 (std::transform(other.data().begin(), other.data().end(), it,
85 template <
typename... Ts>
87 :
data_({make_shared<Ts>(std::forward<Ts>(
data))...}) {}
90 inline static constexpr auto dims() {
91 return std::array<std::size_t,
sizeof...(Dims)>({Dims...});
95 template <std::
size_t i>
inline static constexpr std::size_t
dim() {
96 if constexpr (i <
sizeof...(Dims))
97 return std::get<i>(std::forward_as_tuple(Dims...));
103 inline static constexpr std::size_t
size() {
return sizeof...(Dims); }
106 inline static constexpr std::size_t
entries() {
return (Dims * ...); }
109 inline const std::array<std::shared_ptr<T>, (Dims * ...)> &
data()
const {
114 inline std::array<std::shared_ptr<T>, (Dims * ...)> &
data() {
return data_; }
117 inline const std::shared_ptr<T> &
operator[](std::size_t idx)
const {
118 assert(0 <= idx && idx < (Dims * ...));
124 assert(0 <= idx && idx < (Dims * ...));
130 assert(0 <= idx && idx < (Dims * ...));
136 assert(0 <= idx && idx < (Dims * ...));
141 template <
typename Data>
inline T &
set(std::size_t idx, Data &&
data) {
142 assert(0 <= idx && idx < (Dims * ...));
143 data_[idx] = make_shared<Data>(std::forward<Data>(
data));
153template <
typename T, std::size_t... Dims>
161template <
typename T, std::
size_t Rows>
170 inline static constexpr std::size_t
rows() {
return Rows; }
175 os << Base::name() <<
"\n";
176 for (std::size_t row = 0; row < Rows; ++row)
177 os <<
"[" << row <<
"] = \n" << *Base::data_[row] <<
"\n";
186template <
typename T, std::
size_t Rows, std::
size_t Cols>
195 inline static constexpr std::size_t
rows() {
return Rows; }
198 inline static constexpr std::size_t
cols() {
return Cols; }
200 using Base::operator();
203 inline const T &
operator()(std::size_t row, std::size_t col)
const {
204 assert(0 <= row && row < Rows && 0 <= col && col < Cols);
205 return *Base::data_[Cols * row + col];
210 assert(0 <= row && row < Rows && 0 <= col && col < Cols);
211 return *Base::data_[Cols * row + col];
217 template <
typename D>
218 inline T &
set(std::size_t row, std::size_t col, D &&data) {
219 assert(0 <= row && row < Rows && 0 <= col && col < Cols);
220 Base::data_[Cols * row + col] = make_shared<D>(std::forward<D>(data));
221 return *Base::data_[Cols * row + col];
225 inline auto tr()
const {
227 for (std::size_t row = 0; row < Rows; ++row)
228 for (std::size_t col = 0; col < Cols; ++col)
229 result[Rows * col + row] = Base::data_[Cols * row + col];
238 if constexpr (Rows == 1 && Cols == 1) {
239 auto result = *Base::data_[0];
241 }
else if constexpr (Rows == 2 && Cols == 2) {
242 auto result = torch::mul(*Base::data_[0], *Base::data_[3]) -
243 torch::mul(*Base::data_[1], *Base::data_[2]);
245 }
else if constexpr (Rows == 3 && Cols == 3) {
247 torch::mul(*Base::data_[0],
248 torch::mul(*Base::data_[4], *Base::data_[8]) -
249 torch::mul(*Base::data_[5], *Base::data_[7])) -
250 torch::mul(*Base::data_[1],
251 torch::mul(*Base::data_[3], *Base::data_[8]) -
252 torch::mul(*Base::data_[5], *Base::data_[6])) +
253 torch::mul(*Base::data_[2],
254 torch::mul(*Base::data_[3], *Base::data_[7]) -
255 torch::mul(*Base::data_[4], *Base::data_[6]));
257 }
else if constexpr (Rows == 4 && Cols == 4) {
258 auto a11 = torch::mul(*Base::data_[5],
259 (torch::mul(*Base::data_[10], *Base::data_[15]) -
260 torch::mul(*Base::data_[11], *Base::data_[14]))) -
261 torch::mul(*Base::data_[9],
262 (torch::mul(*Base::data_[6], *Base::data_[15]) -
263 torch::mul(*Base::data_[7], *Base::data_[14]))) -
264 torch::mul(*Base::data_[13],
265 (torch::mul(*Base::data_[7], *Base::data_[10]) -
266 torch::mul(*Base::data_[6], *Base::data_[11])));
268 auto a21 = torch::mul(*Base::data_[4],
269 (torch::mul(*Base::data_[11], *Base::data_[14]) -
270 torch::mul(*Base::data_[10], *Base::data_[15]))) -
271 torch::mul(*Base::data_[8],
272 (torch::mul(*Base::data_[7], *Base::data_[14]) -
273 torch::mul(*Base::data_[6], *Base::data_[15]))) -
274 torch::mul(*Base::data_[12],
275 (torch::mul(*Base::data_[6], *Base::data_[11]) -
276 torch::mul(*Base::data_[7], *Base::data_[10])));
278 auto a31 = torch::mul(*Base::data_[4],
279 (torch::mul(*Base::data_[9], *Base::data_[15]) -
280 torch::mul(*Base::data_[11], *Base::data_[13]))) -
281 torch::mul(*Base::data_[8],
282 (torch::mul(*Base::data_[5], *Base::data_[15]) -
283 torch::mul(*Base::data_[7], *Base::data_[13]))) -
284 torch::mul(*Base::data_[12],
285 (torch::mul(*Base::data_[7], *Base::data_[9]) -
286 torch::mul(*Base::data_[5], *Base::data_[11])));
288 auto a41 = torch::mul(*Base::data_[4],
289 (torch::mul(*Base::data_[10], *Base::data_[13]) -
290 torch::mul(*Base::data_[9], *Base::data_[14]))) -
291 torch::mul(*Base::data_[8],
292 (torch::mul(*Base::data_[6], *Base::data_[13]) -
293 torch::mul(*Base::data_[5], *Base::data_[14]))) -
294 torch::mul(*Base::data_[12],
295 (torch::mul(*Base::data_[5], *Base::data_[10]) -
296 torch::mul(*Base::data_[6], *Base::data_[9])));
299 torch::mul(*Base::data_[0], a11) + torch::mul(*Base::data_[1], a21) +
300 torch::mul(*Base::data_[2], a31) + torch::mul(*Base::data_[3], a41);
304 throw std::runtime_error(
"Unsupported block tensor dimension");
314 auto det_ = this->det();
316 if constexpr (Rows == 1 && Cols == 1) {
318 result[0] = std::make_shared<T>(torch::reciprocal(*Base::data_[0]));
320 }
else if constexpr (Rows == 2 && Cols == 2) {
323 result[0] = std::make_shared<T>(torch::div(*Base::data_[3], det_));
324 result[1] = std::make_shared<T>(torch::div(*Base::data_[2], -det_));
325 result[2] = std::make_shared<T>(torch::div(*Base::data_[1], -det_));
326 result[3] = std::make_shared<T>(torch::div(*Base::data_[0], det_));
328 }
else if constexpr (Rows == 3 && Cols == 3) {
330 auto a11 = torch::mul(*Base::data_[4], *Base::data_[8]) -
331 torch::mul(*Base::data_[5], *Base::data_[7]);
332 auto a12 = torch::mul(*Base::data_[2], *Base::data_[7]) -
333 torch::mul(*Base::data_[1], *Base::data_[8]);
334 auto a13 = torch::mul(*Base::data_[1], *Base::data_[5]) -
335 torch::mul(*Base::data_[2], *Base::data_[4]);
336 auto a21 = torch::mul(*Base::data_[5], *Base::data_[6]) -
337 torch::mul(*Base::data_[3], *Base::data_[8]);
338 auto a22 = torch::mul(*Base::data_[0], *Base::data_[8]) -
339 torch::mul(*Base::data_[2], *Base::data_[6]);
340 auto a23 = torch::mul(*Base::data_[2], *Base::data_[3]) -
341 torch::mul(*Base::data_[0], *Base::data_[5]);
342 auto a31 = torch::mul(*Base::data_[3], *Base::data_[7]) -
343 torch::mul(*Base::data_[4], *Base::data_[6]);
344 auto a32 = torch::mul(*Base::data_[1], *Base::data_[6]) -
345 torch::mul(*Base::data_[0], *Base::data_[7]);
346 auto a33 = torch::mul(*Base::data_[0], *Base::data_[4]) -
347 torch::mul(*Base::data_[1], *Base::data_[3]);
350 result[0] = std::make_shared<T>(torch::div(a11, det_));
351 result[1] = std::make_shared<T>(torch::div(a12, det_));
352 result[2] = std::make_shared<T>(torch::div(a13, det_));
353 result[3] = std::make_shared<T>(torch::div(a21, det_));
354 result[4] = std::make_shared<T>(torch::div(a22, det_));
355 result[5] = std::make_shared<T>(torch::div(a23, det_));
356 result[6] = std::make_shared<T>(torch::div(a31, det_));
357 result[7] = std::make_shared<T>(torch::div(a32, det_));
358 result[8] = std::make_shared<T>(torch::div(a33, det_));
360 }
else if constexpr (Rows == 4 && Cols == 4) {
361 auto a11 = torch::mul(*Base::data_[5],
362 (torch::mul(*Base::data_[10], *Base::data_[15]) -
363 torch::mul(*Base::data_[11], *Base::data_[14]))) -
364 torch::mul(*Base::data_[9],
365 (torch::mul(*Base::data_[6], *Base::data_[15]) -
366 torch::mul(*Base::data_[7], *Base::data_[14]))) -
367 torch::mul(*Base::data_[13],
368 (torch::mul(*Base::data_[7], *Base::data_[10]) -
369 torch::mul(*Base::data_[6], *Base::data_[11])));
371 auto a12 = torch::mul(*Base::data_[1],
372 (torch::mul(*Base::data_[11], *Base::data_[14]) -
373 torch::mul(*Base::data_[10], *Base::data_[15]))) -
374 torch::mul(*Base::data_[9],
375 (torch::mul(*Base::data_[3], *Base::data_[14]) -
376 torch::mul(*Base::data_[2], *Base::data_[15]))) -
377 torch::mul(*Base::data_[13],
378 (torch::mul(*Base::data_[2], *Base::data_[11]) -
379 torch::mul(*Base::data_[3], *Base::data_[10])));
381 auto a13 = torch::mul(*Base::data_[1],
382 (torch::mul(*Base::data_[6], *Base::data_[15]) -
383 torch::mul(*Base::data_[7], *Base::data_[14]))) -
384 torch::mul(*Base::data_[5],
385 (torch::mul(*Base::data_[2], *Base::data_[15]) -
386 torch::mul(*Base::data_[3], *Base::data_[14]))) -
387 torch::mul(*Base::data_[13],
388 (torch::mul(*Base::data_[3], *Base::data_[6]) -
389 torch::mul(*Base::data_[2], *Base::data_[7])));
391 auto a14 = torch::mul(*Base::data_[1],
392 (torch::mul(*Base::data_[7], *Base::data_[10]) -
393 torch::mul(*Base::data_[6], *Base::data_[11]))) -
394 torch::mul(*Base::data_[5],
395 (torch::mul(*Base::data_[3], *Base::data_[10]) -
396 torch::mul(*Base::data_[2], *Base::data_[11]))) -
397 torch::mul(*Base::data_[9],
398 (torch::mul(*Base::data_[2], *Base::data_[7]) -
399 torch::mul(*Base::data_[3], *Base::data_[6])));
401 auto a21 = torch::mul(*Base::data_[4],
402 (torch::mul(*Base::data_[11], *Base::data_[14]) -
403 torch::mul(*Base::data_[10], *Base::data_[15]))) -
404 torch::mul(*Base::data_[8],
405 (torch::mul(*Base::data_[7], *Base::data_[14]) -
406 torch::mul(*Base::data_[6], *Base::data_[15]))) -
407 torch::mul(*Base::data_[12],
408 (torch::mul(*Base::data_[6], *Base::data_[11]) -
409 torch::mul(*Base::data_[7], *Base::data_[10])));
411 auto a22 = torch::mul(*Base::data_[0],
412 (torch::mul(*Base::data_[10], *Base::data_[15]) -
413 torch::mul(*Base::data_[11], *Base::data_[14]))) -
414 torch::mul(*Base::data_[8],
415 (torch::mul(*Base::data_[2], *Base::data_[15]) -
416 torch::mul(*Base::data_[3], *Base::data_[14]))) -
417 torch::mul(*Base::data_[12],
418 (torch::mul(*Base::data_[3], *Base::data_[10]) -
419 torch::mul(*Base::data_[2], *Base::data_[11])));
421 auto a23 = torch::mul(*Base::data_[0],
422 (torch::mul(*Base::data_[7], *Base::data_[14]) -
423 torch::mul(*Base::data_[6], *Base::data_[15]))) -
424 torch::mul(*Base::data_[4],
425 (torch::mul(*Base::data_[3], *Base::data_[14]) -
426 torch::mul(*Base::data_[2], *Base::data_[15]))) -
427 torch::mul(*Base::data_[12],
428 (torch::mul(*Base::data_[2], *Base::data_[7]) -
429 torch::mul(*Base::data_[3], *Base::data_[6])));
431 auto a24 = torch::mul(*Base::data_[0],
432 (torch::mul(*Base::data_[6], *Base::data_[11]) -
433 torch::mul(*Base::data_[7], *Base::data_[10]))) -
434 torch::mul(*Base::data_[4],
435 (torch::mul(*Base::data_[2], *Base::data_[11]) -
436 torch::mul(*Base::data_[3], *Base::data_[10]))) -
437 torch::mul(*Base::data_[8],
438 (torch::mul(*Base::data_[3], *Base::data_[6]) -
439 torch::mul(*Base::data_[2], *Base::data_[7])));
441 auto a31 = torch::mul(*Base::data_[4],
442 (torch::mul(*Base::data_[9], *Base::data_[15]) -
443 torch::mul(*Base::data_[11], *Base::data_[13]))) -
444 torch::mul(*Base::data_[8],
445 (torch::mul(*Base::data_[5], *Base::data_[15]) -
446 torch::mul(*Base::data_[7], *Base::data_[13]))) -
447 torch::mul(*Base::data_[12],
448 (torch::mul(*Base::data_[7], *Base::data_[9]) -
449 torch::mul(*Base::data_[5], *Base::data_[11])));
451 auto a32 = torch::mul(*Base::data_[0],
452 (torch::mul(*Base::data_[11], *Base::data_[13]) -
453 torch::mul(*Base::data_[9], *Base::data_[15]))) -
454 torch::mul(*Base::data_[8],
455 (torch::mul(*Base::data_[3], *Base::data_[13]) -
456 torch::mul(*Base::data_[1], *Base::data_[15]))) -
457 torch::mul(*Base::data_[12],
458 (torch::mul(*Base::data_[1], *Base::data_[11]) -
459 torch::mul(*Base::data_[3], *Base::data_[9])));
461 auto a33 = torch::mul(*Base::data_[0],
462 (torch::mul(*Base::data_[5], *Base::data_[15]) -
463 torch::mul(*Base::data_[7], *Base::data_[13]))) -
464 torch::mul(*Base::data_[4],
465 (torch::mul(*Base::data_[1], *Base::data_[15]) -
466 torch::mul(*Base::data_[3], *Base::data_[13]))) -
467 torch::mul(*Base::data_[12],
468 (torch::mul(*Base::data_[3], *Base::data_[5]) -
469 torch::mul(*Base::data_[1], *Base::data_[7])));
471 auto a34 = torch::mul(*Base::data_[0],
472 (torch::mul(*Base::data_[7], *Base::data_[9]) -
473 torch::mul(*Base::data_[5], *Base::data_[11]))) -
474 torch::mul(*Base::data_[4],
475 (torch::mul(*Base::data_[3], *Base::data_[9]) -
476 torch::mul(*Base::data_[1], *Base::data_[11]))) -
477 torch::mul(*Base::data_[8],
478 (torch::mul(*Base::data_[1], *Base::data_[7]) -
479 torch::mul(*Base::data_[3], *Base::data_[5])));
481 auto a41 = torch::mul(*Base::data_[4],
482 (torch::mul(*Base::data_[10], *Base::data_[13]) -
483 torch::mul(*Base::data_[9], *Base::data_[14]))) -
484 torch::mul(*Base::data_[8],
485 (torch::mul(*Base::data_[6], *Base::data_[13]) -
486 torch::mul(*Base::data_[5], *Base::data_[14]))) -
487 torch::mul(*Base::data_[12],
488 (torch::mul(*Base::data_[5], *Base::data_[10]) -
489 torch::mul(*Base::data_[6], *Base::data_[9])));
491 auto a42 = torch::mul(*Base::data_[0],
492 (torch::mul(*Base::data_[9], *Base::data_[14]) -
493 torch::mul(*Base::data_[10], *Base::data_[13]))) -
494 torch::mul(*Base::data_[8],
495 (torch::mul(*Base::data_[1], *Base::data_[14]) -
496 torch::mul(*Base::data_[2], *Base::data_[13]))) -
497 torch::mul(*Base::data_[12],
498 (torch::mul(*Base::data_[2], *Base::data_[9]) -
499 torch::mul(*Base::data_[1], *Base::data_[10])));
501 auto a43 = torch::mul(*Base::data_[0],
502 (torch::mul(*Base::data_[6], *Base::data_[13]) -
503 torch::mul(*Base::data_[5], *Base::data_[14]))) -
504 torch::mul(*Base::data_[4],
505 (torch::mul(*Base::data_[2], *Base::data_[13]) -
506 torch::mul(*Base::data_[1], *Base::data_[14]))) -
507 torch::mul(*Base::data_[12],
508 (torch::mul(*Base::data_[1], *Base::data_[6]) -
509 torch::mul(*Base::data_[2], *Base::data_[5])));
511 auto a44 = torch::mul(*Base::data_[0],
512 (torch::mul(*Base::data_[5], *Base::data_[10]) -
513 torch::mul(*Base::data_[6], *Base::data_[9]))) -
514 torch::mul(*Base::data_[4],
515 (torch::mul(*Base::data_[1], *Base::data_[10]) -
516 torch::mul(*Base::data_[2], *Base::data_[9]))) -
517 torch::mul(*Base::data_[8],
518 (torch::mul(*Base::data_[2], *Base::data_[5]) -
519 torch::mul(*Base::data_[1], *Base::data_[6])));
521 result[0] = std::make_shared<T>(torch::div(a11, det_));
522 result[1] = std::make_shared<T>(torch::div(a12, det_));
523 result[2] = std::make_shared<T>(torch::div(a13, det_));
524 result[3] = std::make_shared<T>(torch::div(a14, det_));
525 result[4] = std::make_shared<T>(torch::div(a21, det_));
526 result[5] = std::make_shared<T>(torch::div(a22, det_));
527 result[6] = std::make_shared<T>(torch::div(a23, det_));
528 result[7] = std::make_shared<T>(torch::div(a24, det_));
529 result[8] = std::make_shared<T>(torch::div(a31, det_));
530 result[9] = std::make_shared<T>(torch::div(a32, det_));
531 result[10] = std::make_shared<T>(torch::div(a33, det_));
532 result[11] = std::make_shared<T>(torch::div(a34, det_));
533 result[12] = std::make_shared<T>(torch::div(a41, det_));
534 result[13] = std::make_shared<T>(torch::div(a42, det_));
535 result[14] = std::make_shared<T>(torch::div(a43, det_));
536 result[15] = std::make_shared<T>(torch::div(a44, det_));
539 throw std::runtime_error(
"Unsupported block tensor dimension");
552 if constexpr (Rows == Cols)
556 return (this->tr() * (*this)).inv() * this->tr();
566 auto det_ = this->det();
568 if constexpr (Rows == 1 && Cols == 1) {
570 result[0] = std::make_shared<T>(torch::reciprocal(*Base::data_[0]));
572 }
else if constexpr (Rows == 2 && Cols == 2) {
575 result[0] = std::make_shared<T>(torch::div(*Base::data_[3], det_));
576 result[1] = std::make_shared<T>(torch::div(*Base::data_[1], -det_));
577 result[2] = std::make_shared<T>(torch::div(*Base::data_[2], -det_));
578 result[3] = std::make_shared<T>(torch::div(*Base::data_[0], det_));
580 }
else if constexpr (Rows == 3 && Cols == 3) {
582 auto a11 = torch::mul(*Base::data_[4], *Base::data_[8]) -
583 torch::mul(*Base::data_[5], *Base::data_[7]);
584 auto a12 = torch::mul(*Base::data_[2], *Base::data_[7]) -
585 torch::mul(*Base::data_[1], *Base::data_[8]);
586 auto a13 = torch::mul(*Base::data_[1], *Base::data_[5]) -
587 torch::mul(*Base::data_[2], *Base::data_[4]);
588 auto a21 = torch::mul(*Base::data_[5], *Base::data_[6]) -
589 torch::mul(*Base::data_[3], *Base::data_[8]);
590 auto a22 = torch::mul(*Base::data_[0], *Base::data_[8]) -
591 torch::mul(*Base::data_[2], *Base::data_[6]);
592 auto a23 = torch::mul(*Base::data_[2], *Base::data_[3]) -
593 torch::mul(*Base::data_[0], *Base::data_[5]);
594 auto a31 = torch::mul(*Base::data_[3], *Base::data_[7]) -
595 torch::mul(*Base::data_[4], *Base::data_[6]);
596 auto a32 = torch::mul(*Base::data_[1], *Base::data_[6]) -
597 torch::mul(*Base::data_[0], *Base::data_[7]);
598 auto a33 = torch::mul(*Base::data_[0], *Base::data_[4]) -
599 torch::mul(*Base::data_[1], *Base::data_[3]);
602 result[0] = std::make_shared<T>(torch::div(a11, det_));
603 result[1] = std::make_shared<T>(torch::div(a21, det_));
604 result[2] = std::make_shared<T>(torch::div(a31, det_));
605 result[3] = std::make_shared<T>(torch::div(a12, det_));
606 result[4] = std::make_shared<T>(torch::div(a22, det_));
607 result[5] = std::make_shared<T>(torch::div(a32, det_));
608 result[6] = std::make_shared<T>(torch::div(a13, det_));
609 result[7] = std::make_shared<T>(torch::div(a23, det_));
610 result[8] = std::make_shared<T>(torch::div(a33, det_));
612 }
else if constexpr (Rows == 4 && Cols == 4) {
614 auto a11 = torch::mul(*Base::data_[5],
615 (torch::mul(*Base::data_[10], *Base::data_[15]) -
616 torch::mul(*Base::data_[11], *Base::data_[14]))) -
617 torch::mul(*Base::data_[9],
618 (torch::mul(*Base::data_[6], *Base::data_[15]) -
619 torch::mul(*Base::data_[7], *Base::data_[14]))) -
620 torch::mul(*Base::data_[13],
621 (torch::mul(*Base::data_[7], *Base::data_[10]) -
622 torch::mul(*Base::data_[6], *Base::data_[11])));
624 auto a12 = torch::mul(*Base::data_[1],
625 (torch::mul(*Base::data_[11], *Base::data_[14]) -
626 torch::mul(*Base::data_[10], *Base::data_[15]))) -
627 torch::mul(*Base::data_[9],
628 (torch::mul(*Base::data_[3], *Base::data_[14]) -
629 torch::mul(*Base::data_[2], *Base::data_[15]))) -
630 torch::mul(*Base::data_[13],
631 (torch::mul(*Base::data_[2], *Base::data_[11]) -
632 torch::mul(*Base::data_[3], *Base::data_[10])));
634 auto a13 = torch::mul(*Base::data_[1],
635 (torch::mul(*Base::data_[6], *Base::data_[15]) -
636 torch::mul(*Base::data_[7], *Base::data_[14]))) -
637 torch::mul(*Base::data_[5],
638 (torch::mul(*Base::data_[2], *Base::data_[15]) -
639 torch::mul(*Base::data_[3], *Base::data_[14]))) -
640 torch::mul(*Base::data_[13],
641 (torch::mul(*Base::data_[3], *Base::data_[6]) -
642 torch::mul(*Base::data_[2], *Base::data_[7])));
644 auto a14 = torch::mul(*Base::data_[1],
645 (torch::mul(*Base::data_[7], *Base::data_[10]) -
646 torch::mul(*Base::data_[6], *Base::data_[11]))) -
647 torch::mul(*Base::data_[5],
648 (torch::mul(*Base::data_[3], *Base::data_[10]) -
649 torch::mul(*Base::data_[2], *Base::data_[11]))) -
650 torch::mul(*Base::data_[9],
651 (torch::mul(*Base::data_[2], *Base::data_[7]) -
652 torch::mul(*Base::data_[3], *Base::data_[6])));
654 auto a21 = torch::mul(*Base::data_[4],
655 (torch::mul(*Base::data_[11], *Base::data_[14]) -
656 torch::mul(*Base::data_[10], *Base::data_[15]))) -
657 torch::mul(*Base::data_[8],
658 (torch::mul(*Base::data_[7], *Base::data_[14]) -
659 torch::mul(*Base::data_[6], *Base::data_[15]))) -
660 torch::mul(*Base::data_[12],
661 (torch::mul(*Base::data_[6], *Base::data_[11]) -
662 torch::mul(*Base::data_[7], *Base::data_[10])));
664 auto a22 = torch::mul(*Base::data_[0],
665 (torch::mul(*Base::data_[10], *Base::data_[15]) -
666 torch::mul(*Base::data_[11], *Base::data_[14]))) -
667 torch::mul(*Base::data_[8],
668 (torch::mul(*Base::data_[2], *Base::data_[15]) -
669 torch::mul(*Base::data_[3], *Base::data_[14]))) -
670 torch::mul(*Base::data_[12],
671 (torch::mul(*Base::data_[3], *Base::data_[10]) -
672 torch::mul(*Base::data_[2], *Base::data_[11])));
674 auto a23 = torch::mul(*Base::data_[0],
675 (torch::mul(*Base::data_[7], *Base::data_[14]) -
676 torch::mul(*Base::data_[6], *Base::data_[15]))) -
677 torch::mul(*Base::data_[4],
678 (torch::mul(*Base::data_[3], *Base::data_[14]) -
679 torch::mul(*Base::data_[2], *Base::data_[15]))) -
680 torch::mul(*Base::data_[12],
681 (torch::mul(*Base::data_[2], *Base::data_[7]) -
682 torch::mul(*Base::data_[3], *Base::data_[6])));
684 auto a24 = torch::mul(*Base::data_[0],
685 (torch::mul(*Base::data_[6], *Base::data_[11]) -
686 torch::mul(*Base::data_[7], *Base::data_[10]))) -
687 torch::mul(*Base::data_[4],
688 (torch::mul(*Base::data_[2], *Base::data_[11]) -
689 torch::mul(*Base::data_[3], *Base::data_[10]))) -
690 torch::mul(*Base::data_[8],
691 (torch::mul(*Base::data_[3], *Base::data_[6]) -
692 torch::mul(*Base::data_[2], *Base::data_[7])));
694 auto a31 = torch::mul(*Base::data_[4],
695 (torch::mul(*Base::data_[9], *Base::data_[15]) -
696 torch::mul(*Base::data_[11], *Base::data_[13]))) -
697 torch::mul(*Base::data_[8],
698 (torch::mul(*Base::data_[5], *Base::data_[15]) -
699 torch::mul(*Base::data_[7], *Base::data_[13]))) -
700 torch::mul(*Base::data_[12],
701 (torch::mul(*Base::data_[7], *Base::data_[9]) -
702 torch::mul(*Base::data_[5], *Base::data_[11])));
704 auto a32 = torch::mul(*Base::data_[0],
705 (torch::mul(*Base::data_[11], *Base::data_[13]) -
706 torch::mul(*Base::data_[9], *Base::data_[15]))) -
707 torch::mul(*Base::data_[8],
708 (torch::mul(*Base::data_[3], *Base::data_[13]) -
709 torch::mul(*Base::data_[1], *Base::data_[15]))) -
710 torch::mul(*Base::data_[12],
711 (torch::mul(*Base::data_[1], *Base::data_[11]) -
712 torch::mul(*Base::data_[3], *Base::data_[9])));
714 auto a33 = torch::mul(*Base::data_[0],
715 (torch::mul(*Base::data_[5], *Base::data_[15]) -
716 torch::mul(*Base::data_[7], *Base::data_[13]))) -
717 torch::mul(*Base::data_[4],
718 (torch::mul(*Base::data_[1], *Base::data_[15]) -
719 torch::mul(*Base::data_[3], *Base::data_[13]))) -
720 torch::mul(*Base::data_[12],
721 (torch::mul(*Base::data_[3], *Base::data_[5]) -
722 torch::mul(*Base::data_[1], *Base::data_[7])));
724 auto a34 = torch::mul(*Base::data_[0],
725 (torch::mul(*Base::data_[7], *Base::data_[9]) -
726 torch::mul(*Base::data_[5], *Base::data_[11]))) -
727 torch::mul(*Base::data_[4],
728 (torch::mul(*Base::data_[3], *Base::data_[9]) -
729 torch::mul(*Base::data_[1], *Base::data_[11]))) -
730 torch::mul(*Base::data_[8],
731 (torch::mul(*Base::data_[1], *Base::data_[7]) -
732 torch::mul(*Base::data_[3], *Base::data_[5])));
734 auto a41 = torch::mul(*Base::data_[4],
735 (torch::mul(*Base::data_[10], *Base::data_[13]) -
736 torch::mul(*Base::data_[9], *Base::data_[14]))) -
737 torch::mul(*Base::data_[8],
738 (torch::mul(*Base::data_[6], *Base::data_[13]) -
739 torch::mul(*Base::data_[5], *Base::data_[14]))) -
740 torch::mul(*Base::data_[12],
741 (torch::mul(*Base::data_[5], *Base::data_[10]) -
742 torch::mul(*Base::data_[6], *Base::data_[9])));
744 auto a42 = torch::mul(*Base::data_[0],
745 (torch::mul(*Base::data_[9], *Base::data_[14]) -
746 torch::mul(*Base::data_[10], *Base::data_[13]))) -
747 torch::mul(*Base::data_[8],
748 (torch::mul(*Base::data_[1], *Base::data_[14]) -
749 torch::mul(*Base::data_[2], *Base::data_[13]))) -
750 torch::mul(*Base::data_[12],
751 (torch::mul(*Base::data_[2], *Base::data_[9]) -
752 torch::mul(*Base::data_[1], *Base::data_[10])));
754 auto a43 = torch::mul(*Base::data_[0],
755 (torch::mul(*Base::data_[6], *Base::data_[13]) -
756 torch::mul(*Base::data_[5], *Base::data_[14]))) -
757 torch::mul(*Base::data_[4],
758 (torch::mul(*Base::data_[2], *Base::data_[13]) -
759 torch::mul(*Base::data_[1], *Base::data_[14]))) -
760 torch::mul(*Base::data_[12],
761 (torch::mul(*Base::data_[1], *Base::data_[6]) -
762 torch::mul(*Base::data_[2], *Base::data_[5])));
764 auto a44 = torch::mul(*Base::data_[0],
765 (torch::mul(*Base::data_[5], *Base::data_[10]) -
766 torch::mul(*Base::data_[6], *Base::data_[9]))) -
767 torch::mul(*Base::data_[4],
768 (torch::mul(*Base::data_[1], *Base::data_[10]) -
769 torch::mul(*Base::data_[2], *Base::data_[9]))) -
770 torch::mul(*Base::data_[8],
771 (torch::mul(*Base::data_[2], *Base::data_[5]) -
772 torch::mul(*Base::data_[1], *Base::data_[6])));
775 result[0] = std::make_shared<T>(torch::div(a11, det_));
776 result[1] = std::make_shared<T>(torch::div(a21, det_));
777 result[2] = std::make_shared<T>(torch::div(a31, det_));
778 result[3] = std::make_shared<T>(torch::div(a41, det_));
779 result[4] = std::make_shared<T>(torch::div(a12, det_));
780 result[5] = std::make_shared<T>(torch::div(a22, det_));
781 result[6] = std::make_shared<T>(torch::div(a32, det_));
782 result[7] = std::make_shared<T>(torch::div(a42, det_));
783 result[8] = std::make_shared<T>(torch::div(a13, det_));
784 result[9] = std::make_shared<T>(torch::div(a23, det_));
785 result[10] = std::make_shared<T>(torch::div(a33, det_));
786 result[11] = std::make_shared<T>(torch::div(a43, det_));
787 result[12] = std::make_shared<T>(torch::div(a14, det_));
788 result[13] = std::make_shared<T>(torch::div(a24, det_));
789 result[14] = std::make_shared<T>(torch::div(a34, det_));
790 result[15] = std::make_shared<T>(torch::div(a44, det_));
793 throw std::runtime_error(
"Unsupported block tensor dimension");
808 if constexpr (Rows == Cols)
809 return this->invtr();
812 return (*
this) * (this->tr() * (*this)).invtr();
817 static_assert(Rows == Cols,
"trace(.) requires square block tensor");
819 if constexpr (Rows == 1)
822 else if constexpr (Rows == 2)
825 else if constexpr (Rows == 3)
829 else if constexpr (Rows == 4)
831 *Base::data_[10] + *Base::data_[15]);
834 throw std::runtime_error(
"Unsupported block tensor dimension");
839 template <std::size_t... Is>
840 inline auto norm_(std::index_sequence<Is...>)
const {
841 return torch::sqrt(std::apply([](
const auto&... tensors) {
842 return (tensors + ...);
843 }, std::make_tuple(std::get<Is>(Base::data_)->
square()...)));
854 template <std::size_t... Is>
855 inline auto normalize_(std::index_sequence<Is...> is)
const {
863 return normalize_(std::make_index_sequence<Rows*Cols>{});
868 template <std::size_t... Is>
870 return std::apply([](
const auto&... tensors) {
871 return (tensors + ...);
872 }, std::make_tuple(torch::mul(*std::get<Is>(Base::data_), *std::get<Is>(other.data_))...));
878 return BlockTensor<T, 1, 1>(std::make_shared<T>(dot_(std::make_index_sequence<Rows*Cols>{}, other)));
884 os << Base::name() <<
"\n";
885 for (std::size_t row = 0; row < Rows; ++row)
886 for (std::size_t col = 0; col < Cols; ++col)
887 os <<
"[" << row <<
"," << col <<
"] = \n"
888 << *Base::data_[Cols * row + col] <<
"\n";
894template <
typename T,
typename U, std::size_t Rows, std::size_t Common,
899 for (std::size_t row = 0; row < Rows; ++row)
900 for (std::size_t col = 0; col < Cols; ++col) {
902 (lhs[Common * row]->dim() > rhs[col]->dim()
903 ? torch::mul(*lhs[Common * row], rhs[col]->unsqueeze(-1))
904 : (lhs[Common * row]->dim() < rhs[col]->dim()
905 ? torch::mul(lhs[Common * row]->unsqueeze(-1), *rhs[col])
906 : torch::mul(*lhs[Common * row], *rhs[col])));
907 for (std::size_t idx = 1; idx < Common; ++idx)
908 tmp += (lhs[Common * row]->dim() > rhs[col]->dim()
909 ? torch::mul(*lhs[Common * row + idx],
910 rhs[Cols * idx + col]->unsqueeze(-1))
911 : (lhs[Common * row]->dim() < rhs[col]->dim()
912 ? torch::mul(lhs[Common * row + idx]->unsqueeze(-1),
913 *rhs[Cols * idx + col])
914 : torch::mul(*lhs[Common * row + idx],
915 *rhs[Cols * idx + col])));
916 result[Cols * row + col] = std::make_shared<T>(tmp);
927template <
typename T, std::
size_t Rows, std::
size_t Cols, std::
size_t Slices>
937 inline static constexpr std::size_t
rows() {
return Rows; }
940 inline static constexpr std::size_t
cols() {
return Cols; }
943 inline static constexpr std::size_t
slices() {
return Slices; }
945 using Base::operator();
949 std::size_t slice)
const {
950 assert(0 <= row && row < Rows && 0 <= col && col < Cols && 0 <= slice &&
952 return *Base::data_[Rows * Cols * slice + Cols * row + col];
956 inline T &
operator()(std::size_t row, std::size_t col, std::size_t slice) {
957 assert(0 <= row && row < Rows && 0 <= col && col < Cols && 0 <= slice &&
959 return *Base::data_[Rows * Cols * slice + Cols * row + col];
965 template <
typename D>
966 inline T &
set(std::size_t row, std::size_t col, std::size_t slice, D &&data) {
967 Base::data_[Rows * Cols * slice + Cols * row + col] =
968 make_shared<D>(std::forward<D>(data));
969 return *Base::data_[Rows * Cols * slice + Cols * row + col];
973 inline auto slice(std::size_t slice)
const {
974 assert(0 <= slice && slice < Slices);
976 for (std::size_t row = 0; row < Rows; ++row)
977 for (std::size_t col = 0; col < Cols; ++col)
978 result[Cols * row + col] =
979 Base::data_[Rows * Cols * slice + Cols * row + col];
987 for (std::size_t slice = 0; slice < Slices; ++slice)
988 for (std::size_t row = 0; row < Rows; ++row)
989 for (std::size_t col = 0; col < Cols; ++col)
990 result[Rows * Slices * col + Slices * row + slice] =
991 Base::data_[Rows * Cols * slice + Cols * row + col];
1000 for (std::size_t slice = 0; slice < Slices; ++slice)
1001 for (std::size_t row = 0; row < Rows; ++row)
1002 for (std::size_t col = 0; col < Cols; ++col)
1003 result[Rows * Cols * slice + Rows * col + row] =
1004 Base::data_[Rows * Cols * slice + Cols * row + col];
1012 for (std::size_t slice = 0; slice < Slices; ++slice)
1013 for (std::size_t row = 0; row < Rows; ++row)
1014 for (std::size_t col = 0; col < Cols; ++col)
1015 result[Slices * Cols * row + Cols * slice + col] =
1016 Base::data_[Rows * Cols * slice + Cols * row + col];
1024 for (std::size_t slice = 0; slice < Slices; ++slice)
1025 for (std::size_t row = 0; row < Rows; ++row)
1026 for (std::size_t col = 0; col < Cols; ++col)
1027 result[Slices * Rows * col + Rows * slice + row] =
1028 Base::data_[Rows * Cols * slice + Cols * row + col];
1035 os << Base::name() <<
"\n";
1036 for (std::size_t slice = 0; slice < Slices; ++slice)
1037 for (std::size_t row = 0; row < Rows; ++row)
1038 for (std::size_t col = 0; col < Cols; ++col)
1039 os <<
"[" << row <<
"," << col <<
"," << slice <<
"] = \n"
1040 << *Base::data_[Rows * Cols * slice + Cols * row + col] <<
"\n";
1046template <
typename T,
typename U, std::size_t Rows, std::size_t Common,
1047 std::size_t Cols, std::size_t Slices>
1051 for (std::size_t slice = 0; slice < Slices; ++slice)
1052 for (std::size_t row = 0; row < Rows; ++row)
1053 for (std::size_t col = 0; col < Cols; ++col) {
1055 (lhs[Common * row]->dim() > rhs[Rows * Cols * slice + col]->dim()
1056 ? torch::mul(*lhs[Common * row],
1057 rhs[Rows * Cols * slice + col]->unsqueeze(-1))
1058 : (lhs[Common * row]->dim() <
1059 rhs[Rows * Cols * slice + col]->dim()
1060 ? torch::mul(lhs[Common * row]->unsqueeze(-1),
1061 *rhs[Rows * Cols * slice + col])
1062 : torch::mul(*lhs[Common * row],
1063 *rhs[Rows * Cols * slice + col])));
1064 for (std::size_t idx = 1; idx < Common; ++idx)
1066 (lhs[Common * row]->dim() > rhs[Rows * Cols * slice + col]->dim()
1068 *lhs[Common * row + idx],
1069 rhs[Rows * Cols * slice + Cols * idx + col]->unsqueeze(
1071 : (lhs[Common * row]->dim() <
1072 rhs[Rows * Cols * slice + col]->dim()
1074 lhs[Common * row + idx]->unsqueeze(-1),
1075 *rhs[Rows * Cols * slice + Cols * idx + col])
1077 *lhs[Common * row + idx],
1078 *rhs[Rows * Cols * slice + Cols * idx + col])));
1079 result[Rows * Cols * slice + Cols * row + col] =
1080 std::make_shared<T>(tmp);
1087template <
typename T,
typename U, std::size_t Rows, std::size_t Common,
1088 std::size_t Cols, std::size_t Slices>
1092 for (std::size_t slice = 0; slice < Slices; ++slice)
1093 for (std::size_t row = 0; row < Rows; ++row)
1094 for (std::size_t col = 0; col < Cols; ++col) {
1096 (lhs[Rows * Cols * slice + Common * row]->dim() > rhs[col]->dim()
1097 ? torch::mul(*lhs[Rows * Cols * slice + Common * row],
1098 rhs[col]->unsqueeze(-1))
1099 : (lhs[Rows * Cols * slice + Common * row]->dim() <
1101 ? torch::mul(lhs[Rows * Cols * slice + Common * row]
1104 : torch::mul(*lhs[Rows * Cols * slice + Common * row],
1106 for (std::size_t idx = 1; idx < Common; ++idx)
1108 (lhs[Rows * Cols * slice + Common * row + idx]->dim() >
1109 rhs[Cols * idx + col]->dim()
1110 ? torch::mul(*lhs[Rows * Cols * slice + Common * row + idx],
1111 rhs[Cols * idx + col])
1113 : (lhs[Rows * Cols * slice + Common * row + idx]->dim() <
1114 rhs[Cols * idx + col]->dim()
1116 lhs[Rows * Cols * slice + Common * row + idx]
1118 *rhs[Cols * idx + col])
1120 *lhs[Rows * Cols * slice + Common * row + idx],
1121 *rhs[Cols * idx + col])));
1122 result[Rows * Cols * slice + Cols * row + col] =
1123 std::make_shared<T>(tmp);
1128#define blocktensor_unary_op(name) \
1129 template <typename T, std::size_t... Dims> \
1130 inline auto name(const BlockTensor<T, Dims...> &input) { \
1131 BlockTensor<T, Dims...> result; \
1132 for (std::size_t idx = 0; idx < (Dims * ...); ++idx) \
1133 result[idx] = std::make_shared<T>(torch::name(*input[idx])); \
1137#define blocktensor_unary_special_op(name) \
1138 template <typename T, std::size_t... Dims> \
1139 inline auto name(const BlockTensor<T, Dims...> &input) { \
1140 BlockTensor<T, Dims...> result; \
1141 for (std::size_t idx = 0; idx < (Dims * ...); ++idx) \
1142 result[idx] = std::make_shared<T>(torch::special::name(*input[idx])); \
1146#define blocktensor_binary_op(name) \
1147 template <typename T, typename U, std::size_t... Dims> \
1148 inline auto name(const BlockTensor<T, Dims...> &input, \
1149 const BlockTensor<U, Dims...> &other) { \
1150 BlockTensor<typename std::common_type<T, U>::type, Dims...> result; \
1151 for (std::size_t idx = 0; idx < (Dims * ...); ++idx) \
1153 std::make_shared<T>(torch::name(*input[idx], *other[idx])); \
1157#define blocktensor_binary_special_op(name) \
1158 template <typename T, typename U, std::size_t... Dims> \
1159 inline auto name(const BlockTensor<T, Dims...> &input, \
1160 const BlockTensor<U, Dims...> &other) { \
1161 BlockTensor<typename std::common_type<T, U>::type, Dims...> result; \
1162 for (std::size_t idx = 0; idx < (Dims * ...); ++idx) \
1164 std::make_shared<T>(torch::special::name(*input[idx], *other[idx])); \
1191template <
typename T,
typename U,
typename V, std::size_t... Dims>
1195 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1197 std::make_shared<T>(torch::add(*input[idx], *other[idx], alpha));
1203template <
typename T,
typename U,
typename V, std::size_t... Dims>
1206 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1207 result[idx] = std::make_shared<T>(torch::add(*input[idx], other, alpha));
1213template <
typename T,
typename U,
typename V, std::size_t... Dims>
1216 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1217 result[idx] = std::make_shared<T>(torch::add(input, *other[idx], alpha));
1224template <
typename T,
typename U,
typename V,
typename W, std::size_t... Dims>
1229 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1230 result[idx] = std::make_shared<T>(
1231 torch::addcdiv(*input[idx], *tensor1[idx], *tensor2[idx], value));
1239template <
typename T,
typename U,
typename V,
typename W, std::size_t... Dims>
1244 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1245 result[idx] = std::make_shared<T>(
1246 torch::addcmul(*input[idx], *tensor1[idx], *tensor2[idx], value));
1287#if TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 11 || \
1288 TORCH_VERSION_MAJOR >= 2
1324template <
typename T,
typename U, std::size_t... Dims>
1327 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1328 result[idx] = std::make_shared<T>(torch::clamp(*input[idx], min, max));
1333template <
typename T,
typename U, std::size_t... Dims>
1336 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1337 result[idx] = std::make_shared<T>(torch::clip(*input[idx], min, max));
1374template <typename T,
std::
size_t Rows,
std::
size_t Cols>
1376 return input.dot(tensor);
1593template <typename T, typename U, typename V,
std::
size_t... Dims>
1595 const
BlockTensor<U, Dims...> &other, V alpha = 1.0) {
1597 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1599 std::make_shared<T>(torch::sub(*input[idx], *other[idx], alpha));
1604template <
typename T,
typename U,
typename V, std::size_t... Dims>
1608 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1610 std::make_shared<T>(torch::sub(*input[idx], *other[idx], alpha));
1631template <typename T, typename U,
std::
size_t... Dims>
1635 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1636 result[idx] = std::make_shared<T>(*lhs[idx] + *rhs[idx]);
1642template <
typename T,
typename U, std::size_t... Dims>
1645 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1646 result[idx] = std::make_shared<T>(*lhs[idx] + rhs);
1652template <
typename T,
typename U, std::size_t... Dims>
1655 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1656 result[idx] = std::make_shared<U>(lhs + *rhs[idx]);
1661template <
typename T,
typename U, std::size_t... Dims>
1664 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1665 lhs[idx] = std::make_shared<T>(*lhs[idx] + *rhs[idx]);
1670template <
typename T,
typename U, std::size_t... Dims>
1672 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1673 lhs[idx] = std::make_shared<T>(*lhs[idx] + rhs);
1679template <
typename T,
typename U, std::size_t... Dims>
1683 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1684 result[idx] = std::make_shared<T>(*lhs[idx] - *rhs[idx]);
1690template <
typename T,
typename U, std::size_t... Dims>
1693 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1694 result[idx] = std::make_shared<T>(*lhs[idx] - rhs);
1700template <
typename T,
typename U, std::size_t... Dims>
1703 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1704 result[idx] = std::make_shared<U>(lhs - *rhs[idx]);
1709template <
typename T,
typename U, std::size_t... Dims>
1712 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1713 lhs[idx] = std::make_shared<T>(*lhs[idx] - *rhs[idx]);
1718template <
typename T,
typename U, std::size_t... Dims>
1720 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1721 lhs[idx] = std::make_shared<T>(*lhs[idx] - rhs);
1727template <
typename T,
typename U, std::size_t... Dims>
1730 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1732 (lhs[idx]->dim() > rhs.dim()
1733 ? std::make_shared<T>(*lhs[idx] * rhs.unsqueeze(-1))
1734 : (lhs[idx]->dim() < rhs.dim()
1735 ? std::make_shared<T>(lhs[idx]->unsqueeze(-1) * rhs)
1736 : std::make_shared<T>(*lhs[idx] * rhs)));
1743template <
typename T,
typename U, std::size_t... Dims>
1746 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1748 (lhs.dim() > rhs[idx]->dim()
1749 ? std::make_shared<U>(lhs * rhs[idx]->unsqueeze(-1))
1750 : (lhs.dim() < rhs[idx]->dim()
1751 ? std::make_shared<U>(lhs.unsqueeze(-1) * *rhs[idx])
1752 : std::make_shared<U>(lhs * *rhs[idx])));
1757template <
typename T,
typename U, std::size_t... TDims, std::size_t... UDims>
1760 if constexpr ((
sizeof...(TDims) !=
sizeof...(UDims)) ||
1761 ((TDims != UDims) || ...))
1765 for (std::size_t idx = 0; idx < (TDims * ...); ++idx)
1766 result = result && torch::equal(*lhs[idx], *rhs[idx]);
1772template <
typename T,
typename U, std::size_t... TDims, std::size_t... UDims>
1775 return !(lhs == rhs);
#define blocktensor_unary_op(name)
Definition blocktensor.hpp:1128
#define blocktensor_unary_special_op(name)
Definition blocktensor.hpp:1137
#define blocktensor_binary_op(name)
Definition blocktensor.hpp:1146
#define blocktensor_binary_special_op(name)
Definition blocktensor.hpp:1157
static constexpr std::size_t slices()
Returns the number of slices.
Definition blocktensor.hpp:943
auto reorder_jik() const
Returns a new block tensor with rows and columns transposed and slices remaining fixed....
Definition blocktensor.hpp:998
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the BlockTensor object.
Definition blocktensor.hpp:1034
const T & operator()(std::size_t row, std::size_t col, std::size_t slice) const
Returns a constant reference to entry (row, col, slice)
Definition blocktensor.hpp:948
T & set(std::size_t row, std::size_t col, std::size_t slice, D &&data)
Stores the given data object at the given position.
Definition blocktensor.hpp:966
auto reorder_kij() const
Returns a new block tensor with rows, columns, and slices permuted according to (i,...
Definition blocktensor.hpp:1022
auto reorder_ikj() const
Returns a new block tensor with rows, columns, and slices permuted according to (i,...
Definition blocktensor.hpp:985
static constexpr std::size_t rows()
Returns the number of rows.
Definition blocktensor.hpp:937
T & operator()(std::size_t row, std::size_t col, std::size_t slice)
Returns a non-constant reference to entry (row, col, slice)
Definition blocktensor.hpp:956
auto reorder_kji() const
Returns a new block tensor with rows, columns, and slices permuted according to (i,...
Definition blocktensor.hpp:1010
auto slice(std::size_t slice) const
Returns a rank-2 tensor of the k-th slice.
Definition blocktensor.hpp:973
static constexpr std::size_t cols()
Returns the number of columns.
Definition blocktensor.hpp:940
auto inv() const
Returns the inverse of the block tensor.
Definition blocktensor.hpp:312
auto dot_(std::index_sequence< Is... >, const BlockTensor< T, Rows, Cols > &other) const
Returns the dot product of two BlockTensor objects.
Definition blocktensor.hpp:869
auto ginv() const
Returns the (generalized) inverse of the block tensor.
Definition blocktensor.hpp:551
auto ginvtr() const
Returns the transpose of the (generalized) inverse of the block tensor.
Definition blocktensor.hpp:807
auto tr() const
Returns the transpose of the block tensor.
Definition blocktensor.hpp:225
auto invtr() const
Returns the transpose of the inverse of the block tensor.
Definition blocktensor.hpp:564
static constexpr std::size_t cols()
Returns the number of columns.
Definition blocktensor.hpp:198
auto normalize_(std::index_sequence< Is... > is) const
Returns the normalized BlockTensor object.
Definition blocktensor.hpp:855
static constexpr std::size_t rows()
Returns the number of rows.
Definition blocktensor.hpp:195
auto norm() const
Returns the norm of the BlockTensor object.
Definition blocktensor.hpp:848
const T & operator()(std::size_t row, std::size_t col) const
Returns a constant reference to entry (row, col)
Definition blocktensor.hpp:203
T & set(std::size_t row, std::size_t col, D &&data)
Stores the given data object at the given position.
Definition blocktensor.hpp:218
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the BlockTensor object.
Definition blocktensor.hpp:883
auto dot(const BlockTensor< T, Rows, Cols > &other) const
Returns the dot product of two BlockTensor objects.
Definition blocktensor.hpp:877
auto trace() const
Returns the trace of the block tensor.
Definition blocktensor.hpp:816
auto det() const
Returns the determinant of a square block tensor.
Definition blocktensor.hpp:237
T & operator()(std::size_t row, std::size_t col)
Returns a non-constant reference to entry (row, col)
Definition blocktensor.hpp:209
auto norm_(std::index_sequence< Is... >) const
Returns the norm of the BlockTensor object.
Definition blocktensor.hpp:840
auto normalize() const
Returns the normalized BlockTensor object.
Definition blocktensor.hpp:862
static constexpr std::size_t rows()
Returns the number of rows.
Definition blocktensor.hpp:170
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept override
Returns a string representation of the BlockTensor object.
Definition blocktensor.hpp:174
Compile-time block tensor core.
Definition blocktensor.hpp:50
const std::array< std::shared_ptr< T >,(Dims *...)> & data() const
Returns a constant reference to the data array.
Definition blocktensor.hpp:109
static constexpr auto dims()
Returns all dimensions as array.
Definition blocktensor.hpp:90
BlockTensorCore(BlockTensorCore< Ts, dims... > &&...other)
Constructor from BlockTensorCore objects.
Definition blocktensor.hpp:62
BlockTensorCore(Ts &&...data)
Constructor from variadic templates.
Definition blocktensor.hpp:86
BlockTensorCore()=default
Default constructor.
std::shared_ptr< T > & operator[](std::size_t idx)
Returns a non-constant shared pointer to entry (idx)
Definition blocktensor.hpp:123
T & set(std::size_t idx, Data &&data)
Stores the given data object at the given index.
Definition blocktensor.hpp:141
virtual void pretty_print(std::ostream &os=Log(log::info)) const noexcept=0
Returns a string representation of the BlockTensorCore object.
static constexpr std::size_t dim()
Returns the i-th dimension.
Definition blocktensor.hpp:95
const std::shared_ptr< T > & operator[](std::size_t idx) const
Returns a constant shared pointer to entry (idx)
Definition blocktensor.hpp:117
BlockTensorCore(BlockTensor< Ts, dims... > &&...other)
Constructor from BlockTensor objects.
Definition blocktensor.hpp:74
static constexpr std::size_t entries()
Returns the total number of entries.
Definition blocktensor.hpp:106
std::array< std::shared_ptr< T >,(Dims *...)> & data()
Returns a non-constant reference to the data array.
Definition blocktensor.hpp:114
static constexpr std::size_t size()
Returns the number of dimensions.
Definition blocktensor.hpp:103
const T & operator()(std::size_t idx) const
Returns a constant reference to entry (idx)
Definition blocktensor.hpp:129
std::array< std::shared_ptr< T >,(Dims *...)> data_
Array storing the data.
Definition blocktensor.hpp:54
T & operator()(std::size_t idx)
Returns a non-constant reference to entry (idx)
Definition blocktensor.hpp:135
Full qualified name descriptor.
Definition fqn.hpp:26
Full qualified name utility functions.
auto addcmul(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &tensor1, const BlockTensor< V, Dims... > &tensor2, W value=1.0)
Returns a new block tensor with the elements of tensor1 multiplied by the elements of tensor2,...
Definition blocktensor.hpp:1240
auto log2(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the logarithm to the base-2 of the elements of input
Definition blocktensor.hpp:1454
auto tan(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the tangent of the elements of input.
Definition blocktensor.hpp:1616
auto square(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the square of the elements of input
Definition blocktensor.hpp:1590
auto mul(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the product of each element of input and other
Definition blocktensor.hpp:1506
auto divide(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Alias for div()
Definition blocktensor.hpp:1366
auto exp2(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the base-2 exponential of the elements of input
Definition blocktensor.hpp:1397
auto frexp(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the decomposition of the elements of input into mantissae and exponen...
Definition blocktensor.hpp:1425
auto xlogy(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Computes input * log(other)
Definition blocktensor.hpp:1627
auto operator-(const BlockTensor< T, Dims... > &lhs, const BlockTensor< U, Dims... > &rhs)
Subtracts one compile-time block tensor from another and returns a new compile-time block tensor.
Definition blocktensor.hpp:1680
auto floor(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the floor of the elements of input, the largest integer less than or ...
Definition blocktensor.hpp:1413
auto bitwise_not(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the bitwise NOT of the elements of input
Definition blocktensor.hpp:1295
auto i0(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the element-wise zeroth order modified Bessel function of the first k...
Definition blocktensor.hpp:1488
auto float_power(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the elements of input raised to the power of exponent,...
Definition blocktensor.hpp:1409
auto round(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the elements of input rounded to the nearest integer.
Definition blocktensor.hpp:1547
auto hypot(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
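Returns a new block tensor with the element-wise hypotenuse sqrt(input^2 + other^2) of the elements of input and other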
Definition blocktensor.hpp:1483
auto imag(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the imaginary values of the elements of input
Definition blocktensor.hpp:1429
auto atan(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the arctangent of the elements of input
Definition blocktensor.hpp:1270
auto copysign(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the magnitude of the elements of input and the sign of the elements o...
Definition blocktensor.hpp:1347
auto add(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other, V alpha=1.0)
Returns a new block tensor with the elements of other, scaled by alpha, added to the elements of inpu...
Definition blocktensor.hpp:1192
auto asin(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the arcsine of the elements of input
Definition blocktensor.hpp:1256
bool operator==(const BlockTensor< T, TDims... > &lhs, const BlockTensor< U, UDims... > &rhs)
Returns true if both compile-time block tensors are equal.
Definition blocktensor.hpp:1758
auto angle(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the angle (in radians) of the elements of input
Definition blocktensor.hpp:1252
auto nextafter(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Return a new block tensor with the next elementwise floating-point value after input towards other
Definition blocktensor.hpp:1520
auto arcsinh(const BlockTensor< T, Dims... > &input)
Alias for asinh()
Definition blocktensor.hpp:1266
auto sub(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other, V alpha=1.0)
Subtracts other, scaled by alpha, from input.
Definition blocktensor.hpp:1594
auto sign(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the signs of the elements of input
Definition blocktensor.hpp:1562
auto logical_and(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the element-wise logical AND of the elements of input and other
Definition blocktensor.hpp:1466
auto absolute(const BlockTensor< T, Dims... > &input)
Alias for abs()
Definition blocktensor.hpp:1173
auto fix(const BlockTensor< T, Dims... > &input)
Alias for trunc()
Definition blocktensor.hpp:1404
auto sinc(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the normalized sinc of the elements of input
Definition blocktensor.hpp:1578
auto logical_not(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the element-wise logical NOT of the elements of input
Definition blocktensor.hpp:1470
auto positive(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the input
Definition blocktensor.hpp:1523
auto sqrt(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the square-root of the elements of input
Definition blocktensor.hpp:1586
auto reciprocal(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the reciprocal of the elements of input
Definition blocktensor.hpp:1539
auto clip(const BlockTensor< T, Dims... > &input, U min, U max)
Alias for clamp()
Definition blocktensor.hpp:1334
auto arcsin(const BlockTensor< T, Dims... > &input)
Alias for asin()
Definition blocktensor.hpp:1259
auto bitwise_or(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the bitwise OR of the elements of input and other
Definition blocktensor.hpp:1303
auto atanh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse hyperbolic tangent of the elements of input
Definition blocktensor.hpp:1277
auto subtract(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other, V alpha=1.0)
Alias for sub()
Definition blocktensor.hpp:1605
auto atan2(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the arctangent of the elements in input and other with consideration ...
Definition blocktensor.hpp:1285
auto expit(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the expit (also known as the logistic sigmoid function) of the elemen...
Definition blocktensor.hpp:1555
auto rsqrt(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the reciprocal of the square-root of the elements of input
Definition blocktensor.hpp:1551
auto sin(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the sine of the elements of input
Definition blocktensor.hpp:1574
auto cosh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the hyperbolic cosine of the elements of input
Definition blocktensor.hpp:1355
std::ostream & operator<<(std::ostream &os, const BlockTensorCore< T, Dims... > &obj)
Prints (as string) a compile-time block tensor object.
Definition blocktensor.hpp:154
auto bitwise_and(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the bitwise AND of the elements of input and other
Definition blocktensor.hpp:1299
auto erfc(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the complementary error function of the elements of input
Definition blocktensor.hpp:1385
auto gammainc(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the regularized lower incomplete gamma function of each element of in...
Definition blocktensor.hpp:1492
auto operator-=(BlockTensor< T, Dims... > &lhs, const BlockTensor< U, Dims... > &rhs)
Decrements one compile-time block tensor by another.
Definition blocktensor.hpp:1710
auto arccos(const BlockTensor< T, Dims... > &input)
Alias for acos()
Definition blocktensor.hpp:1180
auto negative(const BlockTensor< T, Dims... > &input)
Alias for neg()
Definition blocktensor.hpp:1516
auto multiply(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Alias for mul()
Definition blocktensor.hpp:1509
auto dot(const BlockTensor< T, Rows, Cols > &input, const BlockTensor< T, Rows, Cols > &tensor)
Returns a new block tensor with the dot product of the two input block tensors.
Definition blocktensor.hpp:1375
auto logaddexp2(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the logarithm of the sum of exponentiations of the elements of input ...
Definition blocktensor.hpp:1462
auto pow(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the power of each element in input with exponent other
Definition blocktensor.hpp:1527
auto bitwise_xor(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the bitwise XOR of the elements of input and other
Definition blocktensor.hpp:1307
auto ldexp(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the elements of input multiplied by 2**other.
Definition blocktensor.hpp:1433
auto igammac(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Alias for gammainc()
Definition blocktensor.hpp:1502
auto neg(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the negative of the elements of input
Definition blocktensor.hpp:1513
auto exp(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the exponential of the elements of input
Definition blocktensor.hpp:1393
auto arctan(const BlockTensor< T, Dims... > &input)
Alias for atan()
Definition blocktensor.hpp:1273
auto log1p(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the natural logarithm of (1 + the elements of input)
Definition blocktensor.hpp:1450
auto arctanh(const BlockTensor< T, Dims... > &input)
Alias for atanh()
Definition blocktensor.hpp:1280
auto ceil(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the ceil of the elements of input, the smallest integer greater than ...
Definition blocktensor.hpp:1320
auto trunc(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the truncated integer values of the elements of input.
Definition blocktensor.hpp:1624
auto cos(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the cosine of the elements of input
Definition blocktensor.hpp:1351
auto erfinv(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse error function of the elements of input
Definition blocktensor.hpp:1389
bool operator!=(const BlockTensor< T, TDims... > &lhs, const BlockTensor< U, UDims... > &rhs)
Returns true if both compile-time block tensors are not equal.
Definition blocktensor.hpp:1773
auto expm1(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the exponential minus 1 of the elements of input
Definition blocktensor.hpp:1401
auto signbit(const BlockTensor< T, Dims... > &input)
Tests if each element of input has its sign bit set (is less than zero) or not.
Definition blocktensor.hpp:1570
auto conj_physical(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the conjugate of the elements of input tensor.
Definition blocktensor.hpp:1343
auto sinh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the hyperbolic sine of the elements of input
Definition blocktensor.hpp:1582
auto remainder(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the modulus of the elements of input
Definition blocktensor.hpp:1543
auto digamma(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the logarithmic derivative of the gamma function of the elements of i...
Definition blocktensor.hpp:1370
auto asinh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse hyperbolic sine of the elements of input
Definition blocktensor.hpp:1263
auto div(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the elements of input divided by the elements of other
Definition blocktensor.hpp:1363
auto igamma(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Alias for gammainc()
Definition blocktensor.hpp:1495
auto bitwise_left_shift(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the left arithmetic shift of the elements of input by other bits.
Definition blocktensor.hpp:1311
auto rad2deg(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with each of the elements of input converted from angles in radians to deg...
Definition blocktensor.hpp:1531
auto real(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the real values of the elements of input
Definition blocktensor.hpp:1535
auto lgamma(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the natural logarithm of the absolute value of the gamma function of ...
Definition blocktensor.hpp:1438
auto gammaincc(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the regularized upper incomplete gamma function of each element of in...
Definition blocktensor.hpp:1499
auto acos(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse cosine of the elements of input
Definition blocktensor.hpp:1177
auto fmod(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the fmod of the elements of input and other
Definition blocktensor.hpp:1417
auto bitwise_right_shift(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the right arithmetic shift of the element of input by other bits.
Definition blocktensor.hpp:1315
auto operator+=(BlockTensor< T, Dims... > &lhs, const BlockTensor< U, Dims... > &rhs)
Increments one compile-time block tensor by another.
Definition blocktensor.hpp:1662
auto sgn(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the signs of the elements of input, extension to complex value.
Definition blocktensor.hpp:1566
auto arccosh(const BlockTensor< T, Dims... > &input)
Alias for acosh().
Definition blocktensor.hpp:1187
auto abs(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the absolute value of the elements of input
Definition blocktensor.hpp:1170
auto logical_xor(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the element-wise logical XOR of the elements of input and other
Definition blocktensor.hpp:1478
auto addcdiv(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &tensor1, const BlockTensor< V, Dims... > &tensor2, W value=1.0)
Returns a new block tensor with the elements of tensor1 divided by the elements of tensor2,...
Definition blocktensor.hpp:1225
auto deg2rad(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the elements of input converted from angles in degrees to radians.
Definition blocktensor.hpp:1359
auto logaddexp(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the logarithm of the sum of exponentiations of the elements of input and other
Definition blocktensor.hpp:1458
auto erf(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the error function of the elements of input
Definition blocktensor.hpp:1381
auto operator*(const BlockTensor< T, Rows, Common > &lhs, const BlockTensor< U, Common, Cols > &rhs)
Multiplies one compile-time rank-2 block tensor with another compile-time rank-2 block tensor.
Definition blocktensor.hpp:896
auto log10(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the logarithm to the base-10 of the elements of input
Definition blocktensor.hpp:1446
auto make_shared(T &&arg)
Returns an std::shared_ptr<T> object from arg.
Definition blocktensor.hpp:38
auto acosh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse hyperbolic cosine of the elements of input
Definition blocktensor.hpp:1184
auto clamp(const BlockTensor< T, Dims... > &input, U min, U max)
Returns a new block tensor with the elements of input clamped into the range [ min,...
Definition blocktensor.hpp:1325
auto logical_or(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the element-wise logical OR of the elements of input and other
Definition blocktensor.hpp:1474
auto frac(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the fractional portion of the elements of input
Definition blocktensor.hpp:1421
Forward declaration of BlockTensor.
Definition blocktensor.hpp:46
Definition boundary.hpp:22
constexpr auto operator+(deriv lhs, deriv rhs)
Adds two enumerators for specifying the derivative of B-spline evaluation.
Definition bspline.hpp:91
struct iganet::@0 Log
Logger.
log
Enumerator for specifying the logging level.
Definition core.hpp:90
Type trait that checks whether the template argument is of type std::shared_ptr<T>
Definition blocktensor.hpp:31