47inline auto VSlice(torch::Tensor index, int64_t start_offset,
48 int64_t stop_offset) {
49 if constexpr (transpose)
50 return index.repeat_interleave(stop_offset - start_offset) +
51 torch::linspace(start_offset, stop_offset - 1,
52 stop_offset - start_offset, index.options())
53 .repeat(index.numel());
55 return index.repeat(stop_offset - start_offset) +
56 torch::linspace(start_offset, stop_offset - 1,
57 stop_offset - start_offset, index.options())
58 .repeat_interleave(index.numel());
73 const std::array<int64_t, N> &start_offset,
74 const std::array<int64_t, N> &stop_offset,
75 const std::array<int64_t, N - 1> &leading_dim =
76 make_array<int64_t, N - 1>(1)) {
79 for (std::size_t i = 1; i < N; ++i)
80 assert(index[i - 1].numel() == index[i].numel());
82 auto dist = stop_offset - start_offset;
84 if constexpr (transpose) {
87 auto vslice_summand_ = [&](std::size_t k) {
89 return (index[k].repeat_interleave(
utils::prod(dist, 0, k)) +
90 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
93 .repeat(index[0].numel())) *
96 if constexpr (N == 2) {
97 return index[0].repeat_interleave(dist[0]).repeat_interleave(
99 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
101 .repeat(index[1].numel())
104 return index[0].repeat_interleave(dist[0]).repeat_interleave(
106 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
108 .repeat(index[0].numel())
114 .repeat_interleave(
utils::prod(dist, k + 1, N - 1)) +
115 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
118 .repeat(index[0].numel())
125 auto vslice_ = [&]<std::size_t... Is>(std::index_sequence<Is...>) {
126 return (vslice_summand_(Is) + ...);
129 return vslice_(std::make_index_sequence<N>{});
133 auto vslice_summand_ = [&](std::size_t k) {
136 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
138 .repeat_interleave(index[0].numel() *
142 if constexpr (N == 2) {
143 return (index[0].repeat(dist[0]) +
144 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
146 .repeat_interleave(index[0].numel()))
149 return (index[0].repeat(dist[0]) +
150 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
152 .repeat_interleave(index[0].numel()))
157 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
159 .repeat_interleave(index[0].numel() *
167 auto vslice_ = [&]<std::size_t... Is>(std::index_sequence<Is...>) {
168 return (vslice_summand_(Is) + ...);
171 return vslice_(std::make_index_sequence<N>{});
T prod(std::array< T, N > array, std::size_t start_index=0, std::size_t stop_index=N - 1)
Computes the (partial) product of all std::array entries.
Definition linalg.hpp:221