
Graph custom gradient support #292

Merged: 27 commits from rnett:rn_custom_gradients into tensorflow:master on Nov 11, 2021

Conversation

rnett
Contributor

@rnett rnett commented Apr 19, 2021

This PR adds support for custom gradients for graphs, using the legacy gradient setup. Eventually it will be superseded by the gradient tape API in #283, but there is no timeline for when that will happen.

@rnett
Contributor Author

rnett commented Apr 19, 2021

@saudet I'm getting a bunch of JavaCPP errors from GradFunc that seem to be related to the std::vector adapter.

See here. Do you have any idea what could be causing it? I'm not doing any special mapping around those classes.

@saudet
Contributor

saudet commented Apr 19, 2021

We probably need to "define" a wrapper class for the std::vector<tensorflow::Output> class, with something like this:
https://github.com/saudet/tensorflow-java/blob/add-gradient-tape/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java#L268
BTW, you don't need to give them special names like NativeOutput. They are already in a different package. Also, are you sure we can actually define custom gradients that way? It didn't seem possible the last time I looked at that, or it had limitations like it didn't work with eager mode, or something like that.

@zaleslaw
Contributor

@rnett looks very cool, especially your test with registered grads for the Concat function.
Do you have any plans to add these gradients as part of the library, not only in tests or examples?
I could help with testing and writing a few gradients too.

Did you take grad formulas from C++ code or from Python?

@rnett
Contributor Author

rnett commented Apr 19, 2021

> @rnett looks very cool, especially your test with registered grads for the Concat function.
> Do you have any plans to add these gradients as part of the library, not only in tests or examples?
> I could help with testing and writing a few gradients too.

I don't have any specific plans, but I think it would be good to add missing grads in the TensorFlow init, maybe extracted out to another class. It's definitely something that should live in this repo, IMO.

Unfortunately, writing gradients is rather hard at the moment, since there are no good ways to access attributes or inputs of ops (say, axis for Concat) without using GraphOperation and native code. I'm going to try to add those to the op generator.

> Did you take grad formulas from C++ code or from Python?

I'm not taking grad formulas from anywhere yet; this approach adds the gradient alongside those in tensorflow/cc/gradients.

@rnett
Contributor Author

rnett commented Apr 19, 2021

> BTW, you don't need to give them special names like NativeOutput. They are already in a different package.

Yeah, I know, but IMO it's cleaner than just relying on different packages. There are a number of methods (GradientHelpers mostly) that need both.

> Also, are you sure we can actually define custom gradients that way? It didn't seem possible the last time I looked at that, or it had limitations like it didn't work with eager mode, or something like that.

Yeah, it's graph-only; these are the legacy graph gradients. But until the graph backend starts using the gradient tape API, it's the only way to add gradients for graphs.

@rnett
Contributor Author

rnett commented Apr 26, 2021

cc @saudet

Hmm, ok, I still get errors with a vector adapter: (new Info("std::vector<tensorflow::Output>").valueTypes("@StdMove NativeOutputVector").pointerTypes("NativeOutputVector").define())

/windows/Users/jimne/Desktop/OtherStuff/tensorflow_java/tensorflow-core/tensorflow-core-api/target/native/org/tensorflow/internal/c_api/linux-x86_64/jnitensorflow.cpp: In member function ‘tensorflow::Status JavaCPP_org_tensorflow_internal_c_1api_GradFunc::operator()(const tensorflow::Scope&, const tensorflow::Operation&, std::vector<tensorflow::Output>*, std::vector<tensorflow::Output>*)’:
/windows/Users/jimne/Desktop/OtherStuff/tensorflow_java/tensorflow-core/tensorflow-core-api/target/native/org/tensorflow/internal/c_api/linux-x86_64/jnitensorflow.cpp:1999:65: error: no matching function for call to ‘MoveAdapter<std::vector<tensorflow::Output> >::MoveAdapter(std::vector<tensorflow::Output>*&)’
     MoveAdapter< std::vector<tensorflow::Output> > adapter2(arg2);
                                                                 ^
/windows/Users/jimne/Desktop/OtherStuff/tensorflow_java/tensorflow-core/tensorflow-core-api/target/native/org/tensorflow/internal/c_api/linux-x86_64/jnitensorflow.cpp:723:5: note: candidate: MoveAdapter<T>::MoveAdapter(T&&) [with T = std::vector<tensorflow::Output>]
     MoveAdapter(T&& ptr) : ptr(&movedPtr), size(0), owner(0), movedPtr((T&&)ptr) { }
     ^~~~~~~~~~~
/windows/Users/jimne/Desktop/OtherStuff/tensorflow_java/tensorflow-core/tensorflow-core-api/target/native/org/tensorflow/internal/c_api/linux-x86_64/jnitensorflow.cpp:723:5: note:   no known conversion for argument 1 from ‘std::vector<tensorflow::Output>*’ to ‘std::vector<tensorflow::Output>&&’
/windows/Users/jimne/Desktop/OtherStuff/tensorflow_java/tensorflow-core/tensorflow-core-api/target/native/org/tensorflow/internal/c_api/linux-x86_64/jnitensorflow.cpp:722:5: note: candidate: MoveAdapter<T>::MoveAdapter(const T&) [with T = std::vector<tensorflow::Output>]
     MoveAdapter(const T& ptr) : ptr(&movedPtr), size(0), owner(0), movedPtr(std::move((T&)ptr)) { }
     ^~~~~~~~~~~

Making the adapter type a pointer doesn't help either.

@saudet
Contributor

saudet commented Apr 27, 2021

Adapters don't work for defining function types like that, whether it's @StdVector, @StdMove, or anything else. In this case the declaration doesn't use an rvalue reference (&&), so we don't need @StdMove there anyway, and you can remove it.

@rnett
Contributor Author

rnett commented Apr 27, 2021

I had to remove valueTypes as well, but that worked.
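
For context, the mapping that ends up working is along these lines: a minimal sketch of the relevant part of a JavaCPP presets class (the class name TensorflowPresetSketch is illustrative; the real mapping lives in the tensorflow presets file linked earlier):

```java
import org.bytedeco.javacpp.tools.Info;
import org.bytedeco.javacpp.tools.InfoMap;
import org.bytedeco.javacpp.tools.InfoMapper;

// Sketch only: wrap std::vector<tensorflow::Output> as a NativeOutputVector class.
// No @StdMove and no valueTypes, since adapters are not supported when defining
// function types like GradFunc.
public class TensorflowPresetSketch implements InfoMapper {
  @Override
  public void map(InfoMap infoMap) {
    infoMap.put(new Info("std::vector<tensorflow::Output>")
        .pointerTypes("NativeOutputVector")
        .define());
  }
}
```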

@saudet
Contributor

saudet commented Apr 27, 2021

Sounds like memory corruption; something is probably getting deallocated too early. We can set the "org.bytedeco.javacpp.nopointergc" system property to "true" to check whether GC is responsible.
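
A quick sketch of that check (illustrative only; the property needs to be in place before JavaCPP's Pointer class is loaded, so passing -Dorg.bytedeco.javacpp.nopointergc=true to the JVM is the safest route):

```java
// Disable JavaCPP's GC-driven deallocation to test whether premature
// deallocation by the garbage collector is causing the crash.
// This must run before org.bytedeco.javacpp.Pointer is first loaded.
public class DisablePointerGc {
  public static void main(String[] args) {
    System.setProperty("org.bytedeco.javacpp.nopointergc", "true");
    // ... then load and run the failing gradient test ...
  }
}
```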

@rnett
Contributor Author

rnett commented Apr 27, 2021

Yeah, my function pointers were getting GC'd; I forgot to keep references to them.
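
The usual fix is to hold a strong reference to the function-pointer object for as long as native code may call it. A minimal sketch, assuming the generated GradFunc class from this PR (the registry class and field names here are illustrative):

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.tensorflow.internal.c_api.GradFunc;

// Illustrative only: keep gradient callbacks strongly referenced so the Java
// object (and the native thunk JavaCPP created for it) is not deallocated
// while TensorFlow still holds a raw pointer to it.
public class GradFuncRegistry {
  // Keyed by op type here, but any long-lived structure works.
  private static final Map<String, GradFunc> RETAINED = new ConcurrentHashMap<>();

  public static void retain(String opType, GradFunc fn) {
    RETAINED.put(opType, fn);
  }
}
```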

@rnett
Contributor Author

rnett commented Apr 27, 2021

Next question: I'm generating a wrapper for std::unordered_map here, but I need the erase method and I'm not seeing it. Is it inherited or something? We're using C++11, so I would think it should be there.

@saudet
Contributor

saudet commented Apr 27, 2021

Right, the way that works is by mapping a minimal set of functions that are usually available in these kinds of templates. There should be some other way to erase an element though, isn't there? In any case, let me figure out some way to customize the output of that a bit...

@rnett
Contributor Author

rnett commented Apr 27, 2021

> There should be some other way to erase an element though, isn't there?

Doesn't seem like it; if there is, it's not coming up on Google.

@saudet
Contributor

saudet commented Apr 28, 2021

OK, I'm confident enough that pretty much all "map" containers have an erase(iterator) method, so I've added that in commit bytedeco/javacpp@dcc06df. You'll have to use JavaCPP 1.5.6-SNAPSHOT for it to appear. If you need to remove by key, we can also add overloads with something like the following in this case:

.put(new Info("std::unordered_map<tensorflow::string,tensorflow::Node*>").pointerTypes("NameMap").define().javaText("public native long erase(@StdString BytePointer key);"))
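
Assuming that NameMap mapping generates the key-based overload shown above, using it from Java might look like this (the package and exact class name of the generated NameMap are assumptions):

```java
import org.bytedeco.javacpp.BytePointer;
import org.tensorflow.internal.c_api.NameMap; // generated class name assumed from the Info above

public class NameMapEraseSketch {
  // Remove an entry by key via the erase overload injected through javaText.
  // For an unordered_map this returns the number of elements removed (0 or 1).
  static long removeNode(NameMap nameMap, String nodeName) {
    return nameMap.erase(new BytePointer(nodeName));
  }
}
```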

BTW, it looks like you're starting to map all of the legacy C++ API. You could pick up from what has already been done for TF 1.x:
https://github.com/bytedeco/javacpp-presets/blob/master/tensorflow/src/main/java/org/bytedeco/tensorflow/presets/tensorflow.java

@rnett
Contributor Author

rnett commented Apr 28, 2021

Thanks, that works nicely.

> BTW, it looks like you're starting to map all of the legacy C++ API. You could pick up from what has already been done for TF 1.x:
> https://github.com/bytedeco/javacpp-presets/blob/master/tensorflow/src/main/java/org/bytedeco/tensorflow/presets/tensorflow.java

I'm trying not to, but it's gotten pretty big; that should help.

@rnett
Contributor Author

rnett commented Apr 28, 2021

Now I'm getting a segfault from TF_OperationNumControlOutputs. Do you have any idea what would cause this? It's from GraphOperationTest.controlConsumers, which works fine on master, and it fails regardless of whether the gradient test is run. The only change to TF_Operation is adding the node getter and removing @Opaque; I'm not sure how that would cause a segfault.

dump file

@saudet
Contributor

saudet commented Apr 29, 2021

If those TF_Operation objects end up with a deallocator, they may be getting deallocated prematurely. If that's the case, we can use PointerScope as appropriate to make sure that doesn't happen, and to prevent those objects from sticking around longer than necessary, slowing things down as well.

@rnett
Contributor Author

rnett commented Apr 29, 2021

Yeah, that's what I thought it could be, but I didn't change anything around those methods, and the gradient methods all use PointerScopes. Also, the failing methods all verify the pointer is not null before calling.

I cherry-picked the JavaCPP generation changes (presets/tensorflow.java, Maven configuration) to master, and that alone causes it to happen, so it doesn't seem like a deallocation issue.

@rnett
Contributor Author

rnett commented Apr 29, 2021

Another note: only controlConsumers and consumers fail; all the other tests work fine.

@rnett
Contributor Author

rnett commented Apr 29, 2021

The method looks like this:

try (PointerScope scope = new PointerScope()) {
  TF_Output output = new TF_Output().oper(getUnsafeNativeHandle()).index(index);
  return TF_OperationOutputNumConsumers(output);
}

If I set a breakpoint right before the TF_OperationOutputNumConsumers call, output is valid: I can get output.oper().node().name().getString(), and I can also get the consumer's inputs manually.

@rnett
Contributor Author

rnett commented Apr 29, 2021

OK, I've found the cause. It only happens when I include tensorflow/c/c_api.cc, even if I skip parsing it. I need a few functions from there (ToOperation, TF_FinishOperationLocked, etc.) that aren't exposed in the header file. I assumed that including a non-header file would be supported; is it not?

@saudet
Contributor

saudet commented Apr 29, 2021 via email

@Craigacp
Collaborator

Let's ask upstream first before ad-hoc expanding the C API. There may be a reason those functions aren't part of the C API. Did you check to see if libtensorflow exports those symbols?

@rnett
Contributor Author

rnett commented Apr 29, 2021

It doesn't; some of them are static and others are in an anonymous namespace. The namespaced ones are just helpers that would be nice to have; the static ones skip taking the graph's lock, which is necessary for defining gradient functions via the C API (the lock is already held at that point).

I can open an issue in TensorFlow, but I'm not sure what to ask for other than a full custom-gradient C API, which wouldn't be worth doing for the legacy gradients. These are functions that shouldn't really be public; the C API just isn't designed with custom gradients in mind.

@rnett
Contributor Author

rnett commented Apr 29, 2021

OK, things work now, but it's a bit hacky. I'm going to open a TensorFlow issue asking for the necessary functions, but even if they approve exporting them, we may want to merge this with the patch instead of waiting for 2.6 or whenever they make it in.

I need three functions:

  • TF_Operation* ToOperation(Node* node), to work with the C++ API for gradient definitions.
  • TF_NewOperationLocked and TF_FinishOperationLocked, because both the C API gradient-definition function and the normal versions of those functions lock the graph's mutex, which prevents using the C API op-definition functions inside a gradient definition.

@rnett
Contributor Author

rnett commented Apr 29, 2021

Also, can you mark this with CI Build?

@rnett rnett requested a review from Craigacp November 5, 2021 01:32
@Craigacp
Collaborator

Craigacp commented Nov 5, 2021

OK, I think the scopes should be named if possible, and the docs on CustomGradient and RawCustomGradient need to be tidied up a bit, but otherwise it's fine. If there are issues with throwing exceptions through the TypedGradientAdapter, then let's use TF's status signalling mechanism instead; it will still result in a Java exception on the other end.
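
For readers landing here, registering a custom gradient through these classes could look roughly like the sketch below. This is a hypothetical illustration only: the TensorFlow.registerCustomGradient entry point, the CustomGradient lambda shape, and Concat.Inputs are assumptions extrapolated from the class names discussed in this thread, not a verbatim copy of the merged API.

```java
import java.util.Arrays;
import org.tensorflow.Operand;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.core.Concat;

// Hypothetical sketch: attach a custom gradient to an op type in a Graph.
// Names and signatures here are assumptions, not the confirmed API.
public class CustomGradientSketch {
  public static void register() {
    TensorFlow.registerCustomGradient(Concat.Inputs.class,
        (tf, op, gradInputs) -> {
          // Return one gradient per op input; as a placeholder this just
          // passes the upstream gradient through for the first input.
          Operand<?> upstream = (Operand<?>) gradInputs.get(0);
          return Arrays.asList(upstream /* , ... one entry per remaining input */);
        });
  }
}
```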

@rnett
Contributor Author

rnett commented Nov 5, 2021

Can I get someone to re-run the CI jobs? The cache needs to be populated.

@Craigacp added and later removed the CI build label (triggers a full native build on a pull request) Nov 5, 2021
@karllessard
Collaborator

@Craigacp, @rnett: is this ready to be merged now?

Craigacp previously approved these changes Nov 10, 2021
@Craigacp
Collaborator

I think so.

@rnett
Contributor Author

rnett commented Nov 10, 2021

I'll document the rawtypes and then push the generation later today.

@karllessard
Collaborator

All right, merging this now. Thanks again for that great contribution, @rnett!

@karllessard karllessard merged commit e0eec4a into tensorflow:master Nov 11, 2021
@rnett
Contributor Author

rnett commented Nov 11, 2021

@karllessard @Craigacp generation is pushed; we're good to go.

Edit: Welp I got ninja'd.

@rnett rnett deleted the rn_custom_gradients branch November 11, 2021 03:12