Linear Regression move to Lux.jl
NeroBlackstone committed Oct 28, 2024
1 parent 8d2e6a0 commit 7168262
Showing 3 changed files with 45 additions and 26 deletions.
6 changes: 6 additions & 0 deletions Project.toml
@@ -16,12 +16,18 @@ IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLBase = "f0e99cf1-93fa-52ec-9ecc-5026115318e0"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TextAnalysis = "a2db99b7-8b79-58f8-94bf-bbc811eef33d"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[compat]
+Lux = "1.2.0"
+Optimisers = "0.3.3"
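
With Lux, Optimisers, Zygote, and cuDNN added to the project, the environment can be re-resolved before running the notebook. A minimal sketch, assuming it is run from the repository root:

``` julia
# Hypothetical environment setup from the D2lJulia repository root
using Pkg
Pkg.activate(".")    # activate the project defined by this Project.toml
Pkg.instantiate()    # install Lux 1.2.x, Optimisers 0.3.3 and the other pinned dependencies
Pkg.status()         # confirm the new Lux/Optimisers/Zygote/cuDNN entries resolved
```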
4 changes: 2 additions & 2 deletions README.md
@@ -23,11 +23,11 @@ This tutorial mainly focuses on using pure julia to implement the code in *Dive

Install `jupyterlab-desktop` or vscode with `jupyter` plugin.

-Install `Julia` 1.10:
+Install `Julia` 1.11:

``` shell
$ julia -v
-julia version 1.10.3
+julia version 1.11.1
```

Clone this project, change directory to `D2lJulia`, and install the dependencies:
@@ -37,8 +37,8 @@
"using Distributions\n",
"\n",
"function synthetic_data(w::Vector{<:Real},b::Real,num_example::Int)\n",
" X = rand(Normal(0f0,1f0),(num_example,length(w)))\n",
" y = X * w .+ b\n",
" X = randn(Float32,(num_example,length(w)))\n",
" y = Float32.(X * w .+ b)\n",
" y += rand(Normal(0f0,0.01f0),(size(y)))\n",
" return X',reshape(y,(1,:))\n",
"end"
@@ -61,7 +61,7 @@
{
"data": {
"text/plain": [
"(Float32[-0.4412894 1.1683054 … 1.3911616 0.98486364; -0.9546863 0.6044618 … 0.8892765 0.543286], [6.5628222349681895 4.4863953853026043.9511675149202348 4.3048679489642385])"
"(Float32[0.5049412 0.24637741 … 1.7447525 0.9225617; -0.9852763 -1.5925564 … 0.7616638 1.9118508], Float32[8.557541 10.1087555.107828 -0.4439846])"
]
},
"execution_count": 2,
@@ -93,8 +93,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"features:Float32[-0.4412894, -0.9546863]\n",
"label:6.5628222349681895\n"
"features:Float32[0.5049412, -0.9852763]\n",
"label:8.557541\n"
]
}
],
@@ -159,7 +159,7 @@
{
"data": {
"text/plain": [
"Dense(2 => 1) \u001b[90m# 3 parameters\u001b[39m"
"((weight = Float32[-0.034858026 -0.23098828], bias = Float32[-0.68244076]), NamedTuple())"
]
},
"execution_count": 5,
@@ -168,8 +168,10 @@
}
],
"source": [
"using Flux\n",
"model = Dense(2=>1)"
"using Lux,Random\n",
"rng = Xoshiro(0)\n",
"model = Dense(2=>1)\n",
"ps, st = Lux.setup(rng, model)"
]
},
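
Unlike Flux, a Lux layer holds no parameters itself; `Lux.setup` returns the parameter and state NamedTuples, and both are passed explicitly on every call. A minimal forward-pass sketch with freshly initialised parameters (the input `x` here is a dummy batch):

``` julia
using Lux, Random

rng = Xoshiro(0)
model = Dense(2 => 1)
ps, st = Lux.setup(rng, model)

x = randn(Float32, 2, 5)              # 2 features × 5 samples
ŷ, st = Lux.apply(model, x, ps, st)   # same as model(x, ps, st); returns (output, updated_state)
size(ŷ)                               # (1, 5)
```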
{
@@ -197,7 +199,7 @@
{
"data": {
"text/plain": [
"loss (generic function with 1 method)"
"(::GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(mean)}) (generic function with 2 methods)"
]
},
"execution_count": 6,
@@ -206,7 +208,7 @@
}
],
"source": [
"loss(model,x,y) = Flux.mse(model(x),y)"
"const mse = MSELoss()"
]
},
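
`MSELoss()` from Lux replaces the hand-written Flux `loss(model,x,y)` closure. The returned loss object is callable directly on predictions and targets, and `Training.single_train_step!` below also knows how to evaluate it against a model. A toy check:

``` julia
using Lux

mse = MSELoss()
mse(Float32[1 2 3], Float32[1 2 4])   # mean of squared errors: (0 + 0 + 1) / 3 ≈ 0.33333334f0
```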
{
@@ -228,7 +230,7 @@
{
"data": {
"text/plain": [
"Descent(0.1)"
"Descent(0.1f0)"
]
},
"execution_count": 7,
@@ -237,6 +239,7 @@
}
],
"source": [
"using Optimisers\n",
"opt = Descent()"
]
},
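
`Descent()` is plain stochastic gradient descent from Optimisers.jl with a default learning rate of 0.1 (hence the `Descent(0.1f0)` shown above); each step applies `θ .- η .* gradient`. An illustrative use of the optimiser outside the Lux training API, with made-up parameters and gradients:

``` julia
using Optimisers

θ = (w = Float32[1.0, 2.0],)                  # toy parameter NamedTuple
g = (w = Float32[0.5, -0.5],)                 # pretend gradient with matching structure
state = Optimisers.setup(Descent(0.1f0), θ)   # per-parameter optimiser state
state, θ = Optimisers.update!(state, θ, g)
θ.w                                           # ≈ Float32[0.95, 2.05], i.e. w .- 0.1f0 .* g.w
```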
@@ -260,20 +263,22 @@
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 1, loss 0.381278 \n",
"epoch 2, loss 0.007915 \n",
"epoch 3, loss 0.000265 \n"
"epoch 1, loss 0.408317 \n",
"epoch 2, loss 0.006815 \n",
"epoch 3, loss 0.000224 \n"
]
}
],
"source": [
"using Printf\n",
"using Printf,Zygote\n",
"\n",
"train_state = Training.TrainState(model, ps, st,opt)\n",
"num_epochs = 3\n",
"for epoch in 1:num_epochs\n",
" for data in train_loader\n",
" Flux.train!(loss,model,[data],opt)\n",
" (_, loss, _, train_state) = Training.single_train_step!(AutoZygote(), mse, data, train_state)\n",
" end\n",
" @printf \"epoch %i, loss %f \\n\" epoch loss(model,features,labels)\n",
" @printf \"epoch %i, loss %f \\n\" epoch mse(model(features,ps,st)[1],labels)\n",
"end"
]
},
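
`Training.single_train_step!` wraps the whole step: it differentiates the loss with the chosen AD backend (`AutoZygote()` here, which is why Zygote is now imported), applies the optimiser, and returns `(gradients, loss, stats, train_state)`. Roughly what one step does, written out by hand as a sketch; the real API also threads the layer state and statistics, and `opt_state` would come from `Optimisers.setup(opt, ps)`:

``` julia
using Lux, Optimisers, Zygote

# Hand-rolled analogue of a single training step; illustrative only
function manual_step(model, ps, st, opt_state, x, y)
    res = Zygote.withgradient(p -> mse(first(model(x, p, st)), y), ps)
    opt_state, ps = Optimisers.update!(opt_state, ps, res.grad[1])
    return ps, opt_state, res.val      # updated parameters, optimiser state, loss value
end
```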
@@ -287,37 +292,45 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 14,
"id": "aeeef2ff-e9b2-4f4e-94a9-d37b407da11b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"error in estimating w:[-0.0016717910766601562, -0.0073896884918212]\n",
"error in estimating b:0.00999946594238299\n"
"error in estimating w:[-0.004839897155761719, -0.0038477420806883877]\n",
"error in estimating b:0.010038566589355646\n"
]
}
],
"source": [
"weight,bias = vec(model.weight),first(model.bias)\n",
"weight,bias = vec(ps[1]),first(ps[2])\n",
"println(\"error in estimating w:$(true_w - weight)\")\n",
"println(\"error in estimating b:$(true_b - bias)\")"
]
},
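
Reading the weights from `ps` works here because `Optimisers.update!` mutates the parameter arrays in place; the trained values are also carried on the `TrainState`, which is the more explicit way to get at them. A sketch, assuming the `parameters` field of Lux's `TrainState`:

``` julia
ps_trained = train_state.parameters            # parameters as tracked by the Training API
weight, bias = vec(ps_trained.weight), first(ps_trained.bias)
println("error in estimating w: $(true_w - weight)")
println("error in estimating b: $(true_b - bias)")
```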
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "1be6f02c-fc1c-4e0f-a848-a3d0ef2f55f1",
+"metadata": {},
+"outputs": [],
+"source": []
+}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.9.1",
"display_name": "Julia 1.11.1",
"language": "julia",
"name": "julia-1.9"
"name": "julia-1.11"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.9.1"
"version": "1.11.1"
}
},
"nbformat": 4,
