Skip to content

Commit

Permalink
feat: compile hypernet with reactant
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 11, 2025
1 parent 044efef commit fdf056e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 39 deletions.
6 changes: 2 additions & 4 deletions examples/HyperNet/Project.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
[deps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"

[compat]
ComponentArrays = "0.15.22"
Lux = "1"
LuxCUDA = "0.3"
MLDatasets = "0.7"
MLUtils = "0.4"
OneHotArrays = "0.2.5"
Optimisers = "0.4.1"
Zygote = "0.6.70, 0.7"
Reactant = "0.2.21"
89 changes: 54 additions & 35 deletions examples/HyperNet/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@

# ## Package Imports

using Lux, ComponentArrays, LuxCUDA, MLDatasets, MLUtils, OneHotArrays, Optimisers,
Printf, Random, Zygote

CUDA.allowscalar(false)
using Lux, ComponentArrays, MLDatasets, MLUtils, OneHotArrays, Optimisers, Printf, Random,
Reactant

# ## Loading Datasets
function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},
Expand All @@ -28,9 +26,13 @@ function load_dataset(::Type{dset}, n_train::Union{Nothing, Int},

return (
DataLoader(
(x_train, y_train); batchsize=min(batchsize, size(x_train, 4)), shuffle=true),
(x_train, y_train);
batchsize=min(batchsize, size(x_train, 4)), shuffle=true, partial=false
),
DataLoader(
(x_test, y_test); batchsize=min(batchsize, size(x_test, 4)), shuffle=false)
(x_test, y_test);
batchsize=min(batchsize, size(x_test, 4)), shuffle=false, partial=false
)
)
end

Expand All @@ -42,10 +44,9 @@ end

# ## Implement a HyperNet Layer
function HyperNet(
weight_generator::Lux.AbstractLuxLayer, core_network::Lux.AbstractLuxLayer)
weight_generator::AbstractLuxLayer, core_network::AbstractLuxLayer)
ca_axes = Lux.initialparameters(Random.default_rng(), core_network) |>
ComponentArray |>
getaxes
ComponentArray |> getaxes
return @compact(; ca_axes, weight_generator, core_network, dispatch=:HyperNet) do (x, y)
## Generate the weights
ps_new = ComponentArray(vec(weight_generator(x)), ca_axes)
Expand All @@ -55,7 +56,7 @@ end

# Defining functions on the CompactLuxLayer requires some understanding of how the layer
# is structured; as such, we don't recommend doing it unless you are familiar with the
# internals. In this case, we simply write it to ignore the initialization of the
# internals. In this case, we simply write it to ignore the initialization of the
# `core_network` parameters.

function Lux.initialparameters(rng::AbstractRNG, hn::CompactLuxLayer{:HyperNet})
Expand All @@ -64,27 +65,32 @@ end

# ## Create and Initialize the HyperNet
function create_model()
## Doesn't need to be an MLP; it can be any Lux layer
core_network = Chain(FlattenLayer(), Dense(784, 256, relu), Dense(256, 10))
weight_generator = Chain(
Embedding(2 => 32),
Dense(32, 64, relu),
Dense(64, Lux.parameterlength(core_network))
core_network = Chain(
Conv((3, 3), 1 => 16, relu; stride=2),
Conv((3, 3), 16 => 32, relu; stride=2),
Conv((3, 3), 32 => 64, relu; stride=2),
GlobalMeanPool(),
FlattenLayer(),
Dense(64, 10)
)
return HyperNet(
Chain(
Embedding(2 => 32),
Dense(32, 64, relu),
Dense(64, Lux.parameterlength(core_network))
),
core_network
)

model = HyperNet(weight_generator, core_network)
return model
end

# ## Define Utility Functions
const loss = CrossEntropyLoss(; logits=Val(true))

function accuracy(model, ps, st, dataloader, data_idx)
total_correct, total = 0, 0
cdev = cpu_device()
st = Lux.testmode(st)
for (x, y) in dataloader
target_class = onecold(y)
predicted_class = onecold(first(model((data_idx, x), ps, st)))
target_class = onecold(cdev(y))
predicted_class = onecold(cdev(first(model((data_idx, x), ps, st))))
total_correct += sum(target_class .== predicted_class)
total += length(target_class)
end
Expand All @@ -93,34 +99,45 @@ end

# ## Training
function train()
dev = reactant_device(; force=true)

model = create_model()
dataloaders = load_datasets()
dataloaders = load_datasets() |> dev

dev = gpu_device()
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model) |> dev

train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
train_state = Training.TrainState(model, ps, st, Adam(0.0003f0))

x = first(first(dataloaders[1][1]))
data_idx = ConcreteRNumber(1)
model_compiled = @compile model((data_idx, x), ps, Lux.testmode(st))

### Let's train the model
nepochs = 50
for epoch in 1:nepochs, data_idx in 1:2
train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

### This allows us to trace the data index; otherwise it will be embedded as a constant
### in the IR
concrete_data_idx = ConcreteRNumber(data_idx)

stime = time()
for (x, y) in train_dataloader
(_, _, _, train_state) = Training.single_train_step!(
AutoZygote(), loss, ((data_idx, x), y), train_state)
AutoEnzyme(), CrossEntropyLoss(; logits=Val(true)),
((concrete_data_idx, x), y), train_state; return_gradients=Val(false)
)
end
ttime = time() - stime

train_acc = round(
accuracy(model, train_state.parameters,
train_state.states, train_dataloader, data_idx) * 100;
accuracy(model_compiled, train_state.parameters,
train_state.states, train_dataloader, concrete_data_idx) * 100;
digits=2)
test_acc = round(
accuracy(model, train_state.parameters,
train_state.states, test_dataloader, data_idx) * 100;
accuracy(model_compiled, train_state.parameters,
train_state.states, test_dataloader, concrete_data_idx) * 100;
digits=2)

data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
Expand All @@ -134,13 +151,15 @@ function train()
test_acc_list = [0.0, 0.0]
for data_idx in 1:2
train_dataloader, test_dataloader = dataloaders[data_idx] .|> dev

concrete_data_idx = ConcreteRNumber(data_idx)
train_acc = round(
accuracy(model, train_state.parameters,
train_state.states, train_dataloader, data_idx) * 100;
accuracy(model_compiled, train_state.parameters,
train_state.states, train_dataloader, concrete_data_idx) * 100;
digits=2)
test_acc = round(
accuracy(model, train_state.parameters,
train_state.states, test_dataloader, data_idx) * 100;
accuracy(model_compiled, train_state.parameters,
train_state.states, test_dataloader, concrete_data_idx) * 100;
digits=2)

data_name = data_idx == 1 ? "MNIST" : "FashionMNIST"
Expand Down

0 comments on commit fdf056e

Please sign in to comment.