Skip to content

Commit

Permalink
#2 Add SAC agent info
Browse files Browse the repository at this point in the history
  • Loading branch information
e10101 committed Jan 9, 2022
1 parent ab14e02 commit 3b499e3
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 33 deletions.
131 changes: 100 additions & 31 deletions notebooks/CartPole-v0.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -27,26 +27,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install imageio\n",
"!pip install imageio-ffmpeg"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python 3.8.10\n"
]
}
],
"source": [
"!python --version"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -55,26 +53,48 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'2.7.0'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.__version__"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'0.11.0'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import tf_agents\n",
"tf_agents.__version__"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {
"tags": []
},
Expand Down Expand Up @@ -109,7 +129,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {
"collapsed": false,
"jupyter": {
Expand All @@ -126,7 +146,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {
"tags": []
},
Expand Down Expand Up @@ -162,7 +182,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {
"tags": []
},
Expand All @@ -174,18 +194,33 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {
"tags": []
},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"TimeStep(\n",
"{'discount': array(1., dtype=float32),\n",
" 'observation': array([ 0.02613745, -0.02241037, -0.0004048 , 0.02523139], dtype=float32),\n",
" 'reward': array(0., dtype=float32),\n",
" 'step_type': array(0, dtype=int32)})"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.reset()"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {
"tags": []
},
Expand All @@ -196,28 +231,53 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {
"tags": []
},
"outputs": [],
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAIAAAD9V4nPAAAGdElEQVR4nO3dwU2DYBiAYTFdwjl0DOdoZ2rncAydwzHw4qFWTZqg/DTv89zgQL4LecMXAtM8z3cAUHU/egAAGEkIAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDSdqMHAD69nQ7nh4/746hJIEUIYZiL8gFDWI0CkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQgiLTAuMujJwTggBSBNCANJ2oweAupf3/cWZ54fTkEmgyRMhjPS9gr+dBP6JEMIWaSGsRghhGLWDLRBCANKEEIA0IYRhvB0KWyCEsEUaCasRQhjpx+CpIKxpmud59Axww5Z82PP1+OWt0afDn/XPfQ3XE0JYZJtfuHZfw/WsRgEAAKqsRmERq1G4dVajAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKT5+wQAaZ4IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIO0Dbm0yV3iC604AAAAASUVORK5CYII=\n",
"text/plain": [
"<PIL.Image.Image image mode=RGB size=600x400 at 0x7F6D39CBF460>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"PIL.Image.fromarray(frame)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time Step Spec:\n",
"TimeStep(\n",
"{'discount': BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0),\n",
" 'observation': BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]),\n",
" 'reward': ArraySpec(shape=(), dtype=dtype('float32'), name='reward'),\n",
" 'step_type': ArraySpec(shape=(), dtype=dtype('int32'), name='step_type')})\n"
]
}
],
"source": [
"print('Time Step Spec:')\n",
"print(env.time_step_spec())"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 14,
"metadata": {
"collapsed": false,
"jupyter": {
Expand All @@ -227,7 +287,16 @@
"name": "#%%\n"
}
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Action Spec:\n",
"BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)\n"
]
}
],
"source": [
"print('Action Spec:')\n",
"print(env.action_spec())"
Expand Down
4 changes: 2 additions & 2 deletions notebooks/agents.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Agents in the `tf-agents`

## Agents
## Agents / Algorithms

| Agent (Algorithm) | Description | TFA Module | Action Space | Release | Inventor | Related Agents (or Algorithms) | On-policy / Off-policy |
| ------------------------------------------------------------ | ----------------------------------------------------- | ------------------ | -------------------- | ------- | -------- | ------------------------------------------------------------ | ---------------------- |
Expand All @@ -13,6 +13,6 @@
| [PPOClipAgent](https://www.tensorflow.org/agents/api_docs/python/tf_agents/agents/PPOClipAgent) | PPO with clipped probability ratios | ppo | | | | | |
| [PPOKLPenaltyAgent](https://www.tensorflow.org/agents/api_docs/python/tf_agents/agents/PPOKLPenaltyAgent) | PPO with KL penalty loss | ppo | | | | | |
| ReinforceAgent | REINFORCE | reinforce | | | | | |
| SacAgent | Soft Actor Critic | sas | | | | | |
| [SacAgent](https://www.tensorflow.org/agents/api_docs/python/tf_agents/agents/SacAgent) | [Soft Actor Critic](https://arxiv.org/abs/1801.01290) | sas | Continuous, Discrete | 2018 | Berkeley | | Off-policy |
| TD3 Agent | Twin Delayed Deep Deterministic policy gradient (TD3) | td3 | | | | | |

0 comments on commit 3b499e3

Please sign in to comment.