llama/ggml: add LLM training support #10544
Draft · +767 −297
See ggerganov/ggml#1025; however, I decided to implement the training directly in llama.cpp after all, because the GPT-2 GGML example is already quite complex, would require a significant amount of effort to refactor, and I'm not familiar with that codebase at all.
The goal of this PR is to add general training support to llama.cpp using `ggml_opt`. CPU training seems to work; other backends are still missing support for some GGML ops. It is currently not possible to actually save the finetuned model to disk, but you can confirm that the finetuning works by doing one epoch over the input text prior to the perplexity calculation (or by observing how the loss goes down with the new finetune example). One epoch over the test set of Wikitext-2 (with the stride chosen in such a way that each token is used twice per epoch) currently takes ~1 minute with Stories 260k, or ~20 hours and ~100 GB RAM with LLaMA 3 8b. For the user-facing API my concrete plans are:

- `n_ctx` determines the max. sequence length with which the model is trained.
- `n_batch` determines how many tokens are consumed per optimizer step.
- `n_ubatch` determines the number of tokens evaluated in parallel; it enables a speed <-> memory use tradeoff and should have no effect on the result except for differences in floating point rounding error.
- A way to build a training dataset from a `std::vector<llama_token>`. Currently I have this as part of `llama.h`, but maybe it would make more sense to put it in `common.h`?
- A function `llama_opt_init` that prepares a `llama_context` for training and lets the user define things like the learning rate or which tensors should be trainable parameters.
- A function `llama_opt_epoch` that performs one epoch over a `ggml_opt_dataset`, equivalent to `ggml_opt_epoch`.
- Maybe a function `llama_opt_fit`, equivalent to `ggml_opt_fit`, that is even more high-level? (A rough usage sketch of the proposed API follows below.)
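To make the plan concrete, here is a rough sketch of how a finetuning loop could look from user code, assuming the proposed `llama.h` additions. Apart from `llama_opt_init`, `llama_opt_epoch`, and the `n_ctx`/`n_batch`/`n_ubatch` parameters discussed above, the names in this sketch (in particular the dataset helper and the `llama_opt_params` fields) are illustrative placeholders, not a final interface:

```cpp
// Hypothetical usage sketch for the proposed training API. Only
// llama_opt_init/llama_opt_epoch and n_ctx/n_batch/n_ubatch are part of
// this PR's plan; the other names are placeholders for illustration.
#include "llama.h"

#include <vector>

int main(int argc, char ** argv) {
    llama_model_params   mparams = llama_model_default_params();
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx    = 512; // max. sequence length used for training
    cparams.n_batch  = 512; // tokens consumed per optimizer step
    cparams.n_ubatch = 128; // tokens evaluated in parallel; here 4 ubatches
                            // would be accumulated per optimizer step

    llama_model   * model = llama_load_model_from_file(argv[1], mparams);
    llama_context * ctx   = llama_new_context_with_model(model, cparams);

    // Tokenize the training text with the llama.cpp tokenizer, then build a
    // ggml_opt_dataset from the tokens. Whether this helper belongs in
    // llama.h or common.h is one of the open questions above.
    std::vector<llama_token> tokens = /* tokenized training text */ {};
    ggml_opt_dataset_t dataset = llama_opt_dataset_from_tokens(ctx, tokens); // placeholder name

    // Prepare the context for training: learning rate, which tensors are
    // trainable parameters, etc. (fields are placeholders).
    llama_opt_params oparams = {};
    oparams.lr = 1e-5f;
    llama_opt_init(ctx, model, oparams);

    // One call = one epoch over the dataset, mirroring ggml_opt_epoch.
    for (int epoch = 0; epoch < 2; ++epoch) {
        llama_opt_epoch(ctx, dataset /* , results/callbacks */);
    }

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}
```

The `n_batch`/`n_ubatch` split would behave like gradient accumulation: with `n_batch` = 512 and `n_ubatch` = 128, four ubatches are evaluated per optimizer step, trading speed against memory without changing the result (up to floating point rounding).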
Currently, while functional, the PR is in a bad state in terms of software design and is in need of a refactor. The reason I'm already opening it now is that I want to ask for advice regarding how to best implement `llama_opt_epoch`. My current approach was to try to hijack the first half of `llama_decode_internal`, but I found that in the end all I needed from it was the generation of the next `llama_ubatch` and the corresponding manipulation of the KV cache. Maybe it would make more sense to instead write a function like `llama_prepare_next_ubatch` and to use that function in both `llama_decode_internal` and `llama_opt_epoch`? A sketch of what I have in mind is below.
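To illustrate that second option, here is a minimal sketch of how the shared helper could be factored out. `llama_prepare_next_ubatch` does not exist yet and all signatures here are hypothetical:

```cpp
// Hypothetical refactor: factor ubatch generation and the corresponding
// KV cache manipulation out of llama_decode_internal so that training
// can reuse it. All signatures are illustrative, not final code.
static llama_ubatch llama_prepare_next_ubatch(llama_context & lctx, llama_sbatch & sbatch) {
    // 1. split the next ubatch (up to n_ubatch tokens) off the batch
    // 2. find free slots in the KV cache for it and update the cache metadata
    // ...
}

static int llama_decode_internal(llama_context & lctx, llama_batch inp_batch) {
    // ...
    while (/* tokens remaining in sbatch */) {
        llama_ubatch ubatch = llama_prepare_next_ubatch(lctx, sbatch);
        // build the inference graph, compute, extract logits/embeddings
    }
    // ...
}

void llama_opt_epoch(llama_context * lctx, ggml_opt_dataset_t dataset /* , ... */) {
    // same ubatch preparation as decoding, but the ubatches are fed to the
    // ggml_opt forward/backward pass instead of a plain forward pass
    // ...
}
```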