IgANet
IgANets - Isogeometric Analysis Networks
Loading...
Searching...
No Matches
vslice.hpp
Go to the documentation of this file.
1
15#pragma once
16
17#include <array>
18
19#include <utils/linalg.hpp>
20#include <utils/tensorarray.hpp>
21
22#include <torch/torch.h>
23
24namespace iganet {
25namespace utils {
26
46template <bool transpose = false>
47inline auto VSlice(torch::Tensor index, int64_t start_offset,
49 if constexpr (transpose)
50 return index.repeat_interleave(stop_offset - start_offset) +
51 torch::linspace(start_offset, stop_offset - 1,
52 stop_offset - start_offset, index.options())
53 .repeat(index.numel());
54 else
55 return index.repeat(stop_offset - start_offset) +
56 torch::linspace(start_offset, stop_offset - 1,
57 stop_offset - start_offset, index.options())
58 .repeat_interleave(index.numel());
59}
60
/// @brief N-dimensional vectorized slice (multi-index counterpart of the
/// scalar VSlice overload).
///
/// For each dimension k a summand is built from index[k] plus the offsets
/// torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k]). Summands
/// with k >= 1 are scaled by utils::prod(leading_dim, 0, k - 1); the k == 0
/// summand is unscaled. This presumably flattens a multi-index into a linear
/// index with leading dimensions `leading_dim` (confirm against callers).
/// The N summands are added with a fold expression.
///
/// @tparam transpose  Selects the ordering of the result (see the two
///                    branches below).
/// @tparam N          Number of dimensions.
/// @param index        N index tensors; all must have the same numel(). This
///                     is checked only by assert().
/// @param start_offset Per-dimension first offset.
/// @param stop_offset  Per-dimension end offset. linspace runs up to
///                     stop_offset[k] - 1, so it is presumably exclusive.
/// @param leading_dim  N - 1 per-dimension scaling factors. Its default value
///                     is on a line missing from this listing.
71template <bool transpose = false, std::size_t N>
72inline auto VSlice(const utils::TensorArray<N> &index,
73 const std::array<int64_t, N> &start_offset,
74 const std::array<int64_t, N> &stop_offset,
75 const std::array<int64_t, N - 1> &leading_dim =
77
78 // Check compatibility of arguments
// NOTE(review): assert() is compiled out under NDEBUG, so mismatched numel()
// values go undetected in release builds. This header does not include
// <cassert> directly; presumably it arrives transitively through torch.
79 for (std::size_t i = 1; i < N; ++i)
80 assert(index[i - 1].numel() == index[i].numel());
81
// NOTE(review): `dist` is declared on a line missing from this listing. It is
// used as the step count of linspace(start_offset[k], stop_offset[k] - 1, ...),
// which suggests dist[k] == stop_offset[k] - start_offset[k]. Confirm.
83
84 if constexpr (transpose) {
85
// Transposed ordering: the k == 0 summand repeats each index entry
// consecutively, so the entry position varies slowest. Within each entry's
// block, the dim-0 offset varies fastest.
86 // Lambda expression to evaluate the k-th summand of the vslice
// NOTE(review): for N == 1, k == 0 == N - 1 takes the first branch, where
// `k - 1` wraps around (std::size_t) inside utils::prod on a zero-length
// leading_dim. Confirm that utils::prod tolerates this, or require N >= 2.
87 auto vslice_summand_ = [&](std::size_t k) {
88 if (k == N - 1) {
// NOTE(review): one line of this expression is missing from this listing
// (between the options() line and .repeat); check it against the header.
89 return (index[k].repeat_interleave(utils::prod(dist, 0, k)) +
90 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
91 index[0].options())
93 .repeat(index[0].numel())) *
94 utils::prod(leading_dim, 0, k - 1);
95 } else if (k == 0) {
// The N == 2 arm matches the N > 2 arm. The index[k].numel() values are
// asserted equal, and utils::prod(dist, 1, 1) is presumably dist[1].
96 if constexpr (N == 2) {
97 return index[0].repeat_interleave(dist[0]).repeat_interleave(
98 dist[1]) +
99 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
100 index[0].options())
101 .repeat(index[1].numel())
102 .repeat(dist[1]);
103 } else { // N > 2
104 return index[0].repeat_interleave(dist[0]).repeat_interleave(
105 utils::prod(dist, 1, N - 1)) +
106 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
107 index[0].options())
108 .repeat(index[0].numel())
109 .repeat(utils::prod(dist, 1, N - 1));
110 }
111 } else {
// NOTE(review): two lines of this expression are missing from this listing
// (after `index[k]` and after the options() line); check them against the
// header.
112 return (index[k]
114 .repeat_interleave(utils::prod(dist, k + 1, N - 1)) +
115 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
116 index[0].options())
118 .repeat(index[0].numel())
119 .repeat(utils::prod(dist, k + 1, N - 1))) *
120 utils::prod(leading_dim, 0, k - 1);
121 }
122 };
123
124 // Lambda expression to evaluate the vslice
// Computes vslice_summand_(0) + ... + vslice_summand_(N - 1) with a fold
// expression. The explicit template-parameter lambda requires C++20.
125 auto vslice_ = [&]<std::size_t... Is>(std::index_sequence<Is...>) {
126 return (vslice_summand_(Is) + ...);
127 };
128
129 return vslice_(std::make_index_sequence<N>{});
130 } else {
131
// Non-transposed ordering, per the repeat/repeat_interleave patterns below:
// - the index entry position varies fastest (period numel);
// - then the dim-0 offset, the dim-1 offset, and so on;
// - the dim-(N-1) offset varies slowest.
// The result has numel * prod(dist) entries.
132 // Lambda expression to evaluate the k-th summand of the vslice
// NOTE(review): the same N == 1 `k - 1` wrap-around as in the transposed
// branch applies to both utils::prod(..., 0, k - 1) calls in the first arm.
133 auto vslice_summand_ = [&](std::size_t k) {
134 if (k == N - 1) {
135 return (index[k].repeat(utils::prod(dist, 0, k)) +
136 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
137 index[0].options())
138 .repeat_interleave(index[0].numel() *
139 utils::prod(dist, 0, k - 1))) *
140 utils::prod(leading_dim, 0, k - 1);
141 } else if (k == 0) {
// NOTE(review): the N == 2 and N > 2 arms below are token-for-token
// identical, so the `if constexpr` split is redundant.
142 if constexpr (N == 2) {
143 return (index[0].repeat(dist[0]) +
144 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
145 index[0].options())
146 .repeat_interleave(index[0].numel()))
147 .repeat(utils::prod(dist, k + 1, N - 1));
148 } else { // N > 2
149 return (index[0].repeat(dist[0]) +
150 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
151 index[0].options())
152 .repeat_interleave(index[0].numel()))
153 .repeat(utils::prod(dist, k + 1, N - 1));
154 }
155 } else {
156 return (index[k].repeat(utils::prod(dist, 0, k)) +
157 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
158 index[0].options())
159 .repeat_interleave(index[0].numel() *
160 utils::prod(dist, 0, k - 1)))
161 .repeat(utils::prod(dist, k + 1, N - 1)) *
162 utils::prod(leading_dim, 0, k - 1);
163 }
164 };
165
166 // Lambda expression to evaluate the vslice
// Computes vslice_summand_(0) + ... + vslice_summand_(N - 1) with a fold
// expression. The explicit template-parameter lambda requires C++20.
167 auto vslice_ = [&]<std::size_t... Is>(std::index_sequence<Is...>) {
168 return (vslice_summand_(Is) + ...);
169 };
170
171 return vslice_(std::make_index_sequence<N>{});
172 }
173}
174
175} // namespace utils
176} // namespace iganet
Linear algebra utility functions.
std::array< torch::Tensor, N > TensorArray
Definition tensorarray.hpp:28
T prod(std::array< T, N > array, std::size_t start_index=0, std::size_t stop_index=N - 1)
Computes the (partial) product of all std::array entries.
Definition linalg.hpp:221
auto VSlice(torch::Tensor index, int64_t start_offset, int64_t stop_offset)
Vectorized version of torch::indexing::Slice (see https://pytorch.org/cppdocs/notes/tensor_indexing.html).
Definition vslice.hpp:47
Definition boundary.hpp:22
constexpr bool is_SplineType_v
Alias to the value of is_SplineType.
Definition bspline.hpp:3243
TensorArray utility functions.