Training Proposal: Spec Changes and Gradient Operator (onnx#2314)
* ONNX Training proposal. Major changes:
  1. Add a protobuf message, `TrainingInfoProto`, originally designed in onnx#2013, to store training information.
  2. In `TrainingInfoProto`, the user can store the training algorithm in the `algorithm` field as a `GraphProto`.
  3. The user can also store an initialization algorithm for resetting the model in `TrainingInfoProto.initialization` (proposed by @tbennun in onnx#2517 and agreed on by the Training WG).
  4. `ModelProto.graph` is callable inside `TrainingInfoProto.algorithm`, and `ModelProto.graph.initializer` entries are visible to nodes in `TrainingInfoProto.algorithm.node`.
  5. This PR also introduces a `Gradient` operator to differentiate a function represented by a (sub-)graph. This idea is from onnx#2168.

  Contribution list — Baihan Huang: spec design. Tal Ben-Nun: model initialization design. Wei-Sheng Chin: spec design, `Gradient` operator design. Jonny Shipton and active WG members and participants: many valuable comments and reviews.

  Co-authored-by: Sherlock <[email protected]>
  Co-authored-by: Tal Ben-Nun <[email protected]>
  Co-authored-by: Jonny Shipton <[email protected]>
* Address comments
* Address a comment
* Move `Gradient` to `ai.onnx.training`; update `Gradient` test models
* Address comments:
  1. Create `initialization_binding` instead of reusing `update_binding` for initialization.
  2. Swap key and value in `update_binding`.
  3. Refine documents accordingly.
* Clarify semantics of `algorithm` and `initialization`
* Fix typos
* Address comment and explain the two computation modes of `ModelProto.training_info`
* Fix typo and explain default behavior
* Update onnx/checker.cc (Co-Authored-By: Jonny Shipton <[email protected]>)
* Address comments
* Make `normalization_binding` a repeated field
* Add `GraphCall` operator
* Polish `GraphCall`
* `GraphCall` now uses position to map inputs and outputs
* Address comments:
  1. Clarify `GraphCall`'s semantics.
  2. Implicitly force trainable tensors to be inference graph inputs.
  3. Training operators cannot be called in the inference graph.
* Add accidentally removed changes back
* Use protobuf-lite
* Polish the helper script
* Fix Windows build and polish helper script
* Fix Linux and Mac builds
* One more line
* Fix the attribute types section in IR.md (onnx#2590); update per comments
* Some changes around the behavior of optional inference inputs:
  1. Use pass-by-value for optional inference inputs.
  2. Due to the semantics of `GraphCall`, trainable inputs are implicitly added to the inference graph's input list. Revise docs.
* Update spec per WG discussion
* `update_binding` is optional now because the user might only want to store initialization
* Polish doc
* Address comments; polish wording
* Use an alternative field to declare global variables. In the Operator SIG meeting we agreed to keep global variables in the inference graph and add a model-level field to mark them; this reduces the impact on inference engines, which no longer need to move trainable tensors to a new field.
* Polish docs
* Allow training initializers to be promoted to global, mutable variables
* Merge the functions of `global_mutable_initializer_names` into `update_binding`
* Polish docs
* Remove the restriction on using `ai.onnx.training` in the inference graph
* Split the training operator registration from the `ai.onnx` registration file

Co-authored-by: Sherlock <[email protected]>
Co-authored-by: Tal Ben-Nun <[email protected]>
Co-authored-by: Jonny Shipton <[email protected]>
Co-authored-by: Ke Zhang <[email protected]>
1 parent 9fdae4c · commit 807c62c · 39 changed files, 1,986 additions and 33 deletions.