Skip to content

Commit

Permalink
update to Lux 1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
NeroBlackstone committed Sep 8, 2024
1 parent 4789f2d commit 5d33d63
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 41 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PianoHands"
uuid = "74435128-bd9d-4d82-978b-bd768beb391e"
authors = ["NeroBlackstone <[email protected]>"]
version = "0.1.0"
version = "0.2.0"

[deps]
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
Expand All @@ -26,7 +26,7 @@ test = ["Test"]
CodecZlib = "0.7"
IterTools = "1.10"
JLD2 = "0.4"
Lux = "0.5"
Lux = "1.0"
LuxCUDA = "0.3"
MIDI = "2.7"
MLUtils = "0.4"
Expand Down
20 changes: 10 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
# PianoHands.jl
Predicting hand assignments in piano MIDI using neural networks

# Use Pre-trained weight
# Use pre-trained model

``` julia
using PianoHands
generate_midi("./your_midi.mid";)
```

You will get a midi file `out.mid`, track 1 is left hand notes, track 2 is right hand notes.
You will get a MIDI file `your_midi_out.mid`: track 1 contains the left-hand notes and track 2 contains the right-hand notes.

# Train Your own weight.
# Train your own model

## Dataset preparation

Download the PIG v1.2 dataset to `PianoFingeringDataset` and remove the duplicate fingering files; approximately 150 fingering files are required.

``` julia
train_piano(DATASET_PATH,
function train_piano(DATASET_PATH,
TESTSET_PATH;
BATCH_SIZE = 10,
SEQ_LENGTH = 70,
BATCH_SIZE = 12,
SEQ_LENGTH = 75,
HIDDEN_SIZE = 14,
LEARNING_RATE = 0.0002f0,
LEARNING_RATE = 0.0005f0,
MAX_EPOCH = 200,
EVALUATE_PER_N_TRAIN = 100)
EVALUATE_PER_N_TRAIN = 50)
```

The network structure is a bi-directional GRU followed by a Dense layer, and the hidden layer size can be adjusted with the `HIDDEN_SIZE` keyword argument. There is no stopping condition for training; you need to stop it manually.
Expand All @@ -33,8 +33,8 @@ Use trained weight:

```julia
generate_midi(input_file::String;
output_file="./out.mid",
weight_file=pkgdir(PianoHands,"weight","weight-0.92757.jld2"),
output_file="",
weight_file=pkgdir(PianoHands,"model","model-0.91502.jld2"),
HIDDEN_SIZE=14)
```

