62 switch (
t1.sizes().size()) {
64 return torch::mul(
t0.repeat_interleave(
t1.size(dim), 0),
65 t1.repeat({t0.size(dim)}));
67 if constexpr (dim == 0)
68 return torch::mul(
t0.repeat_interleave(
t1.size(dim), 0),
69 t1.repeat({
t0.size(dim), 1}));
70 else if constexpr (dim == 1)
71 return torch::mul(
t0.repeat_interleave(
t1.size(dim), 1),
72 t1.repeat({1,
t0.size(dim)}));
74 if constexpr (dim == 0)
75 return torch::mul(
t0.repeat_interleave(
t1.size(dim), 0),
76 t1.repeat({
t0.size(dim), 1, 1}));
77 else if constexpr (dim == 1)
78 return torch::mul(
t0.repeat_interleave(
t1.size(dim), 1),
79 t1.repeat({1,
t0.size(dim), 1}));
80 else if constexpr (dim == 2)
81 return torch::mul(
t0.repeat_interleave(
t1.size(dim), 0),
82 t1.repeat({1, 1,
t0.size(dim)}));
84 if constexpr (dim == 0)
85 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
86 t1.repeat({t0.size(dim), 1, 1, 1}));
87 else if constexpr (dim == 1)
88 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
89 t1.repeat({1, t0.size(dim), 1, 1}));
90 else if constexpr (dim == 2)
91 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
92 t1.repeat({1, 1, t0.size(dim), 1}));
93 else if constexpr (dim == 3)
94 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
95 t1.repeat({1, 1, 1, t0.size(dim)}));
97 if constexpr (dim == 0)
98 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
99 t1.repeat({t0.size(dim), 1, 1, 1, 1}));
100 else if constexpr (dim == 1)
101 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
102 t1.repeat({1, t0.size(dim), 1, 1, 1}));
103 else if constexpr (dim == 2)
104 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
105 t1.repeat({1, 1, t0.size(dim), 1, 1}));
106 else if constexpr (dim == 3)
107 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
108 t1.repeat({1, 1, 1, t0.size(dim), 1}));
109 else if constexpr (dim == 4)
110 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
111 t1.repeat({1, 1, 1, 1, t0.size(dim)}));
113 if constexpr (dim == 0)
114 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
115 t1.repeat({t0.size(dim), 1, 1, 1, 1, 1}));
116 else if constexpr (dim == 1)
117 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
118 t1.repeat({1, t0.size(dim), 1, 1, 1, 1}));
119 else if constexpr (dim == 2)
120 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
121 t1.repeat({1, 1, t0.size(dim), 1, 1, 1}));
122 else if constexpr (dim == 3)
123 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
124 t1.repeat({1, 1, 1, t0.size(dim), 1, 1}));
125 else if constexpr (dim == 4)
126 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
127 t1.repeat({1, 1, 1, 1, t0.size(dim), 1}));
128 else if constexpr (dim == 5)
129 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
130 t1.repeat({1, 1, 1, 1, 1, t0.size(dim)}));
132 if constexpr (dim == 0)
133 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
134 t1.repeat({t0.size(dim), 1, 1, 1, 1, 1, 1}));
135 else if constexpr (dim == 1)
136 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
137 t1.repeat({1, t0.size(dim), 1, 1, 1, 1, 1}));
138 else if constexpr (dim == 2)
139 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
140 t1.repeat({1, 1, t0.size(dim), 1, 1, 1, 1}));
141 else if constexpr (dim == 3)
142 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
143 t1.repeat({1, 1, 1, t0.size(dim), 1, 1, 1}));
144 else if constexpr (dim == 4)
145 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
146 t1.repeat({1, 1, 1, 1, t0.size(dim), 1, 1}));
147 else if constexpr (dim == 5)
148 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
149 t1.repeat({1, 1, 1, 1, 1, t0.size(dim), 1}));
150 else if constexpr (dim == 6)
151 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
152 t1.repeat({1, 1, 1, 1, 1, 1, t0.size(dim)}));
154 if constexpr (dim == 0)
155 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
156 t1.repeat({t0.size(dim), 1, 1, 1, 1, 1, 1, 1}));
157 else if constexpr (dim == 1)
158 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
159 t1.repeat({1, t0.size(dim), 1, 1, 1, 1, 1, 1}));
160 else if constexpr (dim == 2)
161 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
162 t1.repeat({1, 1, t0.size(dim), 1, 1, 1, 1, 1}));
163 else if constexpr (dim == 3)
164 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
165 t1.repeat({1, 1, 1, t0.size(dim), 1, 1, 1, 1}));
166 else if constexpr (dim == 4)
167 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
168 t1.repeat({1, 1, 1, 1, t0.size(dim), 1, 1, 1}));
169 else if constexpr (dim == 5)
170 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
171 t1.repeat({1, 1, 1, 1, 1, t0.size(dim), 1, 1}));
172 else if constexpr (dim == 6)
173 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
174 t1.repeat({1, 1, 1, 1, 1, 1, t0.size(dim), 1}));
175 else if constexpr (dim == 7)
176 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
177 t1.repeat({1, 1, 1, 1, 1, 1, 1, t0.size(dim)}));
179 throw std::runtime_error(
"Unsupported tensor dimension");