-
Notifications
You must be signed in to change notification settings - Fork 71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
respect train/eval mode in traced network #1217
base: main
Are you sure you want to change the base?
Conversation
Note that this maybe also fixes https://github.com/mlverse/torch/pull/633/files but I need to check again |
.github/workflows/main.yaml
Outdated
@@ -27,7 +27,7 @@ jobs: | |||
config: | |||
|
|||
- {os: macOS, r_version: release, version: cpu-intel, runner: macos-13} | |||
- {os: macOS, r_version: release, version: cpu-m1, runner: macos-13} | |||
- {os: macOS, r_version: release, version: cpu-m1, runner: macos-latest} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I was looking at why we keep failing on some tests and it seems that github runners, even though running on arm64 images cannot run MPS, as this API can't be acccessed by VM's under macOS, so it requires real self-hosted machines.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks, fixed!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think something might be broken with the custom macos runner, at least it's taking forever to start the job.
@dfalbel I think something is broken with the macOS runner. I don't think that this PR should behave differently on different operating systems, however. |
The runner was missing a |
Thanks, but I still think there is something off with the M1 runner: https://github.com/mlverse/torch/actions/runs/12633265519/job/35213843863 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sebffischer Looks good! I added some comments, let me know what you think!
tests/testthat/assets/linear.pt
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need this file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file already existed before and was used in test-script_module.R, I merely updated it
should_mangle = TRUE, | ||
manage_memory = FALSE | ||
) | ||
mod$eval() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if we should add an argument to jit_trace
that would keep the old behavior. My two concerns are:
- Tracing runs the network twice, which could be problematic for some users.
- Duplicates the size of the graph, which might be undesidered. Maybe a user wants to just trace the forward method in eval mode to export for deployment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, makes sense. But then I think calling $train()
and $eval()
should maybe result in an error, what do you think?
And should the default be to respect the train/eval-mode or not?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually i think there should maybe be no default to force the user to specify this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I changed my mind again, I think the default TRUE is fine, but I am happy to change it as well.
In which cases do you think running the network twice is problematic?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a respect_mode
argument that triggers the double/single tracing
Further issues:
CompilationUnit
per compiled module