Allow linalg.lstsq to use svd to compute the result for rank deficient matrices. #126652
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/126652. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit fd07d5e with merge base 853081a. This comment was automatically generated by Dr. CI and updates every 15 minutes.
```cpp
    solution.set_(solution.storage(), solution_view.storage_offset(),
                  solution_view.sizes(), solution_view.strides());
  } else {
    solution = at::zeros({solution.size(-1), n}, solution.options());
```
what is going on here??
You're referring to just everything inside the `else` clause, correct?

I found that with a tensor `A` that has more rows than columns:
```python
A = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0],
                  [7.0, 8.0]], device='cuda')
# Create tensor B with shape (4, 1)
B = torch.tensor([[1.0],
                  [2.0],
                  [3.0],
                  [4.0]], device='cuda')
X_lstsq = torch.linalg.lstsq(A, B, driver='gelss').solution
```
would lead to

```
RuntimeError: start (2) + length (2) exceeds dimension size (2).
```
Is this incorrect? I'm refreshing my linear algebra here and I might not have the correct understanding.
```python
def svd_lstsq(AA, BB, tol=1e-5):
    U, S, Vh = torch.linalg.svd(AA, full_matrices=False)
    Spinv = torch.zeros_like(S)
    Spinv[S > tol] = 1 / S[S > tol]
    UhBB = U.adjoint() @ BB
    if Spinv.ndim != UhBB.ndim:
        Spinv = Spinv.unsqueeze(-1)
    SpinvUhBB = Spinv * UhBB
    return Vh.adjoint() @ SpinvUhBB

X_svd = svd_lstsq(A, B)
```

This, for example, will not throw an error with the same tensors.
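As a cross-check (my addition, not from the original comment): the minimum-norm solution that `svd_lstsq` computes can also be obtained from `torch.linalg.pinv`, which is likewise SVD-based. A small CPU sketch with the same values:

```python
import torch

A = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0],
                  [7.0, 8.0]])
B = torch.tensor([[1.0],
                  [2.0],
                  [3.0],
                  [4.0]])

# SVD-based minimum-norm least-squares solution; for these tensors
# the system is consistent and the result is [[0.0], [0.5]].
X_pinv = torch.linalg.pinv(A) @ B
```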
Also I should have clarified this earlier. Sorry.
I don't see why you should allocate a new tensor when you already have a solution allocated in the `else` path. And why a tensor of zeros?
Ah, you're right, I shouldn't allocate a new tensor.

As for the zeros: in hindsight it makes no sense to zero the solution out, since that's not the correct behavior (this case should have a solution, right?). An exception at least tells the user what happened, whereas zeroing is just silent undefined behavior. How do we handle this case, though? Does `solution` need to be `reshape`d before using `set_` in the `else` path, or something along those lines?

Since I'm still new, I'm curious whether this is out of scope for this PR. The exception occurs in general whenever `solution.size(-2) < n`. I don't mind doing it in this PR since it is small (and it's a better use of GitHub runners than two split PRs).
```cpp
  if (input.numel() == 0) {
    auto output_shape = input.sizes().vec();
    output_shape.back() = other.size(-1);
    rank.zero_();
```
`rank` is required later on; zeroing it here fixes the integer overflow that occurred when `toInt()` was later called, because the tensor was never set to anything.
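For illustration (my sketch, not code from the PR): an uninitialized tensor holds arbitrary memory, so reading it back, as the C++ `toInt()` does, is undefined unless it is explicitly zeroed first:

```python
import torch

# torch.empty leaves the memory uninitialized; rank.zero_() gives it
# a well-defined value (0) before it is read back for empty inputs.
rank = torch.empty((), dtype=torch.int64)
rank.zero_()
```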
Fixes #117122
This PR adds the logic so that, in the case of rank-deficient matrices, `linalg.lstsq` can fall back to an SVD backend in batched mode.
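To sketch the intended behavior in Python (a hedged illustration using `torch.linalg.pinv`, not the actual C++ change): for a rank-deficient input, an SVD-based path produces the minimum-norm least-squares solution instead of erroring out:

```python
import torch

# Rank-deficient matrix: the second column is twice the first.
A = torch.tensor([[1.0, 2.0],
                  [2.0, 4.0],
                  [3.0, 6.0]])
B = torch.tensor([[1.0],
                  [2.0],
                  [3.0]])

# pinv is SVD-based and handles the rank deficiency gracefully,
# returning the minimum-norm solution [[0.2], [0.4]].
X = torch.linalg.pinv(A) @ B
```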
I apologize for the previous PR... I messed up a rebase and it ended up showing a million changes.
cc @lezcano