1
Fork 0
mirror of https://github.com/RGBCube/serenity synced 2025-05-31 08:48:11 +00:00

Kernel: Use a more detailed state machine for socket setup

This commit is contained in:
Conrad Pankoff 2019-08-10 13:17:00 +10:00 committed by Andreas Kling
parent 638008da13
commit bd6d2c0819
8 changed files with 64 additions and 11 deletions

View file

@ -74,7 +74,7 @@ bool IPv4Socket::get_peer_address(sockaddr* address, socklen_t* address_size)
KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size) KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size)
{ {
ASSERT(!is_connected()); ASSERT(setup_state() == SetupState::Unstarted);
if (address_size != sizeof(sockaddr_in)) if (address_size != sizeof(sockaddr_in))
return KResult(-EINVAL); return KResult(-EINVAL);
if (address->sa_family != AF_INET) if (address->sa_family != AF_INET)

View file

@ -41,7 +41,7 @@ bool LocalSocket::get_peer_address(sockaddr* address, socklen_t* address_size)
KResult LocalSocket::bind(const sockaddr* address, socklen_t address_size) KResult LocalSocket::bind(const sockaddr* address, socklen_t address_size)
{ {
ASSERT(!is_connected()); ASSERT(setup_state() == SetupState::Unstarted);
if (address_size != sizeof(sockaddr_un)) if (address_size != sizeof(sockaddr_un))
return KResult(-EINVAL); return KResult(-EINVAL);
if (address->sa_family != AF_LOCAL) if (address->sa_family != AF_LOCAL)
@ -68,6 +68,7 @@ KResult LocalSocket::bind(const sockaddr* address, socklen_t address_size)
m_address = local_address; m_address = local_address;
m_bound = true; m_bound = true;
set_setup_state(SetupState::Completed);
return KSuccess; return KSuccess;
} }
@ -109,6 +110,10 @@ KResult LocalSocket::connect(FileDescription& description, const sockaddr* addre
if (current->block<Thread::ConnectBlocker>(description) == Thread::BlockResult::InterruptedBySignal) if (current->block<Thread::ConnectBlocker>(description) == Thread::BlockResult::InterruptedBySignal)
return KResult(-EINTR); return KResult(-EINTR);
#ifdef DEBUG_LOCAL_SOCKET
kprintf("%s(%u) LocalSocket{%p} connect(%s) status is %s\n", current->process().name().characters(), current->pid(), this, safe_address, to_string(setup_state()));
#endif
if (!is_connected()) if (!is_connected())
return KResult(-ECONNREFUSED); return KResult(-ECONNREFUSED);
return KSuccess; return KSuccess;

View file

@ -414,12 +414,14 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->send_tcp_packet(TCPFlags::ACK); socket->send_tcp_packet(TCPFlags::ACK);
socket->set_state(TCPSocket::State::Established); socket->set_state(TCPSocket::State::Established);
socket->set_setup_state(Socket::SetupState::Completed);
socket->set_connected(true); socket->set_connected(true);
return; return;
default: default:
kprintf("handle_tcp: unexpected flags in SynSent state\n"); kprintf("handle_tcp: unexpected flags in SynSent state\n");
socket->send_tcp_packet(TCPFlags::RST); socket->send_tcp_packet(TCPFlags::RST);
socket->set_state(TCPSocket::State::Closed); socket->set_state(TCPSocket::State::Closed);
socket->set_setup_state(Socket::SetupState::Completed);
return; return;
} }
case TCPSocket::State::SynReceived: case TCPSocket::State::SynReceived:
@ -427,8 +429,10 @@ void handle_tcp(const IPv4Packet& ipv4_packet)
case TCPFlags::ACK: case TCPFlags::ACK:
socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1); socket->set_ack_number(tcp_packet.sequence_number() + payload_size + 1);
socket->set_state(TCPSocket::State::Established); socket->set_state(TCPSocket::State::Established);
if (socket->direction() == TCPSocket::Direction::Outgoing) if (socket->direction() == TCPSocket::Direction::Outgoing) {
socket->set_setup_state(Socket::SetupState::Completed);
socket->set_connected(true); socket->set_connected(true);
}
return; return;
default: default:
kprintf("handle_tcp: unexpected flags in SynReceived state\n"); kprintf("handle_tcp: unexpected flags in SynReceived state\n");

