Skip to content

Commit 2f504fc

Browse files
committed
refactor
1 parent 726fee1 commit 2f504fc

7 files changed

Lines changed: 106 additions & 102 deletions

File tree

src/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ set(publicHeaders
1515

1616
set(privateHeaders
1717
"simple_socket/socket_common.hpp"
18-
"simple_socket/Socket.hpp"
18+
"simple_socket/SocketConnection.hpp"
19+
20+
"simple_socket/tls/TLSConnection.hpp"
1921

2022
"simple_socket/util/uuid.hpp"
2123
)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
2+
#ifndef SIMPLE_SOCKET_SOCKET_HPP
3+
#define SIMPLE_SOCKET_SOCKET_HPP
4+
5+
#include "simple_socket/SimpleConnection.hpp"
6+
#include "simple_socket/socket_common.hpp"
7+
8+
9+
namespace simple_socket {
10+
11+
struct SocketConnection: SimpleConnection {
12+
13+
explicit SocketConnection(SOCKET socket)
14+
: sockfd_(socket) {}
15+
16+
int read(unsigned char* buffer, size_t size) override {
17+
18+
#ifdef _WIN32
19+
const auto read = recv(sockfd_, reinterpret_cast<char*>(buffer), static_cast<int>(size), 0);
20+
#else
21+
const auto read = ::read(sockfd_, buffer, size);
22+
#endif
23+
24+
return (read != SOCKET_ERROR) && (read != 0) ? read : -1;
25+
}
26+
27+
bool write(const unsigned char* data, size_t size) override {
28+
29+
#ifdef _WIN32
30+
return send(sockfd_, reinterpret_cast<const char*>(data), static_cast<int>(size), 0) != SOCKET_ERROR;
31+
#else
32+
return ::write(sockfd_, data, size) != SOCKET_ERROR;
33+
#endif
34+
}
35+
36+
void close() override {
37+
38+
closeSocket(sockfd_);
39+
}
40+
41+
~SocketConnection() override {
42+
43+
SocketConnection::close();
44+
}
45+
46+
SOCKET sockfd_;
47+
};
48+
49+
50+
}// namespace simple_socket
51+
52+
#endif//SIMPLE_SOCKET_SOCKET_HPP

src/simple_socket/TCPSocket.cpp

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22
#include "simple_socket/TCPSocket.hpp"
33

44
#include "simple_socket/SimpleConnection.hpp"
5-
#include "simple_socket/Socket.hpp"
6-
5+
#include "simple_socket/SocketConnection.hpp"
76

87
#ifdef _WIN32
98
#include <ws2def.h>
@@ -12,6 +11,10 @@
1211
#include <netdb.h>
1312
#endif
1413

14+
#ifdef SIMPLE_SOCKET_WITH_TLS
15+
#include "simple_socket/tls/TLSConnection.hpp"
16+
#endif
17+
1518

1619
using namespace simple_socket;
1720

@@ -82,7 +85,7 @@ struct TCPServer::Impl {
8285
throwSocketError("Accept failed");
8386
}
8487

85-
return std::make_unique<Socket>(new_sock);
88+
return std::make_unique<SocketConnection>(new_sock);
8689
}
8790

8891
void close() {
@@ -95,7 +98,7 @@ struct TCPServer::Impl {
9598
WSASession session;
9699
#endif
97100

98-
Socket socket;
101+
SocketConnection socket;
99102
};
100103

101104
TCPServer::TCPServer(uint16_t port, int backlog)
@@ -142,29 +145,12 @@ TCPServer::~TCPServer() = default;
142145
}
143146
if (useTLS) {
144147
#ifdef SIMPLE_SOCKET_WITH_TLS
145-
146-
SSL_library_init();
147-
SSL_load_error_strings();
148-
const SSL_METHOD* method = TLS_client_method();
149-
SSL_CTX* ctx = SSL_CTX_new(method);
150-
if (!ctx) return nullptr;
151-
152-
SSL* ssl = SSL_new(ctx);
153-
SSL_set_fd(ssl, static_cast<int>(sock));
154-
SSL_set_tlsext_host_name(ssl, ip.c_str());
155-
if (SSL_connect(ssl) <= 0) {
156-
ERR_print_errors_fp(stderr);
157-
SSL_free(ssl);
158-
SSL_CTX_free(ctx);
159-
closeSocket(sock);
160-
return nullptr;
161-
}
162-
return std::make_unique<TLSConnection>(sock, ssl, ctx);
148+
return std::make_unique<TLSConnection>(sock, ip);
163149
#else
164150
throw std::runtime_error("TLS support is not enabled in this build.");
165151
#endif
166152
}
167-
return std::make_unique<Socket>(sock);
153+
return std::make_unique<SocketConnection>(sock);
168154
}
169155

