
Commit be9cf7c

Opt class for positional argument handling

Added support for positional arguments `MODEL` and `PROMPT`.

Signed-off-by: Eric Curtin <[email protected]>
1 parent 46c69e0 commit be9cf7c
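
With this change, the model path and prompt are passed as positional arguments rather than through the removed -m and -p flags. From the new help text added in this commit:

  llama-run [options] MODEL [PROMPT]

  llama-run your_model.gguf
  llama-run --ngl 99 your_model.gguf
  llama-run --ngl 99 your_model.gguf Hello World

Any words after the second positional argument are appended to the prompt, so multi-word prompts need no quoting.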

File tree

1 file changed: +92 -96 lines changed

examples/run/run.cpp

+92 -96
@@ -4,109 +4,114 @@
 #include <unistd.h>
 #endif
 
-#include <climits>
 #include <cstdio>
 #include <cstring>
 #include <iostream>
 #include <sstream>
 #include <string>
-#include <unordered_map>
 #include <vector>
 
 #include "llama-cpp.h"
 
 typedef std::unique_ptr<char[]> char_array_ptr;
 
-struct Argument {
-    std::string flag;
-    std::string help_text;
-};
-
-struct Options {
-    std::string model_path, prompt_non_interactive;
-    int ngl = 99;
-    int n_ctx = 2048;
-};
+class Opt {
+  public:
+    int init_opt(int argc, const char ** argv) {
+        construct_help_str_();
+        // Parse arguments
+        if (parse(argc, argv)) {
+            fprintf(stderr, "Error: Failed to parse arguments.\n");
+            help();
+            return 1;
+        }
 
-class ArgumentParser {
-  public:
-    ArgumentParser(const char * program_name) : program_name(program_name) {}
+        // If help is requested, show help and exit
+        if (help_) {
+            help();
+            return 2;
+        }
 
-    void add_argument(const std::string & flag, std::string & var, const std::string & help_text = "") {
-        string_args[flag] = &var;
-        arguments.push_back({flag, help_text});
+        return 0; // Success
     }
 
-    void add_argument(const std::string & flag, int & var, const std::string & help_text = "") {
-        int_args[flag] = &var;
-        arguments.push_back({flag, help_text});
+    const char * model_ = nullptr;
+    std::string prompt_;
+    int context_size_ = 2048, ngl_ = 0;
+
+  private:
+    std::string help_str_;
+    bool help_ = false;
+
+    void construct_help_str_() {
+        help_str_ =
+            "Description:\n"
+            "  Runs a llm\n"
+            "\n"
+            "Usage:\n"
+            "  llama-run [options] MODEL [PROMPT]\n"
+            "\n"
+            "Options:\n"
+            "  -c, --context-size <value>\n"
+            "      Context size (default: " +
+            std::to_string(context_size_);
+        help_str_ +=
+            ")\n"
+            "  -n, --ngl <value>\n"
+            "      Number of GPU layers (default: " +
+            std::to_string(ngl_);
+        help_str_ +=
+            ")\n"
+            "  -h, --help\n"
+            "      Show help message\n"
+            "\n"
+            "Examples:\n"
+            "  llama-run your_model.gguf\n"
+            "  llama-run --ngl 99 your_model.gguf\n"
+            "  llama-run --ngl 99 your_model.gguf Hello World\n";
     }
 
     int parse(int argc, const char ** argv) {
+        if (parse_arguments(argc, argv)) {
+            return 1;
+        }
+
+        return 0;
+    }
+
+    int parse_arguments(int argc, const char ** argv) {
+        int positional_args_i = 0;
         for (int i = 1; i < argc; ++i) {
-            std::string arg = argv[i];
-            if (string_args.count(arg)) {
-                if (i + 1 < argc) {
-                    *string_args[arg] = argv[++i];
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
+            if (std::strcmp(argv[i], "-c") == 0 || std::strcmp(argv[i], "--context-size") == 0) {
+                if (i + 1 >= argc) {
                     return 1;
                 }
-            } else if (int_args.count(arg)) {
-                if (i + 1 < argc) {
-                    if (parse_int_arg(argv[++i], *int_args[arg]) != 0) {
-                        fprintf(stderr, "error: invalid value for %s: %s\n", arg.c_str(), argv[i]);
-                        print_usage();
-                        return 1;
-                    }
-                } else {
-                    fprintf(stderr, "error: missing value for %s\n", arg.c_str());
-                    print_usage();
+
+                context_size_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-n") == 0 || std::strcmp(argv[i], "--ngl") == 0) {
+                if (i + 1 >= argc) {
                     return 1;
                 }
+
+                ngl_ = std::atoi(argv[++i]);
+            } else if (std::strcmp(argv[i], "-h") == 0 || std::strcmp(argv[i], "--help") == 0) {
+                help_ = true;
+                return 0;
+            } else if (!positional_args_i) {
+                ++positional_args_i;
+                model_ = argv[i];
+            } else if (positional_args_i == 1) {
+                ++positional_args_i;
+                prompt_ = argv[i];
             } else {
-                fprintf(stderr, "error: unrecognized argument %s\n", arg.c_str());
-                print_usage();
-                return 1;
+                prompt_ += " " + std::string(argv[i]);
             }
         }
 
-        if (string_args["-m"]->empty()) {
-            fprintf(stderr, "error: -m is required\n");
-            print_usage();
-            return 1;
-        }
-
-        return 0;
+        return !model_; // model_ is the only required value
     }
 
-  private:
-    const char * program_name;
-    std::unordered_map<std::string, std::string *> string_args;
-    std::unordered_map<std::string, int *> int_args;
-    std::vector<Argument> arguments;
-
-    int parse_int_arg(const char * arg, int & value) {
-        char * end;
-        const long val = std::strtol(arg, &end, 10);
-        if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
-            value = static_cast<int>(val);
-            return 0;
-        }
-        return 1;
-    }
-
-    void print_usage() const {
-        printf("\nUsage:\n");
-        printf("  %s [OPTIONS]\n\n", program_name);
-        printf("Options:\n");
-        for (const auto & arg : arguments) {
-            printf("  %-10s %s\n", arg.flag.c_str(), arg.help_text.c_str());
-        }
-
-        printf("\n");
-    }
+    void help() const { printf("%s", help_str_.c_str()); }
 };
 
 class LlamaData {
@@ -116,13 +121,13 @@ class LlamaData {
     llama_context_ptr context;
     std::vector<llama_chat_message> messages;
 
-    int init(const Options & opt) {
-        model = initialize_model(opt.model_path, opt.ngl);
+    int init(const Opt & opt) {
+        model = initialize_model(opt.model_, opt.ngl_);
         if (!model) {
             return 1;
         }
 
-        context = initialize_context(model, opt.n_ctx);
+        context = initialize_context(model, opt.context_size_);
         if (!context) {
             return 1;
         }
@@ -134,6 +139,7 @@ class LlamaData {
   private:
     // Initializes the model and returns a unique pointer to it
    llama_model_ptr initialize_model(const std::string & model_path, const int ngl) {
+        ggml_backend_load_all();
        llama_model_params model_params = llama_model_default_params();
        model_params.n_gpu_layers = ngl;
 
@@ -273,19 +279,6 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     return 0;
 }
 
-static int parse_arguments(const int argc, const char ** argv, Options & opt) {
-    ArgumentParser parser(argv[0]);
-    parser.add_argument("-m", opt.model_path, "model");
-    parser.add_argument("-p", opt.prompt_non_interactive, "prompt");
-    parser.add_argument("-c", opt.n_ctx, "context_size");
-    parser.add_argument("-ngl", opt.ngl, "n_gpu_layers");
-    if (parser.parse(argc, argv)) {
-        return 1;
-    }
-
-    return 0;
-}
-
 static int read_user_input(std::string & user) {
     std::getline(std::cin, user);
     return user.empty(); // Indicate an error or empty input
@@ -382,17 +375,20 @@ static std::string read_pipe_data() {
 }
 
 int main(int argc, const char ** argv) {
-    Options opt;
-    if (parse_arguments(argc, argv, opt)) {
+    Opt opt;
+    const int opt_ret = opt.init_opt(argc, argv);
+    if (opt_ret == 2) {
+        return 0;
+    } else if (opt_ret) {
         return 1;
     }
 
     if (!is_stdin_a_terminal()) {
-        if (!opt.prompt_non_interactive.empty()) {
-            opt.prompt_non_interactive += "\n\n";
+        if (!opt.prompt_.empty()) {
+            opt.prompt_ += "\n\n";
         }
 
-        opt.prompt_non_interactive += read_pipe_data();
+        opt.prompt_ += read_pipe_data();
     }
 
     llama_log_set(log_callback, nullptr);
@@ -401,7 +397,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.prompt_non_interactive)) {
+    if (chat_loop(llama_data, opt.prompt_)) {
         return 1;
     }
 