main: port basic LLaVA (multimodal) support from llava-cli #5730

Closed · wants to merge 2 commits
6 changes: 4 additions & 2 deletions Makefile
@@ -745,9 +745,11 @@ clean:
 # Helper function that replaces .c, .cpp, and .cu file endings with .o:
 GET_OBJ_FILE = $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(patsubst %.cu,%.o,$(1))))
 
-main: examples/main/main.cpp ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
+main: examples/main/main.cpp examples/llava/clip.h examples/llava/clip.cpp examples/llava/llava.h examples/llava/llava.cpp ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
-	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual
+	$(CXX) $(CXXFLAGS) -c examples/llava/llava.cpp -o $(call GET_OBJ_FILE, examples/llava/llava.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/llava/clip.cpp examples/llava/llava.cpp,$^) $(call GET_OBJ_FILE, examples/llava/clip.cpp) $(call GET_OBJ_FILE, examples/llava/llava.cpp) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 	@echo
 	@echo '==== Run ./main -h for help. ===='
 	@echo
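Build note: main.cpp is still compiled on its own first; clip.cpp additionally gets -Wno-cast-qual because the bundled stb_image.h casts away const (the same warning is suppressed inside clip.cpp itself, see below). Each $(call GET_OBJ_FILE, examples/llava/clip.cpp) expands to examples/llava/clip.o via the patsubst helper above, so the link line consumes the two new objects alongside the usual dependencies, while $(filter-out ...) keeps the headers and the raw sources off it.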
2 changes: 1 addition & 1 deletion build.zig
@@ -129,7 +129,7 @@ pub fn build(b: *std.build.Builder) !void {
     const clip = make.obj("clip", "examples/llava/clip.cpp");
     const llava = make.obj("llava", "examples/llava/llava.cpp");
 
-    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser });
+    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser, clip, llava });
     _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
     _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
     _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
3 changes: 3 additions & 0 deletions examples/llava/clip.cpp
@@ -16,6 +16,9 @@
 #include "ggml-metal.h"
 #endif
 
+#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
+#pragma GCC diagnostic ignored "-Wcast-qual"
+#endif
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb_image.h"
 
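Note that this form of the pragma stays in effect for the rest of the translation unit, so -Wcast-qual is also silenced for clip.cpp's own code below the include. A narrower variant (a sketch, not part of this PR) would scope the suppression to the header alone:

#if defined(__clang__) || defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wcast-qual"   // stb_image.h casts away const internally
#endif
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#if defined(__clang__) || defined(__GNUC__)
#pragma GCC diagnostic pop                     // restore -Wcast-qual for our own code
#endif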
42 changes: 22 additions & 20 deletions examples/llava/llava.cpp
@@ -372,43 +372,45 @@ struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * c
     return result;
 }
 
-static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut) {
+static bool load_image_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut) {
     auto file = fopen(path, "rb");
     if (file == NULL) {
         LOG_TEE("%s: can't read file %s\n", __func__, path);
         return false;
     }
 
-    fseek(file, 0, SEEK_END);
-    auto fileSize = ftell(file);
-    fseek(file, 0, SEEK_SET);
-
-    auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data
+    const size_t limit = 128 * 1024 * 1024; // File size should be less than this
+    // Instead of trying to get file size, let's allow reading from devices which cannot provide it (fifo, sockets)
+
+    auto buffer = (unsigned char*)malloc(limit); // Allocate memory to hold the file data
     if (buffer == NULL) {
-        LOG_TEE("%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path);
         perror("Memory allocation error");
         fclose(file);
         return false;
     }
     errno = 0;
-    size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer
-    if (ferror(file)) {
-        die_fmt("read error: %s", strerror(errno));
-    }
-    if (ret != (size_t) fileSize) {
-        die("unexpectedly reached end of file");
+    size_t total = 0;
+    while (true) {
+        size_t ret = fread(buffer + total, 1, limit - total, file); // Read the file into the buffer
+        if (ferror(file)) {
+            die_fmt("%s: read error: %s", __func__, strerror(errno));
+        }
+        total += ret;
+        if (total >= limit) {
+            die_fmt("%s: file too big: %zu bytes or higher", __func__, limit);
+        }
+        if (feof(file)) {
+            break;
+        }
     }
     fclose(file); // Close the file
 
-    *bytesOut = buffer;
-    *sizeOut = fileSize;
+    // Fix memory allocation size
+    *bytesOut = (unsigned char*)realloc(buffer, total);
+    *sizeOut = total;
     return true;
 }
 
 struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path) {
     unsigned char* image_bytes;
     long image_bytes_length;
-    auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length);
+    auto loaded = load_image_to_bytes(image_path, &image_bytes, &image_bytes_length);
     if (!loaded) {
         LOG_TEE("%s: failed to load %s\n", __func__, image_path);
         return NULL;
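The rewritten loader trades the ftell() size probe for a bounded fread() loop, so images can be piped in from fifos or sockets that cannot seek; the final realloc() then shrinks the 128 MiB scratch allocation down to the bytes actually read. (One caveat: realloc(buffer, 0) may legally return NULL for an empty input, which the caller does not check.) The same pattern in a minimal standalone form, with illustrative names and a growable buffer instead of one large malloc:

#include <cstdio>
#include <vector>

// Read an entire (possibly non-seekable) stream into memory, refusing
// inputs larger than `cap` bytes. Returns false on I/O error or overflow.
static bool slurp_stream(FILE * f, std::vector<unsigned char> & out, size_t cap) {
    unsigned char chunk[1 << 16];
    while (!feof(f)) {
        const size_t n = fread(chunk, 1, sizeof(chunk), f);
        if (ferror(f)) {
            return false;
        }
        if (out.size() + n > cap) {
            return false; // input too big
        }
        out.insert(out.end(), chunk, chunk + n);
    }
    return true;
}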
2 changes: 1 addition & 1 deletion examples/main/CMakeLists.txt
@@ -1,5 +1,5 @@
 set(TARGET main)
 add_executable(${TARGET} main.cpp)
 install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE common llama llava ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
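The llava library target linked here is the one defined under examples/llava alongside clip; linking it pulls the CLIP image encoder and the image-embedding helpers into main, mirroring what the Makefile and build.zig changes above do by hand.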
90 changes: 73 additions & 17 deletions examples/main/main.cpp
@@ -1,5 +1,7 @@
 #include "common.h"
 
+#include "../llava/clip.h"
+#include "../llava/llava.h"
 #include "console.h"
 #include "llama.h"
 
@@ -194,6 +196,9 @@ int main(int argc, char ** argv) {
     g_model = &model;
     g_ctx = &ctx;
 
+    clip_ctx* ctx_clip = nullptr;
+    llava_image_embed* image_embed = nullptr;
+
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
@@ -207,6 +212,27 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (!params.image.empty() && params.mmproj.empty()) {
+        LOG_TEE("%s: error: image specified without mmproj\n", __func__);
+        return 1;
+    }
+
+    if (!params.mmproj.empty()) {
+        ctx_clip = clip_model_load(params.mmproj.c_str(), /*verbosity=*/1);
+        if (!ctx_clip) {
+            LOG_TEE("%s: error: failed to load mmproj (CLIP)\n", __func__);
+            return 1;
+        }
+
+        if (!params.image.empty()) {
+            image_embed = llava_image_embed_make_with_filename(ctx_clip, params.n_threads, params.image.c_str());
+            if (!image_embed) {
+                LOG_TEE("%s: error: failed to load image\n", __func__);
+                return 1;
+            }
+        }
+    }
+
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
     LOG("n_ctx: %d\n", n_ctx);
@@ -250,13 +276,22 @@ int main(int argc, char ** argv) {
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
+    int embd_img_pos = -1;
 
     if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
         if (params.chatml) {
             params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
+        const auto epos = params.prompt.find("<image>");
+        if (epos != std::string::npos && image_embed) {
+            embd_inp = ::llama_tokenize(ctx, params.prompt.substr(0, epos), true, true);
+            embd_img_pos = embd_inp.size();
+            auto end = ::llama_tokenize(ctx, params.prompt.substr(epos + 7), false, true);
+            embd_inp.insert(embd_inp.end(), end.begin(), end.end());
+        } else {
+            embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
+        }
     } else {
         LOG("use session tokens\n");
         embd_inp = session_tokens;
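For example, with an image loaded and the prompt "USER: <image>\nWhat is this? ASSISTANT:", the text before the tag ("USER: ") is tokenized with BOS, embd_img_pos records its token count as the splice point, and the remainder after the 7-character "<image>" tag ("\nWhat is this? ASSISTANT:") is tokenized without BOS and appended. Without an image, or without the tag, the prompt tokenizes exactly as before.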
@@ -333,8 +368,10 @@ int main(int argc, char ** argv) {
     }
 
     // number of tokens to keep when resetting context
+    bool n_keep_full = false;
     if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
         params.n_keep = (int)embd_inp.size();
+        n_keep_full = true;
     } else {
         params.n_keep += add_bos; // always keep the BOS token
     }
@@ -454,6 +491,10 @@ int main(int argc, char ** argv) {
     LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
     LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+    // Extend n_keep with embedded image size (there is an edge case with
+    // explicit n_keep that it must include at least 1 token after img)
+    if (embd_img_pos >= 0 && (params.n_keep > embd_img_pos || n_keep_full))
+        params.n_keep += image_embed->n_image_pos;
 
     // group-attention state
     // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
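Worked example: if the prompt splits into 10 prefix tokens (embd_img_pos = 10) and the image contributes n_image_pos = 576 embedding positions (24 × 24 patches for LLaVA-1.5's 336 px CLIP encoder; other projectors differ), then an n_keep that covers the whole prompt, or an explicit n_keep of at least 11, grows by 576, so a context reset never truncates the image span mid-way.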
@@ -659,26 +700,36 @@ int main(int argc, char ** argv) {
             }
         }
 
-        for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
-            int n_eval = (int) embd.size() - i;
-            if (n_eval > params.n_batch) {
-                n_eval = params.n_batch;
-            }
-
-            LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
-
-            if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
-                LOG_TEE("%s : failed to eval\n", __func__);
-                return 1;
-            }
-
-            n_past += n_eval;
-
-            LOG("n_past = %d\n", n_past);
-            // Display total tokens alongside total time
-            if (params.n_print > 0 && n_past % params.n_print == 0) {
-                LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
-            }
-        }
+        // decode embd[start, end) in n_batch-sized chunks; end == -1 means "to the end"
+        auto decode_tokens = [&](int start, int end) -> void {
+            if (end == -1)
+                end = embd.size();
+            for (int i = start; i < end; i += params.n_batch) {
+                int n_eval = end - i;
+                if (n_eval > params.n_batch) {
+                    n_eval = params.n_batch;
+                }
+
+                LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+
+                llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0));
+
+                n_past += n_eval;
+
+                LOG("n_past = %d\n", n_past);
+                // Display total tokens alongside total time
+                if (params.n_print > 0 && n_past % params.n_print == 0) {
+                    LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
+                }
+            }
+        };
+
+        if (embd_img_pos >= 0) {
+            decode_tokens(0, embd_img_pos);
+            llava_eval_image_embed(ctx, image_embed, params.n_batch, &n_past);
+            decode_tokens(embd_img_pos, -1);
+            embd_img_pos = -1;
+        } else {
+            decode_tokens(0, embd.size());
+        }
 
         if (!embd.empty() && !path_session.empty()) {
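The splice works in three steps: decode the prompt tokens before the tag, let llava_eval_image_embed (from examples/llava/llava.h) feed the precomputed image embedding to the model in n_batch-sized chunks while advancing n_past, then decode the remaining tokens. embd_img_pos is reset to -1 so later interactive turns take the plain path. Note one behavioral change: the lambda returns void, so a llama_decode failure no longer aborts with "failed to eval" as the old loop did.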
@@ -943,6 +994,11 @@ int main(int argc, char ** argv) {
     write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
 
     if (ctx_guidance) { llama_free(ctx_guidance); }
+
+    if (image_embed)
+        llava_image_embed_free(image_embed);
+    if (ctx_clip)
+        clip_free(ctx_clip);
     llama_free(ctx);
     llama_free_model(model);
 