45inline auto VSlice(torch::Tensor index, int64_t start_offset,
46 int64_t stop_offset) {
47 if constexpr (transpose)
48 return index.repeat_interleave(stop_offset - start_offset) +
49 torch::linspace(start_offset, stop_offset - 1,
50 stop_offset - start_offset, index.options())
51 .repeat(index.numel());
53 return index.repeat(stop_offset - start_offset) +
54 torch::linspace(start_offset, stop_offset - 1,
55 stop_offset - start_offset, index.options())
56 .repeat_interleave(index.numel());
71 const std::array<int64_t, N> &start_offset,
72 const std::array<int64_t, N> &stop_offset,
73 const std::array<int64_t, N - 1> &leading_dim =
74 make_array<int64_t, N - 1>(1)) {
77 for (std::size_t i = 1; i < N; ++i)
78 assert(index[i - 1].numel() == index[i].numel());
80 auto dist = stop_offset - start_offset;
82 if constexpr (transpose) {
85 auto vslice_summand_ = [&](std::size_t k) {
87 return (index[k].repeat_interleave(
utils::prod(dist, 0, k)) +
88 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
91 .repeat(index[0].numel())) *
94 if constexpr (N == 2) {
95 return index[0].repeat_interleave(dist[0]).repeat_interleave(
97 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
99 .repeat(index[1].numel())
102 return index[0].repeat_interleave(dist[0]).repeat_interleave(
104 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
106 .repeat(index[0].numel())
112 .repeat_interleave(
utils::prod(dist, k + 1, N - 1)) +
113 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
116 .repeat(index[0].numel())
123 auto vslice_ = [&]<std::size_t... Is>(std::index_sequence<Is...>) {
124 return (vslice_summand_(Is) + ...);
127 return vslice_(std::make_index_sequence<N>{});
131 auto vslice_summand_ = [&](std::size_t k) {
134 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
136 .repeat_interleave(index[0].numel() *
140 if constexpr (N == 2) {
141 return (index[0].repeat(dist[0]) +
142 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
144 .repeat_interleave(index[0].numel()))
147 return (index[0].repeat(dist[0]) +
148 torch::linspace(start_offset[0], stop_offset[0] - 1, dist[0],
150 .repeat_interleave(index[0].numel()))
155 torch::linspace(start_offset[k], stop_offset[k] - 1, dist[k],
157 .repeat_interleave(index[0].numel() *
165 auto vslice_ = [&]<std::size_t... Is>(std::index_sequence<Is...>) {
166 return (vslice_summand_(Is) + ...);
169 return vslice_(std::make_index_sequence<N>{});
T prod(std::array< T, N > array, std::size_t start_index=0, std::size_t stop_index=N - 1)
Computes the (partial) product of all std::array entries.
Definition linalg.hpp:220