REDAC HybridController
Firmware for LUCIDAC/REDAC Teensy
Loading...
Searching...
No Matches
endpoint.cpp
Go to the documentation of this file.
1#include "websockets/endpoint.h"
2#include <string.h>
3
4namespace websockets {
5
6FLASHMEM
7CloseReason GetCloseReason(uint16_t reasonCode) {
8 switch (reasonCode) {
9 case CloseReason_NormalClosure:
10 return CloseReason_NormalClosure;
11
12 case CloseReason_GoingAway:
13 return CloseReason_GoingAway;
14
15 case CloseReason_ProtocolError:
16 return CloseReason_ProtocolError;
17
18 case CloseReason_UnsupportedData:
19 return CloseReason_UnsupportedData;
20
21 case CloseReason_AbnormalClosure:
22 return CloseReason_AbnormalClosure;
23
24 case CloseReason_InvalidPayloadData:
25 return CloseReason_InvalidPayloadData;
26
27 case CloseReason_PolicyViolation:
28 return CloseReason_PolicyViolation;
29
30 case CloseReason_MessageTooBig:
31 return CloseReason_MessageTooBig;
32
33 case CloseReason_NoStatusRcvd:
34 return CloseReason_NoStatusRcvd;
35
36 case CloseReason_InternalServerError:
37 return CloseReason_InternalServerError;
38
39 default:
40 return CloseReason_None;
41 }
42}
43
44namespace internals {
45
uint32_t swapEndianess(uint32_t num) {
  // Reverse the byte order of a 32-bit value, e.g. 0x11223344 -> 0x44332211.
  const uint32_t highest = num >> 24;          // bits 24..31
  const uint32_t second = (num >> 16) & 0xFFu; // bits 16..23
  const uint32_t third = (num >> 8) & 0xFFu;   // bits 8..15
  const uint32_t lowest = num & 0xFFu;         // bits 0..7

  return highest | (second << 8) | (third << 16) | (lowest << 24);
}
54
uint64_t swapEndianess(uint64_t num) {
  // Full 64-bit byte reversal: byte-swap each 32-bit half with the 32-bit
  // overload, then exchange the two halves.
  const uint32_t high = static_cast<uint32_t>(num >> 32);
  const uint32_t low = static_cast<uint32_t>(num & 0xFFFFFFFFu);

  const uint64_t swappedHigh = swapEndianess(high);
  const uint64_t swappedLow = swapEndianess(low);

  return (swappedLow << 32) | swappedHigh;
}
67
68FLASHMEM
69WebsocketsEndpoint::WebsocketsEndpoint(std::shared_ptr<network::TcpClient> client,
70 FragmentsPolicy fragmentsPolicy)
71 : _client(client), _fragmentsPolicy(fragmentsPolicy), _recvMode(RecvMode_Normal),
72 _streamBuilder(fragmentsPolicy == FragmentsPolicy_Notify ? true : false),
73 _closeReason(CloseReason_None) {
74 // Empty
75}
76
77WebsocketsEndpoint::WebsocketsEndpoint(const WebsocketsEndpoint &other)
78 : _client(other._client), _fragmentsPolicy(other._fragmentsPolicy), _recvMode(other._recvMode),
79 _streamBuilder(other._streamBuilder), _closeReason(other._closeReason), _useMasking(other._useMasking) {
80
81 const_cast<WebsocketsEndpoint &>(other)._client = nullptr;
82}
83
84WebsocketsEndpoint::WebsocketsEndpoint(const WebsocketsEndpoint &&other)
85 : _client(other._client), _fragmentsPolicy(other._fragmentsPolicy), _recvMode(other._recvMode),
86 _streamBuilder(other._streamBuilder), _closeReason(other._closeReason), _useMasking(other._useMasking) {
87
88 const_cast<WebsocketsEndpoint &>(other)._client = nullptr;
89}
90
91FLASHMEM
92WebsocketsEndpoint &WebsocketsEndpoint::operator=(const WebsocketsEndpoint &other) {
93 this->_client = other._client;
94 this->_fragmentsPolicy = other._fragmentsPolicy;
95 this->_recvMode = other._recvMode;
96 this->_streamBuilder = other._streamBuilder;
97 this->_closeReason = other._closeReason;
98 this->_useMasking = other._useMasking;
99
100 const_cast<WebsocketsEndpoint &>(other)._client = nullptr;
101
102 return *this;
103}
104
105FLASHMEM
106WebsocketsEndpoint &WebsocketsEndpoint::operator=(const WebsocketsEndpoint &&other) {
107 this->_client = other._client;
108 this->_fragmentsPolicy = other._fragmentsPolicy;
109 this->_recvMode = other._recvMode;
110 this->_streamBuilder = other._streamBuilder;
111 this->_closeReason = other._closeReason;
112 this->_useMasking = other._useMasking;
113
114 const_cast<WebsocketsEndpoint &>(other)._client = nullptr;
115
116 return *this;
117}
118
119FLASHMEM
120void WebsocketsEndpoint::setInternalSocket(std::shared_ptr<network::TcpClient> socket) {
121 this->_client = socket;
122}
123
124FLASHMEM
125bool WebsocketsEndpoint::poll() { return this->_client->poll(); }
126
127uint32_t readUntilSuccessfullOrError(network::TcpClient &socket, uint8_t *buffer, const uint32_t len) {
128 uint32_t numRead = socket.read(buffer, len);
129 while (numRead == static_cast<uint32_t>(-1) && socket.available()) {
130 numRead = socket.read(buffer, len);
131 }
132 return numRead;
133}
134
135Header readHeaderFromSocket(network::TcpClient &socket) {
136 Header header;
137 header.payload = 0;
138 readUntilSuccessfullOrError(socket, reinterpret_cast<uint8_t *>(&header), 2);
139 return header;
140}
141
142FLASHMEM
143uint64_t readExtendedPayloadLength(network::TcpClient &socket, const Header &header) {
144 uint64_t extendedPayload = header.payload;
145 // in case of extended payload length
146 if (header.payload == 126) {
147 // read next 16 bits as payload length
148 uint16_t tmp = 0;
149 readUntilSuccessfullOrError(socket, reinterpret_cast<uint8_t *>(&tmp), 2);
150 tmp = (tmp << 8) | (tmp >> 8);
151 extendedPayload = tmp;
152 } else if (header.payload == 127) {
153 uint64_t tmp = 0;
154 readUntilSuccessfullOrError(socket, reinterpret_cast<uint8_t *>(&tmp), 8);
155 extendedPayload = swapEndianess(tmp);
156 }
157
158 return extendedPayload;
159}
160
161void readMaskingKey(network::TcpClient &socket, uint8_t *outputBuffer) {
162 readUntilSuccessfullOrError(socket, reinterpret_cast<uint8_t *>(outputBuffer), 4);
163}
164
165FLASHMEM
166std::string readData(network::TcpClient &socket, uint64_t extendedPayload) {
167 const uint64_t BUFFER_SIZE = _WS_BUFFER_SIZE;
168
169 std::string data(extendedPayload, '\0');
170 uint8_t buffer[BUFFER_SIZE];
171 uint64_t done_reading = 0;
172 while (done_reading < extendedPayload && socket.available()) {
173 uint64_t to_read =
174 extendedPayload - done_reading >= BUFFER_SIZE ? BUFFER_SIZE : extendedPayload - done_reading;
175 uint32_t numReceived = readUntilSuccessfullOrError(socket, buffer, to_read);
176
177 // On failed reads, skip
178 if (!socket.available())
179 break;
180
181 for (uint64_t i = 0; i < numReceived; i++) {
182 data[done_reading + i] = static_cast<char>(buffer[i]);
183 }
184
185 done_reading += numReceived;
186 }
187 return (data);
188}
189
void remaskData(std::string &data, const uint8_t *const maskingKey, uint64_t payloadLength) {
  // XOR the first payloadLength bytes with the 4-byte key, cycling through it.
  // XOR is its own inverse, so the same call both masks and unmasks.
  for (uint64_t idx = 0; idx != payloadLength; ++idx) {
    data[idx] ^= maskingKey[idx & 3];
  }
}
195
196FLASHMEM
197WebsocketsFrame WebsocketsEndpoint::_recv() {
198 auto header = readHeaderFromSocket(*this->_client);
199 if (!_client->available())
200 return WebsocketsFrame(); // In case of faliure
201
202 uint64_t payloadLength = readExtendedPayloadLength(*this->_client, header);
203 if (!_client->available())
204 return WebsocketsFrame(); // In case of faliure
205
206#ifdef _WS_CONFIG_MAX_MESSAGE_SIZE
207 if (payloadLength > _WS_CONFIG_MAX_MESSAGE_SIZE) {
208 return WebsocketsFrame();
209 }
210#endif
211
212 uint8_t maskingKey[4];
213 // if masking is set
214 if (header.mask) {
215 readMaskingKey(*this->_client, maskingKey);
216 if (!_client->available())
217 return WebsocketsFrame(); // In case of faliure
218 }
219
220 WebsocketsFrame frame;
221 // read the message's payload (data) according to the read length
222 frame.payload = readData(*this->_client, payloadLength);
223 if (!_client->available())
224 return WebsocketsFrame(); // In case of faliure
225
226 // if masking is set un-mask the message
227 if (header.mask) {
228 remaskData(frame.payload, maskingKey, payloadLength);
229 }
230
231 // Construct frame from data and header that was read
232 frame.fin = header.fin;
233 frame.mask = header.mask;
234
235 frame.mask_buf[0] = maskingKey[0];
236 frame.mask_buf[1] = maskingKey[1];
237 frame.mask_buf[2] = maskingKey[2];
238 frame.mask_buf[3] = maskingKey[3];
239
240 frame.opcode = header.opcode;
241 frame.payload_length = payloadLength;
242
243 return (frame);
244}
245
246FLASHMEM
247WebsocketsMessage WebsocketsEndpoint::handleFrameInStreamingMode(WebsocketsFrame &frame) {
248 if (frame.isControlFrame()) {
249 auto msg = WebsocketsMessage::CreateFromFrame((frame));
250 this->handleMessageInternally(msg);
251 return (msg);
252 } else if (frame.isBeginningOfFragmentsStream()) {
253 this->_recvMode = RecvMode_Streaming;
254
255 if (this->_streamBuilder.isEmpty()) {
256 this->_streamBuilder.first(frame);
257 // if policy is set to notify, return the frame to the user
258 if (this->_fragmentsPolicy == FragmentsPolicy_Notify) {
259 return WebsocketsMessage(this->_streamBuilder.type(), std::move(frame.payload), MessageRole::First);
260 } else
261 return {};
262 }
263 } else if (frame.isContinuesFragment()) {
264 this->_streamBuilder.append(frame);
265 if (this->_streamBuilder.isOk()) {
266 // if policy is set to notify, return the frame to the user
267 if (this->_fragmentsPolicy == FragmentsPolicy_Notify) {
268 return WebsocketsMessage(this->_streamBuilder.type(), std::move(frame.payload),
269 MessageRole::Continuation);
270 } else
271 return {};
272 }
273 } else if (frame.isEndOfFragmentsStream()) {
274 this->_recvMode = RecvMode_Normal;
275 this->_streamBuilder.end(frame);
276 if (this->_streamBuilder.isOk()) {
277 // if policy is set to notify, return the frame to the user
278 if (this->_fragmentsPolicy == FragmentsPolicy_Aggregate) {
279 auto completeMessage = this->_streamBuilder.build();
280 this->_streamBuilder = WebsocketsMessage::StreamBuilder(false);
281 this->handleMessageInternally(completeMessage);
282 return completeMessage;
283 } else { // in case of notify policy
284 auto messageType = this->_streamBuilder.type();
285 this->_streamBuilder = WebsocketsMessage::StreamBuilder(true);
286 return WebsocketsMessage(messageType, std::move(frame.payload), MessageRole::Last);
287 }
288 }
289 }
290
291 // Error
292 close(CloseReason_ProtocolError);
293 return {};
294}
295
296FLASHMEM
297WebsocketsMessage WebsocketsEndpoint::handleFrameInStandardMode(WebsocketsFrame &frame) {
298 // Normal (unfragmented) frames are handled as a complete message
299 if (frame.isNormalUnfragmentedMessage() || frame.isControlFrame()) {
300 auto msg = WebsocketsMessage::CreateFromFrame(std::move(frame));
301 this->handleMessageInternally(msg);
302 return (msg);
303 } else if (frame.isBeginningOfFragmentsStream()) {
304 return handleFrameInStreamingMode(frame);
305 }
306
307 // This is an error. a bad combination of opcodes and fin flag arrived.
308 close(CloseReason_ProtocolError);
309 return {};
310}
311
312WebsocketsMessage WebsocketsEndpoint::recv() {
313 auto frame = _recv();
314 if (frame.isEmpty()) {
315 return {};
316 }
317
318 if (this->_recvMode == RecvMode_Normal) {
319 return handleFrameInStandardMode(frame);
320 } else /* this->_recvMode == RecvMode_Streaming */ {
321 return handleFrameInStreamingMode(frame);
322 }
323}
324
325FLASHMEM
326void WebsocketsEndpoint::handleMessageInternally(WebsocketsMessage &msg) {
327 if (msg.isPing()) {
328 pong(internals::fromInterfaceString(msg.data()));
329 } else if (msg.isClose()) {
330 // is there a reason field
331 if (internals::fromInterfaceString(msg.data()).size() >= 2) {
332 uint16_t reason = *(reinterpret_cast<const uint16_t *>(msg.data().c_str()));
333 reason = reason >> 8 | reason << 8;
334 this->_closeReason = GetCloseReason(reason);
335 } else {
336 this->_closeReason = CloseReason_GoingAway;
337 }
338 close(this->_closeReason);
339 }
340}
341
342bool WebsocketsEndpoint::send(const char *data, const size_t len, const uint8_t opcode, const bool fin) {
343 return this->send(data, len, opcode, fin, this->_useMasking);
344}
345
346bool WebsocketsEndpoint::send(const std::string &data, const uint8_t opcode, const bool fin) {
347 return this->send(data, opcode, fin, this->_useMasking);
348}
349
350bool WebsocketsEndpoint::send(const std::string &data, const uint8_t opcode, const bool fin, const bool mask,
351 const char *maskingKey) {
352 return send(data.c_str(), data.size(), opcode, fin, mask, maskingKey);
353}
354
355FLASHMEM
356std::string WebsocketsEndpoint::getHeader(uint64_t len, uint8_t opcode, bool fin, bool mask) {
357 std::string header_data;
358
359 if (len < 126) {
360 auto header = MakeHeader<Header>(len, opcode, fin, mask);
361 header_data = std::string(reinterpret_cast<char *>(&header), 2 + 0);
362 } else if (len < 65536) {
363 auto header = MakeHeader<HeaderWithExtended16>(len, opcode, fin, mask);
364 header.extendedPayload = (len << 8) | (len >> 8);
365 header_data = std::string(reinterpret_cast<char *>(&header), 2 + 2);
366 } else {
367 auto header = MakeHeader<HeaderWithExtended64>(len, opcode, fin, mask);
368 // header.extendedPayload = swapEndianess(len);
369 header.extendedPayload = swapEndianess(len);
370
371 header_data = std::string(reinterpret_cast<char *>(&header), 2);
372 header_data += std::string(reinterpret_cast<char *>(&header.extendedPayload), 8);
373 }
374
375 return header_data;
376}
377
void remaskData(std::string &data, const char *const maskingKey, size_t first, size_t len) {
  // XOR the payload occupying data[first, first + len) with the 4-byte masking key.
  // Per RFC 6455 §5.3 the key index is relative to the start of the payload,
  // (i - first) % 4, not the absolute position in the buffer. Using the absolute
  // index rotated the key whenever the header + key prefix was not a multiple of 4
  // (6 or 14 bytes for short/64-bit lengths), so the peer could not unmask the data.
  for (size_t i = 0; i < len; i++) {
    data[first + i] = data[first + i] ^ maskingKey[i % 4];
  }
}
383
384bool WebsocketsEndpoint::send(const char *data, const size_t len, const uint8_t opcode, const bool fin,
385 const bool mask, const char *maskingKey) {
386
387#ifdef _WS_CONFIG_MAX_MESSAGE_SIZE
388 if (len > _WS_CONFIG_MAX_MESSAGE_SIZE) {
389 return false;
390 }
391#endif
392 // send the header
393 std::string message_data = getHeader(len, opcode, fin, mask);
394
395 if (mask) {
396 message_data += std::string(maskingKey, 4);
397 }
398
399 size_t data_start = message_data.size();
400 message_data += std::string(data, len);
401
402 if (mask && memcmp(maskingKey, __TINY_WS_INTERNAL_DEFAULT_MASK, 4) != 0) {
403 remaskData(message_data, maskingKey, data_start, len);
404 }
405
406 this->_client->send(message_data);
407 return true; // TODO dont assume success
408}
409
410void WebsocketsEndpoint::close(CloseReason reason) {
411 this->_closeReason = reason;
412
413 if (!this->_client->available())
414 return;
415
416 if (reason == CloseReason_None) {
417 send(nullptr, 0, internals::ContentType::Close, true, this->_useMasking);
418 } else {
419 uint16_t reasonNum = static_cast<uint16_t>(reason);
420 reasonNum = (reasonNum >> 8) | (reasonNum << 8);
421 send(reinterpret_cast<const char *>(&reasonNum), 2, internals::ContentType::Close, true,
422 this->_useMasking);
423 }
424 this->_client->close();
425}
426
427CloseReason WebsocketsEndpoint::getCloseReason() const { return _closeReason; }
428
429bool WebsocketsEndpoint::ping(const std::string &msg) {
430 // Ping data must be shorter than 125 bytes
431 if (msg.size() > 125) {
432 return false;
433 } else {
434 return send(msg, ContentType::Ping, true, this->_useMasking);
435 }
436}
437
438bool WebsocketsEndpoint::ping(const std::string &&msg) {
439 // Ping data must be shorter than 125 bytes
440 if (msg.size() > 125) {
441 return false;
442 } else {
443 return send(msg, ContentType::Ping, true, this->_useMasking);
444 }
445}
446
447bool WebsocketsEndpoint::pong(const std::string &msg) {
448 // Pong data must be shorter than 125 bytes
449 if (msg.size() > 125) {
450 return false;
451 } else {
452 return this->send(msg, ContentType::Pong, true, this->_useMasking);
453 }
454}
455
456bool WebsocketsEndpoint::pong(const std::string &&msg) {
457 // Pong data must be shorter than 125 bytes
458 if (msg.size() > 125) {
459 return false;
460 } else {
461 return this->send(msg, ContentType::Pong, true, this->_useMasking);
462 }
463}
464
465void WebsocketsEndpoint::setFragmentsPolicy(FragmentsPolicy newPolicy) { this->_fragmentsPolicy = newPolicy; }
466
467FragmentsPolicy WebsocketsEndpoint::getFragmentsPolicy() const { return this->_fragmentsPolicy; }
468
469WebsocketsEndpoint::~WebsocketsEndpoint() {}
470} // namespace internals
471} // namespace websockets
uint32_t
Definition flasher.cpp:195
FLASHMEM uint64_t readExtendedPayloadLength(network::TcpClient &socket, const Header &header)
Definition endpoint.cpp:143
uint32_t readUntilSuccessfullOrError(network::TcpClient &socket, uint8_t *buffer, const uint32_t len)
Definition endpoint.cpp:127
FLASHMEM std::string readData(network::TcpClient &socket, uint64_t extendedPayload)
Definition endpoint.cpp:166
uint32_t swapEndianess(uint32_t num)
Definition endpoint.cpp:46
void remaskData(std::string &data, const uint8_t *const maskingKey, uint64_t payloadLength)
Definition endpoint.cpp:190
void readMaskingKey(network::TcpClient &socket, uint8_t *outputBuffer)
Definition endpoint.cpp:161
Header readHeaderFromSocket(network::TcpClient &socket)
Definition endpoint.cpp:135
FLASHMEM CloseReason GetCloseReason(uint16_t reasonCode)
Definition endpoint.cpp:7