diff --git a/src/BizHawk.Client.Common/Api/ClientWebSocketWrapper.cs b/src/BizHawk.Client.Common/Api/ClientWebSocketWrapper.cs index 9f754573936..467aea200a4 100644 --- a/src/BizHawk.Client.Common/Api/ClientWebSocketWrapper.cs +++ b/src/BizHawk.Client.Common/Api/ClientWebSocketWrapper.cs @@ -5,20 +5,30 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using System.Collections.Generic; namespace BizHawk.Client.Common { - public struct ClientWebSocketWrapper + public class ClientWebSocketWrapper { private ClientWebSocket? _w; + + private List _receivedMessages; + + Uri _uri; /// calls getter (unless closed/disposed, then is always returned) public WebSocketState State => _w?.State ?? WebSocketState.Closed; - public ClientWebSocketWrapper(Uri uri, CancellationToken? cancellationToken = null) + public ClientWebSocketWrapper(Uri uri, int bufferSize, int maxMessages) { + _uri = uri; _w = new ClientWebSocket(); - _w.ConnectAsync(uri, cancellationToken ?? CancellationToken.None).Wait(); + _receivedMessages = new List(); + try{ + Connect(bufferSize, maxMessages).Wait(); + } + catch(Exception ex){} } /// calls @@ -26,24 +36,32 @@ public ClientWebSocketWrapper(Uri uri, CancellationToken? cancellationToken = nu public Task Close(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken? cancellationToken = null) { if (_w == null) throw new ObjectDisposedException(nameof(_w)); - var task = _w.CloseAsync(closeStatus, statusDescription, cancellationToken ?? CancellationToken.None); + var task = _w.CloseOutputAsync(closeStatus, statusDescription, cancellationToken ?? CancellationToken.None); _w.Dispose(); _w = null; return task; } /// calls - public Task Receive(ArraySegment buffer, CancellationToken? cancellationToken = null) - => _w?.ReceiveAsync(buffer, cancellationToken ?? CancellationToken.None) - ?? throw new ObjectDisposedException(nameof(_w)); + public async Task Receive(int bufferSize, int maxMessages){ + var buffer = new ArraySegment(new byte[bufferSize]); + while (_w != null && _w.State == WebSocketState.Open) + { + WebSocketReceiveResult result; + result = await _w.ReceiveAsync(buffer, CancellationToken.None); + if (maxMessages == 0 || _receivedMessages.Count < maxMessages) + _receivedMessages.Add(Encoding.UTF8.GetString(buffer.Array,0,result.Count)); + } + } - /// calls - public string Receive(int bufferCap, CancellationToken? cancellationToken = null) - { - if (_w == null) throw new ObjectDisposedException(nameof(_w)); - var buffer = new byte[bufferCap]; - var result = Receive(new ArraySegment(buffer), cancellationToken ?? CancellationToken.None).Result; - return Encoding.UTF8.GetString(buffer, 0, result.Count); + public async Task Connect(int bufferSize, int maxMessages){ + if (_w == null){ + _w = new ClientWebSocket(); + } + if(_w != null && _w.State != WebSocketState.Open){ + _w.ConnectAsync(_uri, CancellationToken.None).Wait(); + Receive(bufferSize, maxMessages); + } } /// calls @@ -62,5 +80,13 @@ public Task Send(string message, bool endOfMessage, CancellationToken? cancellat cancellationToken ?? CancellationToken.None ); } + + public string GetMessage() + { + if (_receivedMessages == null || _receivedMessages.Count == 0) return ""; + string returnThis = _receivedMessages[0]; + _receivedMessages.RemoveAt(0); + return returnThis; + } } } diff --git a/src/BizHawk.Client.Common/Api/Interfaces/ICommApi.cs b/src/BizHawk.Client.Common/Api/Interfaces/ICommApi.cs index 3d40a598f75..38eb1f02641 100644 --- a/src/BizHawk.Client.Common/Api/Interfaces/ICommApi.cs +++ b/src/BizHawk.Client.Common/Api/Interfaces/ICommApi.cs @@ -10,9 +10,7 @@ public interface ICommApi : IExternalApi SocketServer? Sockets { get; } -#if ENABLE_WEBSOCKETS WebSocketServer WebSockets { get; } -#endif string? HttpTest(); diff --git a/src/BizHawk.Client.Common/Api/WebSocketServer.cs b/src/BizHawk.Client.Common/Api/WebSocketServer.cs index da6a34a5100..457f9006ffe 100644 --- a/src/BizHawk.Client.Common/Api/WebSocketServer.cs +++ b/src/BizHawk.Client.Common/Api/WebSocketServer.cs @@ -7,6 +7,6 @@ namespace BizHawk.Client.Common { public sealed class WebSocketServer { - public ClientWebSocketWrapper Open(Uri uri, CancellationToken? cancellationToken = null) => new ClientWebSocketWrapper(uri, cancellationToken); + public ClientWebSocketWrapper Open(Uri uri, int bufferSize, int maxMessages) => new ClientWebSocketWrapper(uri, bufferSize, maxMessages); } } diff --git a/src/BizHawk.Client.Common/lua/CommonLibs/CommLuaLibrary.cs b/src/BizHawk.Client.Common/lua/CommonLibs/CommLuaLibrary.cs index bc44d08d1cb..333163a33b8 100644 --- a/src/BizHawk.Client.Common/lua/CommonLibs/CommLuaLibrary.cs +++ b/src/BizHawk.Client.Common/lua/CommonLibs/CommLuaLibrary.cs @@ -3,8 +3,10 @@ using System.ComponentModel; using System.Linq; using System.Text; +using System.Net.WebSockets; using NLua; +using System.Threading.Tasks; namespace BizHawk.Client.Common { @@ -254,20 +256,22 @@ private void CheckHttp() } } -#if ENABLE_WEBSOCKETS - [LuaMethod("ws_open", "Opens a websocket and returns the id so that it can be retrieved later.")] + [LuaMethod("ws_open", "Opens a websocket and returns the id so that it can be retrieved later. If an id is provided, reconnects to the ")] [LuaMethodExample("local ws_id = comm.ws_open(\"wss://echo.websocket.org\");")] - public string WebSocketOpen(string uri) + public string WebSocketOpen(string uri, string guid = null, int bufferSize = 1024, int maxMessages = 20) { var wsServer = APIs.Comm.WebSockets; + var localGuid = guid == null ? new Guid() : Guid.Parse(guid); if (wsServer == null) { - Log("WebSocket server is somehow not available"); + Log("WebSocket server is not available"); return null; } - var guid = new Guid(); - _websockets[guid] = wsServer.Open(new Uri(uri)); - return guid.ToString(); + if (guid == null) + _websockets[localGuid] = wsServer.Open(new Uri(uri),bufferSize, maxMessages); + else + _websockets[localGuid].Connect(bufferSize, maxMessages); + return localGuid.ToString(); } [LuaMethod("ws_send", "Send a message to a certain websocket id (boolean flag endOfMessage)")] @@ -280,11 +284,11 @@ public void WebSocketSend( if (_websockets.TryGetValue(Guid.Parse(guid), out var wrapper)) wrapper.Send(content, endOfMessage); } - [LuaMethod("ws_receive", "Receive a message from a certain websocket id and a maximum number of bytes to read")] - [LuaMethodExample("local ws = comm.ws_receive(ws_id, str_len);")] - public string WebSocketReceive(string guid, int bufferCap) + [LuaMethod("ws_receive", "Get a receive message from a certain websocket id")] + [LuaMethodExample("local ws = comm.ws_receive(ws_id);")] + public string WebSocketReceive(string guid) => _websockets.TryGetValue(Guid.Parse(guid), out var wrapper) - ? wrapper.Receive(bufferCap) + ? wrapper.GetMessage() : null; [LuaMethod("ws_get_status", "Get a websocket's status")] @@ -298,11 +302,10 @@ public string WebSocketReceive(string guid, int bufferCap) [LuaMethodExample("local ws_status = comm.ws_close(ws_id, close_status);")] public void WebSocketClose( string guid, - WebSocketCloseStatus status, + int status, string closeMessage) { - if (_websockets.TryGetValue(Guid.Parse(guid), out var wrapper)) wrapper.Close(status, closeMessage); + if (_websockets.TryGetValue(Guid.Parse(guid), out var wrapper)) wrapper.Close((WebSocketCloseStatus)status, closeMessage); } -#endif } } \ No newline at end of file