diff --git a/src/tcp_server.cpp b/src/tcp_server.cpp index 1fdeb79..9d48ca3 100644 --- a/src/tcp_server.cpp +++ b/src/tcp_server.cpp @@ -128,7 +128,7 @@ void TCPServer::parse_address_list_internal(const std::string& address_list, con } } - const uint32_t port = std::stoul(address.substr(k2 + 1), nullptr, 10); + const uint32_t port = strtoul(address.c_str() + k2 + 1, nullptr, 10); if ((port > 0) && (port < 65536)) { callback(is_v6, address, ip, static_cast(port)); } @@ -340,6 +340,7 @@ bool TCPServer::connect_to_peer(const std::string& domain, int port) if (s.m_spilled) { LOGERR(1, "Can't connect to " << domain << ": too long domain name"); + return_client(client); return false; } @@ -429,6 +430,7 @@ bool TCPServer::connect_to_peer(Client* client) } else { LOGWARN(5, "failed to initiate tcp connection to " << static_cast(client->m_addrString) << ": domain name is unresolved and SOCKS5 proxy is not enabled"); + uv_close(reinterpret_cast(&client->m_socket), on_connection_error); return false; } } @@ -582,6 +584,11 @@ bool TCPServer::send_internal(Client* client, const Callback m_callbackBuf.size()) { @@ -613,7 +620,7 @@ bool TCPServer::send_internal(Client* client, const Callback(buf->m_data); - bufs[0].len = static_cast(size); + bufs[0].len = static_cast(size); const int err = uv_write(&buf->m_write, reinterpret_cast(&client->m_socket), bufs, 1, Client::on_write); if (err) { @@ -937,6 +944,7 @@ void TCPServer::on_new_client(uv_stream_t* server, Client* client) } else { client->close(); + return; } } } @@ -1258,7 +1266,8 @@ bool TCPServer::Client::on_proxy_handshake(const char* data, uint32_t size) m_socks5ProxyState = Socks5ProxyState::ConnectRequestSent; } else { - close(); + LOGWARN(5, "Failed to send SOCKS5 proxy connect request"); + return false; } } break; @@ -1466,6 +1475,13 @@ void TCPServer::Client::on_write(uv_write_t* req, int status) if (server) { server->return_write_buffer(buf); } + else { + // no server to return buf to, so just delete it + if (buf->m_data) { + free_hook(buf->m_data); + } + delete buf; + } if (status != 0) { LOGWARN(5, "client " << static_cast(client->m_addrString) << " failed to write data to client connection, error " << uv_err_name(status)); diff --git a/src/tls.cpp b/src/tls.cpp index ed2b10e..d589c8a 100644 --- a/src/tls.cpp +++ b/src/tls.cpp @@ -232,7 +232,16 @@ bool ServerTls::on_read_internal(const char* data, uint32_t size, const ReadCall return false; } - if (!BIO_write_all(SSL_get_rbio(ssl), data, size)) { + BIO* rbio = SSL_get_rbio(ssl); + if (!rbio) { + return false; + } + + if (!BIO_write_all(rbio, data, size)) { + return false; + } + + if (BIO_ctrl_pending(rbio) > 32768u) { return false; } @@ -266,9 +275,12 @@ bool ServerTls::on_read_internal(const char* data, uint32_t size, const ReadCall } } - if ((result < 0) && (SSL_get_error(ssl, result) == SSL_ERROR_WANT_READ)) { - // Continue handshake, nothing to read yet - return true; + if (result < 0) { + const int err = SSL_get_error(ssl, result); + if ((err == SSL_ERROR_WANT_READ) || (err == SSL_ERROR_WANT_WRITE)) { + // Continue handshake, nothing to read yet + return true; + } } if (result == 1) { @@ -289,11 +301,16 @@ bool ServerTls::on_read_internal(const char* data, uint32_t size, const ReadCall } } - return true; + const int err = SSL_get_error(ssl, bytes_read); + return (err == SSL_ERROR_WANT_READ) || (err == SSL_ERROR_WANT_WRITE); } bool ServerTls::on_write_internal(const uint8_t* data, size_t size, const WriteCallback::Base& write_callback) { + if (size > static_cast(std::numeric_limits::max())) { + return false; + } + SSL* ssl = m_ssl.get(); if (!ssl) { return false;