// Dispatch on the number of dimensions of t1 (t1.sizes().size()).
// Every visible branch computes
//   torch::mul(t0.repeat_interleave(t1.size(dim), AXIS), t1.repeat(REPS))
// where REPS is all 1s except t0.size(dim) at position `dim` (the last
// position for dim == -1). Along that axis each slice of t0 is repeated
// t1.size(dim) times back-to-back, t1 is tiled t0.size(dim) times, and the
// two are multiplied elementwise. This looks like a Kronecker /
// Khatri-Rao-style product along `dim` — confirm intent. It presumably
// requires t0 and t1 to agree on every axis other than `dim` (TODO confirm).
// `dim` must be a constant expression because it is tested with
// `if constexpr`; presumably it is a template parameter declared above this
// chunk.
//
// NOTE(review): every line carries a leading number (61, 63, 64, ...) that
// looks like a pasted source line number; those tokens are not valid C++.
// The numbering skips 62, 65, 72, 82, 95, 111, 130, 152 and 177. Those are
// exactly where each arm below and the final throw begin, so the
// `case N:` / `default:` labels were most likely lost. Restore them from the
// upstream file.
//
// NOTE(review): in the arms whose t1.repeat(...) list has 3 or more entries,
// the branches for dim >= 2 and for dim == -1 pass axis 0 to
// repeat_interleave, even though REPS tiles position `dim`. Only the
// dim == 0 and dim == 1 branches use axis == dim. These look like copy-paste
// bugs: the interleave axis should presumably be `dim` (the last axis for
// -1). All arms could collapse into one rank-generic body, e.g.:
//   const int64_t rank = t1.dim();
//   const int64_t axis = dim < 0 ? dim + rank : dim;
//   std::vector<int64_t> reps(rank, 1);
//   reps[axis] = t0.size(axis);
//   return torch::mul(t0.repeat_interleave(t1.size(axis), axis),
//                     t1.repeat(reps));
//
// NOTE(review): no `break` follows any `if constexpr` chain. When no branch
// matches `dim`, control presumably falls through into the next arm
// (assuming the lost `case` labels). For example, a dim below -1 with a t1
// of rank 2 or more matches no branch and would reach the throw at the end.
// Verify this fallthrough is intended.
61 switch (t1.sizes().size()) {
// Single-entry REPS (presumably rank 1). Not guarded on `dim`: it
// interleaves along axis 0 unconditionally.
63 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
64 t1.repeat({t0.size(dim)}));
// Two-entry REPS (presumably rank 2). Both branches interleave along the
// same axis that REPS tiles.
66 if constexpr (dim == 0)
67 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
68 t1.repeat({t0.size(dim), 1}));
69 else if constexpr (dim == 1 || dim == -1)
70 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
71 t1.repeat({1, t0.size(dim)}));
// Three-entry REPS (presumably rank 3).
// NOTE(review): the `dim == 2 || dim == -1` branch interleaves along axis 0
// but tiles position 2; the axis should presumably be 2.
73 if constexpr (dim == 0)
74 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
75 t1.repeat({t0.size(dim), 1, 1}));
76 else if constexpr (dim == 1)
77 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
78 t1.repeat({1, t0.size(dim), 1}));
79 else if constexpr (dim == 2 || dim == -1)
80 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
81 t1.repeat({1, 1, t0.size(dim)}));
// Four-entry REPS (presumably rank 4).
// NOTE(review): the dim == 2 and `dim == 3 || dim == -1` branches interleave
// along axis 0; they should presumably use axes 2 and 3 respectively.
83 if constexpr (dim == 0)
84 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
85 t1.repeat({t0.size(dim), 1, 1, 1}));
86 else if constexpr (dim == 1)
87 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
88 t1.repeat({1, t0.size(dim), 1, 1}));
89 else if constexpr (dim == 2)
90 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
91 t1.repeat({1, 1, t0.size(dim), 1}));
92 else if constexpr (dim == 3 || dim == -1)
93 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
94 t1.repeat({1, 1, 1, t0.size(dim)}));
// Five-entry REPS (presumably rank 5).
// NOTE(review): the dim == 2, dim == 3 and `dim == 4 || dim == -1` branches
// interleave along axis 0; they should presumably use axes 2, 3 and 4.
96 if constexpr (dim == 0)
97 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
98 t1.repeat({t0.size(dim), 1, 1, 1, 1}));
99 else if constexpr (dim == 1)
100 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
101 t1.repeat({1, t0.size(dim), 1, 1, 1}));
102 else if constexpr (dim == 2)
103 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
104 t1.repeat({1, 1, t0.size(dim), 1, 1}));
105 else if constexpr (dim == 3)
106 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
107 t1.repeat({1, 1, 1, t0.size(dim), 1}));
108 else if constexpr (dim == 4 || dim == -1)
109 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
110 t1.repeat({1, 1, 1, 1, t0.size(dim)}));
// Six-entry REPS (presumably rank 6).
// NOTE(review): the branches for dim == 2..5 (and -1) interleave along
// axis 0; each should presumably use axis == dim (5 for -1).
112 if constexpr (dim == 0)
113 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
114 t1.repeat({t0.size(dim), 1, 1, 1, 1, 1}));
115 else if constexpr (dim == 1)
116 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
117 t1.repeat({1, t0.size(dim), 1, 1, 1, 1}));
118 else if constexpr (dim == 2)
119 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
120 t1.repeat({1, 1, t0.size(dim), 1, 1, 1}));
121 else if constexpr (dim == 3)
122 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
123 t1.repeat({1, 1, 1, t0.size(dim), 1, 1}));
124 else if constexpr (dim == 4)
125 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
126 t1.repeat({1, 1, 1, 1, t0.size(dim), 1}));
127 else if constexpr (dim == 5 || dim == -1)
128 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
129 t1.repeat({1, 1, 1, 1, 1, t0.size(dim)}));
// Seven-entry REPS (presumably rank 7).
// NOTE(review): the branches for dim == 2..6 (and -1) interleave along
// axis 0; each should presumably use axis == dim (6 for -1).
131 if constexpr (dim == 0)
132 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
133 t1.repeat({t0.size(dim), 1, 1, 1, 1, 1, 1}));
134 else if constexpr (dim == 1)
135 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
136 t1.repeat({1, t0.size(dim), 1, 1, 1, 1, 1}));
137 else if constexpr (dim == 2)
138 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
139 t1.repeat({1, 1, t0.size(dim), 1, 1, 1, 1}));
140 else if constexpr (dim == 3)
141 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
142 t1.repeat({1, 1, 1, t0.size(dim), 1, 1, 1}));
143 else if constexpr (dim == 4)
144 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
145 t1.repeat({1, 1, 1, 1, t0.size(dim), 1, 1}));
146 else if constexpr (dim == 5)
147 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
148 t1.repeat({1, 1, 1, 1, 1, t0.size(dim), 1}));
149 else if constexpr (dim == 6 || dim == -1)
150 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
151 t1.repeat({1, 1, 1, 1, 1, 1, t0.size(dim)}));
// Eight-entry REPS (presumably rank 8).
// NOTE(review): the branches for dim == 2..7 (and -1) interleave along
// axis 0; each should presumably use axis == dim (7 for -1).
153 if constexpr (dim == 0)
154 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
155 t1.repeat({t0.size(dim), 1, 1, 1, 1, 1, 1, 1}));
156 else if constexpr (dim == 1)
157 return torch::mul(t0.repeat_interleave(t1.size(dim), 1),
158 t1.repeat({1, t0.size(dim), 1, 1, 1, 1, 1, 1}));
159 else if constexpr (dim == 2)
160 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
161 t1.repeat({1, 1, t0.size(dim), 1, 1, 1, 1, 1}));
162 else if constexpr (dim == 3)
163 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
164 t1.repeat({1, 1, 1, t0.size(dim), 1, 1, 1, 1}));
165 else if constexpr (dim == 4)
166 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
167 t1.repeat({1, 1, 1, 1, t0.size(dim), 1, 1, 1}));
168 else if constexpr (dim == 5)
169 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
170 t1.repeat({1, 1, 1, 1, 1, t0.size(dim), 1, 1}));
171 else if constexpr (dim == 6)
172 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
173 t1.repeat({1, 1, 1, 1, 1, 1, t0.size(dim), 1}));
174 else if constexpr (dim == 7 || dim == -1)
175 return torch::mul(t0.repeat_interleave(t1.size(dim), 0),
176 t1.repeat({1, 1, 1, 1, 1, 1, 1, t0.size(dim)}));
// Reached for t1 ranks no arm handles (presumably the lost `default:`), and,
// via fallthrough, for a `dim` that no branch matches.
178 throw std::runtime_error(
"Unsupported tensor dimension");