Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
NeroBlackstone committed Jun 3, 2024
1 parent 6a3ba04 commit 068aa4f
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 30 deletions.
10 changes: 5 additions & 5 deletions _toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ parts:
chapters:
- file: notebooks/chapter_recurrent_neural_networks/sequence.ipynb
- file: notebooks/chapter_recurrent_neural_networks/text-sequence.ipynb
- caption: Modern Recurrent Neural Networks
- caption: Modern Recurrent Neural Networks
numbered: true
- chapters:
- file: notebooks/chapter_recurrent-modern/lstm.ipynb
- file: notebooks/chapter_recurrent-modern/gru.ipynb
- file: notebooks/chapter_recurrent-modern/deep-rnn.ipynb
chapters:
- file: notebooks/chapter_recurrent_modern/lstm.ipynb
- file: notebooks/chapter_recurrent_modern/gru.ipynb
- file: notebooks/chapter_recurrent_modern/deep-rnn.ipynb

23 changes: 0 additions & 23 deletions notebooks/chapter_recurrent-modern/deep-rnn.ipynb

This file was deleted.

147 changes: 147 additions & 0 deletions notebooks/chapter_recurrent_modern/deep-rnn.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"vscode": {
"languageId": "julia"
}
},
"source": [
"# Deep Recurrent Neural Networks\n",
"\n",
"## Concise Implementation\n",
"\n",
"Fortunately many of the logistical details required to implement multiple layers of an RNN are readily available in high-level APIs. Our concise implementation will use such built-in functionalities. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"┌ Info: Using backend: CUDA.\n",
"└ @ Flux /home/nero/.julia/packages/Flux/Wz6D4/src/functor.jl:662\n"
]
},
{
"data": {
"text/plain": [
"predict (generic function with 1 method)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"using Downloads,IterTools,CUDA,Flux\n",
"using StatsBase: wsample\n",
"\n",
"device = Flux.get_device(; verbose=true)\n",
"\n",
"file_path = Downloads.download(\"http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt\")\n",
"raw_text = open(io->read(io, String),file_path)\n",
"str = lowercase(replace(raw_text,r\"[^A-Za-z]+\"=>\" \"))\n",
"tokens = [str...]\n",
"vocab = unique(tokens)\n",
"vocab_len = length(vocab)\n",
"\n",
"# n*[seq_length x feature x batch_size]\n",
"function getdata(str::String,vocab::Vector{Char},seq_length::Int,batch_size::Int)::Tuple\n",
" data = collect.(partition(str,seq_length,1))\n",
" x = [[Flux.onehotbatch(i,vocab) for i in d] for d in Flux.batchseq.(Flux.chunk(data[begin:end-1];size = batch_size))]\n",
" y = [[Flux.onehotbatch(i,vocab) for i in d] for d in Flux.batchseq.(Flux.chunk(data[2:end];size = batch_size))]\n",
" return x,y\n",
"end\n",
"\n",
"function loss(model, xs, ys)\n",
" Flux.reset!(model)\n",
" return sum(Flux.logitcrossentropy.([model(x) for x in xs], ys))\n",
"end\n",
"\n",
"function predict(model::Chain, prefix::String, num_preds::Int)\n",
" model = cpu(model)\n",
" Flux.reset!(model)\n",
" buf = IOBuffer()\n",
" write(buf, prefix)\n",
"\n",
" c = wsample(vocab, softmax([model(Flux.onehot(c, vocab)) for c in collect(prefix)][end]))\n",
" for i in 1:num_preds\n",
" write(buf, c)\n",
" c = wsample(vocab, softmax(model(Flux.onehot(c, vocab))))\n",
" end\n",
" return String(take!(buf))\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Chain(\n",
" Recur(\n",
" GRUv3Cell(27 => 32), \u001b[90m# 5_792 parameters\u001b[39m\n",
" ),\n",
" Recur(\n",
" GRUv3Cell(32 => 32), \u001b[90m# 6_272 parameters\u001b[39m\n",
" ),\n",
" Dense(32 => 27), \u001b[90m# 891 parameters\u001b[39m\n",
") \u001b[90m # Total: 12 trainable arrays, \u001b[39m12_955 parameters,\n",
"\u001b[90m # plus 2 non-trainable, 64 parameters, summarysize \u001b[39m1.938 KiB."
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = Chain(GRUv3(vocab_len => 32),GRUv3(32 => 32),Dense(32 => vocab_len)) |> device"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"opt_state = Flux.setup(Adam(1e-2), model)\n",
"\n",
"x,y = getdata(str, vocab, 32, 1024) |> device\n",
"data = zip(x,y)\n",
"loss_train = []\n",
"\n",
"for epoch in 1:50\n",
" @show \"$epoch start\"\n",
" Flux.reset!(model)\n",
" Flux.train!(loss,model,data,opt_state)\n",
" push!(loss_train,sum(loss.(Ref(model), x, y)) / length(str)) \n",
"end"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.10.3",
"language": "julia",
"name": "julia-1.10"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.10.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@
"loss_train = []\n",
"\n",
"for epoch in 1:50\n",
" @show \"$epoch start\"\n",
" Flux.reset!(model)\n",
" Flux.train!(loss,model,data,opt_state)\n",
" push!(loss_train,sum(loss.(Ref(model), x, y)) / length(str)) \n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@
"loss_train = []\n",
"\n",
"for epoch in 1:50\n",
" @show \"$epoch start\"\n",
" Flux.reset!(model)\n",
" Flux.train!(loss,model,data,opt_state)\n",
" push!(loss_train,sum(loss.(Ref(model), x, y)) / length(str)) \n",
Expand Down

0 comments on commit 068aa4f

Please sign in to comment.