170156
std::unique_ptr<SimpleConnection> TCPClientContext::connect(const std::string& host) {

src/simple_socket/UnixDomainSocket.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
#include "simple_socket/UnixDomainSocket.hpp"
33

4-
#include "simple_socket/Socket.hpp"
4+
#include "simple_socket/SocketConnection.hpp"
55

66
#ifdef _WIN32
77
#include <afunix.h>
@@ -55,15 +55,15 @@ struct UnixDomainServer::Impl {
5555
}
5656
}
5757

58-
std::unique_ptr<Socket> accept() {
58+
std::unique_ptr<SocketConnection> accept() {
5959

6060
SOCKET new_sock = ::accept(socket.sockfd_, nullptr, nullptr);
6161
if (new_sock == INVALID_SOCKET) {
6262

6363
throwSocketError("Accept failed");
6464
}
6565

66-
return std::make_unique<Socket>(new_sock);
66+
return std::make_unique<SocketConnection>(new_sock);
6767
}
6868

6969
void close() {
@@ -80,7 +80,7 @@ struct UnixDomainServer::Impl {
8080
WSASession session;
8181
#endif
8282

83-
Socket socket;
83+
SocketConnection socket;
8484
std::string domain;
8585
};
8686

@@ -111,7 +111,7 @@ std::unique_ptr<SimpleConnection> UnixDomainClientContext::connect(const std::st
111111

112112
if (::connect(sockfd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) >= 0) {
113113

114-
return std::make_unique<Socket>(sockfd);
114+
return std::make_unique<SocketConnection>(sockfd);
115115
}
116116

117117
return nullptr;

src/simple_socket/socket_common.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "WSASession.hpp"
99
#include <WinSock2.h>
1010
#include <ws2tcpip.h>
11-
#include <ws2def.h>
1211
#else
1312
#include <arpa/inet.h>
1413
#include <sys/socket.h>
Original file line numberDiff line numberDiff line change
@@ -1,86 +1,43 @@
11

2-
#ifndef SIMPLE_SOCKET_SOCKET_HPP
3-
#define SIMPLE_SOCKET_SOCKET_HPP
2+
#ifndef SIMPLE_SOCKET_TLSCONNECTION_HPP
3+
#define SIMPLE_SOCKET_TLSCONNECTION_HPP
44

55
#include "simple_socket/SimpleConnection.hpp"
66
#include "simple_socket/socket_common.hpp"
77

8-
#ifdef SIMPLE_SOCKET_WITH_TLS
98
#include <openssl/err.h>
109
#include <openssl/ssl.h>
11-
#endif
1210

1311
#ifndef _WIN32
1412
#include <fcntl.h>
1513
#endif
1614

17-
1815
namespace simple_socket {
1916

20-
struct Socket: SimpleConnection {
21-
22-
explicit Socket(SOCKET socket)
23-
: sockfd_(socket) {}
24-
25-
int read(unsigned char* buffer, size_t size) override {
26-
27-
#ifdef _WIN32
28-
const auto read = recv(sockfd_, reinterpret_cast<char*>(buffer), static_cast<int>(size), 0);
29-
#else
30-
const auto read = ::read(sockfd_, buffer, size);
31-
#endif
32-
33-
return (read != SOCKET_ERROR) && (read != 0) ? read : -1;
34-
}
35-
36-
bool write(const unsigned char* data, size_t size) override {
37-
38-
#ifdef _WIN32
39-
return send(sockfd_, reinterpret_cast<const char*>(data), static_cast<int>(size), 0) != SOCKET_ERROR;
40-
#else
41-
return ::write(sockfd_, data, size) != SOCKET_ERROR;
42-
#endif
43-
}
44-
45-
void close() override {
46-
47-
closeSocket(sockfd_);
48-
}
49-
50-
~Socket() override {
51-
52-
closeSocket(sockfd_);
53-
}
54-
55-
SOCKET sockfd_;
56-
};
57-
58-
#ifdef SIMPLE_SOCKET_WITH_TLS
59-
60-
61-
// Helper: put socket into non-blocking mode
62-
inline void set_nonblocking(SOCKET s) {
63-
#ifdef _WIN32
64-
u_long mode = 1;
65-
ioctlsocket(s, FIONBIO, &mode);
66-
#else
67-
int flags = fcntl(s, F_GETFL, 0);
68-
if (flags >= 0) fcntl(s, F_SETFL, flags | O_NONBLOCK);
69-
#endif
70-
}
71-
72-
7317
class TLSConnection: public SimpleConnection {
7418
public:
75-
TLSConnection(SOCKET sock, SSL* ssl, SSL_CTX* ctx = nullptr)
76-
: sockfd_(sock), ssl_(ssl), ctx_(ctx) {
19+
TLSConnection(SOCKET sock, const std::string& ip)
20+
: sockfd_(sock) {
21+
22+
SSL_library_init();
23+
SSL_load_error_strings();
24+
const SSL_METHOD* method = TLS_client_method();
25+
ctx_ = SSL_CTX_new(method);
26+
if (!ctx_) throw std::runtime_error("Failed to create SSL context");
27+
28+
ssl_ = SSL_new(ctx_);
29+
SSL_set_fd(ssl_, static_cast<int>(sock));
30+
SSL_set_tlsext_host_name(ssl_, ip.c_str());
31+
if (SSL_connect(ssl_) <= 0) {
32+
ERR_print_errors_fp(stderr);
33+
throw std::runtime_error("Failed to connect to TLS host");
34+
}
7735

7836
set_nonblocking(sockfd_);
7937
// Ensure AUTO_RETRY is off for non-blocking semantics
8038
SSL_clear_mode(ssl_, SSL_MODE_AUTO_RETRY);
8139
}
8240

83-
8441
bool write(const uint8_t* buf, size_t len) override {
8542
if (!ssl_) return false;
8643

@@ -93,9 +50,9 @@ namespace simple_socket {
9350
}
9451
const int err = SSL_get_error(ssl_, n);
9552
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
96-
continue; // retry
53+
continue;// retry
9754
}
98-
return false; // fatal
55+
return false;// fatal
9956
}
10057
return true;
10158
}
@@ -108,12 +65,12 @@ namespace simple_socket {
10865
if (n > 0) return n;
10966

11067
const int err = SSL_get_error(ssl_, n);
111-
if (err == SSL_ERROR_ZERO_RETURN) return 0; // clean TLS shutdown
68+
if (err == SSL_ERROR_ZERO_RETURN) return 0;// clean TLS shutdown
11269
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
11370
// retry until caller cancels via close()
11471
continue;
11572
}
116-
return -1; // fatal
73+
return -1;// fatal
11774
}
11875
}
11976

@@ -141,9 +98,19 @@ namespace simple_socket {
14198
SSL* ssl_;
14299
SSL_CTX* ctx_;
143100

144-
};
101+
// Helper: put socket into non-blocking mode
102+
static void set_nonblocking(SOCKET s) {
103+
#ifdef _WIN32
104+
u_long mode = 1;
105+
ioctlsocket(s, FIONBIO, &mode);
106+
#else
107+
int flags = fcntl(s, F_GETFL, 0);
108+
if (flags >= 0) fcntl(s, F_SETFL, flags | O_NONBLOCK);
145109
#endif
110+
}
111+
};
112+
146113

147-
}// namespace simple_socket
114+
}
148115

149-
#endif//SIMPLE_SOCKET_SOCKET_HPP
116+
#endif //TLSCONNECTION_HPP

tests/test_wss_client.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11

22
#include <catch2/catch_test_macros.hpp>
33

4-
#include "../include/simple_socket/ws/WebSocket.hpp"
4+
#include "simple_socket/ws/WebSocket.hpp"
55

6-
#include <algorithm>
7-
#include <execution>
86
#include <iostream>
97
#include <mutex>
108
#include <semaphore>
@@ -18,7 +16,7 @@ TEST_CASE("Echo WebSocketClient wss") {
1816

1917
std::mutex mutex;
2018

21-
std::vector<WebSocketClient> clients(5);
19+
std::vector<WebSocketClient> clients(3);
2220
for (auto& client : clients) {
2321

2422
std::binary_semaphore latch{0};

0 commit comments

Comments
 (0)