View file

@ -6,6 +6,8 @@
#include <Kernel/UnixTypes.h> #include <Kernel/UnixTypes.h>
#include <LibC/errno_numbers.h> #include <LibC/errno_numbers.h>
//#define SOCKET_DEBUG
KResultOr<NonnullRefPtr<Socket>> Socket::create(int domain, int type, int protocol) KResultOr<NonnullRefPtr<Socket>> Socket::create(int domain, int type, int protocol)
{ {
(void)protocol; (void)protocol;
@ -31,6 +33,15 @@ Socket::~Socket()
{ {
} }
void Socket::set_setup_state(SetupState new_setup_state)
{
#ifdef SOCKET_DEBUG
kprintf("%s(%u) Socket{%p} setup state moving from %s to %s\n", current->process().name().characters(), current->pid(), this, to_string(m_setup_state), to_string(new_setup_state));
#endif
m_setup_state = new_setup_state;
}
KResult Socket::listen(int backlog) KResult Socket::listen(int backlog)
{ {
LOCKER(m_lock); LOCKER(m_lock);
@ -46,14 +57,21 @@ RefPtr<Socket> Socket::accept()
LOCKER(m_lock); LOCKER(m_lock);
if (m_pending.is_empty()) if (m_pending.is_empty())
return nullptr; return nullptr;
#ifdef SOCKET_DEBUG
kprintf("%s(%u) Socket{%p} de-queueing connection\n", current->process().name().characters(), current->pid(), this);
#endif
auto client = m_pending.take_first(); auto client = m_pending.take_first();
ASSERT(!client->is_connected()); ASSERT(!client->is_connected());
client->set_setup_state(SetupState::Completed);
client->m_connected = true; client->m_connected = true;
return client; return client;
} }
KResult Socket::queue_connection_from(NonnullRefPtr<Socket> peer) KResult Socket::queue_connection_from(NonnullRefPtr<Socket> peer)
{ {
#ifdef SOCKET_DEBUG
kprintf("%s(%u) Socket{%p} queueing connection\n", current->process().name().characters(), current->pid(), this);
#endif
LOCKER(m_lock); LOCKER(m_lock);
if (m_pending.size() >= m_backlog) if (m_pending.size() >= m_backlog)
return KResult(-ECONNREFUSED); return KResult(-ECONNREFUSED);
@ -144,7 +162,7 @@ static const char* to_string(SocketRole role)
String Socket::absolute_path(const FileDescription& description) const String Socket::absolute_path(const FileDescription& description) const
{ {
return String::format("socket:%x (role: %s)", this, to_string(description.socket_role())); return String::format("socket:%x (role: %s)", this, ::to_string(description.socket_role()));
} }
ssize_t Socket::read(FileDescription& description, u8* buffer, ssize_t size) ssize_t Socket::read(FileDescription& description, u8* buffer, ssize_t size)

View file

@ -32,9 +32,34 @@ public:
int type() const { return m_type; } int type() const { return m_type; }
int protocol() const { return m_protocol; } int protocol() const { return m_protocol; }
enum class SetupState {
Unstarted, // we haven't tried to set the socket up yet
InProgress, // we're in the process of setting things up - for TCP maybe we've sent a SYN packet
Completed, // the setup process is complete, but not necessarily successful
};
static const char* to_string(SetupState setup_state)
{
switch (setup_state) {
case SetupState::Unstarted:
return "Unstarted";
case SetupState::InProgress:
return "InProgress";
case SetupState::Completed:
return "Completed";
default:
return "None";
}
}
SetupState setup_state() const { return m_setup_state; }
void set_setup_state(SetupState setup_state);
bool is_connected() const { return m_connected; }
void set_connected(bool connected) { m_connected = connected; }
bool can_accept() const { return !m_pending.is_empty(); } bool can_accept() const { return !m_pending.is_empty(); }
RefPtr<Socket> accept(); RefPtr<Socket> accept();
bool is_connected() const { return m_connected; }
virtual KResult bind(const sockaddr*, socklen_t) = 0; virtual KResult bind(const sockaddr*, socklen_t) = 0;
virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock) = 0; virtual KResult connect(FileDescription&, const sockaddr*, socklen_t, ShouldBlock) = 0;
@ -56,8 +81,6 @@ public:
timeval receive_deadline() const { return m_receive_deadline; } timeval receive_deadline() const { return m_receive_deadline; }
timeval send_deadline() const { return m_send_deadline; } timeval send_deadline() const { return m_send_deadline; }
void set_connected(bool connected) { m_connected = connected; }
Lock& lock() { return m_lock; } Lock& lock() { return m_lock; }
// ^File // ^File
@ -87,6 +110,7 @@ private:
int m_type { 0 }; int m_type { 0 };
int m_protocol { 0 }; int m_protocol { 0 };
int m_backlog { 0 }; int m_backlog { 0 };
SetupState m_setup_state { SetupState::Unstarted };
bool m_connected { false }; bool m_connected { false };
timeval m_receive_timeout { 0, 0 }; timeval m_receive_timeout { 0, 0 };

View file

@ -70,6 +70,7 @@ SocketHandle<TCPSocket> TCPSocket::create_client(const IPv4Address& new_local_ad
auto client = TCPSocket::create(protocol()); auto client = TCPSocket::create(protocol());
client->set_setup_state(SetupState::InProgress);
client->set_local_address(new_local_address); client->set_local_address(new_local_address);
client->set_local_port(new_local_port); client->set_local_port(new_local_port);
client->set_peer_address(new_peer_address); client->set_peer_address(new_peer_address);
@ -239,7 +240,7 @@ KResult TCPSocket::protocol_listen()
sockets_by_tuple().resource().set(tuple(), this); sockets_by_tuple().resource().set(tuple(), this);
set_direction(Direction::Passive); set_direction(Direction::Passive);
set_state(State::Listen); set_state(State::Listen);
set_connected(true); set_setup_state(SetupState::Completed);
return KSuccess; return KSuccess;
} }
@ -264,6 +265,7 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh
m_sequence_number = 0; m_sequence_number = 0;
m_ack_number = 0; m_ack_number = 0;
set_setup_state(SetupState::InProgress);
send_tcp_packet(TCPFlags::SYN); send_tcp_packet(TCPFlags::SYN);
m_state = State::SynSent; m_state = State::SynSent;
m_direction = Direction::Outgoing; m_direction = Direction::Outgoing;
@ -271,7 +273,7 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh
if (should_block == ShouldBlock::Yes) { if (should_block == ShouldBlock::Yes) {
if (current->block<Thread::ConnectBlocker>(description) == Thread::BlockResult::InterruptedBySignal) if (current->block<Thread::ConnectBlocker>(description) == Thread::BlockResult::InterruptedBySignal)
return KResult(-EINTR); return KResult(-EINTR);
ASSERT(is_connected()); ASSERT(setup_state() == SetupState::Completed);
return KSuccess; return KSuccess;
} }

View file

@ -2313,7 +2313,7 @@ int Process::sys$getpeername(int sockfd, sockaddr* addr, socklen_t* addrlen)
auto& socket = *description->socket(); auto& socket = *description->socket();
if (!socket.is_connected()) if (socket.setup_state() != Socket::SetupState::Completed)
return -ENOTCONN; return -ENOTCONN;
if (!socket.get_peer_address(addr, addrlen)) if (!socket.get_peer_address(addr, addrlen))

View file

@ -111,7 +111,7 @@ Thread::ConnectBlocker::ConnectBlocker(const FileDescription& description)
bool Thread::ConnectBlocker::should_unblock(Thread&, time_t, long) bool Thread::ConnectBlocker::should_unblock(Thread&, time_t, long)
{ {
auto& socket = *blocked_description().socket(); auto& socket = *blocked_description().socket();
return socket.is_connected(); return socket.setup_state() == Socket::SetupState::Completed;
} }
Thread::WriteBlocker::WriteBlocker(const FileDescription& description) Thread::WriteBlocker::WriteBlocker(const FileDescription& description)