271 lines
8.2 KiB
C++
271 lines
8.2 KiB
C++
#pragma once
|
|
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <random>
#include <span>
#include <string>
#include <string_view>

#include "utils.hpp"
#include "ws_define.h"
|
|
|
|
namespace cinatra {
|
|
// Result of parsing a WebSocket frame header.
enum ws_header_status {
  // Every header byte the parser expects has been consumed.
  complete = 0,
  // The header could not be parsed.
  error = -1,
  // More header bytes are required before parsing can finish.
  incomplete = -2
};
|
|
class websocket {
|
|
public:
|
|
void sec_ws_key(std::string_view sec_key) { sec_ws_key_ = sec_key; }
|
|
|
|
std::string_view get_sec_ws_key() { return sec_ws_key_; }
|
|
|
|
/*
|
|
0 1 2 3
|
|
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
|
+-+-+-+-+-------+-+-------------+-------------------------------+
|
|
|F|R|R|R| opcode|M| Payload len | Extended payload length |
|
|
|I|S|S|S| (4) |A| (7) | (16/64) |
|
|
|N|V|V|V| |S| | (if payload len==126/127) |
|
|
| |1|2|3| |K| | |
|
|
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|
|
| Extended payload length continued, if payload len == 127 |
|
|
+ - - - - - - - - - - - - - - - +-------------------------------+
|
|
| |Masking-key, if MASK set to 1 |
|
|
+-------------------------------+-------------------------------+
|
|
| Masking-key (continued) | Payload Data |
|
|
+-------------------------------- - - - - - - - - - - - - - - - +
|
|
: Payload Data continued ... :
|
|
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|
|
| Payload Data continued ... |
|
|
+---------------------------------------------------------------+
|
|
opcode:
|
|
* %x0 denotes a continuation frame
|
|
* %x1 denotes a text frame
|
|
* %x2 denotes a binary frame
|
|
* %x3-7 are reserved for further non-control frames
|
|
* %x8 denotes a connection close
|
|
* %x9 denotes a ping
|
|
* %xA denotes a pong
|
|
* %xB-F are reserved for further control frames
|
|
Payload length: 7 bits, 7+16 bits, or 7+64 bits
|
|
Masking-key: 0 or 4 bytes
|
|
*/
|
|
ws_header_status parse_header(const char *buf, size_t size,
|
|
bool is_server = true) {
|
|
const unsigned char *inp = (const unsigned char *)(buf);
|
|
|
|
msg_opcode_ = inp[0] & 0x0F;
|
|
msg_fin_ = (inp[0] >> 7) & 0x01;
|
|
unsigned char msg_masked = (inp[1] >> 7) & 0x01;
|
|
|
|
int pos = 2;
|
|
int length_field = inp[1] & (~0x80);
|
|
|
|
left_header_len_ = 0;
|
|
if (length_field <= 125) {
|
|
len_bytes_ = SHORT_HEADER;
|
|
payload_length_ = length_field;
|
|
}
|
|
else if (length_field == 126) // msglen is 16bit!
|
|
{
|
|
len_bytes_ = MEDIUM_HEADER;
|
|
payload_length_ = ntohs(*(uint16_t *)&inp[2]); // (inp[2] << 8) + inp[3];
|
|
pos += 2;
|
|
left_header_len_ =
|
|
is_server ? MEDIUM_HEADER - size : CLIENT_MEDIUM_HEADER - size;
|
|
}
|
|
else if (length_field == 127) // msglen is 64bit!
|
|
{
|
|
len_bytes_ = LONG_HEADER;
|
|
payload_length_ = (size_t)be64toh(*(uint64_t *)&inp[2]);
|
|
pos += 8;
|
|
left_header_len_ =
|
|
is_server ? LONG_HEADER - size : CLIENT_LONG_HEADER - size;
|
|
}
|
|
else {
|
|
len_bytes_ = INVALID_HEADER;
|
|
return ws_header_status::error;
|
|
}
|
|
|
|
if (msg_masked) {
|
|
std::memcpy(mask_, inp + pos, 4);
|
|
}
|
|
|
|
return left_header_len_ == 0 ? ws_header_status::complete
|
|
: ws_header_status::incomplete;
|
|
}
|
|
|
|
int len_bytes() const { return len_bytes_; }
|
|
void reset_len_bytes() { len_bytes_ = SHORT_HEADER; }
|
|
|
|
ws_frame_type parse_payload(std::span<char> buf) {
|
|
// unmask data:
|
|
if (*(uint32_t *)mask_ != 0) {
|
|
for (size_t i = 0; i < payload_length_; i++) {
|
|
buf[i] = buf[i] ^ mask_[i % 4];
|
|
}
|
|
}
|
|
|
|
if (msg_opcode_ == 0x0)
|
|
return (msg_fin_)
|
|
? ws_frame_type::WS_TEXT_FRAME
|
|
: ws_frame_type::WS_INCOMPLETE_TEXT_FRAME; // continuation
|
|
// frame ?
|
|
if (msg_opcode_ == 0x1)
|
|
return (msg_fin_) ? ws_frame_type::WS_TEXT_FRAME
|
|
: ws_frame_type::WS_INCOMPLETE_TEXT_FRAME;
|
|
if (msg_opcode_ == 0x2)
|
|
return (msg_fin_) ? ws_frame_type::WS_BINARY_FRAME
|
|
: ws_frame_type::WS_INCOMPLETE_BINARY_FRAME;
|
|
if (msg_opcode_ == 0x8)
|
|
return ws_frame_type::WS_CLOSE_FRAME;
|
|
if (msg_opcode_ == 0x9)
|
|
return ws_frame_type::WS_PING_FRAME;
|
|
if (msg_opcode_ == 0xA)
|
|
return ws_frame_type::WS_PONG_FRAME;
|
|
return ws_frame_type::WS_BINARY_FRAME;
|
|
}
|
|
|
|
std::string format_header(size_t length, opcode code) {
|
|
size_t header_length = encode_header(length, code);
|
|
return {msg_header_, header_length};
|
|
}
|
|
|
|
std::string encode_frame(std::span<char> &data, opcode op, bool need_mask,
|
|
bool eof = true) {
|
|
std::string header;
|
|
/// Base header.
|
|
frame_header hdr{};
|
|
hdr.fin = eof;
|
|
hdr.rsv1 = 0;
|
|
hdr.rsv2 = 0;
|
|
hdr.rsv3 = 0;
|
|
hdr.opcode = static_cast<uint8_t>(op);
|
|
hdr.mask = 1;
|
|
|
|
if (data.empty()) {
|
|
int mask = 0;
|
|
header.resize(sizeof(frame_header) + sizeof(mask));
|
|
std::memcpy(header.data(), &hdr, sizeof(hdr));
|
|
std::memcpy(header.data() + sizeof(hdr), &mask, sizeof(mask));
|
|
return header;
|
|
}
|
|
|
|
hdr.len =
|
|
data.size() < 126 ? data.size() : (data.size() < 65536 ? 126 : 127);
|
|
|
|
uint8_t buffer[sizeof(frame_header)];
|
|
std::memcpy(buffer, (uint8_t *)&hdr, sizeof(hdr));
|
|
std::string str_hdr_len =
|
|
std::string((const char *)buffer, sizeof(frame_header));
|
|
header.append(str_hdr_len);
|
|
|
|
/// The payload length may be larger than 126 bytes.
|
|
std::string str_payload_len;
|
|
if (data.size() >= 126) {
|
|
if (data.size() >= 65536) {
|
|
uint64_t len = data.size();
|
|
str_payload_len.resize(sizeof(uint64_t));
|
|
*((uint64_t *)&str_payload_len[0]) = htobe64(len);
|
|
}
|
|
else {
|
|
uint16_t len = data.size();
|
|
str_payload_len.resize(sizeof(uint16_t));
|
|
*((uint16_t *)&str_payload_len[0]) = htons(static_cast<uint16_t>(len));
|
|
}
|
|
header.append(str_payload_len);
|
|
}
|
|
|
|
/// The mask is a 32-bit value.
|
|
uint8_t mask[4] = {};
|
|
if (need_mask) {
|
|
header[1] |= 0x80;
|
|
uint32_t random = (uint32_t)rand();
|
|
memcpy(mask, &random, 4);
|
|
}
|
|
|
|
size_t size = header.size();
|
|
header.resize(size + 4);
|
|
std::memcpy(header.data() + size, mask, 4);
|
|
|
|
for (int i = 0; i < data.size(); ++i) {
|
|
data[i] ^= mask[i % 4];
|
|
}
|
|
|
|
return header;
|
|
}
|
|
|
|
close_frame parse_close_payload(char *src, size_t length) {
|
|
close_frame cf = {};
|
|
if (length >= 2) {
|
|
std::memcpy(&cf.code, src, 2);
|
|
cf = {ntohs(cf.code), src + 2, length - 2};
|
|
if (cf.code < 1000 || cf.code > 4999 ||
|
|
(cf.code > 1011 && cf.code < 4000) ||
|
|
(cf.code >= 1004 && cf.code <= 1006) ||
|
|
!is_valid_utf8((unsigned char *)cf.message, cf.length)) {
|
|
return {};
|
|
}
|
|
}
|
|
return cf;
|
|
}
|
|
|
|
std::string format_close_payload(uint16_t code, char *message,
|
|
size_t length) {
|
|
if (length == 0) {
|
|
return "";
|
|
}
|
|
std::string close_payload;
|
|
if (code) {
|
|
close_payload.resize(length + 2);
|
|
code = htons(code);
|
|
std::memcpy(close_payload.data(), &code, 2);
|
|
std::memcpy(close_payload.data() + 2, message, length);
|
|
}
|
|
return close_payload;
|
|
}
|
|
|
|
size_t left_header_len() const { return left_header_len_; }
|
|
|
|
size_t payload_length() const { return payload_length_; }
|
|
|
|
opcode get_opcode() { return (opcode)msg_opcode_; }
|
|
|
|
private:
|
|
size_t encode_header(size_t length, opcode code) {
|
|
size_t header_length;
|
|
|
|
if (length < 126) {
|
|
header_length = 2;
|
|
msg_header_[1] = static_cast<char>(length);
|
|
}
|
|
else if (length <= UINT16_MAX) {
|
|
header_length = 4;
|
|
msg_header_[1] = 126;
|
|
*((uint16_t *)&msg_header_[2]) = htons(static_cast<uint16_t>(length));
|
|
}
|
|
else {
|
|
header_length = 10;
|
|
msg_header_[1] = 127;
|
|
*((uint64_t *)&msg_header_[2]) = htobe64(length);
|
|
}
|
|
|
|
int flags = 0;
|
|
msg_header_[0] = (flags & SND_NO_FIN ? 0 : char(128));
|
|
if (!(flags & SND_CONTINUATION)) {
|
|
msg_header_[0] |= code;
|
|
}
|
|
|
|
return header_length;
|
|
}
|
|
|
|
std::string_view sec_ws_key_;
|
|
|
|
size_t payload_length_ = 0;
|
|
|
|
size_t left_header_len_ = 0;
|
|
uint8_t mask_[4] = {};
|
|
unsigned char msg_opcode_ = 0;
|
|
unsigned char msg_fin_ = 0;
|
|
|
|
char msg_header_[10];
|
|
ws_head_len len_bytes_ = SHORT_HEADER;
|
|
};
|
|
|
|
} // namespace cinatra
|