diff --git a/Project.toml b/Project.toml index 99a6137..2433575 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PianoHands" uuid = "74435128-bd9d-4d82-978b-bd768beb391e" authors = ["NeroBlackstone "] -version = "0.1.0" +version = "0.2.0" [deps] CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" @@ -26,7 +26,7 @@ test = ["Test"] CodecZlib = "0.7" IterTools = "1.10" JLD2 = "0.4" -Lux = "0.5" +Lux = "1.0" LuxCUDA = "0.3" MIDI = "2.7" MLUtils = "0.4" diff --git a/README.md b/README.md index f4822a1..2d69b05 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,30 @@ # PianoHands.jl Predicting hand assignments in piano MIDI using neural networks -# Use Pre-trained weight +# Use pre-trained model ``` julia using PianoHands generate_midi("./your_midi.mid";) ``` -You will get a midi file `out.mid`, track 1 is left hand notes, track 2 is right hand notes. +You will get a midi file `your_midi_out.mid`, track 1 is left hand notes, track 2 is right hand notes. -# Train Your own weight. +# Train Your own model. ## Dataset preparation Download PIG v1.2 Dataset to `PianoFingeringDataset` and remove duplicate fingering file, approximately 150 fingering files are required. ``` julia -train_piano(DATASET_PATH, +function train_piano(DATASET_PATH, TESTSET_PATH; - BATCH_SIZE = 10, - SEQ_LENGTH = 70, + BATCH_SIZE = 12, + SEQ_LENGTH = 75, HIDDEN_SIZE = 14, - LEARNING_RATE = 0.0002f0, + LEARNING_RATE = 0.0005f0, MAX_EPOCH = 200, - EVALUATE_PER_N_TRAIN = 100) + EVALUATE_PER_N_TRAIN = 50) ``` The network structure is bi-directional GRU + Dense, and the hidden layer size can be adjusted by parameters. There is no stopping condition for training, you need stop manually. @@ -33,8 +33,8 @@ Use trained weight: ```julia generate_midi(input_file::String; - output_file="./out.mid", - weight_file=pkgdir(PianoHands,"weight","weight-0.92757.jld2"), + output_file="", + weight_file=pkgdir(PianoHands,"model","model-0.91502.jld2"), HIDDEN_SIZE=14) ``` diff --git a/model/model-0.91502.jld2 b/model/model-0.91502.jld2 new file mode 100644 index 0000000..90edf20 Binary files /dev/null and b/model/model-0.91502.jld2 differ diff --git a/src/data_processing.jl b/src/data_processing.jl index 68915ed..dfe9f01 100644 --- a/src/data_processing.jl +++ b/src/data_processing.jl @@ -69,7 +69,8 @@ function get_train_dataloaders(dataset_path::String; batch_size=10, seq_length=2 push!(feature_result,stack(stack.(partition(features,seq_length,1)))) push!(label_result,stack(stack.(partition(labels,seq_length,1)))) end - return DataLoader((cat(feature_result...;dims=3),cat(label_result...;dims=2)); batchsize = batch_size, shuffle=true, parallel = true) + return DataLoader((cat(feature_result...;dims=3),cat(label_result...;dims=2)); + batchsize = batch_size, shuffle=true, parallel = true) end """ @@ -96,7 +97,9 @@ Predict left hand or right hand by output. """ predict_y(y) = y > 0.5f0 ? 1 : 0 -function generate_midi(input_file::String;output_file="./out.mid",weight_file=pkgdir(PianoHands,"weight","weight-0.92757.jld2"),HIDDEN_SIZE=14) +function generate_midi(input_file::String; output_file::String="", + weight_file=pkgdir(PianoHands,"model","model-0.91502.jld2"),HIDDEN_SIZE=14) + midi_file = load(input_file) hand_classify = inferance_midi(midi_file,weight_file,HIDDEN_SIZE) @@ -115,5 +118,5 @@ function generate_midi(input_file::String;output_file="./out.mid",weight_file=pk addnotes!(track_rh, notes_rh) addtrackname!(track_rh, "piano right") push!(new_midi_file.tracks, track_lh, track_rh) - save(output_file, new_midi_file) + save(isempty(output_file) ? first(splitext(input_file))*"_out.mid" : output_file, new_midi_file) end \ No newline at end of file diff --git a/src/training.jl b/src/training.jl index 8059811..4066b9d 100644 --- a/src/training.jl +++ b/src/training.jl @@ -7,23 +7,26 @@ end function train_piano(DATASET_PATH, TESTSET_PATH; - BATCH_SIZE = 10, - SEQ_LENGTH = 70, + BATCH_SIZE = 12, + SEQ_LENGTH = 75, HIDDEN_SIZE = 14, - LEARNING_RATE = 0.0002f0, + LEARNING_RATE = 0.0005f0, MAX_EPOCH = 200, - EVALUATE_PER_N_TRAIN = 100) + EVALUATE_PER_N_TRAIN = 50) + + dev = gpu_device() # Get the dataloaders - train_loader = get_train_dataloaders(DATASET_PATH;batch_size=BATCH_SIZE, seq_length=SEQ_LENGTH) + train_loader = get_train_dataloaders(DATASET_PATH;batch_size=BATCH_SIZE, seq_length=SEQ_LENGTH) |> dev val_x, val_y = get_val_datas(TESTSET_PATH) # Create the model model = build_model(GRUCell,HIDDEN_SIZE) display(model) rng = Xoshiro(0) - dev = gpu_device() - train_state = Lux.Experimental.TrainState(rng, model, Adam(LEARNING_RATE); transform_variables=dev) + + ps, st = Lux.setup(rng, model) |> dev + train_state = Training.TrainState(model, ps, st,Adam(LEARNING_RATE)) logitbce = BinaryCrossEntropyLoss(); loss_fn(ŷ,y) = sum(logitbce.(vec.(ŷ),eachslice(y;dims=1))) @@ -40,28 +43,21 @@ function train_piano(DATASET_PATH, loss_sum = 0 # Train the model for (x,y) in train_loader - x = x |> dev - y = y |> dev - - (_, loss, _, train_state) = Lux.Experimental.single_train_step!( + (_, loss, _, train_state) = Training.single_train_step!( AutoZygote(), compute_loss, (x, y), train_state) - i+=1 loss_sum += loss if i % EVALUATE_PER_N_TRAIN == 0 @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss_sum/i - # Validate the model st_ = Lux.testmode(train_state.states) matchs = 0 note_count = mapreduce(length,+,val_y) loss_sum_in = 0 for (x, y) in zip(val_x,val_y) - x = reshape(x, Val(3)) |> dev + x = x |> dev y = y |> dev - ŷ, st_ = model(x, train_state.parameters, st_) - loss_sum_in += loss_fn(ŷ, y) matchs += matches_num(vcat(ŷ...),y) end @@ -77,17 +73,12 @@ function train_piano(DATASET_PATH, heightest_acc = acc end end - - end i = 1 end end function inferance_midi(midi_file::MIDIFile,weight_file::String,HIDDEN_SIZE::Int)::Vector{Int} - f = midi_to_features(midi_file) - x = reshape(stack(f), Val(3)) - model = build_model(GRUCell,HIDDEN_SIZE) display(model) dev = gpu_device() @@ -95,7 +86,7 @@ function inferance_midi(midi_file::MIDIFile,weight_file::String,HIDDEN_SIZE::Int ps_trained,st_trained |> dev st_ = Lux.testmode(st_trained) - y, st_ = model(x, ps_trained, st_) + y, st_ = model(stack(midi_to_features(midi_file)), ps_trained, st_) y |> cpu_device() return (predict_y ∘ first).(y) end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 097a364..cb680ac 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,12 +3,12 @@ using Test,PianoHands,MIDI,Lux,Random,Printf,LuxCUDA,Optimisers,Zygote,JLD2 @testset "pig to feature" begin # train_piano("../PianoFingeringDataset/dataset/", # "../PianoFingeringDataset/testset/"; - # SEQ_LENGTH=70, - # BATCH_SIZE=10, + # SEQ_LENGTH=75, + # BATCH_SIZE=12, # LEARNING_RATE = 0.0005f0, # HIDDEN_SIZE = 14, - # EVALUATE_PER_N_TRAIN = 100 + # EVALUATE_PER_N_TRAIN = 50 # ) - # generate_midi("./ymsn_full.mid";weight_file="./14trained_model-0.92757.jld2",HIDDEN_SIZE=14) + generate_midi("./ymsn_full.mid";weight_file="../model/model-0.91502.jld2") end \ No newline at end of file diff --git a/weight/weight-0.92757.jld2 b/weight/weight-0.92757.jld2 deleted file mode 100644 index bb4d899..0000000 Binary files a/weight/weight-0.92757.jld2 and /dev/null differ