-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6a3ba04
commit 068aa4f
Showing
5 changed files
with
152 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "julia" | ||
} | ||
}, | ||
"source": [ | ||
"# Deep Recurrent Neural Networks\n", | ||
"\n", | ||
"## Concise Implementation\n", | ||
"\n", | ||
"Fortunately many of the logistical details required to implement multiple layers of an RNN are readily available in high-level APIs. Our concise implementation will use such built-in functionalities. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"┌ Info: Using backend: CUDA.\n", | ||
"└ @ Flux /home/nero/.julia/packages/Flux/Wz6D4/src/functor.jl:662\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"predict (generic function with 1 method)" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"using Downloads,IterTools,CUDA,Flux\n", | ||
"using StatsBase: wsample\n", | ||
"\n", | ||
"device = Flux.get_device(; verbose=true)\n", | ||
"\n", | ||
"file_path = Downloads.download(\"http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt\")\n", | ||
"raw_text = open(io->read(io, String),file_path)\n", | ||
"str = lowercase(replace(raw_text,r\"[^A-Za-z]+\"=>\" \"))\n", | ||
"tokens = [str...]\n", | ||
"vocab = unique(tokens)\n", | ||
"vocab_len = length(vocab)\n", | ||
"\n", | ||
"# n*[seq_length x feature x batch_size]\n", | ||
"function getdata(str::String,vocab::Vector{Char},seq_length::Int,batch_size::Int)::Tuple\n", | ||
" data = collect.(partition(str,seq_length,1))\n", | ||
" x = [[Flux.onehotbatch(i,vocab) for i in d] for d in Flux.batchseq.(Flux.chunk(data[begin:end-1];size = batch_size))]\n", | ||
" y = [[Flux.onehotbatch(i,vocab) for i in d] for d in Flux.batchseq.(Flux.chunk(data[2:end];size = batch_size))]\n", | ||
" return x,y\n", | ||
"end\n", | ||
"\n", | ||
"function loss(model, xs, ys)\n", | ||
" Flux.reset!(model)\n", | ||
" return sum(Flux.logitcrossentropy.([model(x) for x in xs], ys))\n", | ||
"end\n", | ||
"\n", | ||
"function predict(model::Chain, prefix::String, num_preds::Int)\n", | ||
" model = cpu(model)\n", | ||
" Flux.reset!(model)\n", | ||
" buf = IOBuffer()\n", | ||
" write(buf, prefix)\n", | ||
"\n", | ||
" c = wsample(vocab, softmax([model(Flux.onehot(c, vocab)) for c in collect(prefix)][end]))\n", | ||
" for i in 1:num_preds\n", | ||
" write(buf, c)\n", | ||
" c = wsample(vocab, softmax(model(Flux.onehot(c, vocab))))\n", | ||
" end\n", | ||
" return String(take!(buf))\n", | ||
"end" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"Chain(\n", | ||
" Recur(\n", | ||
" GRUv3Cell(27 => 32), \u001b[90m# 5_792 parameters\u001b[39m\n", | ||
" ),\n", | ||
" Recur(\n", | ||
" GRUv3Cell(32 => 32), \u001b[90m# 6_272 parameters\u001b[39m\n", | ||
" ),\n", | ||
" Dense(32 => 27), \u001b[90m# 891 parameters\u001b[39m\n", | ||
") \u001b[90m # Total: 12 trainable arrays, \u001b[39m12_955 parameters,\n", | ||
"\u001b[90m # plus 2 non-trainable, 64 parameters, summarysize \u001b[39m1.938 KiB." | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"model = Chain(GRUv3(vocab_len => 32),GRUv3(32 => 32),Dense(32 => vocab_len)) |> device" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"opt_state = Flux.setup(Adam(1e-2), model)\n", | ||
"\n", | ||
"x,y = getdata(str, vocab, 32, 1024) |> device\n", | ||
"data = zip(x,y)\n", | ||
"loss_train = []\n", | ||
"\n", | ||
"for epoch in 1:50\n", | ||
" @show \"$epoch start\"\n", | ||
" Flux.reset!(model)\n", | ||
" Flux.train!(loss,model,data,opt_state)\n", | ||
" push!(loss_train,sum(loss.(Ref(model), x, y)) / length(str)) \n", | ||
"end" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Julia 1.10.3", | ||
"language": "julia", | ||
"name": "julia-1.10" | ||
}, | ||
"language_info": { | ||
"file_extension": ".jl", | ||
"mimetype": "application/julia", | ||
"name": "julia", | ||
"version": "1.10.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters