Skip to content

Commit

Permalink
add valueerror for != 4 corners and create matrices directly as float64
Browse files Browse the repository at this point in the history
  • Loading branch information
Parskatt committed Feb 5, 2024
1 parent 8e4c931 commit 159d577
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,16 +676,17 @@ def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[i
Returns:
octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
"""
# TODO: this should raise an error if < 4 points are provided (and probably also if more than 4) //Parskatt
a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float)
if len(startpoints) != 4 or len(endpoints) != 4:
raise ValueError(f"Please provide exactly four corners, got {len(startpoints)} startpoints and {len(endpoints)} endpoints.")
a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float64)

for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8)
b_matrix = torch.tensor(startpoints, dtype=torch.float64).view(8)
# do least squares in double precision to prevent numerical issues
res = torch.linalg.lstsq(a_matrix.double(), b_matrix.double(), driver="gels").solution.to(torch.float32)
res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution.to(torch.float32)

output: List[float] = res.tolist()
return output
Expand Down

0 comments on commit 159d577

Please sign in to comment.