Skip to content

Commit

Permalink
WebSocketApi: Implement backend for obs-websocket event listening
Browse files Browse the repository at this point in the history
  • Loading branch information
tt2468 committed Apr 23, 2024
1 parent ee283c7 commit 5b4aa9d
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 28 deletions.
99 changes: 88 additions & 11 deletions src/WebSocketApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,19 @@ WebSocketApi::WebSocketApi()

proc_handler_add(_procHandler, "bool get_api_version(out int version)", &get_api_version, nullptr);
proc_handler_add(_procHandler, "bool call_request(in string request_type, in string request_data, out ptr response)",
&call_request, nullptr);
&call_request, this);
proc_handler_add(_procHandler, "bool register_event_callback(in ptr callback, out bool success)", &register_event_callback,
this);
proc_handler_add(_procHandler, "bool unregister_event_callback(in ptr callback, out bool success)",
&unregister_event_callback, this);
proc_handler_add(_procHandler, "bool vendor_register(in string name, out ptr vendor)", &vendor_register_cb, this);
proc_handler_add(_procHandler, "bool vendor_request_register(in ptr vendor, in string type, in ptr callback)",
proc_handler_add(_procHandler,
"bool vendor_request_register(in ptr vendor, in string type, in ptr callback, out bool success)",
&vendor_request_register_cb, this);
proc_handler_add(_procHandler, "bool vendor_request_unregister(in ptr vendor, in string type)",
proc_handler_add(_procHandler, "bool vendor_request_unregister(in ptr vendor, in string type, out bool success)",
&vendor_request_unregister_cb, this);
proc_handler_add(_procHandler, "bool vendor_event_emit(in ptr vendor, in string type, in ptr data)", &vendor_event_emit_cb,
this);
proc_handler_add(_procHandler, "bool vendor_event_emit(in ptr vendor, in string type, in ptr data, out bool success)",
&vendor_event_emit_cb, this);

proc_handler_t *ph = obs_get_proc_handler();
assert(ph != NULL);
Expand All @@ -70,6 +75,10 @@ WebSocketApi::~WebSocketApi()

proc_handler_destroy(_procHandler);

size_t numEventCallbacks = _eventCallbacks.size();
_eventCallbacks.clear();
blog_debug("[WebSocketApi::~WebSocketApi] Deleted %ld event callbacks", numEventCallbacks);

for (auto vendor : _vendors) {
blog_debug("[WebSocketApi::~WebSocketApi] Deleting vendor: %s", vendor.first.c_str());
delete vendor.second;
Expand All @@ -80,10 +89,19 @@ WebSocketApi::~WebSocketApi()

void WebSocketApi::BroadcastEvent(uint64_t requiredIntent, const std::string &eventType, const json &eventData, uint8_t rpcVersion)
{
UNUSED_PARAMETER(requiredIntent);
UNUSED_PARAMETER(eventType);
UNUSED_PARAMETER(eventData);
UNUSED_PARAMETER(rpcVersion);
if (!_obsReady)
return;

// Only broadcast events applicable to the latest RPC version
if (rpcVersion && rpcVersion != CURRENT_RPC_VERSION)
return;

std::string eventDataString = eventData.dump();

std::shared_lock l(_mutex);

for (auto &cb : _eventCallbacks)
cb.callback(requiredIntent, eventType.c_str(), eventDataString.c_str(), cb.priv_data);
}

enum WebSocketApi::RequestReturnCode WebSocketApi::PerformVendorRequest(std::string vendorName, std::string requestType,
Expand Down Expand Up @@ -128,14 +146,27 @@ void WebSocketApi::get_api_version(void *, calldata_t *cd)
RETURN_SUCCESS();
}

void WebSocketApi::call_request(void *, calldata_t *cd)
void WebSocketApi::call_request(void *priv_data, calldata_t *cd)
{
auto c = static_cast<WebSocketApi *>(priv_data);

#if !defined(PLUGIN_TESTS)
if (!c->_obsReady)
RETURN_FAILURE();
#endif

const char *request_type = calldata_string(cd, "request_type");
const char *request_data = calldata_string(cd, "request_data");

if (!request_type)
RETURN_FAILURE();

#ifdef PLUGIN_TESTS
// Allow plugin tests to complete, even though OBS wouldn't be ready at the time of the test
if (!c->_obsReady && std::string(request_type) != "GetVersion")
RETURN_FAILURE();
#endif

auto response = static_cast<obs_websocket_request_response *>(bzalloc(sizeof(struct obs_websocket_request_response)));
if (!response)
RETURN_FAILURE();
Expand Down Expand Up @@ -164,6 +195,52 @@ void WebSocketApi::call_request(void *, calldata_t *cd)
RETURN_SUCCESS();
}

void WebSocketApi::register_event_callback(void *priv_data, calldata_t *cd)
{
auto c = static_cast<WebSocketApi *>(priv_data);

void *voidCallback;
if (!calldata_get_ptr(cd, "callback", &voidCallback) || !voidCallback) {
blog(LOG_WARNING, "[WebSocketApi::register_event_callback] Failed due to missing `callback` pointer.");
RETURN_FAILURE();
}

auto cb = static_cast<obs_websocket_event_callback *>(voidCallback);

std::unique_lock l(c->_mutex);

int64_t foundIndex = c->GetEventCallbackIndex(*cb);
if (foundIndex != -1)
RETURN_FAILURE();

c->_eventCallbacks.push_back(*cb);

RETURN_SUCCESS();
}

void WebSocketApi::unregister_event_callback(void *priv_data, calldata_t *cd)
{
auto c = static_cast<WebSocketApi *>(priv_data);

void *voidCallback;
if (!calldata_get_ptr(cd, "callback", &voidCallback) || !voidCallback) {
blog(LOG_WARNING, "[WebSocketApi::register_event_callback] Failed due to missing `callback` pointer.");
RETURN_FAILURE();
}

auto cb = static_cast<obs_websocket_event_callback *>(voidCallback);

std::unique_lock l(c->_mutex);

int64_t foundIndex = c->GetEventCallbackIndex(*cb);
if (foundIndex == -1)
RETURN_FAILURE();

c->_eventCallbacks.erase(c->_eventCallbacks.begin() + foundIndex);

RETURN_SUCCESS();
}

void WebSocketApi::vendor_register_cb(void *priv_data, calldata_t *cd)
{
auto c = static_cast<WebSocketApi *>(priv_data);
Expand All @@ -174,7 +251,7 @@ void WebSocketApi::vendor_register_cb(void *priv_data, calldata_t *cd)
RETURN_FAILURE();
}

// Theoretically doesn't need a mutex, but it's good to be safe.
// Theoretically doesn't need a mutex due to module load being single-thread, but it's good to be safe.
std::unique_lock l(c->_mutex);

if (c->_vendors.count(vendorName)) {
Expand Down
14 changes: 14 additions & 0 deletions src/WebSocketApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,22 @@ class WebSocketApi {
inline void SetVendorEventCallback(VendorEventCallback cb) { _vendorEventCallback = cb; }

private:
inline int64_t GetEventCallbackIndex(obs_websocket_event_callback &cb)
{
for (int64_t i = 0; i < (int64_t)_eventCallbacks.size(); i++) {
auto currentCb = _eventCallbacks[i];
if (currentCb.callback == cb.callback && currentCb.priv_data == cb.priv_data)
return i;
}
return -1;
}

// Proc handlers
static void get_ph_cb(void *priv_data, calldata_t *cd);
static void get_api_version(void *, calldata_t *cd);
static void call_request(void *, calldata_t *cd);
static void register_event_callback(void *, calldata_t *cd);
static void unregister_event_callback(void *, calldata_t *cd);
static void vendor_register_cb(void *priv_data, calldata_t *cd);
static void vendor_request_register_cb(void *priv_data, calldata_t *cd);
static void vendor_request_unregister_cb(void *priv_data, calldata_t *cd);
Expand All @@ -68,6 +81,7 @@ class WebSocketApi {
std::shared_mutex _mutex;
proc_handler_t *_procHandler;
std::map<std::string, Vendor *> _vendors;
std::vector<obs_websocket_event_callback> _eventCallbacks;

std::atomic<bool> _obsReady = false;

Expand Down
58 changes: 42 additions & 16 deletions src/obs-websocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,16 @@ bool obs_module_load(void)
}

#ifdef PLUGIN_TESTS
void test_call_request();
void test_register_event_callback();
void test_register_vendor();
#endif

void obs_module_post_load(void)
{
#ifdef PLUGIN_TESTS
test_call_request();
test_register_event_callback();
test_register_vendor();
#endif

Expand Down Expand Up @@ -229,12 +233,43 @@ void OnObsReady(bool ready)
}

#ifdef PLUGIN_TESTS
void test_call_request()
{
blog(LOG_INFO, "[test_call_request] Testing obs-websocket plugin API request calling...");

static void test_vendor_request_cb(obs_data_t *requestData, obs_data_t *responseData, void *priv_data)
struct obs_websocket_request_response *response = obs_websocket_call_request("GetVersion");
if (response) {
blog(LOG_INFO, "[test_call_request] Called GetVersion. Status Code: %u | Comment: %s | Response Data: %s",
response->status_code, response->comment, response->response_data);
obs_websocket_request_response_free(response);
} else {
blog(LOG_ERROR, "[test_call_request] Failed to call GetVersion request via obs-websocket plugin API!");
}

blog(LOG_INFO, "[test_call_request] Test done.");
}

static void test_event_cb(uint64_t eventIntent, const char *eventType, const char *eventData, void *priv_data)
{
blog(LOG_DEBUG, "[test_event_cb] New event! Type: %s | Data: %s", eventType, eventData);

UNUSED_PARAMETER(eventIntent);
UNUSED_PARAMETER(priv_data);
}

void test_register_event_callback()
{
blog(LOG_INFO, "[test_vendor_request_cb] Request called!");
blog(LOG_INFO, "[test_register_event_callback] Registering test event callback...");

if (!obs_websocket_register_event_callback(test_event_cb, nullptr))
blog(LOG_ERROR, "[test_register_event_callback] Failed to register event callback!");

blog(LOG_INFO, "[test_vendor_request_cb] Request data: %s", obs_data_get_json(requestData));
blog(LOG_INFO, "[test_register_event_callback] Test done.");
}

static void test_vendor_request_cb(obs_data_t *requestData, obs_data_t *responseData, void *priv_data)
{
blog(LOG_INFO, "[test_vendor_request_cb] Request called! Request data: %s", obs_data_get_json(requestData));

// Set an item to the response data
obs_data_set_string(responseData, "test", "pp");
Expand All @@ -245,34 +280,25 @@ static void test_vendor_request_cb(obs_data_t *requestData, obs_data_t *response

void test_register_vendor()
{
blog(LOG_INFO, "[test_register_vendor] Registering test vendor...");
blog(LOG_INFO, "[test_register_vendor] Testing vendor registration...");

// Test plugin API version fetch
uint apiVersion = obs_websocket_get_api_version();
blog(LOG_INFO, "[test_register_vendor] obs-websocket plugin API version: %u", apiVersion);

// Test calling obs-websocket requests
struct obs_websocket_request_response *response = obs_websocket_call_request("GetVersion");
if (response) {
blog(LOG_INFO, "[test_register_vendor] Called GetVersion. Status Code: %u | Comment: %s | Response Data: %s",
response->status_code, response->comment, response->response_data);
obs_websocket_request_response_free(response);
}

// Test vendor creation
auto vendor = obs_websocket_register_vendor("obs-websocket-test");
if (!vendor) {
blog(LOG_WARNING, "[test_register_vendor] Failed to create vendor!");
blog(LOG_ERROR, "[test_register_vendor] Failed to create vendor!");
return;
}

// Test vendor request registration
if (!obs_websocket_vendor_register_request(vendor, "TestRequest", test_vendor_request_cb, vendor)) {
blog(LOG_WARNING, "[test_register_vendor] Failed to register vendor request!");
blog(LOG_ERROR, "[test_register_vendor] Failed to register vendor request!");
return;
}

blog(LOG_INFO, "[test_register_vendor] Post load completed.");
blog(LOG_INFO, "[test_register_vendor] Test done.");
}

#endif
2 changes: 2 additions & 0 deletions src/obs-websocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ with this program. If not, see <https://www.gnu.org/licenses/>
#include "utils/Obs.h"
#include "plugin-macros.generated.h"

#define CURRENT_RPC_VERSION 1

struct Config;
typedef std::shared_ptr<Config> ConfigPtr;

Expand Down
2 changes: 1 addition & 1 deletion src/websocketserver/WebSocketServer_Protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ with this program. If not, see <https://www.gnu.org/licenses/>

static bool IsSupportedRpcVersion(uint8_t requestedVersion)
{
return (requestedVersion == 1);
return (requestedVersion == CURRENT_RPC_VERSION);
}

static json ConstructRequestResult(RequestResult requestResult, const json &requestJson)
Expand Down

0 comments on commit 5b4aa9d

Please sign in to comment.