forked from mikelma/craftium
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mt_server.c
240 lines (198 loc) · 6.94 KB
/
mt_server.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
// mt_server.c: CPython extension module implementing the MT communication server.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <Python.h> // must precede any standard header
#include <numpy/arrayobject.h>

#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <netdb.h>
#include <netinet/in.h>
#include <poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h> // read(), write(), close()
// Shorthand for the generic socket-address type that bind()/accept() take;
// used to cast `struct sockaddr_in *` arguments in the functions below.
#define SA struct sockaddr
/*
 * init_server() -> (port, sockfd)
 *
 * Create an IPv4 TCP socket and bind it to all interfaces on an ephemeral
 * port (port 0 lets the kernel choose). The socket is only bound here;
 * listen()/accept() happen in server_listen().
 *
 * Returns a 2-tuple of ints: the port the kernel assigned and the socket's
 * file descriptor (ownership of the descriptor passes to the caller).
 * Raises Exception if any socket step fails; the socket is closed first so
 * no descriptor is leaked.
 */
static PyObject* init_server(PyObject* self, PyObject* args) {
    (void)self;
    (void)args; // no arguments are parsed

    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sockfd == -1) {
        PyErr_SetString(PyExc_Exception, "Server socket creation failed");
        return NULL;
    }

    struct sockaddr_in servaddr;
    memset(&servaddr, 0, sizeof(servaddr)); // bzero() is obsolete in POSIX
    servaddr.sin_family = AF_INET;
    servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
    servaddr.sin_port = htons(0); // 0 = let the kernel pick a free port
    if (bind(sockfd, (SA *)&servaddr, sizeof(servaddr)) != 0) {
        close(sockfd);
        PyErr_SetString(PyExc_Exception, "Server socket bind failed");
        return NULL;
    }

    // Ask the kernel which port it actually assigned.
    struct sockaddr_in sin;
    socklen_t len = sizeof(sin);
    if (getsockname(sockfd, (SA *)&sin, &len) == -1) {
        close(sockfd);
        PyErr_SetString(PyExc_Exception, "Server socket getsockname failed");
        return NULL;
    }
    int port = ntohs(sin.sin_port);

    // Build the tuple and its items in one call: PyTuple_Pack does not steal
    // references, so packing separately-built ints would leak them.
    PyObject *result = Py_BuildValue("(ii)", port, sockfd);
    if (result == NULL) {
        close(sockfd);
    }
    return result;
}
/*
 * server_listen(sockfd, timeout_ms) -> connfd
 *
 * Put the bound socket `sockfd` into listening mode (backlog 1), wait up to
 * `timeout_ms` milliseconds (the value is passed straight to poll()) for a
 * client to connect, then accept the connection.
 *
 * The GIL is released while blocked in poll()/accept() so other Python
 * threads keep running. If poll() is interrupted by a signal, Python signal
 * handlers are run first (so Ctrl-C raises KeyboardInterrupt). If no handler
 * raises, the wait restarts with the full timeout.
 *
 * Returns the accepted connection's file descriptor as an int.
 * Raises TypeError on bad arguments, ConnectionError on timeout, and
 * Exception if listen/poll/accept fail.
 */
static PyObject* server_listen(PyObject* self, PyObject* args) {
    int sockfd, timeout_ms;
    (void)self;
    if (!PyArg_ParseTuple(args, "ii", &sockfd, &timeout_ms)) {
        PyErr_SetString(PyExc_TypeError, "Expected two integers as arguments: sockfd, and timeout_ms");
        return NULL;
    }
    if (listen(sockfd, 1) != 0) {
        PyErr_SetString(PyExc_Exception, "Server socket listen failed");
        return NULL;
    }

    struct pollfd pfd = { .fd = sockfd, .events = POLLIN, .revents = 0 };
    int res = 0;
    int saved_errno = 0;
    for (;;) {
        Py_BEGIN_ALLOW_THREADS
        res = poll(&pfd, 1, timeout_ms);
        saved_errno = errno; // capture before re-acquiring the GIL
        Py_END_ALLOW_THREADS
        if (res >= 0 || saved_errno != EINTR) {
            break;
        }
        // Interrupted by a signal: give Python a chance to raise, else retry.
        if (PyErr_CheckSignals() < 0) {
            return NULL;
        }
    }
    if (res < 0) {
        PyErr_SetString(PyExc_Exception, "Server socket poll failed");
        return NULL;
    }
    if (res == 0) {
        PyErr_SetString(PyExc_ConnectionError, "Server socket listen timeout reached");
        return NULL;
    }

    struct sockaddr_in cli;
    socklen_t len = sizeof(cli); // accept() takes socklen_t *, not int *
    int connfd;
    Py_BEGIN_ALLOW_THREADS
    connfd = accept(sockfd, (SA *)&cli, &len);
    Py_END_ALLOW_THREADS
    if (connfd < 0) {
        PyErr_SetString(PyExc_Exception, "Server socket accept failed");
        return NULL;
    }

    PyObject *result = Py_BuildValue("i", connfd); // the connection's fd
    if (result == NULL) {
        close(connfd); // don't leak the accepted connection
    }
    return result;
}
#define BUFFER_SIZE 8192 /* upper bound on the length requested by each recv() */

