Add warnings and fallback for unassigned devices in infer_auto_device_map #3066
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…map" This reverts commit d607bfb.
The fallback allocation will be reintroduced once the branching logic is fully refactored. This commit prepares the function infer_auto_device_map for further refactoring.
Implemented fallback allocation to allow modules to be allocated to devices using BFS when regular allocation fails. This enhancement improves the allocation process by ensuring that at least one module is assigned to the device, even under tight memory constraints.
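To illustrate the idea of a BFS fallback allocation, here is a minimal, self-contained sketch. It is not accelerate's actual implementation: the helper name `find_fallback_module` and the `module_sizes`/`children` structures are invented for this example.

```python
from collections import deque

def find_fallback_module(module_sizes, device_free_memory, max_layer_size, children):
    """Breadth-first search for the first module small enough to fit on the
    device. Hypothetical helper, not accelerate's actual API: `module_sizes`
    maps module names to sizes, `children` maps a module to its direct
    submodules."""
    queue = deque([""])  # "" stands for the root module
    while queue:
        name = queue.popleft()
        # A module fits if its size plus the reserved largest-layer size
        # stays within the device's free memory.
        if module_sizes[name] + max_layer_size <= device_free_memory:
            return name
        # Otherwise descend into its (smaller) submodules.
        queue.extend(children.get(name, []))
    return None  # nothing fits; the device keeps no assignment

# Toy model: root (100) -> encoder (60) -> encoder.0 (30); decoder (35).
sizes = {"": 100, "encoder": 60, "encoder.0": 30, "decoder": 35}
children = {"": ["encoder", "decoder"], "encoder": ["encoder.0"]}
# With 40 units free and a 5-unit largest layer, only "decoder" fits.
print(find_fallback_module(sizes, 40, 5, children))  # -> decoder
```

The point of the BFS order is that a whole module is tried before any of its submodules, so the fallback assigns the largest unit that still fits.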
Thanks for the updates and the style fix. I'm not very knowledgeable about the whole logic being applied here, so I won't comment on that.
Personally, I find the use of `continue` in addition to many nested conditionals makes the logic super hard to follow. Usually, I would try to stick to either `if` + `continue` or `if ... elif ...` without `continue`. Not sure if the code could be simplified here.
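For illustration, the two styles contrast like this on a made-up allocation loop (the `Module` class and the skip conditions are hypothetical, not code from this PR):

```python
from dataclasses import dataclass

@dataclass
class Module:
    name: str
    size: int
    is_tied: bool

modules = [Module("a", 10, False), Module("b", 99, False), Module("c", 5, True)]
free_memory = 50

# Style 1: guard clauses with `continue` -- each early exit keeps the body flat.
allocated = []
for m in modules:
    if m.size > free_memory:
        continue
    if m.is_tied:
        continue
    allocated.append(m.name)

# Style 2: one if/elif chain -- every branch is spelled out in one place.
allocated_2 = []
for m in modules:
    if m.size > free_memory:
        pass  # too large for the device
    elif m.is_tied:
        pass  # tied weights are handled elsewhere
    else:
        allocated_2.append(m.name)

assert allocated == allocated_2 == ["a"]
```

Both produce the same result; the objection above is that mixing the two styles inside deeply nested conditionals makes the control flow hard to trace.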
One thing I believe we should ensure is that the new logic does not add any unnecessary warnings. Right now, we have some unit tests to ensure that specific warnings are there, but AFAICT we don't have tests to ensure that for other cases, there are no warnings. Maybe it would be good to add tests for the "happy path" and show that there is no warning. Potentially, we can even use existing tests and just add a check there is no warning. WDYT?
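One common way to assert a code path emits no warnings, using only the standard library, looks like the following sketch. The `infer_map_happy_path` stand-in is hypothetical; a real test would call `infer_auto_device_map` with a model that fits comfortably.

```python
import warnings

def infer_map_happy_path():
    # Stand-in for a call to infer_auto_device_map on a model that fits
    # comfortably on one device (hypothetical, for illustration only).
    return {"model": 0}

def test_happy_path_emits_no_warnings():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # don't let filters swallow anything
        device_map = infer_map_happy_path()
    assert device_map == {"model": 0}
    assert caught == [], f"unexpected warnings: {[str(w.message) for w in caught]}"

test_happy_path_emits_no_warnings()
```

With pytest specifically, the built-in `recwarn` fixture achieves the same thing without the explicit context manager.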
test_infer_auto_device_map and test_infer_auto_device_map_with_fallback_allocation now each have a no-warning test case. Simplified and rewrote code sections that were made unreadable by the linter.
Added complete return type hinting for _init_infer_auto_device_map
Hey @BenjaminBossan, I appreciate your feedback! Regarding the use of `continue` and nested conditionals, I've tried simplifying the logic where possible. I completely agree with your point about avoiding unnecessary warnings, and I've added checks in both tests.
Thanks for cleaning the code up and extending the tests.
I agree that the logic was already complex beforehand so it's not just because of this PR. But I think your recent changes helped a little bit to make it easier to understand, even if the overall complexity is still high and I can't say I understand all that's going on.
Nice work! This will be very handy. cc @SunMarc for a final look since it's big model inference :)
Thanks for the PR @Nech-C ! Really appreciate that you are putting a lot of effort into this PR ! I will review it soon, but first I have a question: could you explain a bit, with an example, what this fallback allocation will do?

From our conversation last time, the biggest issue with `infer_auto_device_map` is that we are saving memory for the largest layer in case we need to offload it to the CPU. I think that in your case, you are trying to find a module that fits `device memory - largest layer`?
Hi @SunMarc, sure thing! You are right. My code doesn't directly address the issue that, when there are multiple execution devices, the function may reserve space on a device for a module that won't be loaded onto it during inference. I have tried to come up with new allocation strategies, but the task is really complex. If possible, I would like to open a separate PR to address this issue when I come up with a reasonable solution.

While this PR doesn't solve the most significant concern, it does alleviate the problem. The constraint for allocating a module to a device is roughly `module size + max layer size <= device memory`. The aforementioned issue focuses on lowering max layer size, and this PR focuses on lowering module size by looking for a smaller module in the module list. It tries to assign a module to a device that receives no assignment when the regular allocation logic fails.

In theory, a device can be used during execution if it has more memory than the largest layer, even with no module assigned to it. Thus, we could achieve the same result without going through such a roundabout approach. However, I believe this would be a breaking change, as we need the returned value.

Thanks for your feedback. I'm open to further suggestions or clarifications if needed.
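To make the constraint concrete, here is a toy calculation with made-up sizes (the 8 GiB / 3 GiB / 6 GiB numbers are illustrative, not taken from the PR):

```python
GiB = 2**30
device_memory = 8 * GiB    # total memory budget for the device
max_layer_size = 3 * GiB   # reserved in case the largest layer is offloaded
module_size = 6 * GiB      # candidate module from the regular allocation pass

# Regular allocation fails: 6 + 3 > 8, so the device would receive nothing.
print(module_size + max_layer_size <= device_memory)  # -> False

# The fallback instead searches for a smaller module: 4 + 3 <= 8.
smaller_module = 4 * GiB
print(smaller_module + max_layer_size <= device_memory)  # -> True
```

Lowering `max_layer_size` (the issue discussed above) and lowering `module_size` (this PR) are thus two independent ways to relax the same inequality.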
Nice explanation @Nech-C ! Thanks for confirming !
I think that a quick solution to the max_layer size issue would be the following algorithm
We can add your fallback option each time we run `infer_auto_device_map` if wanted. This will help fix the following issue I've seen a couple of times:

Let me know what you think! Nevertheless, I think it will be nice to first merge this PR before moving on to the max_layer size fix.
Ohhh, now I get it @SunMarc. Thanks for breaking it down. Working on the code really helped me understand your idea. TBH, I didn't fully understand it when you first mentioned it in the issue 😅. Your algorithm idea sounds solid. I'm on board with merging this PR first, then tackling the max_layer size fix.

How should I proceed with the max_layer fix? Do I just open a new PR referencing the original issue, or do we need a new issue for this? Also, just a heads up, I've got a couple of busy weeks coming up, so I may not be able to start working on this right away. But I'll definitely get to it as soon as I can. Any tweaks you want me to make to this PR before we move on?
Sounds good ! I'll try to review this today !
What does this PR do?
This PR proposes changes to the `infer_auto_device_map` function from #3041. It makes the following improvements:
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.