-
Notifications
You must be signed in to change notification settings - Fork 119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BREAKING] Rust representation of GPU memory re-worked #412
Conversation
pub fn cuda_malloc_for_device(count: usize, device_id: usize) -> CudaResult<Self> { | ||
check_device(device_id); | ||
Self::cuda_malloc(count) | ||
} | ||
|
||
pub fn cuda_malloc_async_for_device(count: usize, stream: &CudaStream, device_id: usize) -> CudaResult<Self> { | ||
check_device(device_id); | ||
Self::cuda_malloc_async(count, stream) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whats the use case here for passing a specific deviceId if we require it to match the current device's id?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This version is preferred by @vhnatyk over implicit choice of device id in malloc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tl;dr - not that it's preferred but seems aligned with everything else. Longer version: the check ensures we maintain correct device_id through our calls - since we don't have implicit device_id
management elsewhere except here (by my legacy wip implementation and missed on review?). Maybe we can think of fully implicit or umm automated device_id
management inherent from the call to get_device()
on the current thread - but since we have the field in DeviceContext
- that makes it redundant and anyway feels like can cause multiple bugs? I guess I had to provide basic description in the doc - but I somehow hoped it will emerge from everyone's edit, since it was so intensively discussed:) have to update doc with the current version of API and diagrams
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think internal management of device_id
by thread makes the most sense. DeviceContext
can also get set this way by using get_device()
. It might be more bug prone on our end but I think it will be less error prone on the user end which we don't have control over
@@ -205,9 +227,8 @@ impl<T, const D_ID: usize> DeviceSlice<T, D_ID> { | |||
} | |||
} | |||
|
|||
impl<T, const D_ID: usize> DeviceVec<T, D_ID> { | |||
impl<T> DeviceVec<T> { | |||
pub fn cuda_malloc(count: usize) -> CudaResult<Self> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this and cuda_malloc_async
be private now that we have cuda_malloc_for_device<_async>
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know... If you're only using one card, there's no need in the for_device
version of these methods. I personally would use cuda_malloc_async
in multi-device settings too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good , just one minor thing similar error strings to const, pr is pretty big one 😄 - maybe worth one more quick look
if let Some(device_id) = input.device_id() { | ||
assert_eq!( | ||
device_id, ctx_device_id, | ||
"Device ids in input and context are different" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe worth to do string const
I like the distinction made for Would like to confirm one thing: the updated
(currently I'm using |
@alxiong yea, the new functions do accept either |
Closing in favor of #443 which includes these changes |
Overview of the PR
Currently, we use a pretty awkward representation of on-device memory. This PR creates a more idiomatic pair:
DeviceVec
(which is not really a vector but a boxed slice) which allocates, deallocates and owns device memory, andDeviceSlice
which provides mutable and immutable views into device memory.@ChickenLover I also changed vector operations a little bit - Montgomery is removed as it was working differently for multiplication and addition/subtraction, plus device id checks have been added. Also, I wasn't sure how to change Poseidon related data that is currently raw slices, is it correct to say that here we should use
DeviceVec
(which should help with the memory leak, at least on the Rust side) and digest here should be a host slice?Unresolved questions
One potential improvement for H2D/D2H memory operations would be to pin memory inside
HostSlice::from_slice
method. But I'm still not sure how effective pinned memory is for modern GPUs, plus we don't really have ownership of host data so it's hard to make sure it's pinned and unpinned exactly once (which I'm not sure is a real issue). If someone has good understanding of pinned memory, please comment. And maybe we can just provide a flag for the user to decide if they want to pin memory or not.