-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Transform llvm.ptr arg to memref arg #314
Conversation
I think we had a bit of a race condition whree @ivanradanov also added this functionality in https://github.com/EnzymeAD/Enzyme-JAX/pull/303/files#diff-226e98da706121c678606489ff574cfe82f6143f1cfeb844c99d4d7dc2a94deb That said a separate pass is nicer for debugging, so let's go with this. A quick review: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm modulo the jitcall fix
@ivanradanov can you also give this a once over?
Updated with jit call fix |
DenseMap<FunctionOpInterface, SetVector<CallOpInterface>> funcToKernelMap; | ||
moduleOp->walk([&](CallOpInterface callOp) { | ||
auto symbolName = | ||
dyn_cast_or_null<SymbolRefAttr>(callOp.getCallableForCallee()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this shouldn't be any call interface, but just jit_call and kernel_call. We can still store as callopinterface, but we should only walk the Jit and kernel call ops
No description provided.