From 3b499e39b225eea6cc3807e35ecaa74c024d84bc Mon Sep 17 00:00:00 2001 From: e10101 Date: Sun, 9 Jan 2022 16:23:16 +0800 Subject: [PATCH] #2 Add SAC agent info --- notebooks/CartPole-v0.ipynb | 131 +++++++++++++++++++++++++++--------- notebooks/agents.md | 4 +- 2 files changed, 102 insertions(+), 33 deletions(-) diff --git a/notebooks/CartPole-v0.ipynb b/notebooks/CartPole-v0.ipynb index e81c063..46f42cd 100644 --- a/notebooks/CartPole-v0.ipynb +++ b/notebooks/CartPole-v0.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -27,26 +27,24 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!pip install imageio\n", - "!pip install imageio-ffmpeg" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Python 3.8.10\n" + ] + } + ], "source": [ "!python --version" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -55,18 +53,40 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'2.7.0'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "tf.__version__" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'0.11.0'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import tf_agents\n", "tf_agents.__version__" @@ -74,7 +94,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": { "tags": [] }, @@ -109,7 +129,7 @@ }, { "cell_type": "code", - 
"execution_count": null, + "execution_count": 7, "metadata": { "collapsed": false, "jupyter": { @@ -126,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "tags": [] }, @@ -162,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": { "tags": [] }, @@ -174,18 +194,33 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "TimeStep(\n", + "{'discount': array(1., dtype=float32),\n", + " 'observation': array([ 0.02613745, -0.02241037, -0.0004048 , 0.02523139], dtype=float32),\n", + " 'reward': array(0., dtype=float32),\n", + " 'step_type': array(0, dtype=int32)})" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "env.reset()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { "tags": [] }, @@ -196,20 +231,45 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAIAAAD9V4nPAAAGdElEQVR4nO3dwU2DYBiAYTFdwjl0DOdoZ2rncAydwzHw4qFWTZqg/DTv89zgQL4LecMXAtM8z3cAUHU/egAAGEkIAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDSdqMHAD69nQ7nh4/746hJIEUIYZiL8gFDWI0CkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQgiLTAuMujJwTggBSBNCANJ2oweAupf3/cWZ54fTkEmgyRMhjPS9gr+dBP6JEMIWaSGsRghhGLWDLRBCANKEEIA0IYRhvB0KWyCEsEUaCasRQhjpx+CpIKxpmud59Axww5Z82PP1+OWt0afDn/XPfQ3XE0JYZJtfuHZfw/WsRgEAAKqsRmERq1G4dVajAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKQJIQBpQghAmhACkCaEAKT5+wQAaZ4IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgD
ShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIE0IAUgTQgDShBCANCEEIO0Dbm0yV3iC604AAAAASUVORK5CYII=\n", + "text/plain": [ + "" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "PIL.Image.fromarray(frame)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Time Step Spec:\n", + "TimeStep(\n", + "{'discount': BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0),\n", + " 'observation': BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]),\n", + " 'reward': ArraySpec(shape=(), dtype=dtype('float32'), name='reward'),\n", + " 'step_type': ArraySpec(shape=(), dtype=dtype('int32'), name='step_type')})\n" + ] + } + ], "source": [ "print('Time Step Spec:')\n", "print(env.time_step_spec())" @@ -217,7 +277,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "collapsed": false, "jupyter": { @@ -227,7 +287,16 @@ "name": "#%%\n" } }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Action Spec:\n", + "BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)\n" + ] + } + ], "source": [ "print('Action Spec:')\n", "print(env.action_spec())" diff --git a/notebooks/agents.md b/notebooks/agents.md index 47f7fcb..08fec12 100644 --- a/notebooks/agents.md +++ b/notebooks/agents.md @@ -1,6 +1,6 @@ # Agents in the `tf-agents` -## Agents +## Agents / Algorithms | Agent (Algorithm) | Description | TFA 
Module | Action Space | Release | Inventor | Related Agents (or Algorithms) | On-policy / Off-policy | | ------------------------------------------------------------ | ----------------------------------------------------- | ------------------ | -------------------- | ------- | -------- | ------------------------------------------------------------ | ---------------------- | @@ -13,6 +13,6 @@ | [PPOClipAgent](https://www.tensorflow.org/agents/api_docs/python/tf_agents/agents/PPOClipAgent) | PPO with clipped probability ratios | ppo | | | | | | | [PPOKLPenaltyAgent](https://www.tensorflow.org/agents/api_docs/python/tf_agents/agents/PPOKLPenaltyAgent) | PPO with KL penalty loss | ppo | | | | | | | ReinforceAgent | REINFORCE | reinforce | | | | | | -| SacAgent | Soft Actor Critic | sas | | | | | | +| [SacAgent](https://www.tensorflow.org/agents/api_docs/python/tf_agents/agents/SacAgent) | [Soft Actor Critic](https://arxiv.org/abs/1801.01290) | sac | Continuous, Discrete | 2018 | Berkeley | | Off-policy | | TD3 Agent | Twin Delayed Deep Deterministic policy gradient (TD3) | td3 | | | | | |