37 return std::forward<std::decay_t<T>>(arg);
39 return std::make_shared<std::decay_t<T>>(std::forward<T>(arg));
46template <
typename T, std::size_t... Dims>
51 std::array<std::shared_ptr<T>, (Dims * ...)>
data_;
58 template <
typename... Ts, std::size_t...
dims>
60 auto it =
data_.begin();
61 (std::transform(other.data().begin(), other.data().end(), it,
62 [&it]<
typename D>(D &&d) {
64 return std::forward<D>(d);
70 template <
typename... Ts, std::size_t...
dims>
72 auto it =
data_.begin();
73 (std::transform(other.data().begin(), other.data().end(), it,
74 [&it]<
typename D>(D &&d) {
76 return std::forward<D>(d);
82 template <
typename... Ts>
84 :
data_({make_shared<Ts>(std::forward<Ts>(
data))...}) {}
87 inline static constexpr auto dims() {
88 return std::array<std::size_t,
sizeof...(Dims)>({Dims...});
92 template <std::
size_t i>
inline static constexpr std::size_t
dim() {
93 if constexpr (i <
sizeof...(Dims))
94 return std::get<i>(std::forward_as_tuple(Dims...));
100 inline static constexpr std::size_t
size() {
return sizeof...(Dims); }
103 inline static constexpr std::size_t
entries() {
return (Dims * ...); }
106 inline const std::array<std::shared_ptr<T>, (Dims * ...)> &
data()
const {
111 inline std::array<std::shared_ptr<T>, (Dims * ...)> &
data() {
return data_; }
114 inline const std::shared_ptr<T> &
operator[](std::size_t idx)
const {
115 assert(idx < (Dims * ...));
121 assert(idx < (Dims * ...));
127 assert(idx < (Dims * ...));
133 assert(idx < (Dims * ...));
138 template <
typename Data>
inline T &
set(std::size_t idx, Data &&
data) {
139 assert(idx < (Dims * ...));
140 data_[idx] = make_shared<Data>(std::forward<Data>(
data));
145 inline void pretty_print(std::ostream &os)
const noexcept override = 0;
149template <
typename T, std::size_t... Dims>
157template <
typename T, std::
size_t Rows>
166 inline static constexpr std::size_t
rows() {
return Rows; }
170 os << Base::name() <<
"\n";
171 for (std::size_t row = 0; row < Rows; ++row)
172 os <<
"[" << row <<
"] = \n" << *Base::data_[row] <<
"\n";
181template <
typename T, std::
size_t Rows, std::
size_t Cols>
190 inline static constexpr std::size_t
rows() {
return Rows; }
193 inline static constexpr std::size_t
cols() {
return Cols; }
195 using Base::operator();
198 inline const T &
operator()(std::size_t row, std::size_t col)
const {
199 assert(row < Rows && col < Cols);
200 return *Base::data_[Cols * row + col];
205 assert(row < Rows && col < Cols);
206 return *Base::data_[Cols * row + col];
212 template <
typename D>
213 inline T &
set(std::size_t row, std::size_t col, D &&data) {
214 assert(row < Rows && col < Cols);
215 Base::data_[Cols * row + col] = make_shared<D>(std::forward<D>(data));
216 return *Base::data_[Cols * row + col];
220 inline auto tr()
const {
222 for (std::size_t row = 0; row < Rows; ++row)
223 for (std::size_t col = 0; col < Cols; ++col)
224 result[Rows * col + row] = Base::data_[Cols * row + col];
233 if constexpr (Rows == 1 && Cols == 1) {
234 auto result = *Base::data_[0];
236 }
else if constexpr (Rows == 2 && Cols == 2) {
237 auto result = torch::mul(*Base::data_[0], *Base::data_[3]) -
238 torch::mul(*Base::data_[1], *Base::data_[2]);
240 }
else if constexpr (Rows == 3 && Cols == 3) {
242 torch::mul(*Base::data_[0],
243 torch::mul(*Base::data_[4], *Base::data_[8]) -
244 torch::mul(*Base::data_[5], *Base::data_[7])) -
245 torch::mul(*Base::data_[1],
246 torch::mul(*Base::data_[3], *Base::data_[8]) -
247 torch::mul(*Base::data_[5], *Base::data_[6])) +
248 torch::mul(*Base::data_[2],
249 torch::mul(*Base::data_[3], *Base::data_[7]) -
250 torch::mul(*Base::data_[4], *Base::data_[6]));
252 }
else if constexpr (Rows == 4 && Cols == 4) {
253 auto a11 = torch::mul(*Base::data_[5],
254 (torch::mul(*Base::data_[10], *Base::data_[15]) -
255 torch::mul(*Base::data_[11], *Base::data_[14]))) -
256 torch::mul(*Base::data_[9],
257 (torch::mul(*Base::data_[6], *Base::data_[15]) -
258 torch::mul(*Base::data_[7], *Base::data_[14]))) -
259 torch::mul(*Base::data_[13],
260 (torch::mul(*Base::data_[7], *Base::data_[10]) -
261 torch::mul(*Base::data_[6], *Base::data_[11])));
263 auto a21 = torch::mul(*Base::data_[4],
264 (torch::mul(*Base::data_[11], *Base::data_[14]) -
265 torch::mul(*Base::data_[10], *Base::data_[15]))) -
266 torch::mul(*Base::data_[8],
267 (torch::mul(*Base::data_[7], *Base::data_[14]) -
268 torch::mul(*Base::data_[6], *Base::data_[15]))) -
269 torch::mul(*Base::data_[12],
270 (torch::mul(*Base::data_[6], *Base::data_[11]) -
271 torch::mul(*Base::data_[7], *Base::data_[10])));
273 auto a31 = torch::mul(*Base::data_[4],
274 (torch::mul(*Base::data_[9], *Base::data_[15]) -
275 torch::mul(*Base::data_[11], *Base::data_[13]))) -
276 torch::mul(*Base::data_[8],
277 (torch::mul(*Base::data_[5], *Base::data_[15]) -
278 torch::mul(*Base::data_[7], *Base::data_[13]))) -
279 torch::mul(*Base::data_[12],
280 (torch::mul(*Base::data_[7], *Base::data_[9]) -
281 torch::mul(*Base::data_[5], *Base::data_[11])));
283 auto a41 = torch::mul(*Base::data_[4],
284 (torch::mul(*Base::data_[10], *Base::data_[13]) -
285 torch::mul(*Base::data_[9], *Base::data_[14]))) -
286 torch::mul(*Base::data_[8],
287 (torch::mul(*Base::data_[6], *Base::data_[13]) -
288 torch::mul(*Base::data_[5], *Base::data_[14]))) -
289 torch::mul(*Base::data_[12],
290 (torch::mul(*Base::data_[5], *Base::data_[10]) -
291 torch::mul(*Base::data_[6], *Base::data_[9])));
294 torch::mul(*Base::data_[0], a11) + torch::mul(*Base::data_[1], a21) +
295 torch::mul(*Base::data_[2], a31) + torch::mul(*Base::data_[3], a41);
299 throw std::runtime_error(
"Unsupported block tensor dimension");
309 auto det_ = this->det();
311 if constexpr (Rows == 1 && Cols == 1) {
313 result[0] = std::make_shared<T>(torch::reciprocal(*Base::data_[0]));
315 }
else if constexpr (Rows == 2 && Cols == 2) {
318 result[0] = std::make_shared<T>(torch::div(*Base::data_[3], det_));
319 result[1] = std::make_shared<T>(torch::div(*Base::data_[2], -det_));
320 result[2] = std::make_shared<T>(torch::div(*Base::data_[1], -det_));
321 result[3] = std::make_shared<T>(torch::div(*Base::data_[0], det_));
323 }
else if constexpr (Rows == 3 && Cols == 3) {
325 auto a11 = torch::mul(*Base::data_[4], *Base::data_[8]) -
326 torch::mul(*Base::data_[5], *Base::data_[7]);
327 auto a12 = torch::mul(*Base::data_[2], *Base::data_[7]) -
328 torch::mul(*Base::data_[1], *Base::data_[8]);
329 auto a13 = torch::mul(*Base::data_[1], *Base::data_[5]) -
330 torch::mul(*Base::data_[2], *Base::data_[4]);
331 auto a21 = torch::mul(*Base::data_[5], *Base::data_[6]) -
332 torch::mul(*Base::data_[3], *Base::data_[8]);
333 auto a22 = torch::mul(*Base::data_[0], *Base::data_[8]) -
334 torch::mul(*Base::data_[2], *Base::data_[6]);
335 auto a23 = torch::mul(*Base::data_[2], *Base::data_[3]) -
336 torch::mul(*Base::data_[0], *Base::data_[5]);
337 auto a31 = torch::mul(*Base::data_[3], *Base::data_[7]) -
338 torch::mul(*Base::data_[4], *Base::data_[6]);
339 auto a32 = torch::mul(*Base::data_[1], *Base::data_[6]) -
340 torch::mul(*Base::data_[0], *Base::data_[7]);
341 auto a33 = torch::mul(*Base::data_[0], *Base::data_[4]) -
342 torch::mul(*Base::data_[1], *Base::data_[3]);
345 result[0] = std::make_shared<T>(torch::div(a11, det_));
346 result[1] = std::make_shared<T>(torch::div(a12, det_));
347 result[2] = std::make_shared<T>(torch::div(a13, det_));
348 result[3] = std::make_shared<T>(torch::div(a21, det_));
349 result[4] = std::make_shared<T>(torch::div(a22, det_));
350 result[5] = std::make_shared<T>(torch::div(a23, det_));
351 result[6] = std::make_shared<T>(torch::div(a31, det_));
352 result[7] = std::make_shared<T>(torch::div(a32, det_));
353 result[8] = std::make_shared<T>(torch::div(a33, det_));
355 }
else if constexpr (Rows == 4 && Cols == 4) {
356 auto a11 = torch::mul(*Base::data_[5],
357 (torch::mul(*Base::data_[10], *Base::data_[15]) -
358 torch::mul(*Base::data_[11], *Base::data_[14]))) -
359 torch::mul(*Base::data_[9],
360 (torch::mul(*Base::data_[6], *Base::data_[15]) -
361 torch::mul(*Base::data_[7], *Base::data_[14]))) -
362 torch::mul(*Base::data_[13],
363 (torch::mul(*Base::data_[7], *Base::data_[10]) -
364 torch::mul(*Base::data_[6], *Base::data_[11])));
366 auto a12 = torch::mul(*Base::data_[1],
367 (torch::mul(*Base::data_[11], *Base::data_[14]) -
368 torch::mul(*Base::data_[10], *Base::data_[15]))) -
369 torch::mul(*Base::data_[9],
370 (torch::mul(*Base::data_[3], *Base::data_[14]) -
371 torch::mul(*Base::data_[2], *Base::data_[15]))) -
372 torch::mul(*Base::data_[13],
373 (torch::mul(*Base::data_[2], *Base::data_[11]) -
374 torch::mul(*Base::data_[3], *Base::data_[10])));
376 auto a13 = torch::mul(*Base::data_[1],
377 (torch::mul(*Base::data_[6], *Base::data_[15]) -
378 torch::mul(*Base::data_[7], *Base::data_[14]))) -
379 torch::mul(*Base::data_[5],
380 (torch::mul(*Base::data_[2], *Base::data_[15]) -
381 torch::mul(*Base::data_[3], *Base::data_[14]))) -
382 torch::mul(*Base::data_[13],
383 (torch::mul(*Base::data_[3], *Base::data_[6]) -
384 torch::mul(*Base::data_[2], *Base::data_[7])));
386 auto a14 = torch::mul(*Base::data_[1],
387 (torch::mul(*Base::data_[7], *Base::data_[10]) -
388 torch::mul(*Base::data_[6], *Base::data_[11]))) -
389 torch::mul(*Base::data_[5],
390 (torch::mul(*Base::data_[3], *Base::data_[10]) -
391 torch::mul(*Base::data_[2], *Base::data_[11]))) -
392 torch::mul(*Base::data_[9],
393 (torch::mul(*Base::data_[2], *Base::data_[7]) -
394 torch::mul(*Base::data_[3], *Base::data_[6])));
396 auto a21 = torch::mul(*Base::data_[4],
397 (torch::mul(*Base::data_[11], *Base::data_[14]) -
398 torch::mul(*Base::data_[10], *Base::data_[15]))) -
399 torch::mul(*Base::data_[8],
400 (torch::mul(*Base::data_[7], *Base::data_[14]) -
401 torch::mul(*Base::data_[6], *Base::data_[15]))) -
402 torch::mul(*Base::data_[12],
403 (torch::mul(*Base::data_[6], *Base::data_[11]) -
404 torch::mul(*Base::data_[7], *Base::data_[10])));
406 auto a22 = torch::mul(*Base::data_[0],
407 (torch::mul(*Base::data_[10], *Base::data_[15]) -
408 torch::mul(*Base::data_[11], *Base::data_[14]))) -
409 torch::mul(*Base::data_[8],
410 (torch::mul(*Base::data_[2], *Base::data_[15]) -
411 torch::mul(*Base::data_[3], *Base::data_[14]))) -
412 torch::mul(*Base::data_[12],
413 (torch::mul(*Base::data_[3], *Base::data_[10]) -
414 torch::mul(*Base::data_[2], *Base::data_[11])));
416 auto a23 = torch::mul(*Base::data_[0],
417 (torch::mul(*Base::data_[7], *Base::data_[14]) -
418 torch::mul(*Base::data_[6], *Base::data_[15]))) -
419 torch::mul(*Base::data_[4],
420 (torch::mul(*Base::data_[3], *Base::data_[14]) -
421 torch::mul(*Base::data_[2], *Base::data_[15]))) -
422 torch::mul(*Base::data_[12],
423 (torch::mul(*Base::data_[2], *Base::data_[7]) -
424 torch::mul(*Base::data_[3], *Base::data_[6])));
426 auto a24 = torch::mul(*Base::data_[0],
427 (torch::mul(*Base::data_[6], *Base::data_[11]) -
428 torch::mul(*Base::data_[7], *Base::data_[10]))) -
429 torch::mul(*Base::data_[4],
430 (torch::mul(*Base::data_[2], *Base::data_[11]) -
431 torch::mul(*Base::data_[3], *Base::data_[10]))) -
432 torch::mul(*Base::data_[8],
433 (torch::mul(*Base::data_[3], *Base::data_[6]) -
434 torch::mul(*Base::data_[2], *Base::data_[7])));
436 auto a31 = torch::mul(*Base::data_[4],
437 (torch::mul(*Base::data_[9], *Base::data_[15]) -
438 torch::mul(*Base::data_[11], *Base::data_[13]))) -
439 torch::mul(*Base::data_[8],
440 (torch::mul(*Base::data_[5], *Base::data_[15]) -
441 torch::mul(*Base::data_[7], *Base::data_[13]))) -
442 torch::mul(*Base::data_[12],
443 (torch::mul(*Base::data_[7], *Base::data_[9]) -
444 torch::mul(*Base::data_[5], *Base::data_[11])));
446 auto a32 = torch::mul(*Base::data_[0],
447 (torch::mul(*Base::data_[11], *Base::data_[13]) -
448 torch::mul(*Base::data_[9], *Base::data_[15]))) -
449 torch::mul(*Base::data_[8],
450 (torch::mul(*Base::data_[3], *Base::data_[13]) -
451 torch::mul(*Base::data_[1], *Base::data_[15]))) -
452 torch::mul(*Base::data_[12],
453 (torch::mul(*Base::data_[1], *Base::data_[11]) -
454 torch::mul(*Base::data_[3], *Base::data_[9])));
456 auto a33 = torch::mul(*Base::data_[0],
457 (torch::mul(*Base::data_[5], *Base::data_[15]) -
458 torch::mul(*Base::data_[7], *Base::data_[13]))) -
459 torch::mul(*Base::data_[4],
460 (torch::mul(*Base::data_[1], *Base::data_[15]) -
461 torch::mul(*Base::data_[3], *Base::data_[13]))) -
462 torch::mul(*Base::data_[12],
463 (torch::mul(*Base::data_[3], *Base::data_[5]) -
464 torch::mul(*Base::data_[1], *Base::data_[7])));
466 auto a34 = torch::mul(*Base::data_[0],
467 (torch::mul(*Base::data_[7], *Base::data_[9]) -
468 torch::mul(*Base::data_[5], *Base::data_[11]))) -
469 torch::mul(*Base::data_[4],
470 (torch::mul(*Base::data_[3], *Base::data_[9]) -
471 torch::mul(*Base::data_[1], *Base::data_[11]))) -
472 torch::mul(*Base::data_[8],
473 (torch::mul(*Base::data_[1], *Base::data_[7]) -
474 torch::mul(*Base::data_[3], *Base::data_[5])));
476 auto a41 = torch::mul(*Base::data_[4],
477 (torch::mul(*Base::data_[10], *Base::data_[13]) -
478 torch::mul(*Base::data_[9], *Base::data_[14]))) -
479 torch::mul(*Base::data_[8],
480 (torch::mul(*Base::data_[6], *Base::data_[13]) -
481 torch::mul(*Base::data_[5], *Base::data_[14]))) -
482 torch::mul(*Base::data_[12],
483 (torch::mul(*Base::data_[5], *Base::data_[10]) -
484 torch::mul(*Base::data_[6], *Base::data_[9])));
486 auto a42 = torch::mul(*Base::data_[0],
487 (torch::mul(*Base::data_[9], *Base::data_[14]) -
488 torch::mul(*Base::data_[10], *Base::data_[13]))) -
489 torch::mul(*Base::data_[8],
490 (torch::mul(*Base::data_[1], *Base::data_[14]) -
491 torch::mul(*Base::data_[2], *Base::data_[13]))) -
492 torch::mul(*Base::data_[12],
493 (torch::mul(*Base::data_[2], *Base::data_[9]) -
494 torch::mul(*Base::data_[1], *Base::data_[10])));
496 auto a43 = torch::mul(*Base::data_[0],
497 (torch::mul(*Base::data_[6], *Base::data_[13]) -
498 torch::mul(*Base::data_[5], *Base::data_[14]))) -
499 torch::mul(*Base::data_[4],
500 (torch::mul(*Base::data_[2], *Base::data_[13]) -
501 torch::mul(*Base::data_[1], *Base::data_[14]))) -
502 torch::mul(*Base::data_[12],
503 (torch::mul(*Base::data_[1], *Base::data_[6]) -
504 torch::mul(*Base::data_[2], *Base::data_[5])));
506 auto a44 = torch::mul(*Base::data_[0],
507 (torch::mul(*Base::data_[5], *Base::data_[10]) -
508 torch::mul(*Base::data_[6], *Base::data_[9]))) -
509 torch::mul(*Base::data_[4],
510 (torch::mul(*Base::data_[1], *Base::data_[10]) -
511 torch::mul(*Base::data_[2], *Base::data_[9]))) -
512 torch::mul(*Base::data_[8],
513 (torch::mul(*Base::data_[2], *Base::data_[5]) -
514 torch::mul(*Base::data_[1], *Base::data_[6])));
516 result[0] = std::make_shared<T>(torch::div(a11, det_));
517 result[1] = std::make_shared<T>(torch::div(a12, det_));
518 result[2] = std::make_shared<T>(torch::div(a13, det_));
519 result[3] = std::make_shared<T>(torch::div(a14, det_));
520 result[4] = std::make_shared<T>(torch::div(a21, det_));
521 result[5] = std::make_shared<T>(torch::div(a22, det_));
522 result[6] = std::make_shared<T>(torch::div(a23, det_));
523 result[7] = std::make_shared<T>(torch::div(a24, det_));
524 result[8] = std::make_shared<T>(torch::div(a31, det_));
525 result[9] = std::make_shared<T>(torch::div(a32, det_));
526 result[10] = std::make_shared<T>(torch::div(a33, det_));
527 result[11] = std::make_shared<T>(torch::div(a34, det_));
528 result[12] = std::make_shared<T>(torch::div(a41, det_));
529 result[13] = std::make_shared<T>(torch::div(a42, det_));
530 result[14] = std::make_shared<T>(torch::div(a43, det_));
531 result[15] = std::make_shared<T>(torch::div(a44, det_));
534 throw std::runtime_error(
"Unsupported block tensor dimension");
547 if constexpr (Rows == Cols)
551 return (this->tr() * (*this)).inv() * this->tr();
561 auto det_ = this->det();
563 if constexpr (Rows == 1 && Cols == 1) {
565 result[0] = std::make_shared<T>(torch::reciprocal(*Base::data_[0]));
567 }
else if constexpr (Rows == 2 && Cols == 2) {
570 result[0] = std::make_shared<T>(torch::div(*Base::data_[3], det_));
571 result[1] = std::make_shared<T>(torch::div(*Base::data_[1], -det_));
572 result[2] = std::make_shared<T>(torch::div(*Base::data_[2], -det_));
573 result[3] = std::make_shared<T>(torch::div(*Base::data_[0], det_));
575 }
else if constexpr (Rows == 3 && Cols == 3) {
577 auto a11 = torch::mul(*Base::data_[4], *Base::data_[8]) -
578 torch::mul(*Base::data_[5], *Base::data_[7]);
579 auto a12 = torch::mul(*Base::data_[2], *Base::data_[7]) -
580 torch::mul(*Base::data_[1], *Base::data_[8]);
581 auto a13 = torch::mul(*Base::data_[1], *Base::data_[5]) -
582 torch::mul(*Base::data_[2], *Base::data_[4]);
583 auto a21 = torch::mul(*Base::data_[5], *Base::data_[6]) -
584 torch::mul(*Base::data_[3], *Base::data_[8]);
585 auto a22 = torch::mul(*Base::data_[0], *Base::data_[8]) -
586 torch::mul(*Base::data_[2], *Base::data_[6]);
587 auto a23 = torch::mul(*Base::data_[2], *Base::data_[3]) -
588 torch::mul(*Base::data_[0], *Base::data_[5]);
589 auto a31 = torch::mul(*Base::data_[3], *Base::data_[7]) -
590 torch::mul(*Base::data_[4], *Base::data_[6]);
591 auto a32 = torch::mul(*Base::data_[1], *Base::data_[6]) -
592 torch::mul(*Base::data_[0], *Base::data_[7]);
593 auto a33 = torch::mul(*Base::data_[0], *Base::data_[4]) -
594 torch::mul(*Base::data_[1], *Base::data_[3]);
597 result[0] = std::make_shared<T>(torch::div(a11, det_));
598 result[1] = std::make_shared<T>(torch::div(a21, det_));
599 result[2] = std::make_shared<T>(torch::div(a31, det_));
600 result[3] = std::make_shared<T>(torch::div(a12, det_));
601 result[4] = std::make_shared<T>(torch::div(a22, det_));
602 result[5] = std::make_shared<T>(torch::div(a32, det_));
603 result[6] = std::make_shared<T>(torch::div(a13, det_));
604 result[7] = std::make_shared<T>(torch::div(a23, det_));
605 result[8] = std::make_shared<T>(torch::div(a33, det_));
607 }
else if constexpr (Rows == 4 && Cols == 4) {
609 auto a11 = torch::mul(*Base::data_[5],
610 (torch::mul(*Base::data_[10], *Base::data_[15]) -
611 torch::mul(*Base::data_[11], *Base::data_[14]))) -
612 torch::mul(*Base::data_[9],
613 (torch::mul(*Base::data_[6], *Base::data_[15]) -
614 torch::mul(*Base::data_[7], *Base::data_[14]))) -
615 torch::mul(*Base::data_[13],
616 (torch::mul(*Base::data_[7], *Base::data_[10]) -
617 torch::mul(*Base::data_[6], *Base::data_[11])));
619 auto a12 = torch::mul(*Base::data_[1],
620 (torch::mul(*Base::data_[11], *Base::data_[14]) -
621 torch::mul(*Base::data_[10], *Base::data_[15]))) -
622 torch::mul(*Base::data_[9],
623 (torch::mul(*Base::data_[3], *Base::data_[14]) -
624 torch::mul(*Base::data_[2], *Base::data_[15]))) -
625 torch::mul(*Base::data_[13],
626 (torch::mul(*Base::data_[2], *Base::data_[11]) -
627 torch::mul(*Base::data_[3], *Base::data_[10])));
629 auto a13 = torch::mul(*Base::data_[1],
630 (torch::mul(*Base::data_[6], *Base::data_[15]) -
631 torch::mul(*Base::data_[7], *Base::data_[14]))) -
632 torch::mul(*Base::data_[5],
633 (torch::mul(*Base::data_[2], *Base::data_[15]) -
634 torch::mul(*Base::data_[3], *Base::data_[14]))) -
635 torch::mul(*Base::data_[13],
636 (torch::mul(*Base::data_[3], *Base::data_[6]) -
637 torch::mul(*Base::data_[2], *Base::data_[7])));
639 auto a14 = torch::mul(*Base::data_[1],
640 (torch::mul(*Base::data_[7], *Base::data_[10]) -
641 torch::mul(*Base::data_[6], *Base::data_[11]))) -
642 torch::mul(*Base::data_[5],
643 (torch::mul(*Base::data_[3], *Base::data_[10]) -
644 torch::mul(*Base::data_[2], *Base::data_[11]))) -
645 torch::mul(*Base::data_[9],
646 (torch::mul(*Base::data_[2], *Base::data_[7]) -
647 torch::mul(*Base::data_[3], *Base::data_[6])));
649 auto a21 = torch::mul(*Base::data_[4],
650 (torch::mul(*Base::data_[11], *Base::data_[14]) -
651 torch::mul(*Base::data_[10], *Base::data_[15]))) -
652 torch::mul(*Base::data_[8],
653 (torch::mul(*Base::data_[7], *Base::data_[14]) -
654 torch::mul(*Base::data_[6], *Base::data_[15]))) -
655 torch::mul(*Base::data_[12],
656 (torch::mul(*Base::data_[6], *Base::data_[11]) -
657 torch::mul(*Base::data_[7], *Base::data_[10])));
659 auto a22 = torch::mul(*Base::data_[0],
660 (torch::mul(*Base::data_[10], *Base::data_[15]) -
661 torch::mul(*Base::data_[11], *Base::data_[14]))) -
662 torch::mul(*Base::data_[8],
663 (torch::mul(*Base::data_[2], *Base::data_[15]) -
664 torch::mul(*Base::data_[3], *Base::data_[14]))) -
665 torch::mul(*Base::data_[12],
666 (torch::mul(*Base::data_[3], *Base::data_[10]) -
667 torch::mul(*Base::data_[2], *Base::data_[11])));
669 auto a23 = torch::mul(*Base::data_[0],
670 (torch::mul(*Base::data_[7], *Base::data_[14]) -
671 torch::mul(*Base::data_[6], *Base::data_[15]))) -
672 torch::mul(*Base::data_[4],
673 (torch::mul(*Base::data_[3], *Base::data_[14]) -
674 torch::mul(*Base::data_[2], *Base::data_[15]))) -
675 torch::mul(*Base::data_[12],
676 (torch::mul(*Base::data_[2], *Base::data_[7]) -
677 torch::mul(*Base::data_[3], *Base::data_[6])));
679 auto a24 = torch::mul(*Base::data_[0],
680 (torch::mul(*Base::data_[6], *Base::data_[11]) -
681 torch::mul(*Base::data_[7], *Base::data_[10]))) -
682 torch::mul(*Base::data_[4],
683 (torch::mul(*Base::data_[2], *Base::data_[11]) -
684 torch::mul(*Base::data_[3], *Base::data_[10]))) -
685 torch::mul(*Base::data_[8],
686 (torch::mul(*Base::data_[3], *Base::data_[6]) -
687 torch::mul(*Base::data_[2], *Base::data_[7])));
689 auto a31 = torch::mul(*Base::data_[4],
690 (torch::mul(*Base::data_[9], *Base::data_[15]) -
691 torch::mul(*Base::data_[11], *Base::data_[13]))) -
692 torch::mul(*Base::data_[8],
693 (torch::mul(*Base::data_[5], *Base::data_[15]) -
694 torch::mul(*Base::data_[7], *Base::data_[13]))) -
695 torch::mul(*Base::data_[12],
696 (torch::mul(*Base::data_[7], *Base::data_[9]) -
697 torch::mul(*Base::data_[5], *Base::data_[11])));
699 auto a32 = torch::mul(*Base::data_[0],
700 (torch::mul(*Base::data_[11], *Base::data_[13]) -
701 torch::mul(*Base::data_[9], *Base::data_[15]))) -
702 torch::mul(*Base::data_[8],
703 (torch::mul(*Base::data_[3], *Base::data_[13]) -
704 torch::mul(*Base::data_[1], *Base::data_[15]))) -
705 torch::mul(*Base::data_[12],
706 (torch::mul(*Base::data_[1], *Base::data_[11]) -
707 torch::mul(*Base::data_[3], *Base::data_[9])));
709 auto a33 = torch::mul(*Base::data_[0],
710 (torch::mul(*Base::data_[5], *Base::data_[15]) -
711 torch::mul(*Base::data_[7], *Base::data_[13]))) -
712 torch::mul(*Base::data_[4],
713 (torch::mul(*Base::data_[1], *Base::data_[15]) -
714 torch::mul(*Base::data_[3], *Base::data_[13]))) -
715 torch::mul(*Base::data_[12],
716 (torch::mul(*Base::data_[3], *Base::data_[5]) -
717 torch::mul(*Base::data_[1], *Base::data_[7])));
719 auto a34 = torch::mul(*Base::data_[0],
720 (torch::mul(*Base::data_[7], *Base::data_[9]) -
721 torch::mul(*Base::data_[5], *Base::data_[11]))) -
722 torch::mul(*Base::data_[4],
723 (torch::mul(*Base::data_[3], *Base::data_[9]) -
724 torch::mul(*Base::data_[1], *Base::data_[11]))) -
725 torch::mul(*Base::data_[8],
726 (torch::mul(*Base::data_[1], *Base::data_[7]) -
727 torch::mul(*Base::data_[3], *Base::data_[5])));
729 auto a41 = torch::mul(*Base::data_[4],
730 (torch::mul(*Base::data_[10], *Base::data_[13]) -
731 torch::mul(*Base::data_[9], *Base::data_[14]))) -
732 torch::mul(*Base::data_[8],
733 (torch::mul(*Base::data_[6], *Base::data_[13]) -
734 torch::mul(*Base::data_[5], *Base::data_[14]))) -
735 torch::mul(*Base::data_[12],
736 (torch::mul(*Base::data_[5], *Base::data_[10]) -
737 torch::mul(*Base::data_[6], *Base::data_[9])));
739 auto a42 = torch::mul(*Base::data_[0],
740 (torch::mul(*Base::data_[9], *Base::data_[14]) -
741 torch::mul(*Base::data_[10], *Base::data_[13]))) -
742 torch::mul(*Base::data_[8],
743 (torch::mul(*Base::data_[1], *Base::data_[14]) -
744 torch::mul(*Base::data_[2], *Base::data_[13]))) -
745 torch::mul(*Base::data_[12],
746 (torch::mul(*Base::data_[2], *Base::data_[9]) -
747 torch::mul(*Base::data_[1], *Base::data_[10])));
749 auto a43 = torch::mul(*Base::data_[0],
750 (torch::mul(*Base::data_[6], *Base::data_[13]) -
751 torch::mul(*Base::data_[5], *Base::data_[14]))) -
752 torch::mul(*Base::data_[4],
753 (torch::mul(*Base::data_[2], *Base::data_[13]) -
754 torch::mul(*Base::data_[1], *Base::data_[14]))) -
755 torch::mul(*Base::data_[12],
756 (torch::mul(*Base::data_[1], *Base::data_[6]) -
757 torch::mul(*Base::data_[2], *Base::data_[5])));
759 auto a44 = torch::mul(*Base::data_[0],
760 (torch::mul(*Base::data_[5], *Base::data_[10]) -
761 torch::mul(*Base::data_[6], *Base::data_[9]))) -
762 torch::mul(*Base::data_[4],
763 (torch::mul(*Base::data_[1], *Base::data_[10]) -
764 torch::mul(*Base::data_[2], *Base::data_[9]))) -
765 torch::mul(*Base::data_[8],
766 (torch::mul(*Base::data_[2], *Base::data_[5]) -
767 torch::mul(*Base::data_[1], *Base::data_[6])));
770 result[0] = std::make_shared<T>(torch::div(a11, det_));
771 result[1] = std::make_shared<T>(torch::div(a21, det_));
772 result[2] = std::make_shared<T>(torch::div(a31, det_));
773 result[3] = std::make_shared<T>(torch::div(a41, det_));
774 result[4] = std::make_shared<T>(torch::div(a12, det_));
775 result[5] = std::make_shared<T>(torch::div(a22, det_));
776 result[6] = std::make_shared<T>(torch::div(a32, det_));
777 result[7] = std::make_shared<T>(torch::div(a42, det_));
778 result[8] = std::make_shared<T>(torch::div(a13, det_));
779 result[9] = std::make_shared<T>(torch::div(a23, det_));
780 result[10] = std::make_shared<T>(torch::div(a33, det_));
781 result[11] = std::make_shared<T>(torch::div(a43, det_));
782 result[12] = std::make_shared<T>(torch::div(a14, det_));
783 result[13] = std::make_shared<T>(torch::div(a24, det_));
784 result[14] = std::make_shared<T>(torch::div(a34, det_));
785 result[15] = std::make_shared<T>(torch::div(a44, det_));
788 throw std::runtime_error(
"Unsupported block tensor dimension");
803 if constexpr (Rows == Cols)
804 return this->invtr();
807 return (*
this) * (this->tr() * (*this)).invtr();
812 static_assert(Rows == Cols,
"trace(.) requires square block tensor");
814 if constexpr (Rows == 1)
817 else if constexpr (Rows == 2)
820 else if constexpr (Rows == 3)
824 else if constexpr (Rows == 4)
826 *Base::data_[10] + *Base::data_[15]);
829 throw std::runtime_error(
"Unsupported block tensor dimension");
834 template <std::size_t... Is>
835 inline auto norm_(std::index_sequence<Is...>)
const {
837 std::apply([](
const auto &...tensors) {
return (tensors + ...); },
838 std::make_tuple(std::get<Is>(Base::data_)->
square()...)));
845 std::make_shared<T>(norm_(std::make_index_sequence<Rows * Cols>{})));
850 template <std::size_t... Is>
851 inline auto normalize_(std::index_sequence<Is...> is)
const {
854 std::make_shared<T>(*std::get<Is>(Base::data_) / n_)...);
860 return normalize_(std::make_index_sequence<Rows * Cols>{});
865 template <std::size_t... Is>
866 inline auto dot_(std::index_sequence<Is...>,
869 [](
const auto &...tensors) {
return (tensors + ...); },
870 std::make_tuple(torch::mul(*std::get<Is>(Base::data_),
871 *std::get<Is>(other.data_))...));
878 dot_(std::make_index_sequence<Rows * Cols>{}, other)));
883 os << Base::name() <<
"\n";
884 for (std::size_t row = 0; row < Rows; ++row)
885 for (std::size_t col = 0; col < Cols; ++col)
886 os <<
"[" << row <<
"," << col <<
"] = \n"
887 << *Base::data_[Cols * row + col] <<
"\n";
893template <
typename T,
typename U, std::size_t Rows, std::size_t Common,
898 for (std::size_t row = 0; row < Rows; ++row)
899 for (std::size_t col = 0; col < Cols; ++col) {
901 (lhs[Common * row]->dim() > rhs[col]->dim()
902 ? torch::mul(*lhs[Common * row], rhs[col]->unsqueeze(-1))
903 : (lhs[Common * row]->dim() < rhs[col]->dim()
904 ? torch::mul(lhs[Common * row]->unsqueeze(-1), *rhs[col])
905 : torch::mul(*lhs[Common * row], *rhs[col])));
906 for (std::size_t idx = 1; idx < Common; ++idx)
907 tmp += (lhs[Common * row]->dim() > rhs[col]->dim()
908 ? torch::mul(*lhs[Common * row + idx],
909 rhs[Cols * idx + col]->unsqueeze(-1))
910 : (lhs[Common * row]->dim() < rhs[col]->dim()
911 ? torch::mul(lhs[Common * row + idx]->unsqueeze(-1),
912 *rhs[Cols * idx + col])
913 : torch::mul(*lhs[Common * row + idx],
914 *rhs[Cols * idx + col])));
915 result[Cols * row + col] = std::make_shared<T>(tmp);
926template <
typename T, std::
size_t Rows, std::
size_t Cols, std::
size_t Slices>
936 inline static constexpr std::size_t
rows() {
return Rows; }
939 inline static constexpr std::size_t
cols() {
return Cols; }
942 inline static constexpr std::size_t
slices() {
return Slices; }
944 using Base::operator();
948 std::size_t slice)
const {
949 assert(row < Rows && col < Cols && slice < Slices);
950 return *Base::data_[Rows * Cols * slice + Cols * row + col];
954 inline T &
operator()(std::size_t row, std::size_t col, std::size_t slice) {
955 assert(row < Rows && col < Cols && slice < Slices);
956 return *Base::data_[Rows * Cols * slice + Cols * row + col];
962 template <
typename D>
963 inline T &
set(std::size_t row, std::size_t col, std::size_t slice, D &&data) {
964 Base::data_[Rows * Cols * slice + Cols * row + col] =
965 make_shared<D>(std::forward<D>(data));
966 return *Base::data_[Rows * Cols * slice + Cols * row + col];
970 inline auto slice(std::size_t slice)
const {
971 assert(slice < Slices);
973 for (std::size_t row = 0; row < Rows; ++row)
974 for (std::size_t col = 0; col < Cols; ++col)
975 result[Cols * row + col] =
976 Base::data_[Rows * Cols * slice + Cols * row + col];
984 for (std::size_t slice = 0; slice < Slices; ++slice)
985 for (std::size_t row = 0; row < Rows; ++row)
986 for (std::size_t col = 0; col < Cols; ++col)
987 result[Rows * Slices * col + Slices * row + slice] =
988 Base::data_[Rows * Cols * slice + Cols * row + col];
997 for (std::size_t slice = 0; slice < Slices; ++slice)
998 for (std::size_t row = 0; row < Rows; ++row)
999 for (std::size_t col = 0; col < Cols; ++col)
1000 result[Rows * Cols * slice + Rows * col + row] =
1001 Base::data_[Rows * Cols * slice + Cols * row + col];
1009 for (std::size_t slice = 0; slice < Slices; ++slice)
1010 for (std::size_t row = 0; row < Rows; ++row)
1011 for (std::size_t col = 0; col < Cols; ++col)
1012 result[Slices * Cols * row + Cols * slice + col] =
1013 Base::data_[Rows * Cols * slice + Cols * row + col];
1021 for (std::size_t slice = 0; slice < Slices; ++slice)
1022 for (std::size_t row = 0; row < Rows; ++row)
1023 for (std::size_t col = 0; col < Cols; ++col)
1024 result[Slices * Rows * col + Rows * slice + row] =
1025 Base::data_[Rows * Cols * slice + Cols * row + col];
1031 os << Base::name() <<
"\n";
1032 for (std::size_t slice = 0; slice < Slices; ++slice)
1033 for (std::size_t row = 0; row < Rows; ++row)
1034 for (std::size_t col = 0; col < Cols; ++col)
1035 os <<
"[" << row <<
"," << col <<
"," << slice <<
"] = \n"
1036 << *Base::data_[Rows * Cols * slice + Cols * row + col] <<
"\n";
1042template <
typename T,
typename U, std::size_t Rows, std::size_t Common,
1043 std::size_t Cols, std::size_t Slices>
1047 for (std::size_t slice = 0; slice < Slices; ++slice)
1048 for (std::size_t row = 0; row < Rows; ++row)
1049 for (std::size_t col = 0; col < Cols; ++col) {
1051 (lhs[Common * row]->dim() > rhs[Rows * Cols * slice + col]->dim()
1052 ? torch::mul(*lhs[Common * row],
1053 rhs[Rows * Cols * slice + col]->unsqueeze(-1))
1054 : (lhs[Common * row]->dim() <
1055 rhs[Rows * Cols * slice + col]->dim()
1056 ? torch::mul(lhs[Common * row]->unsqueeze(-1),
1057 *rhs[Rows * Cols * slice + col])
1058 : torch::mul(*lhs[Common * row],
1059 *rhs[Rows * Cols * slice + col])));
1060 for (std::size_t idx = 1; idx < Common; ++idx)
1062 (lhs[Common * row]->dim() > rhs[Rows * Cols * slice + col]->dim()
1064 *lhs[Common * row + idx],
1065 rhs[Rows * Cols * slice + Cols * idx + col]->unsqueeze(
1067 : (lhs[Common * row]->dim() <
1068 rhs[Rows * Cols * slice + col]->dim()
1070 lhs[Common * row + idx]->unsqueeze(-1),
1071 *rhs[Rows * Cols * slice + Cols * idx + col])
1073 *lhs[Common * row + idx],
1074 *rhs[Rows * Cols * slice + Cols * idx + col])));
1075 result[Rows * Cols * slice + Cols * row + col] =
1076 std::make_shared<T>(tmp);
1083template <
typename T,
typename U, std::size_t Rows, std::size_t Common,
1084 std::size_t Cols, std::size_t Slices>
1088 for (std::size_t slice = 0; slice < Slices; ++slice)
1089 for (std::size_t row = 0; row < Rows; ++row)
1090 for (std::size_t col = 0; col < Cols; ++col) {
1092 (lhs[Rows * Cols * slice + Common * row]->dim() > rhs[col]->dim()
1093 ? torch::mul(*lhs[Rows * Cols * slice + Common * row],
1094 rhs[col]->unsqueeze(-1))
1095 : (lhs[Rows * Cols * slice + Common * row]->dim() <
1097 ? torch::mul(lhs[Rows * Cols * slice + Common * row]
1100 : torch::mul(*lhs[Rows * Cols * slice + Common * row],
1102 for (std::size_t idx = 1; idx < Common; ++idx)
1104 (lhs[Rows * Cols * slice + Common * row + idx]->dim() >
1105 rhs[Cols * idx + col]->dim()
1106 ? torch::mul(*lhs[Rows * Cols * slice + Common * row + idx],
1107 rhs[Cols * idx + col])
1109 : (lhs[Rows * Cols * slice + Common * row + idx]->dim() <
1110 rhs[Cols * idx + col]->dim()
1112 lhs[Rows * Cols * slice + Common * row + idx]
1114 *rhs[Cols * idx + col])
1116 *lhs[Rows * Cols * slice + Common * row + idx],
1117 *rhs[Cols * idx + col])));
1118 result[Rows * Cols * slice + Cols * row + col] =
1119 std::make_shared<T>(tmp);
1124#define blocktensor_unary_op(name) \
1125 template <typename T, std::size_t... Dims> \
1126 inline auto name(const BlockTensor<T, Dims...> &input) { \
1127 BlockTensor<T, Dims...> result; \
1128 for (std::size_t idx = 0; idx < (Dims * ...); ++idx) \
1129 result[idx] = std::make_shared<T>(torch::name(*input[idx])); \
1133#define blocktensor_unary_special_op(name) \
1134 template <typename T, std::size_t... Dims> \
1135 inline auto name(const BlockTensor<T, Dims...> &input) { \
1136 BlockTensor<T, Dims...> result; \
1137 for (std::size_t idx = 0; idx < (Dims * ...); ++idx) \
1138 result[idx] = std::make_shared<T>(torch::special::name(*input[idx])); \
1142#define blocktensor_binary_op(name) \
1143 template <typename T, typename U, std::size_t... Dims> \
1144 inline auto name(const BlockTensor<T, Dims...> &input, \
1145 const BlockTensor<U, Dims...> &other) { \
1146 BlockTensor<typename std::common_type<T, U>::type, Dims...> result; \
1147 for (std::size_t idx = 0; idx < (Dims * ...); ++idx) \
1149 std::make_shared<T>(torch::name(*input[idx], *other[idx])); \
1153#define blocktensor_binary_special_op(name) \
1154 template <typename T, typename U, std::size_t... Dims> \
1155 inline auto name(const BlockTensor<T, Dims...> &input, \
1156 const BlockTensor<U, Dims...> &other) { \
1157 BlockTensor<typename std::common_type<T, U>::type, Dims...> result; \
1158 for (std::size_t idx = 0; idx < (Dims * ...); ++idx) \
1160 std::make_shared<T>(torch::special::name(*input[idx], *other[idx])); \
1187template <
typename T,
typename U,
typename V, std::size_t... Dims>
1191 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1193 std::make_shared<T>(torch::add(*input[idx], *other[idx], alpha));
1199template <
typename T,
typename U,
typename V, std::size_t... Dims>
1202 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1203 result[idx] = std::make_shared<T>(torch::add(*input[idx], other, alpha));
1209template <
typename T,
typename U,
typename V, std::size_t... Dims>
1212 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1213 result[idx] = std::make_shared<T>(torch::add(input, *other[idx], alpha));
1220template <
typename T,
typename U,
typename V,
typename W, std::size_t... Dims>
1225 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1226 result[idx] = std::make_shared<T>(
1227 torch::addcdiv(*input[idx], *tensor1[idx], *tensor2[idx], value));
1235template <
typename T,
typename U,
typename V,
typename W, std::size_t... Dims>
1240 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1241 result[idx] = std::make_shared<T>(
1242 torch::addcmul(*input[idx], *tensor1[idx], *tensor2[idx], value));
1283#if TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 11 || \
1284 TORCH_VERSION_MAJOR >= 2
1320template <
typename T,
typename U, std::size_t... Dims>
1323 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1324 result[idx] = std::make_shared<T>(torch::clamp(*input[idx], min, max));
1329template <
typename T,
typename U, std::size_t... Dims>
1332 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1333 result[idx] = std::make_shared<T>(torch::clip(*input[idx], min, max));
1370template <typename T,
std::
size_t Rows,
std::
size_t Cols>
1373 return input.dot(tensor);
1590template <typename T, typename U, typename V,
std::
size_t... Dims>
1592 const
BlockTensor<U, Dims...> &other, V alpha = 1.0) {
1594 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1596 std::make_shared<T>(torch::sub(*input[idx], *other[idx], alpha));
1601template <
typename T,
typename U,
typename V, std::size_t... Dims>
1605 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1607 std::make_shared<T>(torch::sub(*input[idx], *other[idx], alpha));
1628template <typename T, typename U,
std::
size_t... Dims>
1632 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1633 result[idx] = std::make_shared<T>(*lhs[idx] + *rhs[idx]);
1639template <
typename T,
typename U, std::size_t... Dims>
1642 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1643 result[idx] = std::make_shared<T>(*lhs[idx] + rhs);
1649template <
typename T,
typename U, std::size_t... Dims>
1652 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1653 result[idx] = std::make_shared<U>(lhs + *rhs[idx]);
1658template <
typename T,
typename U, std::size_t... Dims>
1661 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1662 lhs[idx] = std::make_shared<T>(*lhs[idx] + *rhs[idx]);
1667template <
typename T,
typename U, std::size_t... Dims>
1669 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1670 lhs[idx] = std::make_shared<T>(*lhs[idx] + rhs);
1676template <
typename T,
typename U, std::size_t... Dims>
1680 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1681 result[idx] = std::make_shared<T>(*lhs[idx] - *rhs[idx]);
1687template <
typename T,
typename U, std::size_t... Dims>
1690 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1691 result[idx] = std::make_shared<T>(*lhs[idx] - rhs);
1697template <
typename T,
typename U, std::size_t... Dims>
1700 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1701 result[idx] = std::make_shared<U>(lhs - *rhs[idx]);
1706template <
typename T,
typename U, std::size_t... Dims>
1709 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1710 lhs[idx] = std::make_shared<T>(*lhs[idx] - *rhs[idx]);
1715template <
typename T,
typename U, std::size_t... Dims>
1717 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1718 lhs[idx] = std::make_shared<T>(*lhs[idx] - rhs);
1724template <
typename T,
typename U, std::size_t... Dims>
1727 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1729 (lhs[idx]->dim() > rhs.dim()
1730 ? std::make_shared<T>(*lhs[idx] * rhs.unsqueeze(-1))
1731 : (lhs[idx]->dim() < rhs.dim()
1732 ? std::make_shared<T>(lhs[idx]->unsqueeze(-1) * rhs)
1733 : std::make_shared<T>(*lhs[idx] * rhs)));
1740template <
typename T,
typename U, std::size_t... Dims>
1743 for (std::size_t idx = 0; idx < (Dims * ...); ++idx)
1745 (lhs.dim() > rhs[idx]->dim()
1746 ? std::make_shared<U>(lhs * rhs[idx]->unsqueeze(-1))
1747 : (lhs.dim() < rhs[idx]->dim()
1748 ? std::make_shared<U>(lhs.unsqueeze(-1) * *rhs[idx])
1749 : std::make_shared<U>(lhs * *rhs[idx])));
1754template <
typename T,
typename U, std::size_t... TDims, std::size_t... UDims>
1757 if constexpr ((
sizeof...(TDims) !=
sizeof...(UDims)) ||
1758 ((TDims != UDims) || ...))
1762 for (std::size_t idx = 0; idx < (TDims * ...); ++idx)
1763 result = result && torch::equal(*lhs[idx], *rhs[idx]);
1769template <
typename T,
typename U, std::size_t... TDims, std::size_t... UDims>
1772 return !(lhs == rhs);
#define blocktensor_unary_op(name)
Definition blocktensor.hpp:1124
#define blocktensor_unary_special_op(name)
Definition blocktensor.hpp:1133
#define blocktensor_binary_op(name)
Definition blocktensor.hpp:1142
#define blocktensor_binary_special_op(name)
Definition blocktensor.hpp:1153
static constexpr std::size_t slices()
Returns the number of slices.
Definition blocktensor.hpp:942
auto reorder_jik() const
Returns a new block tensor with rows and columns transposed and slices remaining fixed....
Definition blocktensor.hpp:995
const T & operator()(std::size_t row, std::size_t col, std::size_t slice) const
Returns a constant reference to entry (row, col, slice)
Definition blocktensor.hpp:947
T & set(std::size_t row, std::size_t col, std::size_t slice, D &&data)
Stores the given data object at the given position.
Definition blocktensor.hpp:963
auto reorder_kij() const
Returns a new block tensor with rows, columns, and slices permuted according to (i,...
Definition blocktensor.hpp:1019
auto reorder_ikj() const
Returns a new block tensor with rows, columns, and slices permuted according to (i,...
Definition blocktensor.hpp:982
static constexpr std::size_t rows()
Returns the number of rows.
Definition blocktensor.hpp:936
T & operator()(std::size_t row, std::size_t col, std::size_t slice)
Returns a non-constant reference to entry (row, col, slice)
Definition blocktensor.hpp:954
auto reorder_kji() const
Returns a new block tensor with rows, columns, and slices permuted according to (i,...
Definition blocktensor.hpp:1007
auto slice(std::size_t slice) const
Returns a rank-2 tensor of the k-th slice.
Definition blocktensor.hpp:970
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the BlockTensor object.
Definition blocktensor.hpp:1030
static constexpr std::size_t cols()
Returns the number of columns.
Definition blocktensor.hpp:939
auto inv() const
Returns the inverse of the block tensor.
Definition blocktensor.hpp:307
auto dot_(std::index_sequence< Is... >, const BlockTensor< T, Rows, Cols > &other) const
Returns the dot product of two BlockTensor objects.
Definition blocktensor.hpp:866
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the BlockTensor object.
Definition blocktensor.hpp:882
auto ginv() const
Returns the (generalized) inverse of the block tensor.
Definition blocktensor.hpp:546
auto ginvtr() const
Returns the transpose of the (generalized) inverse of the block tensor.
Definition blocktensor.hpp:802
auto tr() const
Returns the transpose of the block tensor.
Definition blocktensor.hpp:220
auto invtr() const
Returns the transpose of the inverse of the block tensor.
Definition blocktensor.hpp:559
static constexpr std::size_t cols()
Returns the number of columns.
Definition blocktensor.hpp:193
auto normalize_(std::index_sequence< Is... > is) const
Returns the normalized BlockTensor object.
Definition blocktensor.hpp:851
static constexpr std::size_t rows()
Returns the number of rows.
Definition blocktensor.hpp:190
auto norm() const
Returns the norm of the BlockTensor object.
Definition blocktensor.hpp:843
const T & operator()(std::size_t row, std::size_t col) const
Returns a constant reference to entry (row, col)
Definition blocktensor.hpp:198
T & set(std::size_t row, std::size_t col, D &&data)
Stores the given data object at the given position.
Definition blocktensor.hpp:213
auto dot(const BlockTensor< T, Rows, Cols > &other) const
Returns the dot product of two BlockTensor objects.
Definition blocktensor.hpp:876
auto trace() const
Returns the trace of the block tensor.
Definition blocktensor.hpp:811
auto det() const
Returns the determinant of a square block tensor.
Definition blocktensor.hpp:232
T & operator()(std::size_t row, std::size_t col)
Returns a non-constant reference to entry (row, col)
Definition blocktensor.hpp:204
auto norm_(std::index_sequence< Is... >) const
Returns the norm of the BlockTensor object.
Definition blocktensor.hpp:835
auto normalize() const
Returns the normalized BlockTensor object.
Definition blocktensor.hpp:859
void pretty_print(std::ostream &os) const noexcept override
Returns a string representation of the BlockTensor object.
Definition blocktensor.hpp:169
static constexpr std::size_t rows()
Returns the number of rows.
Definition blocktensor.hpp:166
Compile-time block tensor core.
Definition blocktensor.hpp:47
const std::array< std::shared_ptr< T >,(Dims *...)> & data() const
Returns a constant reference to the data array.
Definition blocktensor.hpp:106
static constexpr auto dims()
Returns all dimensions as array.
Definition blocktensor.hpp:87
BlockTensorCore(BlockTensorCore< Ts, dims... > &&...other)
Constructor from BlockTensorCore objects.
Definition blocktensor.hpp:59
BlockTensorCore(Ts &&...data)
Constructor from variadic templates.
Definition blocktensor.hpp:83
BlockTensorCore()=default
Default constructor.
std::shared_ptr< T > & operator[](std::size_t idx)
Returns a non-constant shared pointer to entry (idx)
Definition blocktensor.hpp:120
void pretty_print(std::ostream &os) const noexcept override=0
Returns a string representation of the BlockTensorCore object.
T & set(std::size_t idx, Data &&data)
Stores the given data object at the given index.
Definition blocktensor.hpp:138
static constexpr std::size_t dim()
Returns the i-th dimension.
Definition blocktensor.hpp:92
const std::shared_ptr< T > & operator[](std::size_t idx) const
Returns a constant shared pointer to entry (idx)
Definition blocktensor.hpp:114
BlockTensorCore(BlockTensor< Ts, dims... > &&...other)
Constructor from BlockTensor objects.
Definition blocktensor.hpp:71
static constexpr std::size_t entries()
Returns the total number of entries.
Definition blocktensor.hpp:103
std::array< std::shared_ptr< T >,(Dims *...)> & data()
Returns a non-constant reference to the data array.
Definition blocktensor.hpp:111
static constexpr std::size_t size()
Returns the number of dimensions.
Definition blocktensor.hpp:100
const T & operator()(std::size_t idx) const
Returns a constant reference to entry (idx)
Definition blocktensor.hpp:126
std::array< std::shared_ptr< T >,(Dims *...)> data_
Array storing the data.
Definition blocktensor.hpp:51
T & operator()(std::size_t idx)
Returns a non-constant reference to entry (idx)
Definition blocktensor.hpp:132
Fully qualified name descriptor.
Definition fqn.hpp:22
Fully qualified name utility functions.
Definition blocktensor.hpp:24
auto addcmul(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &tensor1, const BlockTensor< V, Dims... > &tensor2, W value=1.0)
Returns a new block tensor with the elements of tensor1 multiplied by the elements of tensor2,...
Definition blocktensor.hpp:1236
auto log2(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the logarithm to the base-2 of the elements of input
Definition blocktensor.hpp:1451
auto tan(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the tangent of the elements of input.
Definition blocktensor.hpp:1613
auto square(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the square of the elements of input
Definition blocktensor.hpp:1587
auto mul(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the product of each element of input and other
Definition blocktensor.hpp:1503
auto divide(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Alias for div()
Definition blocktensor.hpp:1362
auto exp2(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the base-2 exponential of the elements of input
Definition blocktensor.hpp:1394
auto frexp(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the decomposition of the elements of input into mantissae and exponen...
Definition blocktensor.hpp:1422
auto xlogy(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Computes input * log(other)
Definition blocktensor.hpp:1624
auto operator-(const BlockTensor< T, Dims... > &lhs, const BlockTensor< U, Dims... > &rhs)
Subtracts one compile-time block tensor from another and returns a new compile-time block tensor.
Definition blocktensor.hpp:1677
auto floor(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the floor of the elements of input, the largest integer less than or ...
Definition blocktensor.hpp:1410
auto bitwise_not(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the bitwise NOT of the elements of input
Definition blocktensor.hpp:1291
auto i0(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the element-wise zeroth order modified Bessel function of the first k...
Definition blocktensor.hpp:1485
auto float_power(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the elements of input raised to the power of exponent,...
Definition blocktensor.hpp:1406
auto round(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the elements of input rounded to the nearest integer.
Definition blocktensor.hpp:1544
auto hypot(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
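Returns a new block tensor with the hypotenuse computed element-wise from input and other, i.e. sqrt(input^2 + other^2)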
Definition blocktensor.hpp:1480
auto imag(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the imaginary values of the elements of input
Definition blocktensor.hpp:1426
auto atan(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the arctangent of the elements of input
Definition blocktensor.hpp:1266
auto copysign(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the magnitude of the elements of input and the sign of the elements o...
Definition blocktensor.hpp:1343
auto add(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other, V alpha=1.0)
Returns a new block tensor with the elements of other, scaled by alpha, added to the elements of inpu...
Definition blocktensor.hpp:1188
auto asin(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the arcsine of the elements of input
Definition blocktensor.hpp:1252
bool operator==(const BlockTensor< T, TDims... > &lhs, const BlockTensor< U, UDims... > &rhs)
Returns true if both compile-time block tensors are equal.
Definition blocktensor.hpp:1755
auto angle(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the angle (in radians) of the elements of input
Definition blocktensor.hpp:1248
auto nextafter(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Return a new block tensor with the next elementwise floating-point value after input towards other
Definition blocktensor.hpp:1517
auto arcsinh(const BlockTensor< T, Dims... > &input)
Alias for asinh()
Definition blocktensor.hpp:1262
auto sub(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other, V alpha=1.0)
Subtracts other, scaled by alpha, from input.
Definition blocktensor.hpp:1591
auto sign(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the signs of the elements of input
Definition blocktensor.hpp:1559
auto logical_and(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the element-wise logical AND of the elements of input and other
Definition blocktensor.hpp:1463
auto absolute(const BlockTensor< T, Dims... > &input)
Alias for abs()
Definition blocktensor.hpp:1169
auto fix(const BlockTensor< T, Dims... > &input)
Alias for trunc()
Definition blocktensor.hpp:1401
auto sinc(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the normalized sinc of the elements of input
Definition blocktensor.hpp:1575
auto logical_not(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the element-wise logical NOT of the elements of input
Definition blocktensor.hpp:1467
auto positive(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the input
Definition blocktensor.hpp:1520
auto sqrt(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the square-root of the elements of input
Definition blocktensor.hpp:1583
auto reciprocal(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the reciprocal of the elements of input
Definition blocktensor.hpp:1536
auto clip(const BlockTensor< T, Dims... > &input, U min, U max)
Alias for clamp()
Definition blocktensor.hpp:1330
auto arcsin(const BlockTensor< T, Dims... > &input)
Alias for asin()
Definition blocktensor.hpp:1255
auto bitwise_or(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the bitwise OR of the elements of input and other
Definition blocktensor.hpp:1299
auto atanh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse hyperbolic tangent of the elements of input
Definition blocktensor.hpp:1273
auto subtract(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other, V alpha=1.0)
Alias for sub()
Definition blocktensor.hpp:1602
auto atan2(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the arctangent of the elements in input and other with consideration ...
Definition blocktensor.hpp:1281
auto expit(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the expit (also known as the logistic sigmoid function) of the elemen...
Definition blocktensor.hpp:1552
auto rsqrt(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the reciprocal of the square-root of the elements of input
Definition blocktensor.hpp:1548
auto sin(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the sine of the elements of input
Definition blocktensor.hpp:1571
auto cosh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the hyperbolic cosine of the elements of input
Definition blocktensor.hpp:1351
std::ostream & operator<<(std::ostream &os, const BlockTensorCore< T, Dims... > &obj)
Prints (as string) a compile-time block tensor object.
Definition blocktensor.hpp:150
auto bitwise_and(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the bitwise AND of the elements of input and other
Definition blocktensor.hpp:1295
auto erfc(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the complementary error function of the elements of input
Definition blocktensor.hpp:1382
auto gammainc(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the regularized lower incomplete gamma function of each element of in...
Definition blocktensor.hpp:1489
auto operator-=(BlockTensor< T, Dims... > &lhs, const BlockTensor< U, Dims... > &rhs)
Decrements one compile-time block tensor by another.
Definition blocktensor.hpp:1707
auto arccos(const BlockTensor< T, Dims... > &input)
Alias for acos()
Definition blocktensor.hpp:1176
auto negative(const BlockTensor< T, Dims... > &input)
Alias for neg()
Definition blocktensor.hpp:1513
auto multiply(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Alias for mul()
Definition blocktensor.hpp:1506
auto dot(const BlockTensor< T, Rows, Cols > &input, const BlockTensor< T, Rows, Cols > &tensor)
Returns a new block tensor with the dot product of the two input block tensors.
Definition blocktensor.hpp:1371
auto logaddexp2(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block-vector with the logarithm of the sum of exponentiations of the elements of input ...
Definition blocktensor.hpp:1459
auto pow(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the power of each element in input with exponent other
Definition blocktensor.hpp:1524
auto bitwise_xor(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the bitwise XOR of the elements of input and other
Definition blocktensor.hpp:1303
auto ldexp(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the elements of input multiplied by 2**other.
Definition blocktensor.hpp:1430
auto igammac(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Alias for gammainc()
Definition blocktensor.hpp:1499
auto neg(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the negative of the elements of input
Definition blocktensor.hpp:1510
auto exp(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the exponential of the elements of input
Definition blocktensor.hpp:1390
auto arctan(const BlockTensor< T, Dims... > &input)
Alias for atan()
Definition blocktensor.hpp:1269
auto log1p(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the natural logarithm of (1 + the elements of input)
Definition blocktensor.hpp:1447
auto arctanh(const BlockTensor< T, Dims... > &input)
Alias for atanh()
Definition blocktensor.hpp:1276
auto ceil(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the ceil of the elements of input, the smallest integer greater than ...
Definition blocktensor.hpp:1316
auto trunc(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the truncated integer values of the elements of input.
Definition blocktensor.hpp:1621
auto cos(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the cosine of the elements of input
Definition blocktensor.hpp:1347
auto erfinv(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse error function of the elements of input
Definition blocktensor.hpp:1386
bool operator!=(const BlockTensor< T, TDims... > &lhs, const BlockTensor< U, UDims... > &rhs)
Returns true if both compile-time block tensors are not equal.
Definition blocktensor.hpp:1770
auto expm1(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the exponential minus 1 of the elements of input
Definition blocktensor.hpp:1398
auto signbit(const BlockTensor< T, Dims... > &input)
Tests if each element of input has its sign bit set (is less than zero) or not.
Definition blocktensor.hpp:1567
auto conj_physical(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the conjugate of the elements of input tensor.
Definition blocktensor.hpp:1339
auto sinh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the hyperbolic sine of the elements of input
Definition blocktensor.hpp:1579
auto remainder(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the element-wise modulus (remainder) of the elements of input divided by the elements of other
Definition blocktensor.hpp:1540
auto digamma(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the logarithmic derivative of the gamma function of the elements of i...
Definition blocktensor.hpp:1366
auto asinh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse hyperbolic sine of the elements of input
Definition blocktensor.hpp:1259
auto div(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the elements of input divided by the elements of other
Definition blocktensor.hpp:1359
auto igamma(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Alias for gammainc()
Definition blocktensor.hpp:1492
auto bitwise_left_shift(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the left arithmetic shift of the elements of input by other bits.
Definition blocktensor.hpp:1307
auto rad2deg(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with each of the elements of input converted from angles in radians to deg...
Definition blocktensor.hpp:1528
auto real(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the real values of the elements of input
Definition blocktensor.hpp:1532
auto lgamma(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the natural logarithm of the absolute value of the gamma function of ...
Definition blocktensor.hpp:1435
auto gammaincc(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the regularized upper incomplete gamma function of each element of in...
Definition blocktensor.hpp:1496
auto acos(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse cosine of the elements of input
Definition blocktensor.hpp:1173
auto fmod(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the fmod of the elements of input and other
Definition blocktensor.hpp:1414
auto bitwise_right_shift(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the right arithmetic shift of the elements of input by other bits.
Definition blocktensor.hpp:1311
auto operator+=(BlockTensor< T, Dims... > &lhs, const BlockTensor< U, Dims... > &rhs)
Increments one compile-time block tensor by another.
Definition blocktensor.hpp:1659
auto sgn(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the signs of the elements of input, extended to complex values.
Definition blocktensor.hpp:1563
auto arccosh(const BlockTensor< T, Dims... > &input)
Alias for acosh().
Definition blocktensor.hpp:1183
auto abs(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the absolute value of the elements of input
Definition blocktensor.hpp:1166
auto logical_xor(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the element-wise logical XOR of the elements of input and other
Definition blocktensor.hpp:1475
auto addcdiv(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &tensor1, const BlockTensor< V, Dims... > &tensor2, W value=1.0)
Returns a new block tensor with the elements of tensor1 divided by the elements of tensor2, multiplied by value, and added to the elements of input
Definition blocktensor.hpp:1221
auto deg2rad(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the elements of input converted from angles in degrees to radians.
Definition blocktensor.hpp:1355
auto logaddexp(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the logarithm of the sum of exponentiations of the elements of input and other
Definition blocktensor.hpp:1455
auto erf(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the error function of the elements of input
Definition blocktensor.hpp:1378
auto operator*(const BlockTensor< T, Rows, Common > &lhs, const BlockTensor< U, Common, Cols > &rhs)
Multiplies one compile-time rank-2 block tensor with another compile-time rank-2 block tensor.
Definition blocktensor.hpp:895
auto log10(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the logarithm to the base-10 of the elements of input
Definition blocktensor.hpp:1443
auto make_shared(T &&arg)
Returns a std::shared_ptr<T> object from arg.
Definition blocktensor.hpp:35
auto acosh(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the inverse hyperbolic cosine of the elements of input
Definition blocktensor.hpp:1180
auto clamp(const BlockTensor< T, Dims... > &input, U min, U max)
Returns a new block tensor with the elements of input clamped into the range [min, max]
Definition blocktensor.hpp:1321
auto logical_or(const BlockTensor< T, Dims... > &input, const BlockTensor< U, Dims... > &other)
Returns a new block tensor with the element-wise logical OR of the elements of input and other
Definition blocktensor.hpp:1471
auto frac(const BlockTensor< T, Dims... > &input)
Returns a new block tensor with the fractional portion of the elements of input
Definition blocktensor.hpp:1418
Forward declaration of BlockTensor.
Definition blocktensor.hpp:43
constexpr auto operator+(deriv lhs, deriv rhs)
Adds two enumerators for specifying the derivative of B-spline evaluation.
Definition bspline.hpp:89
log
Enumerator for specifying the logging level.
Definition core.hpp:90
Type trait that checks whether the template argument is of type std::shared_ptr<T>
Definition blocktensor.hpp:28