-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Vecops's Rust wrapper fix for batch>1 #741
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Koren-Brand <[email protected]>
…ed tests of sum and product such that result size matches batch size Signed-off-by: Koren-Brand <[email protected]>
Signed-off-by: Koren-Brand <[email protected]>
@@ -165,16 +165,16 @@ pub fn check_vec_ops_scalars_sum<F: FieldImpl>(test_size: usize) | |||
where | |||
<F as FieldImpl>::Config: VecOps<F> + GenerateRandom<F>, | |||
{ | |||
let cfg = VecOpsConfig::default(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's define batch>1 if it should work.
Note that you need to resize the output vector since you cannot assign the cfg.batch field in rust (by api design).
Also it means that using the cfg.batch_size like you did is equivalent to using 1, so don't do that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm planning on changing it after seeing the CI pass when batch_size is still 1
EDIT: added in test.rs with batch = 3
output: &(impl HostOrDeviceSlice<F> + ?Sized), | ||
cfg: &VecOpsConfig, | ||
) -> VecOpsConfig { | ||
if input.len() as u64 != size_in * cfg.batch_size as u64 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you could remove size_in and check that batch divides input.size() like other apis
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this (current multiplication) is less taxing than a modulo operation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's minor and we better have all apis aligned
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But size_in is part of the Rust api (see the function that calls this function) and it should be checked so this is the logical way to check it together with cfg.batch_size
Unless you talk about removing size_in from the API as a whole which is another thing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having batch_size be implictly defined by division led to modulo be used instead of multiplication
cfg.batch_size | ||
); | ||
} | ||
if offset + (size_out - 1) * stride >= size_in { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isn't it 'offset+size*stride-1'?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The first element of the slice is at offset
The last element of the slice is at offset + (size-1)*stride (The slice is of size size_out)
This checks that the last element of the slice does not overflow the inputs (Or panics if the condition above is met)
So the -1 is correctly placed
Signed-off-by: Koren-Brand <[email protected]>
result: &(impl HostOrDeviceSlice<F> + ?Sized), | ||
cfg: &VecOpsConfig, | ||
) -> VecOpsConfig { | ||
if result.len() != cfg.batch_size as usize || input.len() % cfg.batch_size as usize != 0 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you cannot assing batch_size in the rust api and instead we use result.len() as the batch size (see apis such as ntt, msm).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assign*
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Made batch_size be determined implicitly instead of user-defined
result: &(impl HostOrDeviceSlice<F> + ?Sized), | ||
cfg: &VecOpsConfig, | ||
) -> VecOpsConfig { | ||
if a.len() != cfg.batch_size as usize || b.len() != result.len() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why would a.len() be the batch? If this is the scalar-vector ops then I think the scalar should be a single element. Otherwise you need to allocate a full array with the same element. In that case you don't even need those apis.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the api for the cpp function
Signed-off-by: Koren-Brand <[email protected]>
… in this PR Signed-off-by: Koren-Brand <[email protected]>
Signed-off-by: Koren-Brand <[email protected]>
Describe the changes
This PR fixes the Vecops rust wrapper to work with batch > 1 (Which affected most Vector operations besides vector vector operations).
Added assertions when calling vecops functions in rust to prevent invalid configurations.