diff --git a/ait/core/server/stream.py b/ait/core/server/stream.py index 93b4af6b..9aeac911 100644 --- a/ait/core/server/stream.py +++ b/ait/core/server/stream.py @@ -107,15 +107,23 @@ def valid_workflow(self): def output_stream_factory(name, inputs, outputs, handlers, zmq_args=None): + """ + This factory preempts the creating of output streams directly. It accepts + the same args as any given stream class and then based primarily on the + values in 'outputs' decides on the appropriate stream to instantiate and + then returns it. + """ ostream = None if type(outputs) is not list or (type(outputs) is list and len(outputs) == 0): - raise ValueError(f"Output stream specification invalid: {outputs}") + ostream = ZMQStream( + name, + inputs, + handlers, + zmq_args=zmq_args, + ) + return ostream # backwards compatability with original UDP spec - if ( - type(outputs) is list - and type(outputs[0]) is int - and ait.MIN_PORT <= outputs[0] <= ait.MAX_PORT - ): + if type(outputs[0]) is int and ait.MIN_PORT <= outputs[0] <= ait.MAX_PORT: ostream = UDPOutputStream(name, inputs, outputs[0], handlers, zmq_args=zmq_args) elif is_valid_address_spec(outputs[0]): protocol, hostname, port = outputs[0].split(":") @@ -141,7 +149,7 @@ def output_stream_factory(name, inputs, outputs, handlers, zmq_args=None): def input_stream_factory(name, inputs, handlers, zmq_args=None): """ - This factory preempts the creating of streams directly. It accepts + This factory preempts the creating of input streams directly. It accepts the same args as any given stream class and then based primarily on the values in 'inputs' decides on the appropriate stream to instantiate and then returns it. diff --git a/tests/ait/core/server/test_client.py b/tests/ait/core/server/test_client.py index 8cf64d6e..b86a694b 100644 --- a/tests/ait/core/server/test_client.py +++ b/tests/ait/core/server/test_client.py @@ -1,62 +1,82 @@ -# import gevent -# from ait.core.server.broker import Broker -# from ait.core.server.client import TCPInputClient -# from ait.core.server.client import TCPInputServer -# broker = Broker() -# TEST_BYTES = "Howdy".encode() -# TEST_PORT = 6666 -# class SimpleServer(gevent.server.StreamServer): -# def handle(self, socket, address): -# socket.sendall(TEST_BYTES) -# class TCPServer(TCPInputServer): -# def __init__(self, name, inputs, **kwargs): -# super(TCPServer, self).__init__(broker.context, input=inputs) -# def process(self, input_data): -# self.cur_socket.sendall(input_data) -# class TCPClient(TCPInputClient): -# def __init__(self, name, inputs, **kwargs): -# super(TCPClient, self).__init__( -# broker.context, input=inputs, protocol=gevent.socket.SOCK_STREAM -# ) -# self.input_data = None -# def process(self, input_data): -# self.input_data = input_data -# self._exit() -# class TestTCPServer: -# def setup_method(self): -# self.server = TCPServer("test_tcp_server", inputs=["server", TEST_PORT]) -# self.server.start() -# self.client = gevent.socket.create_connection(("127.0.0.1", TEST_PORT)) -# def teardown_method(self): -# self.server.stop() -# self.client.close() -# def test_TCP_server(self): -# nbytes = self.client.send(TEST_BYTES) -# response = self.client.recv(len(TEST_BYTES)) -# assert nbytes == len(TEST_BYTES) -# assert response == TEST_BYTES -# def test_null_send(self): -# nbytes1 = self.client.send(b"") -# nbytes2 = self.client.send(TEST_BYTES) -# response = self.client.recv(len(TEST_BYTES)) -# assert nbytes1 == 0 -# assert nbytes2 == len(TEST_BYTES) -# assert response == TEST_BYTES -# class TestTCPClient: -# def setup_method(self): -# self.server = SimpleServer(("127.0.0.1", 0)) -# self.server.start() -# self.client = TCPClient( -# "test_tcp_client", inputs=["127.0.0.1", self.server.server_port] -# ) -# def teardown_method(self): -# self.server.stop() -# def test_TCP_client(self): -# self.client.start() -# gevent.sleep(1) -# assert self.client.input_data == TEST_BYTES -# def test_bad_connection(self): -# self.client.port = 1 -# self.client.connection_reattempts = 2 -# self.client.start() -# assert self.client.connection_status != 0 +import gevent + +from ait.core.server.broker import Broker +from ait.core.server.client import TCPInputClient +from ait.core.server.client import TCPInputServer + +broker = Broker() +TEST_BYTES = "Howdy".encode() +TEST_PORT = 6666 + + +class SimpleServer(gevent.server.StreamServer): + def handle(self, socket, address): + socket.sendall(TEST_BYTES) + + +class TCPServer(TCPInputServer): + def __init__(self, name, inputs, **kwargs): + super(TCPServer, self).__init__(broker.context, input=inputs) + + def process(self, input_data): + self.cur_socket.sendall(input_data) + + +class TCPClient(TCPInputClient): + def __init__(self, name, inputs, **kwargs): + super(TCPClient, self).__init__( + broker.context, input=inputs, protocol=gevent.socket.SOCK_STREAM + ) + self.input_data = None + + def process(self, input_data): + self.input_data = input_data + self._exit() + + +class TestTCPServer: + def setup_method(self): + self.server = TCPServer("test_tcp_server", inputs=f"tcp:server:{TEST_PORT}") + self.server.start() + self.client = gevent.socket.create_connection(("127.0.0.1", TEST_PORT)) + + def teardown_method(self): + self.server.stop() + self.client.close() + + def test_TCP_server(self): + nbytes = self.client.send(TEST_BYTES) + response = self.client.recv(len(TEST_BYTES)) + assert nbytes == len(TEST_BYTES) + assert response == TEST_BYTES + + def test_null_send(self): + nbytes1 = self.client.send(b"") + nbytes2 = self.client.send(TEST_BYTES) + response = self.client.recv(len(TEST_BYTES)) + assert nbytes1 == 0 + assert nbytes2 == len(TEST_BYTES) + assert response == TEST_BYTES + + +class TestTCPClient: + def setup_method(self): + self.server = SimpleServer(("127.0.0.1", 0)) + self.server.start() + self.client = TCPClient( + "test_tcp_client", inputs=f"tcp:127.0.0.1:{self.server.server_port}" + ) + + def teardown_method(self): + self.server.stop() + + def test_TCP_client(self): + self.client.start() + gevent.sleep(1) + assert self.client.input_data == TEST_BYTES + + def test_bad_connection(self): + self.client.port = 1 + self.client.connection_reattempts = 2 + self.client.start() + assert self.client.connection_status != 0 diff --git a/tests/ait/core/server/test_server.py b/tests/ait/core/server/test_server.py index 28ce004b..25c93765 100644 --- a/tests/ait/core/server/test_server.py +++ b/tests/ait/core/server/test_server.py @@ -375,11 +375,11 @@ def test_successful_outbound_stream_creation( assert type(created_stream.handlers) == list # Testing creation of outbound stream with port output - config = cfg.AitConfig(config={"name": "some_stream", "output": 3333}) + config = cfg.AitConfig(config={"name": "some_stream", "output": [3333]}) created_stream = server._create_outbound_stream(config) - assert type(created_stream) == ait.core.server.stream.PortOutputStream + assert type(created_stream) == ait.core.server.stream.UDPOutputStream assert created_stream.name == "some_stream" - assert created_stream.out_port == 3333 + assert created_stream.addr_spec == ("localhost", 3333) assert created_stream.handlers == [] diff --git a/tests/ait/core/server/test_stream.py b/tests/ait/core/server/test_stream.py index cf0491a1..54f77dc5 100644 --- a/tests/ait/core/server/test_stream.py +++ b/tests/ait/core/server/test_stream.py @@ -8,10 +8,11 @@ from ait.core.server.handlers import PacketHandler from ait.core.server.stream import input_stream_factory from ait.core.server.stream import output_stream_factory -from ait.core.server.stream import PortOutputStream from ait.core.server.stream import TCPInputClientStream from ait.core.server.stream import TCPInputServerStream +from ait.core.server.stream import TCPOutputStream from ait.core.server.stream import UDPInputServerStream +from ait.core.server.stream import UDPOutputStream from ait.core.server.stream import ZMQStream