/*
 * Receive up to `total_size` bytes from `socket_fd` into `buffer`.
 *
 * recv() is called with at most BUFFER_SIZE bytes at a time. Calls
 * interrupted by a signal (EINTR) are simply retried.
 *
 * Returns the number of bytes stored in `buffer`:
 *   - `total_size` on success;
 *   - fewer if the peer closed the connection first (0 if nothing arrived);
 *   - -1 after printing a perror() message, on any other recv() error.
 * `buffer` must have room for `total_size` bytes.
 */
int read_large_from_socket(int socket_fd, char *buffer, int total_size) {
    int filled = 0;
    while (filled < total_size) {
        int remaining = total_size - filled;
        int chunk = remaining > BUFFER_SIZE ? BUFFER_SIZE : remaining;
        int got = recv(socket_fd, buffer + filled, chunk, 0);
        if (got > 0) {
            filled += got;
        } else if (got == 0) {
            break; /* orderly shutdown by the peer */
        } else if (errno != EINTR) {
            perror("recv failed");
            return -1;
        }
        /* got < 0 with errno == EINTR: loop around and retry the recv() */
    }
    return filled;
}
/* Capsule name tagging receive buffers whose ownership is handed to NumPy. */
static const char MT_RECV_BUFFER_CAPSULE[] = "mt_server.recv_buffer";

/* PyCapsule destructor: frees a receive buffer obtained from malloc(). */
static void free_recv_buffer(PyObject *capsule) {
    free(PyCapsule_GetPointer(capsule, MT_RECV_BUFFER_CAPSULE));
}

/*
 * server_recv(connfd, n_bytes, obs_width, obs_height, n_channels)
 *     -> (observation, reward, termination)
 *
 * Read exactly `n_bytes` bytes from `connfd` and decode them as:
 *   bytes [0, H*W*C)                    observation image, read as uint8
 *                                       with shape (obs_height, obs_width,
 *                                       n_channels)
 *   bytes [n_bytes-1-sizeof(double),
 *          n_bytes-1)                   reward, a C double in host byte order
 *   byte  n_bytes-1                     termination flag (non-zero -> True)
 *
 * The returned NumPy array is a view on the receive buffer. A PyCapsule set
 * as the array's base frees the buffer when the array is garbage-collected,
 * which is the pattern NumPy recommends over flipping NPY_ARRAY_OWNDATA on
 * foreign memory. The GIL is released while waiting for data.
 *
 * Raises TypeError on bad arguments.
 * Raises ValueError if the sizes are negative or the image plus trailer
 * does not fit in `n_bytes`.
 * Raises ConnectionError on a socket error, or if the peer closes before
 * `n_bytes` arrive; connfd is closed in that case.
 */
