diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index 8fd8012444..2bdd2c9619 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -204,8 +204,10 @@ ErrorOr IPv4Socket::sendto(OpenFileDescription&, UserOrKernelBuffer cons return set_so_error(EAFNOSUPPORT); } - m_peer_address = IPv4Address((u8 const*)&ia.sin_addr.s_addr); - m_peer_port = ntohs(ia.sin_port); + if (type() != SOCK_STREAM) { + m_peer_address = IPv4Address((u8 const*)&ia.sin_addr.s_addr); + m_peer_port = ntohs(ia.sin_port); + } } if (!is_connected() && m_peer_address.is_zero()) diff --git a/Tests/Kernel/CMakeLists.txt b/Tests/Kernel/CMakeLists.txt index b8b3d5040b..e05be25829 100644 --- a/Tests/Kernel/CMakeLists.txt +++ b/Tests/Kernel/CMakeLists.txt @@ -56,6 +56,7 @@ set(LIBTEST_BASED_SOURCES TestSigAltStack.cpp TestSigHandler.cpp TestSigWait.cpp + TestTCPSocket.cpp ) if (NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") diff --git a/Tests/Kernel/TestTCPSocket.cpp b/Tests/Kernel/TestTCPSocket.cpp new file mode 100644 index 0000000000..34e3f2a2d1 --- /dev/null +++ b/Tests/Kernel/TestTCPSocket.cpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2023, Idan Horowitz + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include +#include +#include +#include + +static constexpr u16 port = 1337; + +static void* server_handler(void*) +{ + int server_fd = socket(AF_INET, SOCK_STREAM, 0); + EXPECT(server_fd >= 0); + + sockaddr_in sin {}; + sin.sin_family = AF_INET; + sin.sin_port = htons(port); + sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + int rc = bind(server_fd, (sockaddr*)(&sin), sizeof(sin)); + EXPECT_EQ(rc, 0); + + rc = listen(server_fd, 1); + EXPECT_EQ(rc, 0); + + int client_fd = accept(server_fd, nullptr, nullptr); + EXPECT(client_fd >= 0); + + u8 data; + int nread = recv(client_fd, &data, sizeof(data), 0); + EXPECT_EQ(nread, 1); + EXPECT_EQ(data, 'A'); + + rc = close(client_fd); + EXPECT_EQ(rc, 0); + + rc = close(server_fd); + EXPECT_EQ(rc, 0); + + pthread_exit(nullptr); + VERIFY_NOT_REACHED(); +} + +static pthread_t start_tcp_server() +{ + pthread_t thread; + int rc = pthread_create(&thread, nullptr, server_handler, nullptr); + EXPECT_EQ(rc, 0); + return thread; +} + +TEST_CASE(tcp_sendto) +{ + pthread_t server = start_tcp_server(); + + int client_fd = socket(AF_INET, SOCK_STREAM, 0); + EXPECT(client_fd >= 0); + + sockaddr_in sin {}; + sin.sin_family = AF_INET; + sin.sin_port = htons(port); + sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + int rc = connect(client_fd, (sockaddr*)(&sin), sizeof(sin)); + EXPECT_EQ(rc, 0); + + u8 data = 'A'; + sockaddr_in dst {}; + dst.sin_family = AF_INET; + dst.sin_port = htons(port + 1); // Different port, should be ignored + dst.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + int nwritten = sendto(client_fd, &data, sizeof(data), 0, (sockaddr*)(&dst), sizeof(dst)); + EXPECT_EQ(nwritten, 1); + + rc = close(client_fd); + EXPECT_EQ(rc, 0); + + rc = pthread_join(server, nullptr); + EXPECT_EQ(rc, 0); +}