Expand Down
Binary file added model/model-0.91502.jld2
Binary file not shown.
9 changes: 6 additions & 3 deletions src/data_processing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ function get_train_dataloaders(dataset_path::String; batch_size=10, seq_length=2
push!(feature_result,stack(stack.(partition(features,seq_length,1))))
push!(label_result,stack(stack.(partition(labels,seq_length,1))))
end
return DataLoader((cat(feature_result...;dims=3),cat(label_result...;dims=2)); batchsize = batch_size, shuffle=true, parallel = true)
return DataLoader((cat(feature_result...;dims=3),cat(label_result...;dims=2));
batchsize = batch_size, shuffle=true, parallel = true)
end

"""
Expand All @@ -96,7 +97,9 @@ Predict left hand or right hand by output.
"""
predict_y(y) = ifelse(y > 0.5f0, 1, 0)  # 1 when the output exceeds 0.5, otherwise 0

function generate_midi(input_file::String;output_file="./out.mid",weight_file=pkgdir(PianoHands,"weight","weight-0.92757.jld2"),HIDDEN_SIZE=14)
function generate_midi(input_file::String; output_file::String="",
weight_file=pkgdir(PianoHands,"model","model-0.91502.jld2"),HIDDEN_SIZE=14)

midi_file = load(input_file)
hand_classify = inferance_midi(midi_file,weight_file,HIDDEN_SIZE)

Expand All @@ -115,5 +118,5 @@ function generate_midi(input_file::String;output_file="./out.mid",weight_file=pk
addnotes!(track_rh, notes_rh)
addtrackname!(track_rh, "piano right")
push!(new_midi_file.tracks, track_lh, track_rh)
save(output_file, new_midi_file)
save(isempty(output_file) ? first(splitext(input_file))*"_out.mid" : output_file, new_midi_file)
end
35 changes: 13 additions & 22 deletions src/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,26 @@ end

function train_piano(DATASET_PATH,
TESTSET_PATH;
BATCH_SIZE = 10,
SEQ_LENGTH = 70,
BATCH_SIZE = 12,
SEQ_LENGTH = 75,
HIDDEN_SIZE = 14,
LEARNING_RATE = 0.0002f0,
LEARNING_RATE = 0.0005f0,
MAX_EPOCH = 200,
EVALUATE_PER_N_TRAIN = 100)
EVALUATE_PER_N_TRAIN = 50)

dev = gpu_device()

# Get the dataloaders
train_loader = get_train_dataloaders(DATASET_PATH;batch_size=BATCH_SIZE, seq_length=SEQ_LENGTH)
train_loader = get_train_dataloaders(DATASET_PATH;batch_size=BATCH_SIZE, seq_length=SEQ_LENGTH) |> dev
val_x, val_y = get_val_datas(TESTSET_PATH)

# Create the model
model = build_model(GRUCell,HIDDEN_SIZE)
display(model)
rng = Xoshiro(0)
dev = gpu_device()
train_state = Lux.Experimental.TrainState(rng, model, Adam(LEARNING_RATE); transform_variables=dev)

ps, st = Lux.setup(rng, model) |> dev
train_state = Training.TrainState(model, ps, st,Adam(LEARNING_RATE))

logitbce = BinaryCrossEntropyLoss();
loss_fn(ŷ,y) = sum(logitbce.(vec.(ŷ),eachslice(y;dims=1)))
Expand All @@ -40,28 +43,21 @@ function train_piano(DATASET_PATH,
loss_sum = 0
# Train the model
for (x,y) in train_loader
x = x |> dev
y = y |> dev

(_, loss, _, train_state) = Lux.Experimental.single_train_step!(
(_, loss, _, train_state) = Training.single_train_step!(
AutoZygote(), compute_loss, (x, y), train_state)

i+=1
loss_sum += loss
if i % EVALUATE_PER_N_TRAIN == 0
@printf "Epoch [%3d]: Loss %4.5f\n" epoch loss_sum/i

# Validate the model
st_ = Lux.testmode(train_state.states)
matchs = 0
note_count = mapreduce(length,+,val_y)
loss_sum_in = 0
for (x, y) in zip(val_x,val_y)
x = reshape(x, Val(3)) |> dev
x = x |> dev
y = y |> dev

ŷ, st_ = model(x, train_state.parameters, st_)

loss_sum_in += loss_fn(ŷ, y)
matchs += matches_num(vcat(ŷ...),y)
end
Expand All @@ -77,25 +73,20 @@ function train_piano(DATASET_PATH,
heightest_acc = acc
end
end


end
i = 1
end
end

function inferance_midi(midi_file::MIDIFile,weight_file::String,HIDDEN_SIZE::Int)::Vector{Int}
f = midi_to_features(midi_file)
x = reshape(stack(f), Val(3))

model = build_model(GRUCell,HIDDEN_SIZE)
display(model)
dev = gpu_device()
@load weight_file ps_trained st_trained
ps_trained,st_trained |> dev

st_ = Lux.testmode(st_trained)
y, st_ = model(x, ps_trained, st_)
y, st_ = model(stack(midi_to_features(midi_file)), ps_trained, st_)
y |> cpu_device()
return (predict_y ∘ first).(y)
end
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ using Test,PianoHands,MIDI,Lux,Random,Printf,LuxCUDA,Optimisers,Zygote,JLD2
@testset "pig to feature" begin
# train_piano("../PianoFingeringDataset/dataset/",
# "../PianoFingeringDataset/testset/";
# SEQ_LENGTH=70,
# BATCH_SIZE=10,
# SEQ_LENGTH=75,
# BATCH_SIZE=12,
# LEARNING_RATE = 0.0005f0,
# HIDDEN_SIZE = 14,
# EVALUATE_PER_N_TRAIN = 100
# EVALUATE_PER_N_TRAIN = 50
# )

# generate_midi("./ymsn_full.mid";weight_file="./14trained_model-0.92757.jld2",HIDDEN_SIZE=14)
generate_midi("./ymsn_full.mid";weight_file="../model/model-0.91502.jld2")
end
Binary file removed weight/weight-0.92757.jld2
Binary file not shown.

0 comments on commit 5d33d63

Please sign in to comment.