Skip to content

Commit

Permalink
feat: compile hypernet with reactant
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 29, 2025
1 parent 94ff497 commit 53693b7
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 27 deletions.
6 changes: 2 additions & 4 deletions examples/HyperNet/Project.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
[deps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"

[compat]
ComponentArrays = "0.15.22"
Lux = "1"
LuxCUDA = "0.3"
MLDatasets = "0.7"
MLUtils = "0.4"
OneHotArrays = "0.2.5"
Optimisers = "0.4.1"
Zygote = "0.6.70, 0.7"
Reactant = "0.2.21"
55 changes: 32 additions & 23 deletions examples/HyperNet/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@

# ## Package Imports

using Lux, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Zygote

CUDA.allowscalar(false)
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random,
Reactant

# ## Loading Datasets
function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
Expand Down Expand Up @@ -42,10 +40,9 @@ end

# ## Implement a HyperNet Layer
function HyperNet(
weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
ComponentArray |>
getaxes
ComponentArray |> getaxes
return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
## Generate the weights
ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
Expand All @@ -55,7 +52,7 @@ end

# Defining functions on the CompactLuxLayer requires some understanding of how the layer
# is structured, as such we don't recommend doing it unless you are familiar with the
# internals. In this case, we simply write it to ignore the initialization of the
# internals. In this case, we simply write it to ignore the initialization of the
# `core_network` parameters.

function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
Expand All @@ -77,14 +74,13 @@ function create_model()
end

# ## Define Utility Functions
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader, data_idx)
total_correct, total = 0, 0
cdev = cpu_device()
st = Lux.testmode(st)
for (x, y) in dataloader
target_class = onecold(y)
predicted_class = onecold(first(model((data_idx, x), ps, st)))
target_class = onecold(cdev(y))
predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
Expand All @@ -93,34 +89,45 @@ end

# ## Training
function train()
dev = reactant_device(; force=true)

model = create_model()
dataloaders = load_datasets()
dataloaders = load_datasets() |> dev

dev = gpu_device()
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model) |> dev

train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

x = first(first(dataloaders[1][1]))
data_idx = ConcreteRNumber(1)
model_compiled = @compile model((data_idx, x), ps, Lux.testmode(st))

### Lets train the model
nepochs = 50
for epoch in 1:nepochs, data_idx in 1:2
train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

### This allows us to trace the data index, else it will be embedded as a constant
### in the IR
concrete_data_idx = ConcreteRNumber(data_idx)

stime = time()
for (x, y) in train_dataloader
(_, _, _, train_state) = Training.single_train_step!(
AutoZygote(), loss, ((data_idx, x), y), train_state)
AutoEnzyme(), CrossEntropyLoss(; logits=Val(true)),
((concrete_data_idx, x), y), train_state; return_gradients=Val(false)
)
end
ttime = time() - stime

train_acc = round(
accuracy(model, train_state.parameters,
train_state.states, train_dataloader, data_idx) * 100;
accuracy(model_compiled, train_state.parameters,
train_state.states, train_dataloader, concrete_data_idx) * 100;
digits=2)
test_acc = round(
accuracy(model, train_state.parameters,
train_state.states, test_dataloader, data_idx) * 100;
accuracy(model_compiled, train_state.parameters,
train_state.states, test_dataloader, concrete_data_idx) * 100;
digits=2)

data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
Expand All @@ -134,13 +141,15 @@ function train()
test_acc_list = [0.0, 0.0]
for data_idx in 1:2
train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

concrete_data_idx = ConcreteRNumber(data_idx)
train_acc = round(
accuracy(model, train_state.parameters,
train_state.states, train_dataloader, data_idx) * 100;
accuracy(model_compiled, train_state.parameters,
train_state.states, train_dataloader, concrete_data_idx) * 100;
digits=2)
test_acc = round(
accuracy(model, train_state.parameters,
train_state.states, test_dataloader, data_idx) * 100;
accuracy(model_compiled, train_state.parameters,
train_state.states, test_dataloader, concrete_data_idx) * 100;
digits=2)

data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
Expand Down

0 comments on commit 53693b7

Please sign in to comment.