IgANet
IgANets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
vslice.hpp
Go to the documentation of this file.
1
15#pragma once
16
17#include <array>
18
19#include <core/core.hpp>
20#include <utils/linalg.hpp>
21#include <utils/tensorarray.hpp>
22
23namespace iganet::utils {
24
44template <bool transpose = false>
45inline auto VSlice(torch::Tensor index, int64_t start_offset,
46 int64_t stop_offset) {
47 if constexpr (transpose)
48 return index.repeat_interleave(stop_offset - start_offset) +
49 torch::linspace(start_offset, stop_offset - 1,
50 stop_offset - start_offset, index.options())
51 .repeat(index.numel());
52 else
53 return index.repeat(stop_offset - start_offset) +
54 torch::linspace(start_offset, stop_offset - 1,
55 stop_offset - start_offset, index.options())
56 .repeat_interleave(index.numel());
57}
58
69template <bool transpose = false, std::size_t N>
70inline auto VSlice(const utils::TensorArray<N> &index,
71 const std::array<int64_t, N> &start_offset,
72 const std::array<int64_t, N> &stop_offset,
73 const std::array<int64_t, N - 1> &leading_dim =
74 make_array<int64_t, N - 1>(1)) {
75
76 // Check compatibility of arguments
77 for (std::size_t i = 1; i < N; ++i)
78 assert(index[i - 1].numel() == index[i].numel());
79
80 auto dist = stop_offset - start_offset;
81
82 if constexpr (transpose) {
83
84 // Lambda expression to evaluate the k-th summand of the vslice
85 auto vslice_summand_ = [&](std::size_t k) {
86 if (k == N - 1) {
87 return (index[k].repeat_interleave(utils::prod(dist, 0, k)) +
88 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
89 index[0].options())
90 .repeat_interleave(utils::prod(dist, 0, k - 1))
91 .repeat(index[0].numel())) *
92 utils::prod(leading_dim, 0, k - 1);
93 } else if (k == 0) {
94 if constexpr (N == 2) {
95 return index[0].repeat_interleave(dist[0]).repeat_interleave(
96 dist[1]) +
97 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
98 index[0].options())
99 .repeat(index[1].numel())
100 .repeat(dist[1]);
101 } else { // N > 2
102 return index[0].repeat_interleave(dist[0]).repeat_interleave(
103 utils::prod(dist, 1, N - 1)) +
104 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
105 index[0].options())
106 .repeat(index[0].numel())
107 .repeat(utils::prod(dist, 1, N - 1));
108 }
109 } else {
110 return (index[k]
111 .repeat_interleave(utils::prod(dist, 0, k))
112 .repeat_interleave(utils::prod(dist, k + 1, N - 1)) +
113 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
114 index[0].options())
115 .repeat_interleave(utils::prod(dist, 0, k - 1))
116 .repeat(index[0].numel())
117 .repeat(utils::prod(dist, k + 1, N - 1))) *
118 utils::prod(leading_dim, 0, k - 1);
119 }
120 };
121
122 // Lambda expression to evaluate the vslice
123 auto vslice_ = [&]<std::size_t... Is>(std::index_sequence<Is...>) {
124 return (vslice_summand_(Is) + ...);
125 };
126
127 return vslice_(std::make_index_sequence<N>{});
128 } else {
129
130 // Lambda expression to evaluate the k-th summand of the vslice
131 auto vslice_summand_ = [&](std::size_t k) {
132 if (k == N - 1) {
133 return (index[k].repeat(utils::prod(dist, 0, k)) +
134 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
135 index[0].options())
136 .repeat_interleave(index[0].numel() *
137 utils::prod(dist, 0, k - 1))) *
138 utils::prod(leading_dim, 0, k - 1);
139 } else if (k == 0) {
140 if constexpr (N == 2) {
141 return (index[0].repeat(dist[0]) +
142 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
143 index[0].options())
144 .repeat_interleave(index[0].numel()))
145 .repeat(utils::prod(dist, k + 1, N - 1));
146 } else { // N > 2
147 return (index[0].repeat(dist[0]) +
148 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
149 index[0].options())
150 .repeat_interleave(index[0].numel()))
151 .repeat(utils::prod(dist, k + 1, N - 1));
152 }
153 } else {
154 return (index[k].repeat(utils::prod(dist, 0, k)) +
155 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
156 index[0].options())
157 .repeat_interleave(index[0].numel() *
158 utils::prod(dist, 0, k - 1)))
159 .repeat(utils::prod(dist, k + 1, N - 1)) *
160 utils::prod(leading_dim, 0, k - 1);
161 }
162 };
163
164 // Lambda expression to evaluate the vslice
165 auto vslice_ = [&]<std::size_t... Is>(std::index_sequence<Is...>) {
166 return (vslice_summand_(Is) + ...);
167 };
168
169 return vslice_(std::make_index_sequence<N>{});
170 }
171}
172
173} // namespace iganet::utils
Core components.
Linear algebra utility functions.
Definition blocktensor.hpp:24
std::array< torch::Tensor, N > TensorArray
Definition tensorarray.hpp:26
T prod(std::array< T, N > array, std::size_t start_index=0, std::size_t stop_index=N - 1)
Computes the (partial) product of all std::array entries.
Definition linalg.hpp:220
auto VSlice(torch::Tensor index, int64_t start_offset, int64_t stop_offset)
Vectorized version of torch::indexing::Slice (see https://pytorch.org/cppdocs/notes/tensor_indexing.html)
Definition vslice.hpp:45
TensorArray utility functions.