diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index b904694ddb..27bd7019d3 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -74,7 +74,7 @@ bool IPv4Socket::get_peer_address(sockaddr* address, socklen_t* address_size) KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size) { - ASSERT(!is_connected()); + ASSERT(setup_state() == SetupState::Unstarted); if (address_size != sizeof(sockaddr_in)) return KResult(-EINVAL); if (address->sa_family != AF_INET) diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp index dde13cd2a1..1309c77078 100644 --- a/Kernel/Net/LocalSocket.cpp +++ b/Kernel/Net/LocalSocket.cpp @@ -41,7 +41,7 @@ bool LocalSocket::get_peer_address(sockaddr* address, socklen_t* address_size) KResult LocalSocket::bind(const sockaddr* address, socklen_t address_size) { - ASSERT(!is_connected()); + ASSERT(setup_state() == SetupState::Unstarted); if (address_size != sizeof(sockaddr_un)) return KResult(-EINVAL); if (address->sa_family != AF_LOCAL) @@ -68,6 +68,7 @@ KResult LocalSocket::bind(const sockaddr* address, socklen_t address_size) m_address = local_address; m_bound = true; + set_setup_state(SetupState::Completed); return KSuccess; } @@ -109,6 +110,10 @@ KResult LocalSocket::connect(FileDescription& description, const sockaddr* addre if (current->block(description) == Thread::BlockResult::InterruptedBySignal) return KResult(-EINTR); +#ifdef DEBUG_LOCAL_SOCKET + kprintf("%s(%u) LocalSocket{%p} connect(%s) status is %s\n", current->process().name().characters(), current->pid(), this, safe_address, to_string(setup_state())); +#endif + if (!is_connected()) return KResult(-ECONNREFUSED); return KSuccess; diff --git a/Kernel/Net/NetworkTask.cpp b/Kernel/Net/NetworkTask.cpp index ae6351c9a5..a4eaa0ee50 100644 --- a/Kernel/Net/NetworkTask.cpp +++ b/Kernel/Net/NetworkTask.cpp @@ -414,12 +414,14 @@ void handle_tcp(const IPv4Packet& ipv4_packet) socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); socket->send_tcp_packet(TCPFlags::ACK); socket->set_state(TCPSocket::State::Established); + socket->set_setup_state(Socket::SetupState::Completed); socket->set_connected(true); return; default: kprintf("handle_tcp: unexpected flags in SynSent state\n"); socket->send_tcp_packet(TCPFlags::RST); socket->set_state(TCPSocket::State::Closed); + socket->set_setup_state(Socket::SetupState::Completed); return; } case TCPSocket::State::SynReceived: @@ -427,8 +429,10 @@ void handle_tcp(const IPv4Packet& ipv4_packet) case TCPFlags::ACK: socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); socket->set_state(TCPSocket::State::Established); - if (socket->direction() == TCPSocket::Direction::Outgoing) + if (socket->direction() == TCPSocket::Direction::Outgoing) { + socket->set_setup_state(Socket::SetupState::Completed); socket->set_connected(true); + } return; default: kprintf("handle_tcp: unexpected flags in SynReceived state\n"); diff --git a/Kernel/Net/Socket.cpp b/Kernel/Net/Socket.cpp index 79e1198157..6a2ea1fbd9 100644 --- a/Kernel/Net/Socket.cpp +++ b/Kernel/Net/Socket.cpp @@ -6,6 +6,8 @@ #include #include +//#define SOCKET_DEBUG + KResultOr> Socket::create(int domain, int type, int protocol) { (void)protocol; @@ -31,6 +33,15 @@ Socket::~Socket() { } +void Socket::set_setup_state(SetupState new_setup_state) +{ +#ifdef SOCKET_DEBUG + kprintf("%s(%u) Socket{%p} setup state moving from %s to %s\n", current->process().name().characters(), current->pid(), this, to_string(m_setup_state), to_string(new_setup_state)); +#endif + + m_setup_state = new_setup_state; +} + KResult Socket::listen(int backlog) { LOCKER(m_lock); @@ -46,14 +57,21 @@ RefPtr Socket::accept() LOCKER(m_lock); if (m_pending.is_empty()) return nullptr; +#ifdef SOCKET_DEBUG + kprintf("%s(%u) Socket{%p} de-queueing connection\n", current->process().name().characters(), current->pid(), this); +#endif auto client = m_pending.take_first(); ASSERT(!client->is_connected()); + client->set_setup_state(SetupState::Completed); client->m_connected = true; return client; } KResult Socket::queue_connection_from(NonnullRefPtr peer) { +#ifdef SOCKET_DEBUG + kprintf("%s(%u) Socket{%p} queueing connection\n", current->process().name().characters(), current->pid(), this); +#endif LOCKER(m_lock); if (m_pending.size() >= m_backlog) return KResult(-ECONNREFUSED); @@ -144,7 +162,7 @@ static const char* to_string(SocketRole role) String Socket::absolute_path(const FileDescription& description) const { - return String::format("socket:%x (role: %s)", this, to_string(description.socket_role())); + return String::format("socket:%x (role: %s)", this, ::to_string(description.socket_role())); } ssize_t Socket::read(FileDescription& description, u8* buffer, ssize_t size) diff --git a/Kernel/Net/Socket.h b/Kernel/Net/Socket.h index 73d040b5e6..d4d4eec40e 100644 --- a/Kernel/Net/Socket.h +++ b/Kernel/Net/Socket.h @@ -32,9 +32,34 @@ public: int type() const { return m_type; } int protocol() const { return m_protocol; } + enum class SetupState { + Unstarted, // we haven't tried to set the socket up yet + InProgress, // we're in the process of setting things up - for TCP maybe we've sent a SYN packet + Completed, // the setup process is complete, but not necessarily successful + }; + + static const char* to_string(SetupState setup_state) + { + switch (setup_state) { + case SetupState::Unstarted: + return "Unstarted"; + case SetupState::InProgress: + return "InProgress"; + case SetupState::Completed: + return "Completed"; + default: + return "None"; + } + } + + SetupState setup_state() const { return m_setup_state; } + void set_setup_state(SetupState setup_state); + + bool is_connected() const { return m_connected; } + void set_connected(bool connected) { m_connected = connected; } + bool can_accept() const { return !m_pending.is_empty(); } RefPtr accept(); - bool is_connected() const { return m_connected; } virtual KResult bind(const sockaddr*, socklen_t) = 0; virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock) = 0; @@ -56,8 +81,6 @@ public: timeval receive_deadline() const { return m_receive_deadline; } timeval send_deadline() const { return m_send_deadline; } - void set_connected(bool connected) { m_connected = connected; } - Lock& lock() { return m_lock; } // ^File @@ -87,6 +110,7 @@ private: int m_type { 0 }; int m_protocol { 0 }; int m_backlog { 0 }; + SetupState m_setup_state { SetupState::Unstarted }; bool m_connected { false }; timeval m_receive_timeout { 0, 0 }; diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index f665823962..8f34f6859e 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -70,6 +70,7 @@ SocketHandle TCPSocket::create_client(const IPv4Address& new_local_ad auto client = TCPSocket::create(protocol()); + client->set_setup_state(SetupState::InProgress); client->set_local_address(new_local_address); client->set_local_port(new_local_port); client->set_peer_address(new_peer_address); @@ -239,7 +240,7 @@ KResult TCPSocket::protocol_listen() sockets_by_tuple().resource().set(tuple(), this); set_direction(Direction::Passive); set_state(State::Listen); - set_connected(true); + set_setup_state(SetupState::Completed); return KSuccess; } @@ -264,6 +265,7 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh m_sequence_number = 0; m_ack_number = 0; + set_setup_state(SetupState::InProgress); send_tcp_packet(TCPFlags::SYN); m_state = State::SynSent; m_direction = Direction::Outgoing; @@ -271,7 +273,7 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh if (should_block == ShouldBlock::Yes) { if (current->block(description) == Thread::BlockResult::InterruptedBySignal) return KResult(-EINTR); - ASSERT(is_connected()); + ASSERT(setup_state() == SetupState::Completed); return KSuccess; } diff --git a/Kernel/Process.cpp b/Kernel/Process.cpp index 526bb3e947..beb3f7024a 100644 --- a/Kernel/Process.cpp +++ b/Kernel/Process.cpp @@ -2313,7 +2313,7 @@ int Process::sys$getpeername(int sockfd, sockaddr* addr, socklen_t* addrlen) auto& socket = *description->socket(); - if (!socket.is_connected()) + if (socket.setup_state() != Socket::SetupState::Completed) return -ENOTCONN; if (!socket.get_peer_address(addr, addrlen)) diff --git a/Kernel/Scheduler.cpp b/Kernel/Scheduler.cpp index 30e06d2403..00e42d7df7 100644 --- a/Kernel/Scheduler.cpp +++ b/Kernel/Scheduler.cpp @@ -111,7 +111,7 @@ Thread::ConnectBlocker::ConnectBlocker(const FileDescription& description) bool Thread::ConnectBlocker::should_unblock(Thread&, time_t, long) { auto& socket = *blocked_description().socket(); - return socket.is_connected(); + return socket.setup_state() == Socket::SetupState::Completed; } Thread::WriteBlocker::WriteBlocker(const FileDescription& description)