
Vecops's Rust wrapper fix for batch>1 #741

Open
wants to merge 8 commits into main

Conversation

Contributor

@Koren-Brand Koren-Brand commented Jan 15, 2025

Describe the changes

This PR fixes the Vecops Rust wrapper to work with batch > 1 (the bug affected most vector operations, apart from vector-vector operations).
It also adds assertions when calling vecops functions from Rust to reject invalid configurations.
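
For illustration, a minimal sketch of the kind of configuration assertion described above (the struct and function names are simplified stand-ins, not the actual icicle types):

```rust
// Hypothetical, simplified stand-in for the real VecOpsConfig; illustration only.
struct VecOpsConfig {
    batch_size: i32,
}

// Sketch of the kind of assertion added before dispatching a vecops call:
// the flat input must hold exactly `size_in` elements per batch entry.
fn check_vec_ops_args(input_len: usize, size_in: u64, cfg: &VecOpsConfig) {
    if input_len as u64 != size_in * cfg.batch_size as u64 {
        panic!(
            "input length ({}) does not match size_in ({}) * batch_size ({})",
            input_len, size_in, cfg.batch_size
        );
    }
}

fn main() {
    let cfg = VecOpsConfig { batch_size: 3 };
    check_vec_ops_args(12, 4, &cfg); // ok: 12 == 4 * 3
    // check_vec_ops_args(10, 4, &cfg); // would panic: 10 != 4 * 3
}
```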

…ed tests of sum and product such that result size matches batch size

Signed-off-by: Koren-Brand <[email protected]>
@@ -165,16 +165,16 @@ pub fn check_vec_ops_scalars_sum<F: FieldImpl>(test_size: usize)
where
<F as FieldImpl>::Config: VecOps<F> + GenerateRandom<F>,
{
let cfg = VecOpsConfig::default();
Collaborator

@yshekel yshekel Jan 15, 2025


Let's define batch > 1 if it should work.
Note that you need to resize the output vector, since you cannot assign the cfg.batch field in Rust (by API design).
It also means that using cfg.batch_size the way you did is equivalent to using 1, so don't do that.

Contributor Author

@Koren-Brand Koren-Brand Jan 15, 2025


I'm planning to change it after seeing the CI pass while batch_size is still 1.
EDIT: added in test.rs with batch = 3
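
For context, a rough sketch of what a batched sum check could look like when the result is sized to the batch (the helper below is illustrative, not the actual test code):

```rust
// Illustrative only: a batched "sum" reference check where the batch size is
// implied by the result length rather than assigned on the config.
fn batched_sum(input: &[u64], batch: usize) -> Vec<u64> {
    let chunk = input.len() / batch; // each batch entry owns `chunk` consecutive elements
    input.chunks(chunk).map(|c| c.iter().sum::<u64>()).collect()
}

fn main() {
    let batch = 3;
    let test_size = 4;
    let input: Vec<u64> = (0..(batch * test_size) as u64).collect();

    // The result is sized to the batch, mirroring how the wrapper derives
    // the batch from the output slice.
    let result = batched_sum(&input, batch);
    assert_eq!(result.len(), batch);
    assert_eq!(result, vec![6, 22, 38]); // sums of [0..4), [4..8), [8..12)
}
```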

output: &(impl HostOrDeviceSlice<F> + ?Sized),
cfg: &VecOpsConfig,
) -> VecOpsConfig {
if input.len() as u64 != size_in * cfg.batch_size as u64 {
Collaborator


You could remove size_in and check that the batch divides input.len(), like the other APIs.

Contributor Author


I think this (current multiplication) is less taxing than a modulo operation.

Collaborator


It's minor, and we'd better have all APIs aligned.

Contributor Author


But size_in is part of the Rust API (see the function that calls this one) and it should be checked, so this is the logical way to check it together with cfg.batch_size.
Unless you mean removing size_in from the API altogether, which is a separate matter.

Contributor Author


Having batch_size defined implicitly by division led to a modulo being used instead of a multiplication.
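
For reference, the two checks being discussed, in simplified form (illustrative only; the first also pins down size_in, the second only requires divisibility):

```rust
// Check as written in the PR: explicit size_in, one multiplication.
fn check_with_size_in(input_len: u64, size_in: u64, batch_size: u64) -> bool {
    input_len == size_in * batch_size
}

// Check suggested in review: drop size_in and require that the batch divides the input length.
fn check_with_divisibility(input_len: u64, batch_size: u64) -> bool {
    input_len % batch_size == 0
}

fn main() {
    assert!(check_with_size_in(12, 4, 3));
    assert!(check_with_divisibility(12, 3));
    assert!(!check_with_size_in(10, 4, 3));
    assert!(!check_with_divisibility(10, 3));
}
```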

cfg.batch_size
);
}
if offset + (size_out - 1) * stride >= size_in {
Collaborator


isn't it 'offset+size*stride-1'?

Contributor Author

@Koren-Brand Koren-Brand Jan 15, 2025


The first element of the slice is at offset.
The last element is at offset + (size_out - 1) * stride (the slice has size_out elements).
This checks that the last element of the slice does not run past the input (and panics if the condition above is met).
So the -1 is placed correctly.
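
A small worked version of that bounds reasoning (variable names follow the quoted snippet; the helper itself is illustrative):

```rust
// Indices touched by a strided slice of `size_out` elements:
// offset, offset + stride, ..., offset + (size_out - 1) * stride.
fn slice_in_bounds(offset: usize, stride: usize, size_out: usize, size_in: usize) -> bool {
    // The last accessed index must stay strictly below size_in.
    offset + (size_out - 1) * stride < size_in
}

fn main() {
    // size_in = 10, offset = 1, stride = 3:
    // 3 elements touch indices 1, 4, 7 -> in bounds.
    assert!(slice_in_bounds(1, 3, 3, 10));
    // 4 elements would touch index 10 -> out of bounds, so the wrapper panics.
    assert!(!slice_in_bounds(1, 3, 4, 10));
}
```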

result: &(impl HostOrDeviceSlice<F> + ?Sized),
cfg: &VecOpsConfig,
) -> VecOpsConfig {
if result.len() != cfg.batch_size as usize || input.len() % cfg.batch_size as usize != 0 {
Collaborator

@yshekel yshekel Jan 15, 2025


I think you cannot assing batch_size in the rust api and instead we use result.len() as the batch size (see apis such as ntt, msm).

Contributor Author


Assing?

Collaborator


assign*

Contributor Author

@Koren-Brand Koren-Brand Jan 16, 2025


Made batch_size determined implicitly instead of user-defined.
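
A sketch of what deriving the batch from the output slice could look like, in the spirit of the NTT/MSM APIs mentioned above (simplified; not the actual wrapper code):

```rust
// Simplified illustration: derive the batch from the result length instead of
// asking the caller to assign it on the config.
#[derive(Clone)]
struct VecOpsConfig {
    batch_size: i32,
}

fn setup_config(input_len: usize, result_len: usize, cfg: &VecOpsConfig) -> VecOpsConfig {
    // Each result element corresponds to one batch entry.
    assert!(
        result_len > 0 && input_len % result_len == 0,
        "input length ({}) must be a positive multiple of result length ({})",
        input_len,
        result_len
    );
    let mut local_cfg = cfg.clone();
    local_cfg.batch_size = result_len as i32;
    local_cfg
}

fn main() {
    let cfg = VecOpsConfig { batch_size: 1 };
    let local_cfg = setup_config(12, 3, &cfg); // 3 batch entries of 4 elements each
    assert_eq!(local_cfg.batch_size, 3);
}
```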

result: &(impl HostOrDeviceSlice<F> + ?Sized),
cfg: &VecOpsConfig,
) -> VecOpsConfig {
if a.len() != cfg.batch_size as usize || b.len() != result.len() {
Collaborator

@yshekel yshekel Jan 15, 2025


Why would a.len() be the batch? If this is the scalar-vector ops, then I think the scalar should be a single element. Otherwise you need to allocate a full array with the same element, and in that case you don't even need these APIs.

Contributor Author


This is the API of the C++ function.
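
For context, a simplified sketch of the layout this check assumes for a batched scalar-vector op: one scalar per batch entry in a, and one equally sized vector per batch entry in b (illustrative only; the real call is forwarded to the C++ backend):

```rust
// Illustrative layout for a batched scalar + vector op: `scalars` holds one value
// per batch entry, `vectors` holds batch_size vectors of equal length back to back.
fn scalar_add_vec(scalars: &[u64], vectors: &[u64]) -> Vec<u64> {
    let batch = scalars.len();
    assert!(batch > 0 && vectors.len() % batch == 0);
    let size = vectors.len() / batch;
    vectors
        .iter()
        .enumerate()
        .map(|(i, &v)| scalars[i / size] + v)
        .collect()
}

fn main() {
    let scalars: [u64; 2] = [100, 200];            // batch of 2 scalars
    let vectors: [u64; 6] = [1, 2, 3, 10, 20, 30]; // two vectors of length 3
    assert_eq!(
        scalar_add_vec(&scalars, &vectors),
        vec![101, 102, 103, 210, 220, 230]
    );
}
```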

@Koren-Brand Koren-Brand marked this pull request as ready for review January 16, 2025 08:05
@Koren-Brand Koren-Brand requested a review from yshekel January 16, 2025 16:19