-
Notifications
You must be signed in to change notification settings - Fork 493
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Interpolate tensor operation (Inference Only) #1246
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1246 +/- ##
========================================
Coverage 85.54% 85.54%
========================================
Files 595 596 +1
Lines 68280 68121 -159
========================================
- Hits 58409 58275 -134
+ Misses 9871 9846 -25 ☔ View full report in Codecov by Sentry. |
Hi @Nikaidou-Shinku Thanks for this great in-progress PR. I have some idea on how to implement the Autodiff for this operation and would like to help out on this. Let me know If there have been any new updates. Would be pushing my changes as well soon! |
Hello! I'm honestly at a loss as to how to implement autodiff, so it would be great if you could help!
I'm sorry that I'm a little busy during this time and may not be able to push this PR forward. But I will find a way to get it done before March, and I really hope to see this new feature in the next big version of Burn! |
feat: bilinear interpolation for tch, ndarray and wgpu backend fix: reduce test case size to avoid exceeding floating-point precision limits feat: support nearest-neighbor interpolation for ndarray backend feat: support nearest-neighbor interpolation for wgpu backend feat: support fusion backend fix: no-std support build: upgrade dependencies
Sorry I may not have time to continue working on this PR at the moment, I would be grateful if someone could take over it, or maybe consider merging the already completed parts and keeping the rest as TODO. I will list the current progress of this PR below. Completed
All these implementations give results consistent with PyTorch. Not done
#1393 has been filed to complete the remaining items |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Nikaidou-Shinku , thank you so much for this contribution! This is a big step for the full feature. We can start working on the Resize ONNX op which was a major missing op for many image related ONNX could not be implemented.
It is very responsible of you to list of TODOs and what it is completed.
I am approving it so it could be merged. We can file all remaining TODOs into an issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No specific correction comments. I think this looks great. The test cases cover both upsample, downsample operation. overall, LGTM!
|
||
/// Interpolation options. | ||
#[derive(new)] | ||
pub struct InterpolateOptions { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is great for future proofing. later on we can add other options like align corners here.
id % strides.2, | ||
); | ||
|
||
let y_in = (y_ratio * h as f64).floor() as usize; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think It's great that we are keeping this consistent with OpenCV version of this operation. Just as FYI, some methods may use ceil() or round() as well to find nearest coordinate so we may see Candle/Tch backend's implementation output slightly differ. But i think this makes sense.
Pull Request Template
Checklist
run-checks all
script has been executed.Related Issues/PRs
Fixes #455
Changes
Notes
Resize
operator #510. Maybe we can implement it in this PR, or in a new PR after this PR is merged.Testing
I will write new test cases to test these features, but currently the test cases are lacking because many features are not completed yet.
Completed
All these implementations give results consistent with PyTorch.
Not done
align_corners
for bilinear and bicubic interpolation. (for more information check here) (Interpolate function for training #1393)nearest-exact
. (see the link previous for details) (Interpolate function for training #1393)