Commit 382ec20

llama/ggml: add LLM training support
1 parent 46c69e0 commit 382ec20

14 files changed: +767 -297 lines

examples/CMakeLists.txt

+1

@@ -46,6 +46,7 @@ else()
     add_subdirectory(speculative)
     add_subdirectory(speculative-simple)
     add_subdirectory(tokenize)
+    add_subdirectory(training)
     if (NOT GGML_BACKEND_DL)
         # these examples use the backends directly and cannot be built with dynamic loading
         add_subdirectory(convert-llama2c-to-ggml)

examples/training/CMakeLists.txt

+5

@@ -0,0 +1,5 @@
+set(TARGET llama-finetune)
+add_executable(${TARGET} finetune.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/training/finetune.cpp

+107

@@ -0,0 +1,107 @@
+#include "arg.h"
+#include "common.h"
+#include "log.h"
+#include "llama.h"
+
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <vector>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+static std::vector<float> softmax(const std::vector<float>& logits) {
+    std::vector<float> probs(logits.size());
+    float max_logit = logits[0];
+    for (float v : logits) {
+        max_logit = std::max(max_logit, v);
+    }
+    double sum_exp = 0.0;
+    for (size_t i = 0; i < logits.size(); i++) {
+        // Subtract the maximum logit value from the current logit value for numerical stability
+        const float logit = logits[i] - max_logit;
+        const float exp_logit = expf(logit);
+        sum_exp += exp_logit;
+        probs[i] = exp_logit;
+    }
+    for (size_t i = 0; i < probs.size(); i++) {
+        probs[i] /= sum_exp;
+    }
+    return probs;
+}
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    params.logits_all = true;
+    params.escape = false;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+        return 1;
+    }
+
+    if (params.use_mmap) {
+        LOG_INF("%s: force disabling memory mapping because it would result in read-only pointers to the weights\n", __func__);
+        params.use_mmap = false;
+    }
+    if (params.cache_type_k == "f16") {
+        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+        params.cache_type_k = "f32";
+    }
+    if (params.cache_type_v == "f16") {
+        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+        params.cache_type_v = "f32";
+    }
+
+    common_init();
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    // load the model and apply lora adapter, if any
+    common_init_result llama_init = common_init_from_params(params);
+    llama_model * model = llama_init.model;
+    llama_context * ctx = llama_init.context;
+
+    if (model == NULL) {
+        LOG_ERR("%s: unable to load model\n", __func__);
+        return 1;
+    }
+
+    // print system information
+    {
+        LOG_INF("\n");
+        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+    }
+
+    constexpr float val_split = 0.05f;
+
+    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
+    ggml_opt_dataset_t dataset = llama_opt_dataset_init(ctx, tokens.data(), tokens.size(), llama_n_ctx(ctx)/2);
+    llama_opt_init(ctx);
+    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
+
+    while (true) {
+        ggml_opt_result_t result_train = ggml_opt_result_init();
+        ggml_opt_result_t result_eval  = ggml_opt_result_init();
+
+        llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
+            ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
+        fprintf(stderr, "\n");
+
+        ggml_opt_result_free(result_train);
+        ggml_opt_result_free(result_eval);
+    }
+
+    LOG("\n");
+    llama_perf_context_print(ctx);
+
+    llama_free(ctx);
+    llama_free_model(model);
+
+    llama_backend_free();
+
+    return 0;
+}
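
Note that the training loop above runs indefinitely. As a minimal sketch, a bounded variant that also reports the mean loss per epoch might look like the following; it assumes an accessor of the form ggml_opt_result_loss(result, &loss, &uncertainty) as declared in ggml-opt.h, which is not part of this diff, and is illustrative rather than the code added by this commit:

// Hypothetical bounded variant of the loop above (illustrative only):
// run a fixed number of epochs and print the mean training/validation loss
// after each one. Assumes ggml_opt_result_loss(ggml_opt_result_t, double * loss,
// double * unc) from ggml-opt.h; everything else is taken from the code above.
for (int epoch = 0; epoch < 2; epoch++) {
    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

    llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
        ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
    fprintf(stderr, "\n");

    double loss_train = 0.0, unc_train = 0.0;
    double loss_eval  = 0.0, unc_eval  = 0.0;
    ggml_opt_result_loss(result_train, &loss_train, &unc_train);
    ggml_opt_result_loss(result_eval,  &loss_eval,  &unc_eval);
    fprintf(stderr, "epoch %d: train loss = %.6lf +- %.6lf, val loss = %.6lf +- %.6lf\n",
        epoch, loss_train, unc_train, loss_eval, unc_eval);

    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);
}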

ggml/include/ggml-opt.h

+23 -9

@@ -37,13 +37,16 @@ extern "C" {
     // ====== Dataset ======
 
     GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
-            int64_t ne_datapoint, // number of elements per datapoint
-            int64_t ne_label,     // number of elements per label
-            int64_t ndata,        // total number of datapoints/labels
-            int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
+            enum ggml_type type_data,    // the type for the internal data tensor
+            enum ggml_type type_label,   // the type for the internal labels tensor
+            int64_t        ne_datapoint, // number of elements per datapoint
+            int64_t        ne_label,     // number of elements per label
+            int64_t        ndata,        // total number of datapoints/labels
+            int64_t        ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
     GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
 
     // get underlying tensors that store the data
+    GGML_API int64_t              ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset);
     GGML_API struct ggml_tensor * ggml_opt_dataset_data  (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
     GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [ne_label, ndata]
 
@@ -56,6 +59,12 @@ extern "C" {
             struct ggml_tensor * data_batch,   // shape = [ne_datapoint, ndata_batch]
             struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch]
             int64_t              ibatch);
+    GGML_API void ggml_opt_dataset_get_batch_host(
+            ggml_opt_dataset_t   dataset,
+            void               * data_batch,
+            size_t               nb_data_batch,
+            void               * labels_batch,
+            int64_t              ibatch);
 
     // ====== Model / Context ======
 
@@ -92,7 +101,8 @@ extern "C" {
         struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
 
         // the forward graph is defined by inputs and outputs
-        // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
+        // the outputs and all tensors between inputs and outputs that have not been statically allocated
+        // are not intended to be reusable between multiple optimization contexts
        struct ggml_tensor * inputs;
        struct ggml_tensor * outputs;
 
@@ -107,7 +117,7 @@ extern "C" {
 
     // get parameters for an optimization context with defaults set where possible
     // parameters for which no sensible defaults exist are supplied as arguments to this function
-    GGML_API ggml_opt_params ggml_opt_default_params(
+    GGML_API struct ggml_opt_params ggml_opt_default_params(
            ggml_backend_sched_t    backend_sched,
            struct ggml_context   * ctx_compute,
            struct ggml_tensor    * inputs,
@@ -144,6 +154,10 @@ extern "C" {
 
     // ====== Computation ======
 
+    GGML_API void ggml_opt_set_forward_graph(
+            ggml_opt_context_t opt_ctx, struct ggml_context * ctx_compute, struct ggml_cgraph * gf,
+            struct ggml_tensor * inputs, struct ggml_tensor * outputs, bool backward);
+
     // do forward pass, increment result if not NULL
     GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
 
@@ -200,9 +214,9 @@ extern "C" {
     // fit model defined by inputs and outputs to dataset
     GGML_API void ggml_opt_fit(
            ggml_backend_sched_t            backend_sched, // backend scheduler for constructing the compute graphs
-           ggml_context                  * ctx_compute,   // context with temporarily allocated tensors to calculate the outputs
-           ggml_tensor                   * inputs,        // input tensor with shape [ne_datapoint, ndata_batch]
-           ggml_tensor                   * outputs,       // output tensor, must have shape [ne_label, ndata_batch] if labels are used
+           struct ggml_context           * ctx_compute,   // context with temporarily allocated tensors to calculate the outputs
+           struct ggml_tensor            * inputs,        // input tensor with shape [ne_datapoint, ndata_batch]
+           struct ggml_tensor            * outputs,       // output tensor, must have shape [ne_label, ndata_batch] if labels are used
            ggml_opt_dataset_t              dataset,       // dataset with data and optionally also labels
            enum ggml_opt_loss_type         loss_type,     // loss to minimize
            ggml_opt_get_optimizer_params   get_opt_pars,  // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
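
For orientation, here is a short sketch of how the extended dataset API declared above might be driven directly from user code. Only the function names and signatures come from the declarations in this file; the tensor types, batch size, and the meaning of nb_data_batch as a byte size are illustrative assumptions, not something this commit specifies:

// Illustrative sketch (not part of this commit) of the extended dataset API.
#include "ggml.h"
#include "ggml-opt.h"

#include <cstdint>
#include <cstdio>
#include <vector>

static void dataset_api_sketch(void) {
    // int32 token data with int32 labels, 512 elements per datapoint, 1024 datapoints,
    // shuffled/copied at the granularity of single datapoints (all values made up).
    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
        /*type_data   =*/ GGML_TYPE_I32,
        /*type_label  =*/ GGML_TYPE_I32,
        /*ne_datapoint=*/ 512,
        /*ne_label    =*/ 512,
        /*ndata       =*/ 1024,
        /*ndata_shard =*/ 1);

    // The dataset owns its data/labels tensors; they would be filled with tokens here.
    struct ggml_tensor * data   = ggml_opt_dataset_data  (dataset); // shape = [512, 1024]
    struct ggml_tensor * labels = ggml_opt_dataset_labels(dataset); // shape = [512, 1024]
    (void) data; (void) labels;

    fprintf(stderr, "dataset has %lld datapoints\n", (long long) ggml_opt_dataset_ndata(dataset));

    // ggml_opt_dataset_get_batch_host copies one batch into plain host buffers rather than
    // ggml tensors; nb_data_batch is assumed here to be the size of the data buffer in bytes.
    const int64_t ndata_batch = 32;
    std::vector<int32_t> data_batch  (ndata_batch*512);
    std::vector<int32_t> labels_batch(ndata_batch*512);
    ggml_opt_dataset_get_batch_host(dataset,
        data_batch.data(), data_batch.size()*sizeof(int32_t), labels_batch.data(), /*ibatch =*/ 0);

    ggml_opt_dataset_free(dataset);
}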

ggml/include/ggml.h

+3 -4

@@ -1994,10 +1994,9 @@ extern "C" {
 
     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
     GGML_API void ggml_build_backward_expand(
-            struct ggml_context * ctx_static,  // context for static gradients (loss + gradient accumulation)
-            struct ggml_context * ctx_compute, // context for gradient computation
-            struct ggml_cgraph  * cgraph,
-            bool                  accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static
+            struct ggml_context  * ctx,        // context for gradient computation
+            struct ggml_cgraph   * cgraph,
+            struct ggml_tensor  ** grad_accs);
 
     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false

ggml/src/ggml-backend.cpp

+2 -2

@@ -611,7 +611,7 @@ static bool ggml_is_view_op(enum ggml_op op) {
 #endif
 
 #ifndef GGML_SCHED_MAX_SPLIT_INPUTS
-#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
+#define GGML_SCHED_MAX_SPLIT_INPUTS 1024
 #endif
 
 #ifndef GGML_SCHED_MAX_COPIES
@@ -1103,7 +1103,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
 
         const int node_backend_id = tensor_backend_id(node);
 
-        assert(node_backend_id != -1); // all nodes should be assigned by now
+        assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback
 
         // check if we should start a new split based on the sources of the current node
         bool need_new_split = false;
