@@ -37,13 +37,16 @@ extern "C" {
37
37
// ====== Dataset ======
38
38
39
39
GGML_API ggml_opt_dataset_t ggml_opt_dataset_init (
40
- int64_t ne_datapoint , // number of elements per datapoint
41
- int64_t ne_label , // number of elements per label
42
- int64_t ndata , // total number of datapoints/labels
43
- int64_t ndata_shard ); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
40
+ enum ggml_type type_data , // the type for the internal data tensor
41
+ enum ggml_type type_label , // the type for the internal labels tensor
42
+ int64_t ne_datapoint , // number of elements per datapoint
43
+ int64_t ne_label , // number of elements per label
44
+ int64_t ndata , // total number of datapoints/labels
45
+ int64_t ndata_shard ); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
44
46
GGML_API void ggml_opt_dataset_free (ggml_opt_dataset_t dataset );
45
47
46
48
// get the number of datapoints and the underlying tensors that store the data
49
+ GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset );
47
50
GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset ); // shape = [ne_datapoint, ndata]
48
51
GGML_API struct ggml_tensor * ggml_opt_dataset_labels (ggml_opt_dataset_t dataset ); // shape = [ne_label, ndata]
49
52
@@ -56,6 +59,12 @@ extern "C" {
56
59
struct ggml_tensor * data_batch , // shape = [ne_datapoint, ndata_batch]
57
60
struct ggml_tensor * labels_batch , // shape = [ne_label, ndata_batch]
58
61
int64_t ibatch );
62
+ GGML_API void ggml_opt_dataset_get_batch_host (
63
+ ggml_opt_dataset_t dataset ,
64
+ void * data_batch ,
65
+ size_t nb_data_batch ,
66
+ void * labels_batch ,
67
+ int64_t ibatch );
59
68
60
69
// ====== Model / Context ======
61
70
@@ -92,7 +101,8 @@ extern "C" {
92
101
struct ggml_context * ctx_compute ; // created in user code, holds non-static tensors
93
102
94
103
// the forward graph is defined by inputs and outputs
95
- // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts
104
+ // the outputs and all tensors between inputs and outputs that have not been statically allocated
105
+ // are not intended to be reusable between multiple optimization contexts
96
106
struct ggml_tensor * inputs ;
97
107
struct ggml_tensor * outputs ;
98
108
@@ -107,7 +117,7 @@ extern "C" {
107
117
108
118
// get parameters for an optimization context with defaults set where possible
109
119
// parameters for which no sensible defaults exist are supplied as arguments to this function
110
- GGML_API ggml_opt_params ggml_opt_default_params (
120
+ GGML_API struct ggml_opt_params ggml_opt_default_params (
111
121
ggml_backend_sched_t backend_sched ,
112
122
struct ggml_context * ctx_compute ,
113
123
struct ggml_tensor * inputs ,
@@ -144,6 +154,10 @@ extern "C" {
144
154
145
155
// ====== Computation ======
146
156
157
+ GGML_API void ggml_opt_set_forward_graph (
158
+ ggml_opt_context_t opt_ctx , struct ggml_context * ctx_compute , struct ggml_cgraph * gf ,
159
+ struct ggml_tensor * inputs , struct ggml_tensor * outputs , bool backward );
160
+
147
161
// do forward pass, increment result if not NULL
148
162
GGML_API void ggml_opt_forward (ggml_opt_context_t opt_ctx , ggml_opt_result_t result );
149
163
@@ -200,9 +214,9 @@ extern "C" {
200
214
// fit model defined by inputs and outputs to dataset
201
215
GGML_API void ggml_opt_fit (
202
216
ggml_backend_sched_t backend_sched , // backend scheduler for constructing the compute graphs
203
- ggml_context * ctx_compute , // context with temporarily allocated tensors to calculate the outputs
204
- ggml_tensor * inputs , // input tensor with shape [ne_datapoint, ndata_batch]
205
- ggml_tensor * outputs , // output tensor, must have shape [ne_label, ndata_batch] if labels are used
217
+ struct ggml_context * ctx_compute , // context with temporarily allocated tensors to calculate the outputs
218
+ struct ggml_tensor * inputs , // input tensor with shape [ne_datapoint, ndata_batch]
219
+ struct ggml_tensor * outputs , // output tensor, must have shape [ne_label, ndata_batch] if labels are used
206
220
ggml_opt_dataset_t dataset , // dataset with data and optionally also labels
207
221
enum ggml_opt_loss_type loss_type , // loss to minimize
208
222
ggml_opt_get_optimizer_params get_opt_pars , // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
0 commit comments