|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <exception>
#include <fstream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
| 12 | + |
| 13 | +#include <gflags/gflags.h> |
| 14 | + |
| 15 | +#include <executorch/extension/llm/runner/audio.h> |
| 16 | +#include <executorch/extension/llm/runner/image.h> |
| 17 | +#include <executorch/extension/llm/runner/llm_runner_helper.h> |
| 18 | +#include <executorch/extension/llm/runner/multimodal_input.h> |
| 19 | +#include <executorch/extension/llm/runner/multimodal_runner.h> |
| 20 | +#include <executorch/runtime/core/error.h> |
| 21 | +#include <executorch/runtime/platform/log.h> |
| 22 | + |
| 23 | +#if defined(ET_USE_THREADPOOL) |
| 24 | +#include <executorch/extension/threadpool/cpuinfo_utils.h> |
| 25 | +#include <executorch/extension/threadpool/threadpool.h> |
| 26 | +#endif |
| 27 | + |
| 28 | +DEFINE_string( |
| 29 | + model_path, |
| 30 | + "multimodal.pte", |
| 31 | + "Model serialized in flatbuffer format."); |
| 32 | + |
| 33 | +DEFINE_string(tokenizer_path, "tekken.json", "Tokenizer stuff."); |
| 34 | + |
| 35 | +DEFINE_string(prompt, "What is happening in this audio?", "Text prompt."); |
| 36 | + |
| 37 | +DEFINE_string(audio_path, "", "Path to input audio file."); |
| 38 | + |
| 39 | +DEFINE_double( |
| 40 | + temperature, |
| 41 | + 0.8f, |
| 42 | + "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic"); |
| 43 | + |
| 44 | +DEFINE_int32( |
| 45 | + cpu_threads, |
| 46 | + -1, |
| 47 | + "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device."); |
| 48 | + |
| 49 | +DEFINE_bool(warmup, false, "Whether to run a warmup run."); |
| 50 | + |
| 51 | +namespace { |
| 52 | + |
| 53 | +using ::executorch::extension::llm::Image; |
| 54 | +using ::executorch::extension::llm::make_image_input; |
| 55 | +using ::executorch::extension::llm::make_text_input; |
| 56 | +using ::executorch::extension::llm::MultimodalInput; |
| 57 | + |
/**
 * @brief Checks whether a string finishes with the given suffix.
 *
 * The comparison is case-sensitive. An empty suffix matches every string.
 *
 * @param str String to inspect
 * @param suffix Suffix to look for at the end of str
 * @return true if the last suffix.size() characters of str equal suffix
 */
bool ends_with(const std::string& str, const std::string& suffix) {
  const std::string::size_type suffix_len = suffix.size();
  // A suffix longer than the string can never match; this guard also keeps
  // the offset computation below from underflowing.
  if (suffix_len > str.size()) {
    return false;
  }
  const std::string::size_type tail_start = str.size() - suffix_len;
  return str.compare(tail_start, std::string::npos, suffix) == 0;
}
| 62 | + |
| 63 | +/** |
| 64 | + * @brief Loads preprocessed audio data from a binary file |
| 65 | + * |
| 66 | + * Reads mel spectrogram features that have been pre-computed and saved as a |
| 67 | + * binary file. The audio data is expected to be stored as float values in |
| 68 | + * binary format, typically saved using: |
| 69 | + * with open("tensor.bin", "wb") as f: |
| 70 | + * f.write(t.numpy().tobytes()) |
| 71 | + * |
| 72 | + * @param audio_path Path to the binary audio file (.bin) |
| 73 | + * @return MultimodalInput containing the loaded audio data |
| 74 | + */ |
| 75 | +MultimodalInput loadPreprocessedAudio(const std::string& audio_path) { |
| 76 | + std::ifstream f(audio_path, std::ios::binary | std::ios::ate); |
| 77 | + int32_t n_bins = 128; |
| 78 | + int32_t n_frames = 3000; |
| 79 | + std::size_t n_floats = |
| 80 | + f.tellg() / sizeof(float); // Number of floats in the audio file. |
| 81 | + f.seekg(0, std::ios::beg); |
| 82 | + int32_t batch_size = ceil( |
| 83 | + n_floats / |
| 84 | + (n_bins * n_frames)); // Batch in increments of n_frames, rounding up. |
| 85 | + std::vector<float> audio_data(batch_size * n_bins * n_frames); |
| 86 | + f.read( |
| 87 | + reinterpret_cast<char*>(audio_data.data()), |
| 88 | + audio_data.size() * sizeof(float)); |
| 89 | + |
| 90 | + ET_LOG(Info, "audio_data len = %d", audio_data.size()); |
| 91 | + |
| 92 | + auto audio = std::make_unique<::executorch::extension::llm::Audio>(); |
| 93 | + audio->batch_size = batch_size; |
| 94 | + audio->n_bins = n_bins; |
| 95 | + audio->n_frames = n_frames; |
| 96 | + audio->data.resize(audio_data.size() * sizeof(float)); |
| 97 | + std::memcpy( |
| 98 | + audio->data.data(), audio_data.data(), audio_data.size() * sizeof(float)); |
| 99 | + return ::executorch::extension::llm::make_audio_input(std::move(*audio)); |
| 100 | +} |
| 101 | + |
| 102 | +/** |
| 103 | + * @brief Processes audio files for multimodal input |
| 104 | + * |
| 105 | + * Dispatches audio file processing based on file extension: |
| 106 | + * - .bin files: Loads preprocessed mel spectrogram features directly |
| 107 | + * - .wav/.mp3 files: Currently unsupported, throws runtime_error |
| 108 | + * |
| 109 | + * This function provides a interface for different audio input formats |
| 110 | + * and can be extended to support raw audio processing in the future. |
| 111 | + * |
| 112 | + * @param audio_path Path to the audio file |
| 113 | + * @return MultimodalInput containing the processed audio data |
| 114 | + * @throws std::runtime_error if file format is unsupported or processing fails |
| 115 | + */ |
| 116 | +MultimodalInput processAudioFile(const std::string& audio_path) { |
| 117 | + if (ends_with(audio_path, ".bin")) { |
| 118 | + // Current behavior - load preprocessed audio stored as a binary file. |
| 119 | + return loadPreprocessedAudio(audio_path); |
| 120 | + } else if (ends_with(audio_path, ".wav") || ends_with(audio_path, ".mp3")) { |
| 121 | + // New: Process raw audio files - unsupported for now |
| 122 | + ET_LOG(Error, "Raw audio file processing (.wav/.mp3) is not yet supported"); |
| 123 | + throw std::runtime_error("Raw audio file processing not supported"); |
| 124 | + } else { |
| 125 | + ET_LOG(Error, "Unsupported audio file format: %s", audio_path.c_str()); |
| 126 | + throw std::runtime_error("Unsupported audio file format"); |
| 127 | + } |
| 128 | +} |
| 129 | + |
| 130 | +} // namespace |
| 131 | + |
| 132 | +int32_t main(int32_t argc, char** argv) { |
| 133 | + gflags::ParseCommandLineFlags(&argc, &argv, true); |
| 134 | + |
| 135 | + const char* model_path = FLAGS_model_path.c_str(); |
| 136 | + |
| 137 | + const char* tokenizer_path = FLAGS_tokenizer_path.c_str(); |
| 138 | + const char* prompt = FLAGS_prompt.c_str(); |
| 139 | + const char* audio_path = FLAGS_audio_path.c_str(); |
| 140 | + float temperature = FLAGS_temperature; |
| 141 | + int32_t cpu_threads = FLAGS_cpu_threads; |
| 142 | + bool warmup = FLAGS_warmup; |
| 143 | + |
| 144 | +#if defined(ET_USE_THREADPOOL) |
| 145 | + uint32_t num_performant_cores = cpu_threads == -1 |
| 146 | + ? ::executorch::extension::cpuinfo::get_num_performant_cores() |
| 147 | + : static_cast<uint32_t>(cpu_threads); |
| 148 | + ET_LOG( |
| 149 | + Info, "Resetting threadpool with num threads = %d", num_performant_cores); |
| 150 | + if (num_performant_cores > 0) { |
| 151 | + ::executorch::extension::threadpool::get_threadpool() |
| 152 | + ->_unsafe_reset_threadpool(num_performant_cores); |
| 153 | + } |
| 154 | +#endif |
| 155 | + |
| 156 | + // Load tokenizer |
| 157 | + std::unique_ptr<::tokenizers::Tokenizer> tokenizer = |
| 158 | + ::executorch::extension::llm::load_tokenizer(tokenizer_path); |
| 159 | + if (tokenizer == nullptr) { |
| 160 | + ET_LOG(Error, "Failed to load tokenizer from: %s", tokenizer_path); |
| 161 | + return 1; |
| 162 | + } |
| 163 | + |
| 164 | + // Create multimodal runner |
| 165 | + std::unique_ptr<::executorch::extension::llm::MultimodalRunner> runner = |
| 166 | + ::executorch::extension::llm::create_multimodal_runner( |
| 167 | + model_path, std::move(tokenizer)); |
| 168 | + if (runner == nullptr) { |
| 169 | + ET_LOG(Error, "Failed to create multimodal runner"); |
| 170 | + return 1; |
| 171 | + } |
| 172 | + |
| 173 | + // Load runner |
| 174 | + auto load_error = runner->load(); |
| 175 | + if (load_error != ::executorch::runtime::Error::Ok) { |
| 176 | + ET_LOG(Error, "Failed to load multimodal runner"); |
| 177 | + return 1; |
| 178 | + } |
| 179 | + |
| 180 | + // Prepare inputs |
| 181 | + std::vector<MultimodalInput> inputs; |
| 182 | + |
| 183 | + // 1. Add start bos-related text inputs and modality start token. |
| 184 | + inputs.emplace_back(make_text_input("<s>[INST][BEGIN_AUDIO]")); |
| 185 | + |
| 186 | + // 2. Add audio input |
| 187 | + inputs.emplace_back(processAudioFile(audio_path)); |
| 188 | + |
| 189 | + // 3. Add text input (the actual user-submitted prompt) |
| 190 | + inputs.emplace_back(make_text_input(std::string(prompt) + "[/INST]")); |
| 191 | + |
| 192 | + ::executorch::extension::llm::GenerationConfig config; |
| 193 | + config.max_new_tokens = 100; |
| 194 | + config.temperature = temperature; |
| 195 | + |
| 196 | + // Run warmup if requested |
| 197 | + if (warmup) { |
| 198 | + ET_LOG(Info, "Running warmup..."); |
| 199 | + auto warmup_error = runner->generate(inputs, config); |
| 200 | + if (warmup_error != ::executorch::runtime::Error::Ok) { |
| 201 | + ET_LOG(Error, "Failed to run warmup"); |
| 202 | + return 1; |
| 203 | + } |
| 204 | + runner->reset(); |
| 205 | + } |
| 206 | + |
| 207 | + // Generate |
| 208 | + ET_LOG(Info, "Starting generation..."); |
| 209 | + auto error = runner->generate(inputs, config); |
| 210 | + if (error != ::executorch::runtime::Error::Ok) { |
| 211 | + ET_LOG(Error, "Failed to generate with multimodal runner"); |
| 212 | + return 1; |
| 213 | + } |
| 214 | + |
| 215 | + printf("\n"); |
| 216 | + return 0; |
| 217 | +} |
0 commit comments