Training Proposal: Spec Changes and Gradient Operator (onnx#2314)
* ONNX Training proposal.

Major changes:
  1. Add a protobuf message, `TrainingInfoProto` (originally designed in
     onnx#2013), to store training information.
  2. In `TrainingInfoProto`, the user can store a training algorithm in the
     `algorithm` field as a `GraphProto` (see the sketch after this list).
  3. The user can also store an initialization algorithm for resetting the
     model in `TrainingInfoProto.initialization` (proposed by @tbennun in
     onnx#2517 and agreed upon by the Training WG).
  4. `ModelProto.graph` is callable inside `TrainingInfoProto.algorithm`.
     `ModelProto.graph.initializer` entries are visible to nodes in
     `TrainingInfoProto.algorithm.node`.
  5. This PR also introduces a `Gradient` operator to differentiate a
     function represented by a (sub-)graph. This idea is from onnx#2168.
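
A rough Python sketch (not part of this commit; the graph, tensor, and
gradient names below are made up) of how a model could carry the new
training information, using the standard onnx.helper API and the proto
fields described above:

```
import onnx
from onnx import helper, TensorProto

# Inference graph (ModelProto.graph): Y = X * W, with trainable initializer W.
inference_graph = helper.make_graph(
    [helper.make_node("Mul", ["X", "W"], ["Y"])],
    "MyInferenceGraph",
    inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [1]),
            helper.make_tensor_value_info("W", TensorProto.FLOAT, [1])],
    outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1])],
    initializer=[helper.make_tensor("W", TensorProto.FLOAT, [1], [0.5])],
)

# A toy training step stored in TrainingInfoProto.algorithm: W_new = W - G,
# where G is an externally supplied gradient. A real algorithm would use the
# Gradient and GraphCall operators introduced by this PR.
algorithm = helper.make_graph(
    [helper.make_node("Sub", ["W", "G"], ["W_new"])],
    "MyTrainingStep",
    inputs=[helper.make_tensor_value_info("G", TensorProto.FLOAT, [1])],
    outputs=[helper.make_tensor_value_info("W_new", TensorProto.FLOAT, [1])],
)

training_info = onnx.TrainingInfoProto()
training_info.algorithm.CopyFrom(algorithm)
training_info.update_binding.add(key="W", value="W_new")  # write W_new back to W

model = helper.make_model(inference_graph)
model.training_info.extend([training_info])
```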

Contribution list:
   Baihan Huang: spec design.
   Tal Ben-Nun: model initialization design.
   Wei-Sheng Chin: spec design, Gradient operator design.
   Jonny Shipton and active WG members and participants: many valuable comments and reviews.

Co-authored-by: Sherlock <[email protected]>
Co-authored-by: Tal Ben-Nun <[email protected]>
Co-authored-by: Jonny Shipton <[email protected]>

* Address comments

* Address a comment

* Move Gradient to ai.onnx.training

Update Gradient test models

* Address comments
1. Create initialization_binding instead of
   using update_binding for initialization.
2. Swap key and value in update_binding.
3. Refine documents accordingly.

* Clarify semantics of algorithm and initialization

* Fix typos

* Address comment and explain the two computation modes of ModelProto.training_info

* Fix typo and explain default behavior

* Update onnx/checker.cc

Co-Authored-By: Jonny Shipton <[email protected]>

* Address comments

* Make normalization_binding a repeated field

* Add GraphCall operator

* Polish GraphCall

* GraphCall now uses position to map inputs and outputs

* Address comments:
1. Clarify GraphCall's semantic.
2. Implicitly force trainable tensors to be inference graph's inputs.
3. Training operators cannot be called in the inference graph.

* Add accidentally removed changes back

* Use protobuf lite

* Polish the helper script

* Fix windows build and polish helper script

* Fix linux and mac builds

* One more line

* fix the attribute types section in IR.md (onnx#2590)

* fix the attribute types section in IR.md

* update per comments.

* Some changes around the behavior of optional inference inputs.

1. Use pass-by-value for optional inference inputs.
2. Due to the semantics of GraphCall, we implicitly force trainable
   inputs to be added to the inference graph's input list.

Revise docs

* Update spec per WG discussion

* update_binding is optional now because the user might only want to store initialization

* Polish doc

* Address comments. Polish words.

* Use an alternative field to declare global variables.
In yesterday's Operator SIG meeting, we agreed to still
put global variables in the inference graph and add a
model-level field to indicate global variables. This way
we have a smaller impact on inference engines, because
they don't need to move trainable tensors to a new field.

* polish docs

* Allow training initializers to be promoted to global & mutable variables

* Merge the functions of global_mutable_initializer_names into update_binding

* Polish docs

* Remove restriction on using ai.onnx.training in the inference graph

* Split training operator registration from the ai.onnx registration file

Co-authored-by: Sherlock <[email protected]>
Co-authored-by: Tal Ben-Nun <[email protected]>
Co-authored-by: Jonny Shipton <[email protected]>
Co-authored-by: Ke Zhang <[email protected]>
5 people authored Feb 17, 2020
1 parent 9fdae4c commit 807c62c
Showing 39 changed files with 1,986 additions and 33 deletions.
3 changes: 2 additions & 1 deletion .travis/script.sh
@@ -37,7 +37,8 @@ git diff --exit-code

# check auto-gen files up-to-date
python onnx/defs/gen_doc.py
-python onnx/gen_proto.py
+python onnx/gen_proto.py -l
+python onnx/gen_proto.py -l --ml
python onnx/backend/test/stat_coverage.py
backend-test-tools generate-data
git status
3 changes: 2 additions & 1 deletion appveyor.yml
@@ -63,7 +63,8 @@ build_script:
- cmd: pip install %_wheel%
- cmd: pytest
- cmd: python onnx/defs/gen_doc.py
-- cmd: python onnx/gen_proto.py
+- cmd: python onnx/gen_proto.py -l
+- cmd: python onnx/gen_proto.py -l --ml
# Run type checks
- cmd: pip uninstall -y %_wheel%
- cmd: rm -rf .setuptools-cmake-build
297 changes: 295 additions & 2 deletions docs/Changelog.md
@@ -14044,8 +14044,12 @@ This version of the operator has been available since version 12 of the default
### <a name="Celu-12"></a>**Celu-12**</a>

Continuously Differentiable Exponential Linear Units:
-Perform the linear unit element-wise on the input tensor X
-using formula: <br/> ``` max(0,x) + min(0,alpha*(exp(x/alpha)−1)) ```
+Perform the linear unit element-wise on the input tensor X
+using formula:
+
+```
+max(0,x) + min(0,alpha*(exp(x/alpha)-1))
+```

#### Version

@@ -14526,3 +14530,292 @@ This version of the operator has been available since version 12 of the default
<dd>Constrain input and output types to high-precision and 8 bit numeric tensors.</dd>
</dl>

# ai.onnx.training
## Version 1 of the 'ai.onnx.training' operator set
### <a name="ai.onnx.training.Gradient-1"></a>**ai.onnx.training.Gradient-1**</a>

The Gradient operator computes the partial derivatives of a specific tensor w.r.t.
some other tensors. This operator is widely used in gradient-based training
algorithms. To illustrate its use, let's consider the following computation graph,

```
X -----.
|
v
W --> Conv --> H --> Gemm --> Y
^
|
Z
```

where W and Z are trainable tensors. Note that operators' attributes are
omitted for the sake of simplicity. Let dY/dW (dY/dZ) be the gradient of
Y with respect to W (Z). The user can compute gradients by inserting a Gradient
operator to form another graph, shown below.

```
W --> Conv --> H --> Gemm --> Y
| ^ ^
| | |
| X Z
| | |
| | .----------'
| | | (W/Z/X is the 1st/2nd/3rd input of Gradient as shown in
| | | "xs" followed by "zs")
| v v
'---> Gradient(xs=["W", "Z"], zs=["X"], y="Y")
| |
| '-----------------------------------> dY/dW (1st output of Gradient)
|
'---------------------------------------> dY/dZ (2nd output of Gradient)
```

By definition, the tensor "y" is a function of independent variables in "xs"
and "zs". Since we only compute the gradient of "y" w.r.t. the differentiable
variables in "xs", this Gradient only outputs dY/dW and dY/dZ. Note that "H"
cannot appear in "xs" and "zs". The reason is that "H" can be determined by
tensors "W" and "X" and therefore "H" is not an independent variable.

All outputs are optional. For example, the user can assign an empty string to
the 1st output name of that Gradient to skip the generation of dY/dW.
Note that the concept of optional outputs can also be found in ONNX's RNN, GRU,
and LSTM.
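
For reference, the Gradient node in the diagram above could be constructed
with onnx.helper roughly as follows (a sketch, not normative; the output
names are made up):

```
from onnx import helper

grad_node = helper.make_node(
    "Gradient",
    inputs=["W", "Z", "X"],       # values for "xs" followed by "zs"
    outputs=["dY_dW", "dY_dZ"],   # one gradient per entry in "xs"
    domain="ai.onnx.training",
    xs=["W", "Z"],
    zs=["X"],
    y="Y",
)
```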

The Gradient operator can also compute derivatives with respect to intermediate
tensors. For example, the gradient of Y with respect to H can be computed via

```
W --> Conv --> H --> Gemm --> Y
^ | ^
| | |
X | Z
.-------' |
| .----------'
| | (H/Z is the 1st/2nd input of Gradient as shown in "xs")
v v
Gradient(xs=["H", "Z"], y="Y")
| |
| '-----------------------------------> dY/dH (1st output of Gradient)
|
'---------------------------------------> dY/dZ (2nd output of Gradient)
```

It is possible to represent high-order differentiation using Gradient operators.
For example, given the following linear model:

```
W --> Gemm --> Y --> Loss --> O
^ ^
| |
X L
```

To compute the 2nd order derivative of O with respect to W (denoted by
d^2O/dW^2), one can do

```
W --> Gemm --> Y --> Loss --> O
| ^ ^
| | |
| X .------------L
| | | |
| | | v
+------+-+> Gradient(xs=["X", "W"], zs=["L"], y="O") ---> dO/dX (1st output of Gradient)
| | | |
| | | '---> dO/dW (2nd output of Gradient)
| v v
'---> Gradient(xs=["X", "W"], zs=["L"], y="dO/dW") ---> d(dO/dW)/dX (1st output of
| Gradient)
|
|
'---> d^2O/dW^2 (2nd output of Gradient)
```
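
As a sketch (tensor names are illustrative; the diagram writes the first
Gradient's second output as "dO/dW", which is spelled "dO_dW" here), the two
chained Gradient nodes above could be declared as:

```
from onnx import helper

first_grad = helper.make_node(
    "Gradient", inputs=["X", "W", "L"], outputs=["dO_dX", "dO_dW"],
    domain="ai.onnx.training", xs=["X", "W"], zs=["L"], y="O",
)
second_grad = helper.make_node(
    "Gradient", inputs=["X", "W", "L"], outputs=["d_dOdW_dX", "d2O_dW2"],
    domain="ai.onnx.training", xs=["X", "W"], zs=["L"], y="dO_dW",
)
```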

The tensors named in attributes "xs", "zs", and "y" define the differentiated
computation graph, and the inputs to the Gradient node define the values at
which the gradient is computed. We can feed different tensors to the identified
graph. For example, one can compute the gradient of Y with respect to H at
a specific value of H, H_1, by providing that value as an input to the Gradient
node.

```
W --> Conv --> H --> Gemm --> Y
^ ^
| |
X Z

Z_1 (2nd input of Gradient)
|
v
H_1 --> Gradient(xs=["H", "Z"], y="Y") ---> dY/dH when H = H_1 and Y = Y_1.
|
'------------------------------> dY/dZ (2nd output of Gradient)
```
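
In helper form, that evaluation could be sketched as (illustrative names):

```
from onnx import helper

grad_at_point = helper.make_node(
    "Gradient",
    inputs=["H_1", "Z_1"],        # values substituted for "H" and "Z"
    outputs=["dY_dH", "dY_dZ"],
    domain="ai.onnx.training",
    xs=["H", "Z"],
    y="Y",
)
```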

When the inputs of Gradient are the tensors named in "xs" and "zs", the
computation can be optimized. More specifically, intermediate variables in the
forward pass can be reused if the gradient is computed via reverse-mode
auto-differentiation.


#### Version

This version of the operator has been available since version 1 of the 'ai.onnx.training' operator set.

#### Attributes

<dl>
<dt><tt>xs</tt> : list of strings (required)</dt>
<dd>Input tensor names of the differentiated sub-graph. It contains only the necessary differentiated inputs of a (sub-)graph. Variables (usually called intermediate variables) that can be generated from inputs cannot be included in this attribute.</dd>
<dt><tt>y</tt> : string (required)</dt>
<dd>The targeted tensor. It can be viewed as the output of the differentiated function. The attribute "xs" and attribute "zs" are the minimal independent variable set that determines the value of "y".</dd>
<dt><tt>zs</tt> : list of strings</dt>
<dd>Input tensor names of the differentiated sub-graph. It contains only the necessary non-differentiated inputs of a (sub-)graph. Variables (usually called intermediate variables) that can be generated from inputs cannot be included in this attribute.</dd>
</dl>

#### Inputs (1 - &#8734;)

<dl>
<dt><tt>Inputs</tt> (variadic, heterogeneous) : T1</dt>
<dd>The values fed into the graph identified by the attributes. The i-th input is the value of the i-th tensor specified in the concatenated list of the attribute "xs" and the attribute "zs". For example, if xs=["A", "B"] and zs=["C"], the first input is used as the value of symbol "A" and the 3rd input is substituted for all the occurrences of "C".</dd>
</dl>

#### Outputs (1 - &#8734;)

<dl>
<dt><tt>Outputs</tt> (variadic, heterogeneous) : T2</dt>
<dd>The gradient of the tensor specified by the attribute "y" with respect to each of tensors specified in the attribute "xs". The i-th output is the gradient of "y" with respect to the i-th tensor specified in the attribute "xs".</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T1</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)</dt>
<dd>Allow inputs to be any kind of tensor.</dd>
<dt><tt>T2</tt> : tensor(float16), tensor(float), tensor(double)</dt>
<dd>Allow outputs to be any kind of floating-point tensor.</dd>
</dl>

### <a name="ai.onnx.training.GraphCall-1"></a>**ai.onnx.training.GraphCall-1**</a>

The GraphCall operator invokes a graph from inside TrainingInfoProto's
algorithm field. The GraphCall inputs and outputs are bound to those of the
invoked graph by position. If a graph input has an initializer, that input
is considered optional. All graph outputs are optional.

Python syntax is used below for describing dictionaries and lists.

Assume that ModelProto's graph field has
- name: "MyInferenceGraph"
- input: ["X", "W", "Z"]
- initializer: [W]
- output: ["Y"]

as visualized below for inference.

```
X -----.
|
v
W --> Conv --> H --> Gemm --> Y
^
|
Z
```

Assume that the training algorithm contains

- inputs: ["X_1", "Z_1", "C"]
- initializer: [T]
- outputs: ["W_new"]

with a dictionary

- update_binding: {"W": "W_new", "T": "T_new"}

Inside the training algorithm graph, one can invoke the inference
graph by adding a GraphCall node with

- inputs: ["X_1", "W", "Z_1"]
- outputs: ["Y_1"]
- an attribute graph_name="MyInferenceGraph".

The initializers named in update_binding ("W" and "T" in this case)
are considered globally-visible and mutable variables, which
can be used as inputs of operators in the training graph.
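
A GraphCall node matching this example could be sketched with onnx.helper as
follows (illustrative only):

```
from onnx import helper

call_node = helper.make_node(
    "GraphCall",
    inputs=["X_1", "W", "Z_1"],   # bound by position to ["X", "W", "Z"]
    outputs=["Y_1"],              # bound by position to ["Y"]
    domain="ai.onnx.training",
    graph_name="MyInferenceGraph",
)
```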

An example training algorithm graph may look like

```
.-------- W (a global and mutable variable from
| | the inference graph)
| |
| .-----'-----------.
| | |
| | v
| | .-- X_1 --> GraphCall(graph_name="MyInferenceGraph")
| | | | |
| | | | |
| | | Z_1 -----' |
| | | | V
| | | | Y_1 ---> Loss ---> O
| | | | ^
| | | | |
| | `--. | C
| | | | |
| | | | .----------------'
| | | | |
| | v v v
| `--> Gradient(xs=["W"], zs=["X_1", "Z_1", "C"], y="O")
| |
| v
| dO_dW (gradient of W) 1 (a scalar one)
| | |
| V v
| Div <--- T ------------> Add ---> T_new
| | (T is the number of training iterations.
| | T is also globally visible and mutable.)
| v
`-----> Sub ----> W_new
```

where Loss is a dummy node that computes the objective function to be minimized.

The variable "W" is an optional input in the called graph.
If the user omits it, the input list of GraphCall becomes ["X_1", "", "Z_1"].
In this case, from the view of computation graph, the Conv operator invoked by
GraphCall's may be still connected the global "W" variable and therefore the
structure of the computation graph is unchanged.
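
For concreteness, the training algorithm graph drawn above could be assembled
roughly as follows. This is a sketch with made-up tensor names; "Loss" is a
placeholder standing in for a real objective (not a standard ONNX operator),
and "W" is omitted from the GraphCall inputs as discussed above.

```
from onnx import helper, TensorProto

def vi(name):
    # Float value_info with an unspecified shape.
    return helper.make_tensor_value_info(name, TensorProto.FLOAT, None)

nodes = [
    helper.make_node("GraphCall", ["X_1", "", "Z_1"], ["Y_1"],
                     domain="ai.onnx.training", graph_name="MyInferenceGraph"),
    helper.make_node("Loss", ["Y_1", "C"], ["O"]),  # placeholder objective
    helper.make_node("Gradient", ["W", "X_1", "Z_1", "C"], ["dO_dW"],
                     domain="ai.onnx.training",
                     xs=["W"], zs=["X_1", "Z_1", "C"], y="O"),
    helper.make_node("Div", ["dO_dW", "T"], ["scaled_grad"]),
    helper.make_node("Sub", ["W", "scaled_grad"], ["W_new"]),
    helper.make_node("Add", ["T", "one"], ["T_new"]),
]

algorithm = helper.make_graph(
    nodes, "MyTrainingAlgorithm",
    inputs=[vi("X_1"), vi("Z_1"), vi("C")],
    outputs=[vi("W_new"), vi("T_new")],
    initializer=[
        helper.make_tensor("T", TensorProto.FLOAT, [], [1.0]),    # iteration count
        helper.make_tensor("one", TensorProto.FLOAT, [], [1.0]),  # scalar constant
    ],
)
```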

#### Version

This version of the operator has been available since version 1 of the 'ai.onnx.training' operator set.

#### Attributes

<dl>
<dt><tt>graph_name</tt> : string (required)</dt>
<dd>The invoked graph's name. The only allowed value is the name of the inference graph, which is stored in "ModelProto.graph.name" in the ONNX model format.</dd>
</dl>

#### Inputs (1 - &#8734;)

<dl>
<dt><tt>Inputs</tt> (variadic, heterogeneous) : T</dt>
<dd>Inputs fed to the invoked graph. The i-th input here goes to the i-th input of the invoked graph. To omit an optional input in this field, the user can drop it or use an empty string.</dd>
</dl>

#### Outputs (1 - &#8734;)

<dl>
<dt><tt>Outputs</tt> (variadic, heterogeneous) : T</dt>
<dd>The outputs generated by the called graph. Its i-th value is bound to the i-th output of the called graph. Similar to the inputs, all outputs are optional.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)</dt>
<dd>Allow inputs and outputs to be any kind of tensor.</dd>
</dl>
