
AbstractModel type & neural kernel network #94

Closed
48 commits
1bf561d
Basic GP examples
willtebbutt Feb 21, 2020
aa412c9
Relax version requirement
willtebbutt Feb 21, 2020
331605e
Complete plotting basics
willtebbutt Feb 21, 2020
2841ba8
Document examples
willtebbutt Feb 21, 2020
e5bedfe
Demonstrate approximate inference with Titsias
willtebbutt Feb 21, 2020
bbccdab
Docuemntation
willtebbutt Feb 21, 2020
40a8725
Furhter docs improvements
willtebbutt Feb 21, 2020
64a3a59
More docs and the process decomposition example
willtebbutt Feb 21, 2020
b2bd74e
More docs, more examples
willtebbutt Feb 21, 2020
c2d8d1d
Sensor fusion
willtebbutt Feb 21, 2020
5c172bc
Tweak docs
willtebbutt Feb 21, 2020
73b87f4
More docs and more examples
willtebbutt Feb 21, 2020
731654a
More examples, more docs
willtebbutt Feb 21, 2020
7b6b0df
WIP on GPPP + Pseudo-Points
willtebbutt Feb 21, 2020
8f32481
add fnn example
HamletWantToCode Feb 25, 2020
d7e70b6
add classification example
HamletWantToCode Feb 25, 2020
dabddaf
correct some syntax
HamletWantToCode Feb 25, 2020
76804ab
Merge branch 'example-revamp'
HamletWantToCode Feb 25, 2020
8f3d599
correct indentation
HamletWantToCode Feb 25, 2020
d0bd5f8
correct indentation
HamletWantToCode Feb 25, 2020
7bcf9e6
add readme entries on new examples
HamletWantToCode Feb 26, 2020
c35f523
Merge branch 'master' of github.com:HamletWantToCode/Stheno.jl
HamletWantToCode Feb 26, 2020
4531354
Merge branch 'master' of https://github.com/willtebbutt/Stheno.jl
HamletWantToCode Feb 29, 2020
84bc524
add neural kernel network
HamletWantToCode Feb 29, 2020
0f0b96e
Merge branch 'master' of https://github.com/willtebbutt/Stheno.jl
HamletWantToCode Feb 29, 2020
e9afe7b
Merge branch 'nkn_kernel'
HamletWantToCode Feb 29, 2020
07f4291
correct indentation
HamletWantToCode Feb 29, 2020
7f6bb36
Update neural_kernel_network.jl
HamletWantToCode Feb 29, 2020
3e1c2d6
update, fix NKN's ew method, modify parameter's type of some kernel i…
HamletWantToCode Mar 1, 2020
fe77f08
update
HamletWantToCode Mar 1, 2020
86d33b8
update example
HamletWantToCode Mar 1, 2020
fe3d483
design a tree structure for handling model parameters
HamletWantToCode Mar 3, 2020
2cf9d6e
add AbstractModel type, add neural network specified for GP
HamletWantToCode Mar 3, 2020
9da0fd4
fix bug
HamletWantToCode Mar 4, 2020
237d99b
fix bug, pass tests
HamletWantToCode Mar 4, 2020
4a98ba8
add kernel parameter constraint, redefine a interface for Scaled, upd…
HamletWantToCode Mar 5, 2020
582277d
update
HamletWantToCode Mar 5, 2020
fba284b
add `child` & `get_iparam` to composite_gp
HamletWantToCode Mar 6, 2020
bc05b77
add annotations
HamletWantToCode Mar 6, 2020
0b46ba2
Update composite_gp.jl
HamletWantToCode Mar 6, 2020
a3a40ef
Update kernel.jl
HamletWantToCode Mar 6, 2020
3250164
Update basic.jl
HamletWantToCode Mar 6, 2020
9d14353
Update abstract_model.jl
HamletWantToCode Mar 6, 2020
cf39c97
fix bug
HamletWantToCode Mar 6, 2020
3c6ef99
fix bug
HamletWantToCode Mar 7, 2020
cecd7b5
remove AbstractModel, add tests for NKN
HamletWantToCode Mar 13, 2020
c487460
update
HamletWantToCode Mar 13, 2020
38c5786
update
HamletWantToCode Mar 13, 2020
145 changes: 145 additions & 0 deletions examples/flux_integration/neural_kernel_network/AirPassengers.csv
@@ -0,0 +1,145 @@
"","time","value"
"1",1949,112
"2",1949.08333333333,118
"3",1949.16666666667,132
"4",1949.25,129
"5",1949.33333333333,121
"6",1949.41666666667,135
"7",1949.5,148
"8",1949.58333333333,148
"9",1949.66666666667,136
"10",1949.75,119
"11",1949.83333333333,104
"12",1949.91666666667,118
"13",1950,115
"14",1950.08333333333,126
"15",1950.16666666667,141
"16",1950.25,135
"17",1950.33333333333,125
"18",1950.41666666667,149
"19",1950.5,170
"20",1950.58333333333,170
"21",1950.66666666667,158
"22",1950.75,133
"23",1950.83333333333,114
"24",1950.91666666667,140
"25",1951,145
"26",1951.08333333333,150
"27",1951.16666666667,178
"28",1951.25,163
"29",1951.33333333333,172
"30",1951.41666666667,178
"31",1951.5,199
"32",1951.58333333333,199
"33",1951.66666666667,184
"34",1951.75,162
"35",1951.83333333333,146
"36",1951.91666666667,166
"37",1952,171
"38",1952.08333333333,180
"39",1952.16666666667,193
"40",1952.25,181
"41",1952.33333333333,183
"42",1952.41666666667,218
"43",1952.5,230
"44",1952.58333333333,242
"45",1952.66666666667,209
"46",1952.75,191
"47",1952.83333333333,172
"48",1952.91666666667,194
"49",1953,196
"50",1953.08333333333,196
"51",1953.16666666667,236
"52",1953.25,235
"53",1953.33333333333,229
"54",1953.41666666667,243
"55",1953.5,264
"56",1953.58333333333,272
"57",1953.66666666667,237
"58",1953.75,211
"59",1953.83333333333,180
"60",1953.91666666667,201
"61",1954,204
"62",1954.08333333333,188
"63",1954.16666666667,235
"64",1954.25,227
"65",1954.33333333333,234
"66",1954.41666666667,264
"67",1954.5,302
"68",1954.58333333333,293
"69",1954.66666666667,259
"70",1954.75,229
"71",1954.83333333333,203
"72",1954.91666666667,229
"73",1955,242
"74",1955.08333333334,233
"75",1955.16666666667,267
"76",1955.25,269
"77",1955.33333333334,270
"78",1955.41666666667,315
"79",1955.5,364
"80",1955.58333333334,347
"81",1955.66666666667,312
"82",1955.75,274
"83",1955.83333333334,237
"84",1955.91666666667,278
"85",1956,284
"86",1956.08333333334,277
"87",1956.16666666667,317
"88",1956.25,313
"89",1956.33333333334,318
"90",1956.41666666667,374
"91",1956.5,413
"92",1956.58333333334,405
"93",1956.66666666667,355
"94",1956.75,306
"95",1956.83333333334,271
"96",1956.91666666667,306
"97",1957,315
"98",1957.08333333334,301
"99",1957.16666666667,356
"100",1957.25,348
"101",1957.33333333334,355
"102",1957.41666666667,422
"103",1957.5,465
"104",1957.58333333334,467
"105",1957.66666666667,404
"106",1957.75,347
"107",1957.83333333334,305
"108",1957.91666666667,336
"109",1958,340
"110",1958.08333333334,318
"111",1958.16666666667,362
"112",1958.25,348
"113",1958.33333333334,363
"114",1958.41666666667,435
"115",1958.5,491
"116",1958.58333333334,505
"117",1958.66666666667,404
"118",1958.75,359
"119",1958.83333333334,310
"120",1958.91666666667,337
"121",1959,360
"122",1959.08333333334,342
"123",1959.16666666667,406
"124",1959.25,396
"125",1959.33333333334,420
"126",1959.41666666667,472
"127",1959.5,548
"128",1959.58333333334,559
"129",1959.66666666667,463
"130",1959.75,407
"131",1959.83333333334,362
"132",1959.91666666667,405
"133",1960,417
"134",1960.08333333334,391
"135",1960.16666666667,419
"136",1960.25,461
"137",1960.33333333334,472
"138",1960.41666666667,535
"139",1960.5,622
"140",1960.58333333334,606
"141",1960.66666666667,508
"142",1960.75,461
"143",1960.83333333334,390
"144",1960.91666666667,432
12 changes: 12 additions & 0 deletions examples/flux_integration/neural_kernel_network/Project.toml
@@ -0,0 +1,12 @@
[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Stheno = "8188c328-b5d6-583d-959b-9690869a5511"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Flux = "0.10"
Stheno = "0.6"
Zygote = "0.4.6"
julia = "1"
157 changes: 157 additions & 0 deletions examples/flux_integration/neural_kernel_network/time_series.jl
@@ -0,0 +1,157 @@
# Set up the environment to run this example. Make sure you're within the folder that this
# file lives in.
using Pkg
Pkg.activate(@__DIR__)
Pkg.instantiate()

using LinearAlgebra, Stheno, Flux, Zygote, DelimitedFiles, Statistics
using Plots; pyplot();
using Random; Random.seed!(4);

######################################################
# Data loading
## read AirPass data
data = readdlm("AirPassengers.csv", ',')
year = data[2:end,2]; passengers = data[2:end,3];
## Split the data into training and testing data
oxtrain = year[year.<1958]; oytrain = passengers[year.<1958];
oxtest = year[year.>=1958]; oytest = passengers[year.>=1958];

## data preprocessing
### standardize X and y
xtrain_mean = mean(oxtrain)
ytrain_mean = mean(oytrain)
xtrain_std = std(oxtrain)
ytrain_std = std(oytrain)
xtrain = @. (oxtrain-xtrain_mean)/xtrain_std
ytrain = @. (oytrain-ytrain_mean)/ytrain_std

xtest = @. (oxtest-xtrain_mean)/xtrain_std
ytest = @. (oytest-ytrain_mean)/ytrain_std

## input data
Xtrain = reshape(xtrain, 1, length(xtrain))
Xtest = reshape(xtest, 1, length(xtest))
Year = hcat(Xtrain, Xtest)
Passengers = vcat(ytrain, ytest)
######################################################

plt = plot(xlabel="Year", ylabel="Airline Passenger number", legend=true)
scatter!(plt, oxtrain, oytrain, label="Observations(train)", color=:black)




######################################################
# Build kernel with Neural Kernel Network
## kernel length scale initialization
function median_distance_local(x)
n = length(x)
dist = []
for i in 1:n
for j in i:n
push!(dist, abs(x[j]-x[i]))
end
end
median(dist)
end
l = median_distance_local(xtrain)
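For reference, the length-scale heuristic above can be checked in isolation. Note that the inner loop runs `j in i:n`, so the zero self-distances are included, which pulls the median down slightly; this sketch mirrors the function above rather than reproducing any Stheno API.

```julia
# Median pairwise-distance heuristic, mirroring the example's function.
# The `j in i:n` loop deliberately includes zero self-distances, as above.
using Statistics

function median_distance_local(x)
    dist = Float64[]
    for i in 1:length(x), j in i:length(x)
        push!(dist, abs(x[j] - x[i]))
    end
    return median(dist)
end

# distances for [0, 1, 3]: 0, 1, 3, 0, 2, 0 → sorted [0,0,0,1,2,3], median 0.5
median_distance_local([0.0, 1.0, 3.0])  # → 0.5
```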

## kernel parameter constraint
g1(x) = exp(-x)
g2(x) = exp(x)
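These two transforms give an unconstrained (log-scale) parameterization: each kernel parameter is stored as a log and mapped back through `exp` (or `exp(-x)` for quantities like inverse length scales), so any real-valued optimizer step still yields a strictly positive parameter. A minimal check:

```julia
# Unconstrained parameterization: store θ = log(p), recover p = exp(θ) > 0.
g1(x) = exp(-x)   # e.g. an inverse length scale
g2(x) = exp(x)    # e.g. a variance or length scale

θ = log(2.0)              # stored, unconstrained parameter
@assert g2(θ) ≈ 2.0       # recovers the positive value
@assert g1(θ) ≈ 0.5       # exp(-log(2)) = 1/2
@assert g2(θ - 10.0) > 0  # stays positive under any gradient step
```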

## define kernels
iso_lin_kernel1 = stretch(Linear(), log(1.0), g1)
iso_per_kernel1 = scale(stretch(PerEQ(log(l), g2), log(l), g1), log(1.0), g2)
iso_eq_kernel1 = scale(stretch(EQ(), log(l/4.0), g1), log(1.0), g2)
iso_rq_kernel1 = scale(stretch(RQ(log(0.2), g2), log(2.0*l), g1), log(1.0), g2)
iso_lin_kernel2 = stretch(Linear(), log(1.0), g1)
iso_rq_kernel2 = scale(stretch(RQ(log(0.1), g2), log(l), g1), log(1.0), g2)
iso_eq_kernel2 = scale(stretch(EQ(), log(l), g1), log(1.0), g2)
iso_per_kernel2 = scale(stretch(PerEQ(log(l/4.0), g2), log(l/4.0), g1), log(1.0), g2)


# define network
linear1 = LinearLayer(8, 8)
prod1 = ProductLayer(2)
linear2 = LinearLayer(4, 4)
prod2 = ProductLayer(2)
linear3 = LinearLayer(2, 1)

## NKN
player = Primitive(iso_lin_kernel1, iso_per_kernel1, iso_eq_kernel1, iso_rq_kernel1,
iso_lin_kernel2, iso_rq_kernel2, iso_eq_kernel2, iso_per_kernel2)
nn = chain(linear1, prod1, linear2, prod2, linear3)
nkn = NeuralKernelNetwork(player, nn)
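The `Primitive` layer evaluates every base kernel on the same input pair, giving a vector of kernel values; a `LinearLayer` then mixes them with non-negative weights (a non-negative combination of kernels is a kernel) and a `ProductLayer` multiplies adjacent entries (a product of kernels is also a kernel), so the network output is itself a valid kernel. A hypothetical numeric sketch of that algebra — standalone helper functions, not the Stheno implementation:

```julia
# Hypothetical sketch of the NKN layer algebra (not the Stheno API).
# `k` holds the primitive-kernel values at one pair of inputs.
linear_layer(W, k) = abs.(W) * k   # non-negative mixing preserves PSD-ness
product_layer(k, step) = [prod(k[i:i+step-1]) for i in 1:step:length(k)]

k = [1.0, 2.0, 3.0, 4.0]
W = [0.5 0.5 0.0 0.0;
     0.0 0.0 1.0 1.0]
linear_layer(W, k)   # → [1.5, 7.0]
product_layer(k, 2)  # → [2.0, 12.0]
```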
#############################################################


# Do some common calculation
σ²_n = 0.1 # specify Gaussian noise
gp = GP(nkn, GPC()) # define GP
loss(m, x, y) = -logpdf(m(ColVecs(x), σ²_n), y) # define loss & compute negative log likelihood
loss(gp, Xtrain, ytrain)
∂gp, = gradient(m->loss(m, Xtrain, ytrain), gp) # compute derivative of loss w.r.t GP parameters

# extract all parameters from the GP model
l_ps = parameters(gp) |> length
# extract the corresponding gradients from the derivative (the cotangent of the GP model)
l_∂ps = extract_gradient(gp, ∂gp) |> length
# make sure parameters and gradients are in one-to-one correspondence
@assert l_ps == l_∂ps


#############################################################
# Optimize GP parameters w.r.t training data
using Flux.Optimise: update!

optimizer = ADAM(0.001)
L = []
for i in 1:5000
nll = loss(gp, Xtrain, ytrain)
push!(L, nll)
if i==1 || i%200 == 0
@info "step=$i, loss=$nll"
end
ps = parameters(gp)
∂gp, = gradient(m->loss(m, Xtrain, ytrain), gp)

Δps = extract_gradient(gp, ∂gp)
update!(optimizer, ps, Δps)
dispatch!(gp, ps) # dispatch! will update the GP model with updated parameters
end

# you can view the loss curve
# plot(L, legend=false)
#############################################################


#############################################################
# make prediction
function predict(gp, X, Xtrain, ytrain)
gp_Xtrain = gp(ColVecs(Xtrain), σ²_n)
posterior = gp | Obs(gp_Xtrain, ytrain)
posterior(ColVecs(X))
end

posterior = predict(gp, Year, Xtrain, ytrain)
post_dist = marginals(posterior)
pred_y = mean.(post_dist)
var_y = std.(post_dist)

pred_oy = @. pred_y*ytrain_std+ytrain_mean
pred_oσ = @. var_y*ytrain_std
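The two lines above undo the z-scoring applied at training time: the predictive mean is mapped back via y = z·σ_y + μ_y, while the predictive standard deviation scales by σ_y alone, since shifting by the mean does not change the spread. A quick round-trip check:

```julia
# De-standardization is the exact inverse of the z-scoring used above.
using Statistics

oy = [100.0, 200.0, 300.0]
μ, σ = mean(oy), std(oy)
z = @. (oy - μ) / σ          # standardize, as in the example
recovered = @. z * σ + μ     # inverse transform
recovered  # → [100.0, 200.0, 300.0]
```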

plot!(plt, year, pred_oy, ribbons=3*pred_oσ, title="Time series prediction", label="±3σ predictive region")
scatter!(plt, oxtest, oytest, label="Observations(test)", color=:red)
display(plt)
##############################################################








25 changes: 15 additions & 10 deletions src/Stheno.jl
@@ -11,20 +11,20 @@ module Stheno
using BlockArrays: _BlockArray
import LinearAlgebra: cholesky, cross
import Distances: pairwise, colwise

const AV{T} = AbstractVector{T}
const AM{T} = AbstractMatrix{T}
const AVM{T} = AbstractVecOrMat{T}

const BlockLowerTriangular{T} = LowerTriangular{T, <:BlockMatrix{T}}
const BlockUpperTriangular{T} = UpperTriangular{T, <:BlockMatrix{T}}
const BlockTriangular{T} = Union{BlockLowerTriangular{T}, BlockUpperTriangular{T}}

function elementwise end

const pw = pairwise
const ew = elementwise

# Various bits of utility that aren't inherently GP-related. Often very type-piratic.
include(joinpath("util", "zygote_rules.jl"))
include(joinpath("util", "covariance_matrices.jl"))
@@ -34,15 +34,20 @@ module Stheno
include(joinpath("util", "abstract_data_set.jl"))
include(joinpath("util", "distances.jl"))
include(joinpath("util", "proper_type_piracy.jl"))

include(joinpath("util", "parameter_handler.jl"))

# Supertype for GPs.
include("abstract_gp.jl")

# Atomic GP objects.

# Neural network used for building neural kernel network
include(joinpath("neural_network", "basic.jl"))

# Atomic GP objects
include(joinpath("gp", "mean.jl"))
include(joinpath("gp", "kernel.jl"))
include(joinpath("gp", "neural_kernel_network.jl"))
include(joinpath("gp", "gp.jl"))

# Composite GPs, constructed via affine transformation of CompositeGPs and GPs.
include(joinpath("composite", "composite_gp.jl"))
include(joinpath("composite", "indexing.jl"))
@@ -54,7 +59,7 @@
include(joinpath("composite", "compose.jl"))
# include(joinpath("composite", "gradient.jl"))
# include(joinpath("composite", "integrate.jl"))

# Various stuff for convenience.
include(joinpath("util", "model.jl"))
include(joinpath("util", "plotting.jl"))
3 changes: 3 additions & 0 deletions src/composite/compose.jl
@@ -1,3 +1,6 @@
# Should the parameters inside `stretch`, `periodic` & `shift` be optimized?
# Or are they just treated as data preprocessing?

import Base: ∘
export ∘, select, stretch, periodic, shift

3 changes: 3 additions & 0 deletions src/composite/composite_gp.jl
@@ -13,6 +13,9 @@ struct CompositeGP{Targs} <: AbstractGP
return gp
end
end
# Note: `UndefVarError` expects a `Symbol`, so constructing it from a `String` would
# itself fail; `error` raises a plain `ErrorException` with the intended message.
get_iparam(::CompositeGP) = error("get_iparam is not yet defined for CompositeGP")
child(::CompositeGP) = error("child is not yet defined for CompositeGP")

CompositeGP(args::Targs, gpc::GPC) where {Targs} = CompositeGP{Targs}(args, gpc)

mean_vector(f::CompositeGP, x::AV) = mean_vector(f.args, x)