24 const torch::Tensor b,
28 auto x = torch::zeros_like(b);
30 if (b.norm().item<
double>() < tol)
31 return std::make_tuple(x, -1, b.norm().item<
double>());
36 for (
int iter = 0; iter < max_iter; iter++) {
38 auto Ap = A.matmul(p);
39 auto beta = torch::dot(r, r);
40 auto alpha = beta / torch::dot(Ap, p);
45 if (r.norm().item<
double>() < tol)
46 return std::make_tuple(x, iter, r.norm().item<
double>());
48 beta = torch::dot(r, r) / beta;
52 return std::make_tuple(x, max_iter, r.norm().item<
double>());
58 const torch::Tensor b,
62 auto x = torch::zeros_like(b);
64 if (b.norm().item<
double>() < tol)
65 return std::make_tuple(x, -1, b.norm().item<
double>());
68 auto r_hat = b.clone();
70 auto alpha = torch::scalar_tensor(1.0, b.options());
71 auto omega = torch::scalar_tensor(1.0, b.options());
72 auto rho = torch::scalar_tensor(1.0, b.options());
74 auto p = torch::zeros_like(b);
75 auto v = torch::zeros_like(b);
77 for (
int iter = 0; iter < max_iter; iter++) {
79 auto rho_hat = torch::dot(r_hat, r);
80 auto beta = rho_hat / rho * alpha / omega;
82 p = r + beta * (p - omega * v);
85 alpha = rho_hat / torch::dot(r_hat, v);
86 auto s = r - alpha * v;
88 if (s.norm().item<
double>() < tol) {
90 return std::make_tuple(x, iter, s.norm().item<
double>());
94 omega = torch::dot(s, t) / torch::dot(t, t);
95 x += alpha * p + omega * s;
99 return std::make_tuple(x, max_iter, r.norm().item<
double>());
auto solve_cg(const torch::Tensor &A, const torch::Tensor b, int max_iter=1000, double tol=1e-10)
Solves the linear system A * x = b using the Conjugate Gradient (CG) method; A is assumed to be symmetric positive definite, which CG requires in order to converge.
Definition solver.hpp:23
auto solve_bicgstab(const torch::Tensor &A, const torch::Tensor b, int max_iter=1000, double tol=1e-10)
Solves the linear system A * x = b using the Bi-Conjugate Gradient Stabilized (BiCGStab) method.
Definition solver.hpp:57