diff --git a/Kernel/FileSystem/FileDescription.cpp b/Kernel/FileSystem/FileDescription.cpp index 4b122bb256..61308e9c7e 100644 --- a/Kernel/FileSystem/FileDescription.cpp +++ b/Kernel/FileSystem/FileDescription.cpp @@ -22,17 +22,16 @@ NonnullRefPtr FileDescription::create(Custody& custody) return description; } -NonnullRefPtr FileDescription::create(File& file, SocketRole role) +NonnullRefPtr FileDescription::create(File& file) { - return adopt(*new FileDescription(file, role)); + return adopt(*new FileDescription(file)); } -FileDescription::FileDescription(File& file, SocketRole role) +FileDescription::FileDescription(File& file) : m_file(file) { if (file.is_inode()) m_inode = static_cast(file).inode(); - set_socket_role(role); if (is_socket()) socket()->attach(*this); } @@ -47,15 +46,6 @@ FileDescription::~FileDescription() m_inode = nullptr; } -void FileDescription::set_socket_role(SocketRole role) -{ - if (role == m_socket_role) - return; - - ASSERT(is_socket()); - m_socket_role = role; -} - KResult FileDescription::fstat(stat& buffer) { ASSERT(!is_fifo()); diff --git a/Kernel/FileSystem/FileDescription.h b/Kernel/FileSystem/FileDescription.h index cd72832ceb..62dcb91b7f 100644 --- a/Kernel/FileSystem/FileDescription.h +++ b/Kernel/FileSystem/FileDescription.h @@ -23,7 +23,7 @@ class SharedMemory; class FileDescription : public RefCounted { public: static NonnullRefPtr create(Custody&); - static NonnullRefPtr create(File&, SocketRole = SocketRole::None); + static NonnullRefPtr create(File&); ~FileDescription(); int close(); @@ -93,9 +93,6 @@ public: void set_original_inode(Badge, NonnullRefPtr&& inode) { m_inode = move(inode); } - SocketRole socket_role() const { return m_socket_role; } - void set_socket_role(SocketRole); - KResult truncate(off_t); off_t offset() const { return m_current_offset; } @@ -104,7 +101,7 @@ public: private: friend class VFS; - FileDescription(File&, SocketRole = SocketRole::None); + explicit FileDescription(File&); FileDescription(FIFO&, FIFO::Direction); RefPtr m_custody; @@ -119,6 +116,5 @@ private: bool m_is_blocking { true }; bool m_should_append { false }; - SocketRole m_socket_role { SocketRole::None }; FIFO::Direction m_fifo_direction { FIFO::Direction::Neither }; }; diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index 27bd7019d3..5621846934 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -108,6 +108,8 @@ KResult IPv4Socket::connect(FileDescription& description, const sockaddr* addres return KResult(-EINVAL); if (address->sa_family != AF_INET) return KResult(-EINVAL); + if (m_role == Role::Connected) + return KResult(-EISCONN); auto& ia = *(const sockaddr_in*)address; m_peer_address = IPv4Address((const u8*)&ia.sin_addr.s_addr); @@ -124,9 +126,9 @@ void IPv4Socket::detach(FileDescription&) { } -bool IPv4Socket::can_read(FileDescription& description) const +bool IPv4Socket::can_read(FileDescription&) const { - if (description.socket_role() == SocketRole::Listener) + if (m_role == Role::Listener) return can_accept(); if (protocol_is_disconnected()) return true; diff --git a/Kernel/Net/LocalSocket.cpp b/Kernel/Net/LocalSocket.cpp index f4e1377134..efcc35c0c8 100644 --- a/Kernel/Net/LocalSocket.cpp +++ b/Kernel/Net/LocalSocket.cpp @@ -99,23 +99,35 @@ KResult LocalSocket::connect(FileDescription& description, const sockaddr* addre m_address = local_address; + ASSERT(m_connect_side_fd == &description); + m_connect_side_role = Role::Connecting; + auto peer = m_file->inode()->socket(); auto result = peer->queue_connection_from(*this); - if (result.is_error()) + if (result.is_error()) { + m_connect_side_role = Role::None; return result; + } - if (is_connected()) + if (is_connected()) { + m_connect_side_role = Role::Connected; return KSuccess; + } - if (current->block(description) == Thread::BlockResult::InterruptedBySignal) + if (current->block(description) == Thread::BlockResult::InterruptedBySignal) { + m_connect_side_role = Role::None; return KResult(-EINTR); + } #ifdef DEBUG_LOCAL_SOCKET kprintf("%s(%u) LocalSocket{%p} connect(%s) status is %s\n", current->process().name().characters(), current->pid(), this, safe_address, to_string(setup_state())); #endif - if (!is_connected()) + if (!is_connected()) { + m_connect_side_role = Role::None; return KResult(-ECONNREFUSED); + } + m_connect_side_role = Role::Connected; return KSuccess; } @@ -125,6 +137,7 @@ KResult LocalSocket::listen(int backlog) if (type() != SOCK_STREAM) return KResult(-EOPNOTSUPP); set_backlog(backlog); + m_connect_side_role = m_role = Role::Listener; kprintf("LocalSocket{%p} listening with backlog=%d\n", this, backlog); return KSuccess; } @@ -132,68 +145,53 @@ KResult LocalSocket::listen(int backlog) void LocalSocket::attach(FileDescription& description) { ASSERT(!m_accept_side_fd_open); - switch (description.socket_role()) { - case SocketRole::None: - ASSERT(!m_connect_side_fd_open); - m_connect_side_fd_open = true; - break; - case SocketRole::Accepted: + if (m_connect_side_role == Role::None) { + ASSERT(m_connect_side_fd == nullptr); + m_connect_side_fd = &description; + } else { + ASSERT(m_connect_side_fd != &description); m_accept_side_fd_open = true; - break; - case SocketRole::Connected: - ASSERT_NOT_REACHED(); - default: - break; } } void LocalSocket::detach(FileDescription& description) { - switch (description.socket_role()) { - case SocketRole::None: - ASSERT(!m_accept_side_fd_open); - ASSERT(m_connect_side_fd_open); - m_connect_side_fd_open = false; - break; - case SocketRole::Accepted: + if (m_connect_side_fd == &description) { + m_connect_side_fd = nullptr; + } else { ASSERT(m_accept_side_fd_open); m_accept_side_fd_open = false; - break; - case SocketRole::Connected: - ASSERT(m_connect_side_fd_open); - m_connect_side_fd_open = false; - break; - default: - break; } } bool LocalSocket::can_read(FileDescription& description) const { - auto role = description.socket_role(); - if (role == SocketRole::Listener) + auto role = this->role(description); + if (role == Role::Listener) return can_accept(); - if (role == SocketRole::Accepted) + if (role == Role::Accepted) return !has_attached_peer(description) || !m_for_server.is_empty(); - if (role == SocketRole::Connected) + if (role == Role::Connected) return !has_attached_peer(description) || !m_for_client.is_empty(); ASSERT_NOT_REACHED(); } bool LocalSocket::has_attached_peer(const FileDescription& description) const { - if (description.socket_role() == SocketRole::Accepted) - return m_connect_side_fd_open; - if (description.socket_role() == SocketRole::Connected) + auto role = this->role(description); + if (role == Role::Accepted) + return m_connect_side_fd != nullptr; + if (role == Role::Connected) return m_accept_side_fd_open; ASSERT_NOT_REACHED(); } bool LocalSocket::can_write(FileDescription& description) const { - if (description.socket_role() == SocketRole::Accepted) + auto role = this->role(description); + if (role == Role::Accepted) return !has_attached_peer(description) || m_for_client.bytes_in_write_buffer() < 16384; - if (description.socket_role() == SocketRole::Connected) + if (role == Role::Connected) return !has_attached_peer(description) || m_for_server.bytes_in_write_buffer() < 16384; ASSERT_NOT_REACHED(); } @@ -202,24 +200,25 @@ ssize_t LocalSocket::sendto(FileDescription& description, const void* data, size { if (!has_attached_peer(description)) return -EPIPE; - if (description.socket_role() == SocketRole::Accepted) + auto role = this->role(description); + if (role == Role::Accepted) return m_for_client.write((const u8*)data, data_size); - if (description.socket_role() == SocketRole::Connected) + if (role == Role::Connected) return m_for_server.write((const u8*)data, data_size); ASSERT_NOT_REACHED(); } ssize_t LocalSocket::recvfrom(FileDescription& description, void* buffer, size_t buffer_size, int, sockaddr*, socklen_t*) { - auto role = description.socket_role(); - if (role == SocketRole::Accepted) { + auto role = this->role(description); + if (role == Role::Accepted) { if (!description.is_blocking()) { if (m_for_server.is_empty()) return -EAGAIN; } return m_for_server.read((u8*)buffer, buffer_size); } - if (role == SocketRole::Connected) { + if (role == Role::Connected) { if (!description.is_blocking()) { if (m_for_client.is_empty()) return -EAGAIN; diff --git a/Kernel/Net/LocalSocket.h b/Kernel/Net/LocalSocket.h index 9cae1adc88..fb47364f12 100644 --- a/Kernel/Net/LocalSocket.h +++ b/Kernel/Net/LocalSocket.h @@ -28,11 +28,25 @@ private: virtual bool is_local() const override { return true; } bool has_attached_peer(const FileDescription&) const; + // An open socket file on the filesystem. RefPtr m_file; + // A single LocalSocket is shared between two file descriptions + // on the connect side and the accept side; so we need to store + // an additional role for the connect side and differentiate + // between them. + Role m_connect_side_role { Role::None }; + FileDescription* m_connect_side_fd { nullptr }; + + virtual Role role(const FileDescription& description) const override + { + if (m_connect_side_fd == &description) + return m_connect_side_role; + return m_role; + } + bool m_bound { false }; bool m_accept_side_fd_open { false }; - bool m_connect_side_fd_open { false }; sockaddr_un m_address; DoubleBuffer m_for_client; diff --git a/Kernel/Net/Socket.cpp b/Kernel/Net/Socket.cpp index 6a2ea1fbd9..2229de18d2 100644 --- a/Kernel/Net/Socket.cpp +++ b/Kernel/Net/Socket.cpp @@ -48,6 +48,7 @@ KResult Socket::listen(int backlog) if (m_type != SOCK_STREAM) return KResult(-EOPNOTSUPP); m_backlog = backlog; + m_role = Role::Listener; kprintf("Socket{%p} listening with backlog=%d\n", this, m_backlog); return KSuccess; } @@ -64,6 +65,7 @@ RefPtr Socket::accept() ASSERT(!client->is_connected()); client->set_setup_state(SetupState::Completed); client->m_connected = true; + client->m_role = Role::Accepted; return client; } @@ -146,14 +148,14 @@ void Socket::load_send_deadline() m_send_deadline.tv_usec %= 1000000; } -static const char* to_string(SocketRole role) +static const char* to_string(Socket::Role role) { switch (role) { - case SocketRole::Listener: + case Socket::Role::Listener: return "Listener"; - case SocketRole::Accepted: + case Socket::Role::Accepted: return "Accepted"; - case SocketRole::Connected: + case Socket::Role::Connected: return "Connected"; default: return "None"; @@ -162,7 +164,7 @@ static const char* to_string(SocketRole role) String Socket::absolute_path(const FileDescription& description) const { - return String::format("socket:%x (role: %s)", this, ::to_string(description.socket_role())); + return String::format("socket:%x (role: %s)", this, ::to_string(role(description))); } ssize_t Socket::read(FileDescription& description, u8* buffer, ssize_t size) diff --git a/Kernel/Net/Socket.h b/Kernel/Net/Socket.h index d4d4eec40e..c10a085212 100644 --- a/Kernel/Net/Socket.h +++ b/Kernel/Net/Socket.h @@ -9,13 +9,6 @@ #include #include -enum class SocketRole : u8 { - None, - Listener, - Accepted, - Connected, - Connecting -}; enum class ShouldBlock { No = 0, Yes = 1 @@ -38,6 +31,14 @@ public: Completed, // the setup process is complete, but not necessarily successful }; + enum class Role : u8 { + None, + Listener, + Accepted, + Connected, + Connecting + }; + static const char* to_string(SetupState setup_state) { switch (setup_state) { @@ -55,6 +56,8 @@ public: SetupState setup_state() const { return m_setup_state; } void set_setup_state(SetupState setup_state); + virtual Role role(const FileDescription&) const { return m_role; } + bool is_connected() const { return m_connected; } void set_connected(bool connected) { m_connected = connected; } @@ -101,6 +104,8 @@ protected: virtual const char* class_name() const override { return "Socket"; } + Role m_role { Role::None }; + private: virtual bool is_socket() const final { return true; } diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index ee7a402015..f7fe4af29c 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -24,6 +24,9 @@ void TCPSocket::set_state(State new_state) #endif m_state = new_state; + + if (new_state == State::Established && m_direction == Direction::Outgoing) + m_role = Role::Connected; } Lockable>& TCPSocket::sockets_by_tuple() @@ -268,14 +271,17 @@ KResult TCPSocket::protocol_connect(FileDescription& description, ShouldBlock sh set_setup_state(SetupState::InProgress); send_tcp_packet(TCPFlags::SYN); m_state = State::SynSent; + m_role = Role::Connecting; m_direction = Direction::Outgoing; if (should_block == ShouldBlock::Yes) { if (current->block(description) == Thread::BlockResult::InterruptedBySignal) return KResult(-EINTR); ASSERT(setup_state() == SetupState::Completed); - if (has_error()) + if (has_error()) { + m_role = Role::None; return KResult(-ECONNREFUSED); + } return KSuccess; } diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp index aa74bc15c1..e0fd77d040 100644 --- a/Kernel/Net/UDPSocket.cpp +++ b/Kernel/Net/UDPSocket.cpp @@ -81,6 +81,12 @@ int UDPSocket::protocol_send(const void* data, int data_length) return data_length; } +KResult UDPSocket::protocol_connect(FileDescription&, ShouldBlock) +{ + m_role = Role::Connected; + return KSuccess; +} + int UDPSocket::protocol_allocate_local_port() { static const u16 first_ephemeral_port = 32768; diff --git a/Kernel/Net/UDPSocket.h b/Kernel/Net/UDPSocket.h index ff08f38cdb..1950f737c7 100644 --- a/Kernel/Net/UDPSocket.h +++ b/Kernel/Net/UDPSocket.h @@ -17,7 +17,7 @@ private: virtual int protocol_receive(const KBuffer&, void* buffer, size_t buffer_size, int flags) override; virtual int protocol_send(const void*, int) override; - virtual KResult protocol_connect(FileDescription&, ShouldBlock) override { return KSuccess; } + virtual KResult protocol_connect(FileDescription&, ShouldBlock) override; virtual int protocol_allocate_local_port() override; virtual KResult protocol_bind() override; }; diff --git a/Kernel/Process.cpp b/Kernel/Process.cpp index fc9e764773..4029b235ac 100644 --- a/Kernel/Process.cpp +++ b/Kernel/Process.cpp @@ -2134,11 +2134,7 @@ int Process::sys$listen(int sockfd, int backlog) if (!description->is_socket()) return -ENOTSOCK; auto& socket = *description->socket(); - auto result = socket.listen(backlog); - if (result.is_error()) - return result; - description->set_socket_role(SocketRole::Listener); - return 0; + return socket.listen(backlog); } int Process::sys$accept(int accepting_socket_fd, sockaddr* address, socklen_t* address_size) @@ -2168,7 +2164,7 @@ int Process::sys$accept(int accepting_socket_fd, sockaddr* address, socklen_t* a ASSERT(accepted_socket); bool success = accepted_socket->get_peer_address(address, address_size); ASSERT(success); - auto accepted_socket_description = FileDescription::create(*accepted_socket, SocketRole::Accepted); + auto accepted_socket_description = FileDescription::create(*accepted_socket); // NOTE: The accepted socket inherits fd flags from the accepting socket. // I'm not sure if this matches other systems but it makes sense to me. accepted_socket_description->set_blocking(accepting_socket_description->is_blocking()); @@ -2188,17 +2184,9 @@ int Process::sys$connect(int sockfd, const sockaddr* address, socklen_t address_ return -EBADF; if (!description->is_socket()) return -ENOTSOCK; - if (description->socket_role() == SocketRole::Connected) - return -EISCONN; + auto& socket = *description->socket(); - description->set_socket_role(SocketRole::Connecting); - auto result = socket.connect(*description, address, address_size, description->is_blocking() ? ShouldBlock::Yes : ShouldBlock::No); - if (result.is_error()) { - description->set_socket_role(SocketRole::None); - return result; - } - description->set_socket_role(SocketRole::Connected); - return 0; + return socket.connect(*description, address, address_size, description->is_blocking() ? ShouldBlock::Yes : ShouldBlock::No); } ssize_t Process::sys$sendto(const Syscall::SC_sendto_params* params)