
Neural Kernel Network implementation #92

Conversation

HamletWantToCode (Contributor)

This PR is an attempt to implement the Neural Kernel Network (NKN) for Stheno; a working example can be found here

Interface

primitive_layer = Primitive(EQ(), PerEQ(), Linear())  # <:Primitive, container for the basic kernels; computes each kernel's covariance matrix and makes it a valid input to the neural network
lin1 = LinearLayer(3, 4)                              # <:LinearLayer, linear transformation without bias or activation
lin2 = LinearLayer(2, 1)
nn = Chain(lin1, Product, lin2)                       # use Flux's `Chain` to build the neural network
nkn_kernel = NeuralKernelNetwork(primitive_layer, nn) # <:Kernel, composite kernel built on the neural network
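
For context, here is a hedged sketch of how the resulting kernel could then be used like any other Stheno kernel (this assumes Stheno's 2020-era GP/GPC API; the toy data and variable names are illustrative, not taken from this PR):

using Stheno

x = collect(range(-3.0, 3.0; length=50))  # toy 1-D inputs (illustrative)
y = sin.(x) .+ 0.1 .* randn(50)           # toy observations (illustrative)

f = GP(nkn_kernel, GPC())   # wrap the NKN kernel in a GP
fx = f(x, 0.1)              # finite-dimensional marginal with observation noise 0.1
logpdf(fx, y)               # log marginal likelihood, as for any other Stheno kernel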

Newly implemented types & functions:

  1. Primitive: NKN can be viewed as a composite kernel, and Primitive serves as a container for all the basic kernels. It has ew & pw methods implemented, but it is not a subtype of Stheno's Kernel. Calls like ew(<:Primitive, x) & pw(<:Primitive, x) compute ew and pw for each kernel inside the Primitive, and then prepare the results as inputs to the following neural network.

  2. LinearLayer: This is just a linear transformation z = W*x. The reason I created this type instead of using Flux's Dense is that we don't need a bias or an activation function here (see the sketch after this list).

  3. Product: A function that performs element-wise multiplication of kernel matrices.

  4. NeuralKernelNetwork: A subtype of Stheno's Kernel type with ew and pw methods implemented, so it can be used like any ordinary Stheno kernel.
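
As mentioned in item 2, here is a minimal, hedged sketch of what LinearLayer and Product could look like. The field names, the constructor, and the assumption that the primitive layer stacks each kernel's flattened output as a row of a matrix are mine, not necessarily what this PR does:

using Flux

# Bias-free, activation-free linear map; the weight matrix is the only
# trainable field, so Flux.params can track it.
struct LinearLayer{T<:AbstractMatrix}
    W::T
end
LinearLayer(in_dim::Int, out_dim::Int) = LinearLayer(randn(out_dim, in_dim))
(l::LinearLayer)(x) = l.W * x
Flux.@functor LinearLayer

# Product: element-wise multiplication of adjacent pairs of rows, assuming
# each row holds one (flattened) kernel matrix from the primitive layer.
# This halves the number of rows, e.g. 4 -> 2 as in the interface example.
Product(x) = x[1:2:end, :] .* x[2:2:end, :]

Note that for the composite to remain a valid (positive semi-definite) kernel, the NKN paper [1] requires the linear weights to be non-negative; how the actual implementation enforces this is not shown in the sketch above.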

Supported features

  • Extract all parameters within the NKN with Flux's params method
  • Use Zygote to compute the gradient of the logpdf w.r.t. all the parameters in the NKN (sketched below)
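
A hedged sketch of what this workflow could look like, reusing the nkn_kernel, x, and y names from above (the linked example's actual training loop may differ):

using Flux, Zygote, Stheno

θ = Flux.params(nkn_kernel)   # collect every trainable array inside the NKN
gs = Zygote.gradient(θ) do
    f = GP(nkn_kernel, GPC())
    -logpdf(f(x, 0.1), y)     # negative log marginal likelihood
end
# gs is a Zygote.Grads; gs[p] holds the gradient for each parameter p in θ.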

To be discussed

In order to allow Flux's params to extract all the parameters inside the NKN, I slightly modified the definitions and field types of Scaled, Stretched and RQ in kernels.jl.

  1. Scaled: the original σ² is replaced by logσ², and its type is restricted to AbstractVector. The reason is that σ² must remain positive during optimization, while Flux's params method requires trainable fields to be AbstractArrays (a sketch of this pattern follows the note below).

  2. Stretched: a is replaced by loga, and its type is restricted to AbstractVecOrMat (for the same reason as above).

  3. RQ: α is replaced by logα, and its type is restricted to AbstractVector (for the same reason as above).

  4. PerEQ: I noticed that this kernel hadn't been exported by Stheno yet, so I reimplemented and exported it.

NOTE: I have only done some basic tests of these modifications; they are not guaranteed to be type stable and may break in other situations.
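
As referenced in item 1 above, here is a minimal sketch of the log-reparameterisation pattern as it might look inside Stheno's kernels.jl (where Kernel and ew are in scope); the field names and the ew body are illustrative, not the exact diff:

using Flux

# Store log(σ²) as a 1-element vector so that Flux.params can track it;
# exponentiating on use keeps σ² positive by construction.
struct Scaled{Tlogσ²<:AbstractVector{<:Real}, Tk<:Kernel} <: Kernel
    logσ²::Tlogσ²
    k::Tk
end
Scaled(σ²::Real, k::Kernel) = Scaled([log(σ²)], k)
Flux.@functor Scaled

# e.g. elementwise evaluation recovers σ² = exp(logσ²[1]):
ew(k::Scaled, x::AbstractVector) = exp(k.logσ²[1]) .* ew(k.k, x)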


Reference

[1] Shengyang Sun, Guodong Zhang, Chaoqi Wang, Wenyuan Zeng, Jiaman Li, Roger Grosse. Differentiable Compositional Kernel Learning for Gaussian Processes (2018).

@willtebbutt (Member)

Thanks for this PR. I'm really busy this week, so I'll do a proper review early next week.

@codecov

codecov bot commented Mar 6, 2020

Codecov Report

Merging #92 into wct/flux-nkn-integration will decrease coverage by 13.17%.
The diff coverage is 43.1%.


@@                    Coverage Diff                     @@
##           wct/flux-nkn-integration   #92       +/-   ##
==========================================================
- Coverage                     88.17%   75%   -13.18%     
==========================================================
  Files                            24    27        +3     
  Lines                           685   844      +159     
==========================================================
+ Hits                            604   633       +29     
- Misses                           81   211      +130
Impacted Files                    Coverage Δ
src/composite/compose.jl          61.53% <ø> (ø) ⬆️
src/Stheno.jl                     100% <ø> (ø) ⬆️
src/neural_network/basic.jl       0% <0%> (ø)
src/abstract_model.jl             0% <0%> (ø)
src/gp/neural_kernel_network.jl   0% <0%> (ø)
src/composite/composite_gp.jl     80.64% <0%> (-19.36%) ⬇️
src/gp/gp.jl                      88.23% <0%> (-11.77%) ⬇️
src/util/zygote_rules.jl          97.43% <100%> (+0.37%) ⬆️
src/abstract_gp.jl                86.36% <100%> (+1.74%) ⬆️
src/gp/mean.jl                    60% <33.33%> (-40%) ⬇️
... and 4 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@willtebbutt (Member)

Would you mind rebasing this on top of master so that it's easier to inspect the diff?

@HamletWantToCode (Contributor, Author)

HamletWantToCode commented Mar 10, 2020 via email

@willtebbutt (Member)

It would be really helpful. It's not really possible to review it as it currently is.

@HamletWantToCode (Contributor, Author)

I have opened a new PR, #95, since rebasing this PR failed (due to the existence of PR #94). Sorry for messing these up, and thank you for your time.

I will close this PR.
