diff --git a/src/WebSocketApi.cpp b/src/WebSocketApi.cpp index 6939fb98..6546992d 100644 --- a/src/WebSocketApi.cpp +++ b/src/WebSocketApi.cpp @@ -47,14 +47,19 @@ WebSocketApi::WebSocketApi() proc_handler_add(_procHandler, "bool get_api_version(out int version)", &get_api_version, nullptr); proc_handler_add(_procHandler, "bool call_request(in string request_type, in string request_data, out ptr response)", - &call_request, nullptr); + &call_request, this); + proc_handler_add(_procHandler, "bool register_event_callback(in ptr callback, out bool success)", ®ister_event_callback, + this); + proc_handler_add(_procHandler, "bool unregister_event_callback(in ptr callback, out bool success)", + &unregister_event_callback, this); proc_handler_add(_procHandler, "bool vendor_register(in string name, out ptr vendor)", &vendor_register_cb, this); - proc_handler_add(_procHandler, "bool vendor_request_register(in ptr vendor, in string type, in ptr callback)", + proc_handler_add(_procHandler, + "bool vendor_request_register(in ptr vendor, in string type, in ptr callback, out bool success)", &vendor_request_register_cb, this); - proc_handler_add(_procHandler, "bool vendor_request_unregister(in ptr vendor, in string type)", + proc_handler_add(_procHandler, "bool vendor_request_unregister(in ptr vendor, in string type, out bool success)", &vendor_request_unregister_cb, this); - proc_handler_add(_procHandler, "bool vendor_event_emit(in ptr vendor, in string type, in ptr data)", &vendor_event_emit_cb, - this); + proc_handler_add(_procHandler, "bool vendor_event_emit(in ptr vendor, in string type, in ptr data, out bool success)", + &vendor_event_emit_cb, this); proc_handler_t *ph = obs_get_proc_handler(); assert(ph != NULL); @@ -70,6 +75,10 @@ WebSocketApi::~WebSocketApi() proc_handler_destroy(_procHandler); + size_t numEventCallbacks = _eventCallbacks.size(); + _eventCallbacks.clear(); + blog_debug("[WebSocketApi::~WebSocketApi] Deleted %ld event callbacks", numEventCallbacks); + for (auto vendor : _vendors) { blog_debug("[WebSocketApi::~WebSocketApi] Deleting vendor: %s", vendor.first.c_str()); delete vendor.second; @@ -80,10 +89,19 @@ WebSocketApi::~WebSocketApi() void WebSocketApi::BroadcastEvent(uint64_t requiredIntent, const std::string &eventType, const json &eventData, uint8_t rpcVersion) { - UNUSED_PARAMETER(requiredIntent); - UNUSED_PARAMETER(eventType); - UNUSED_PARAMETER(eventData); - UNUSED_PARAMETER(rpcVersion); + if (!_obsReady) + return; + + // Only broadcast events applicable to the latest RPC version + if (rpcVersion && rpcVersion != CURRENT_RPC_VERSION) + return; + + std::string eventDataString = eventData.dump(); + + std::shared_lock l(_mutex); + + for (auto &cb : _eventCallbacks) + cb.callback(requiredIntent, eventType.c_str(), eventDataString.c_str(), cb.priv_data); } enum WebSocketApi::RequestReturnCode WebSocketApi::PerformVendorRequest(std::string vendorName, std::string requestType, @@ -128,14 +146,27 @@ void WebSocketApi::get_api_version(void *, calldata_t *cd) RETURN_SUCCESS(); } -void WebSocketApi::call_request(void *, calldata_t *cd) +void WebSocketApi::call_request(void *priv_data, calldata_t *cd) { + auto c = static_cast(priv_data); + +#if !defined(PLUGIN_TESTS) + if (!c->_obsReady) + RETURN_FAILURE(); +#endif + const char *request_type = calldata_string(cd, "request_type"); const char *request_data = calldata_string(cd, "request_data"); if (!request_type) RETURN_FAILURE(); +#ifdef PLUGIN_TESTS + // Allow plugin tests to complete, even though OBS wouldn't be ready at the time of the test + if (!c->_obsReady && std::string(request_type) != "GetVersion") + RETURN_FAILURE(); +#endif + auto response = static_cast(bzalloc(sizeof(struct obs_websocket_request_response))); if (!response) RETURN_FAILURE(); @@ -164,6 +195,52 @@ void WebSocketApi::call_request(void *, calldata_t *cd) RETURN_SUCCESS(); } +void WebSocketApi::register_event_callback(void *priv_data, calldata_t *cd) +{ + auto c = static_cast(priv_data); + + void *voidCallback; + if (!calldata_get_ptr(cd, "callback", &voidCallback) || !voidCallback) { + blog(LOG_WARNING, "[WebSocketApi::register_event_callback] Failed due to missing `callback` pointer."); + RETURN_FAILURE(); + } + + auto cb = static_cast(voidCallback); + + std::unique_lock l(c->_mutex); + + int64_t foundIndex = c->GetEventCallbackIndex(*cb); + if (foundIndex != -1) + RETURN_FAILURE(); + + c->_eventCallbacks.push_back(*cb); + + RETURN_SUCCESS(); +} + +void WebSocketApi::unregister_event_callback(void *priv_data, calldata_t *cd) +{ + auto c = static_cast(priv_data); + + void *voidCallback; + if (!calldata_get_ptr(cd, "callback", &voidCallback) || !voidCallback) { + blog(LOG_WARNING, "[WebSocketApi::register_event_callback] Failed due to missing `callback` pointer."); + RETURN_FAILURE(); + } + + auto cb = static_cast(voidCallback); + + std::unique_lock l(c->_mutex); + + int64_t foundIndex = c->GetEventCallbackIndex(*cb); + if (foundIndex == -1) + RETURN_FAILURE(); + + c->_eventCallbacks.erase(c->_eventCallbacks.begin() + foundIndex); + + RETURN_SUCCESS(); +} + void WebSocketApi::vendor_register_cb(void *priv_data, calldata_t *cd) { auto c = static_cast(priv_data); @@ -174,7 +251,7 @@ void WebSocketApi::vendor_register_cb(void *priv_data, calldata_t *cd) RETURN_FAILURE(); } - // Theoretically doesn't need a mutex, but it's good to be safe. + // Theoretically doesn't need a mutex due to module load being single-thread, but it's good to be safe. std::unique_lock l(c->_mutex); if (c->_vendors.count(vendorName)) { diff --git a/src/WebSocketApi.h b/src/WebSocketApi.h index 1c88502c..5f33df3a 100644 --- a/src/WebSocketApi.h +++ b/src/WebSocketApi.h @@ -57,9 +57,22 @@ class WebSocketApi { inline void SetVendorEventCallback(VendorEventCallback cb) { _vendorEventCallback = cb; } private: + inline int64_t GetEventCallbackIndex(obs_websocket_event_callback &cb) + { + for (int64_t i = 0; i < (int64_t)_eventCallbacks.size(); i++) { + auto currentCb = _eventCallbacks[i]; + if (currentCb.callback == cb.callback && currentCb.priv_data == cb.priv_data) + return i; + } + return -1; + } + + // Proc handlers static void get_ph_cb(void *priv_data, calldata_t *cd); static void get_api_version(void *, calldata_t *cd); static void call_request(void *, calldata_t *cd); + static void register_event_callback(void *, calldata_t *cd); + static void unregister_event_callback(void *, calldata_t *cd); static void vendor_register_cb(void *priv_data, calldata_t *cd); static void vendor_request_register_cb(void *priv_data, calldata_t *cd); static void vendor_request_unregister_cb(void *priv_data, calldata_t *cd); @@ -68,6 +81,7 @@ class WebSocketApi { std::shared_mutex _mutex; proc_handler_t *_procHandler; std::map _vendors; + std::vector _eventCallbacks; std::atomic _obsReady = false; diff --git a/src/obs-websocket.cpp b/src/obs-websocket.cpp index 629890e3..7c9071ee 100644 --- a/src/obs-websocket.cpp +++ b/src/obs-websocket.cpp @@ -103,12 +103,16 @@ bool obs_module_load(void) } #ifdef PLUGIN_TESTS +void test_call_request(); +void test_register_event_callback(); void test_register_vendor(); #endif void obs_module_post_load(void) { #ifdef PLUGIN_TESTS + test_call_request(); + test_register_event_callback(); test_register_vendor(); #endif @@ -229,12 +233,43 @@ void OnObsReady(bool ready) } #ifdef PLUGIN_TESTS +void test_call_request() +{ + blog(LOG_INFO, "[test_call_request] Testing obs-websocket plugin API request calling..."); -static void test_vendor_request_cb(obs_data_t *requestData, obs_data_t *responseData, void *priv_data) + struct obs_websocket_request_response *response = obs_websocket_call_request("GetVersion"); + if (response) { + blog(LOG_INFO, "[test_call_request] Called GetVersion. Status Code: %u | Comment: %s | Response Data: %s", + response->status_code, response->comment, response->response_data); + obs_websocket_request_response_free(response); + } else { + blog(LOG_ERROR, "[test_call_request] Failed to call GetVersion request via obs-websocket plugin API!"); + } + + blog(LOG_INFO, "[test_call_request] Test done."); +} + +static void test_event_cb(uint64_t eventIntent, const char *eventType, const char *eventData, void *priv_data) +{ + blog(LOG_DEBUG, "[test_event_cb] New event! Type: %s | Data: %s", eventType, eventData); + + UNUSED_PARAMETER(eventIntent); + UNUSED_PARAMETER(priv_data); +} + +void test_register_event_callback() { - blog(LOG_INFO, "[test_vendor_request_cb] Request called!"); + blog(LOG_INFO, "[test_register_event_callback] Registering test event callback..."); + + if (!obs_websocket_register_event_callback(test_event_cb, nullptr)) + blog(LOG_ERROR, "[test_register_event_callback] Failed to register event callback!"); - blog(LOG_INFO, "[test_vendor_request_cb] Request data: %s", obs_data_get_json(requestData)); + blog(LOG_INFO, "[test_register_event_callback] Test done."); +} + +static void test_vendor_request_cb(obs_data_t *requestData, obs_data_t *responseData, void *priv_data) +{ + blog(LOG_INFO, "[test_vendor_request_cb] Request called! Request data: %s", obs_data_get_json(requestData)); // Set an item to the response data obs_data_set_string(responseData, "test", "pp"); @@ -245,34 +280,25 @@ static void test_vendor_request_cb(obs_data_t *requestData, obs_data_t *response void test_register_vendor() { - blog(LOG_INFO, "[test_register_vendor] Registering test vendor..."); + blog(LOG_INFO, "[test_register_vendor] Testing vendor registration..."); // Test plugin API version fetch uint apiVersion = obs_websocket_get_api_version(); blog(LOG_INFO, "[test_register_vendor] obs-websocket plugin API version: %u", apiVersion); - // Test calling obs-websocket requests - struct obs_websocket_request_response *response = obs_websocket_call_request("GetVersion"); - if (response) { - blog(LOG_INFO, "[test_register_vendor] Called GetVersion. Status Code: %u | Comment: %s | Response Data: %s", - response->status_code, response->comment, response->response_data); - obs_websocket_request_response_free(response); - } - // Test vendor creation auto vendor = obs_websocket_register_vendor("obs-websocket-test"); if (!vendor) { - blog(LOG_WARNING, "[test_register_vendor] Failed to create vendor!"); + blog(LOG_ERROR, "[test_register_vendor] Failed to create vendor!"); return; } // Test vendor request registration if (!obs_websocket_vendor_register_request(vendor, "TestRequest", test_vendor_request_cb, vendor)) { - blog(LOG_WARNING, "[test_register_vendor] Failed to register vendor request!"); + blog(LOG_ERROR, "[test_register_vendor] Failed to register vendor request!"); return; } - blog(LOG_INFO, "[test_register_vendor] Post load completed."); + blog(LOG_INFO, "[test_register_vendor] Test done."); } - #endif diff --git a/src/obs-websocket.h b/src/obs-websocket.h index c8ca95b4..c6c0f349 100644 --- a/src/obs-websocket.h +++ b/src/obs-websocket.h @@ -26,6 +26,8 @@ with this program. If not, see #include "utils/Obs.h" #include "plugin-macros.generated.h" +#define CURRENT_RPC_VERSION 1 + struct Config; typedef std::shared_ptr ConfigPtr; diff --git a/src/websocketserver/WebSocketServer_Protocol.cpp b/src/websocketserver/WebSocketServer_Protocol.cpp index dd71f780..686a98f6 100644 --- a/src/websocketserver/WebSocketServer_Protocol.cpp +++ b/src/websocketserver/WebSocketServer_Protocol.cpp @@ -31,7 +31,7 @@ with this program. If not, see static bool IsSupportedRpcVersion(uint8_t requestedVersion) { - return (requestedVersion == 1); + return (requestedVersion == CURRENT_RPC_VERSION); } static json ConstructRequestResult(RequestResult requestResult, const json &requestJson)