Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example for set terminate #1042

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions documents/Runtime_option.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ This file will provide details on the usage of SetRuntimeOption API. It will lis

## Set Terminate

Set Terminate is a runtime option to terminate the current session or continue/restart an already terminated session. There are two valid ways to call Set Terminate.
Set Terminate is a runtime option to terminate the current session or continue/restart an already terminated session. The current session will crash when the terminate option is enabled and the user will need to handle that scenario, examples/c/src/phi3_terminate.cpp contains an example for this.

To enable terminate, the valid pair is: ("set_terminate", "1")
There are two valid ways to call Set Terminate.

To disable terminate, the valid pair is: ("set_terminate", "0")
To enable terminate, the valid pair is: ("terminate_session", "1")

Key: "set_terminate"
To disable terminate, the valid pair is: ("terminate_session", "0")

Key: "terminate_session"

Accepted values: ("0", "1")
84 changes: 60 additions & 24 deletions examples/c/src/phi3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
#include <iostream>
#include <string>
#include "ort_genai.h"
#include <thread>
#include <csignal>
#include <atomic>

using Clock = std::chrono::high_resolution_clock;
using TimePoint = std::chrono::time_point<Clock>;
Expand Down Expand Up @@ -65,16 +68,58 @@ class Timing {

// C++ API Example

void Generate_Output_CXX(OgaGenerator* generator, std::unique_ptr<OgaTokenizerStream> tokenizer_stream, bool is_first_token, Timing& timing) {
try {
while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();

if (is_first_token) {
timing.RecordFirstTokenTimestamp();
is_first_token = false;
}

// Show usage of GetOutput
std::unique_ptr<OgaTensor> output_logits = generator->GetOutput("logits");

// Assuming output_logits.Type() is float as it's logits
// Assuming shape is 1 dimensional with shape[0] being the size
auto logits = reinterpret_cast<float*>(output_logits->Data());

// Print out the logits using the following snippet, if needed
//auto shape = output_logits->Shape();
//for (size_t i=0; i < shape[0]; i++)
// std::cout << logits[i] << " ";
//std::cout << std::endl;

const auto num_tokens = generator->GetSequenceCount(0);
const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
std::cout << tokenizer_stream->Decode(new_token) << std::flush;
}
}
catch (const std::exception& e) {
std::cout << "Session Terminated: " << e.what() << std::endl;
}
}

std::atomic<bool> stopFlag(false);

void signalHandler(int signum) {
std::cout << "Interrupt signal received. Terminating current session...\n";
stopFlag = true;
}

void CXX_API(const char* model_path) {
std::cout << "Creating model..." << std::endl;
auto model = OgaModel::Create(model_path);
std::cout << "Creating tokenizer..." << std::endl;
auto tokenizer = OgaTokenizer::Create(*model);
auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);

while (true) {
signal(SIGINT, signalHandler);
auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer);
std::string text;
std::cout << "Prompt: (Use quit() to exit)" << std::endl;
std::cout << "Prompt: (Use quit() to exit) Or (To terminate current output generation, press Ctrl+C)" << std::endl;
std::getline(std::cin, text);

if (text == "quit()") {
Expand All @@ -97,31 +142,22 @@ void CXX_API(const char* model_path) {

auto generator = OgaGenerator::Create(*model, *params);

while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();
std::thread th(Generate_Output_CXX, generator.get(), std::move(tokenizer_stream), is_first_token, std::ref(timing));

if (is_first_token) {
timing.RecordFirstTokenTimestamp();
is_first_token = false;
// Check for stopFlag in a loop
while (th.joinable()) {
if (stopFlag) {
generator->SetRuntimeOption("terminate_session", "1");
stopFlag = false;
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // Check every 100 ms
if (generator->IsDone())
break;
}

// Show usage of GetOutput
std::unique_ptr<OgaTensor> output_logits = generator->GetOutput("logits");

// Assuming output_logits.Type() is float as it's logits
// Assuming shape is 1 dimensional with shape[0] being the size
auto logits = reinterpret_cast<float*>(output_logits->Data());

// Print out the logits using the following snippet, if needed
//auto shape = output_logits->Shape();
//for (size_t i=0; i < shape[0]; i++)
// std::cout << logits[i] << " ";
//std::cout << std::endl;

const auto num_tokens = generator->GetSequenceCount(0);
const auto new_token = generator->GetSequenceData(0)[num_tokens - 1];
std::cout << tokenizer_stream->Decode(new_token) << std::flush;
if (th.joinable()) {
th.join(); // Join the thread if it's still running
}

timing.RecordEndTimestamp();
Expand Down
Loading