static PyObject* server_recv(PyObject* self, PyObject* args) {
    int connfd, n_bytes, obs_width, obs_height, n_channels;
    if (!PyArg_ParseTuple(args, "iiiii", &connfd, &n_bytes, &obs_width, &obs_height, &n_channels)) {
        PyErr_SetString(PyExc_TypeError,
            "Arguments must be 5 integers: connection's fd, num. of bytes to read, obs. width and height, and num. channels.");
        return NULL;
    }

    // Validate up front so every index below stays inside the buffer.
    // Arithmetic is in long long so the size check itself cannot overflow.
    const long long trailer = (long long)sizeof(double) + 1; // reward + flag
    if (obs_width < 0 || obs_height < 0 || n_channels < 0 || n_bytes < trailer) {
        PyErr_SetString(PyExc_ValueError,
            "Invalid sizes: dimensions must be non-negative and n_bytes must hold the reward and termination flag");
        return NULL;
    }
    long long avail = (long long)n_bytes - trailer;             // bytes left for the image
    long long pixels = (long long)obs_height * (long long)obs_width; // < 2^62, no overflow
    if (n_channels > 0 && pixels > avail / n_channels) {
        PyErr_SetString(PyExc_ValueError,
            "Invalid sizes: obs_height*obs_width*n_channels exceeds the image bytes available in n_bytes");
        return NULL;
    }

    // Buffer for the received image + reward + termination flag.
    char *buff = malloc((size_t)n_bytes);
    if (buff == NULL) {
        PyErr_SetString(PyExc_Exception, "Failed to allocate memory for recv buffer");
        return NULL;
    }

    int n_read;
    Py_BEGIN_ALLOW_THREADS // blocking I/O on our own buffer; no Python API used
    n_read = read_large_from_socket(connfd, buff, n_bytes);
    Py_END_ALLOW_THREADS
    if (n_read < 0) {
        free(buff);
        PyErr_SetString(PyExc_ConnectionError, "Failed to receive from MT, error reading from socket.");
        return NULL;
    }
    if (n_read < n_bytes) {
        // EOF before the whole message arrived (n_read == 0: nothing at all).
        // The rest of the buffer is uninitialized, so it must not be decoded.
        free(buff);
        close(connfd);
        PyErr_SetString(PyExc_ConnectionError, "Failed to receive from MT. Connection closed by peer: is MT down?");
        return NULL;
    }

    // Retrieve the termination flag (last byte) and the reward (preceding double).
    int termination = (unsigned char)buff[n_bytes - 1];
    double reward;
    memcpy(&reward, buff + ((long long)n_bytes - trailer), sizeof(reward));

    // Wrap the leading image bytes as a uint8 array without copying.
    npy_intp dims[3] = {obs_height, obs_width, n_channels};
    PyObject *array = PyArray_SimpleNewFromData(3, dims, NPY_UINT8, buff);
    if (array == NULL) {
        free(buff);
        PyErr_SetString(PyExc_RuntimeError, "Failed to create NumPy array");
        return NULL;
    }
    PyObject *capsule = PyCapsule_New(buff, MT_RECV_BUFFER_CAPSULE, free_recv_buffer);
    if (capsule == NULL) {
        Py_DECREF(array); // the array does not own buff, so free it here
        free(buff);
        return NULL;
    }
    // Steals the capsule reference even on failure; in that case the capsule
    // destructor has already freed buff.
    if (PyArray_SetBaseObject((PyArrayObject *)array, capsule) < 0) {
        Py_DECREF(array);
        return NULL;
    }

    PyObject *py_reward = PyFloat_FromDouble(reward);
    PyObject *py_termination = PyBool_FromLong(termination);
    PyObject *tuple = NULL;
    if (py_reward != NULL) {
        tuple = PyTuple_Pack(3, array, py_reward, py_termination);
    }
    // PyTuple_Pack takes its own references; drop ours in every case.
    Py_DECREF(array);
    Py_XDECREF(py_reward);
    Py_DECREF(py_termination);
    return tuple;
}
/*
 * server_send(connfd, data) -> None
 *
 * Write every byte of the bytes object `data` to `connfd`.
 * write() may transfer fewer bytes than requested (e.g. when the socket's
 * send buffer is full), so keep writing until the whole payload is out.
 * Writes interrupted by a signal (EINTR) are retried.
 *
 * The GIL is released during the writes. This is safe because `data` is an
 * immutable bytes object kept alive by `args` for the whole call.
 * An empty bytes object is a no-op.
 *
 * Raises TypeError on bad arguments and ConnectionError if a write fails or
 * makes no progress.
 */
static PyObject* server_send(PyObject* self, PyObject* args) {
    int connfd;
    PyObject *bytes_obj;
    if (!PyArg_ParseTuple(args, "iS", &connfd, &bytes_obj)) {
        PyErr_SetString(PyExc_TypeError,
            "Arguments are: connection's fd (int), and a bytes object.");
        return NULL;
    }

    // Py_ssize_t keeps the full length; an int could truncate large payloads.
    char *buff;
    Py_ssize_t size;
    if (PyBytes_AsStringAndSize(bytes_obj, &buff, &size) < 0) {
        return NULL;
    }

    Py_ssize_t sent = 0;
    int failed = 0;
    Py_BEGIN_ALLOW_THREADS
    while (sent < size) {
        ssize_t n = write(connfd, buff + sent, (size_t)(size - sent));
        if (n > 0) {
            sent += n;
        } else if (n < 0 && errno == EINTR) {
            continue; // interrupted before writing anything; retry
        } else {
            failed = 1; // write error, or no progress was made
            break;
        }
    }
    Py_END_ALLOW_THREADS

    if (failed) {
        PyErr_SetString(PyExc_ConnectionError, "Failed to send data to MT");
        return NULL;
    }
    Py_RETURN_NONE;
}
// Method definitions
static PyMethodDef MyMethods[] = {
{"init_server", init_server, METH_VARARGS, "Initialize the MT server"},
{"server_listen", server_listen, METH_VARARGS, "Listen for MT to connect"},
{"server_recv", server_recv, METH_VARARGS, "Receive message from MT"},
{"server_send", server_send, METH_VARARGS, "Sends a message to MT"},
{NULL, NULL, 0, NULL}
};
// Module definition
static struct PyModuleDef mymodule = {
PyModuleDef_HEAD_INIT,
"mt_server",
"A fast implementation for the MT communication server",
-1,
MyMethods
};
// Module initialization: entry point CPython calls on `import mt_server`.
PyMODINIT_FUNC PyInit_mt_server(void) {
    // Load the NumPy C API before creating the module. On failure the
    // import_array() macro returns NULL from this function directly, which
    // would leak an already-created module object if it ran afterwards.
    import_array();
    return PyModule_Create(&mymodule); // NULL with an exception set on failure
}