From 54b1d6f57fe52de11701e96461a19fd2750fd2d4 Mon Sep 17 00:00:00 2001 From: Andreas Kling Date: Thu, 14 Feb 2019 15:17:30 +0100 Subject: [PATCH] Kernel: More sockets work. Fleshing out accept(). --- Kernel/LocalSocket.cpp | 17 +++++++++++++++-- Kernel/LocalSocket.h | 8 ++++++++ Kernel/Process.cpp | 39 +++++++++++++++++++++++++++++++++++---- Kernel/Process.h | 2 +- Kernel/Socket.cpp | 23 ++++++++++++++++++++++- Kernel/Socket.h | 17 ++++++++++++++++- Kernel/UnixTypes.h | 6 ++++++ 7 files changed, 103 insertions(+), 9 deletions(-) diff --git a/Kernel/LocalSocket.cpp b/Kernel/LocalSocket.cpp index fc6ea87b3b..abd3201cc8 100644 --- a/Kernel/LocalSocket.cpp +++ b/Kernel/LocalSocket.cpp @@ -19,6 +19,16 @@ LocalSocket::~LocalSocket() { } +bool LocalSocket::get_address(sockaddr* address, socklen_t* address_size) +{ + // FIXME: Look into what fallback behavior we should have here. + if (*address_size != sizeof(sockaddr_un)) + return false; + memcpy(address, &m_address, sizeof(sockaddr_un)); + *address_size = sizeof(sockaddr_un); + return true; +} + bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& error) { if (address_size != sizeof(sockaddr_un)) { @@ -37,11 +47,14 @@ bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& err kprintf("%s(%u) LocalSocket{%p} bind(%s)\n", current->name().characters(), current->pid(), safe_address); - auto descriptor = VFS::the().open(safe_address, error, O_CREAT | O_EXCL, S_IFSOCK | 0666, *current->cwd_inode()); - if (!descriptor) { + m_file = VFS::the().open(safe_address, error, O_CREAT | O_EXCL, S_IFSOCK | 0666, *current->cwd_inode()); + if (!m_file) { if (error == -EEXIST) error = -EADDRINUSE; return error; } + + m_address = local_address; + m_bound = true; return true; } diff --git a/Kernel/LocalSocket.h b/Kernel/LocalSocket.h index 916de48ba5..ea5440694f 100644 --- a/Kernel/LocalSocket.h +++ b/Kernel/LocalSocket.h @@ -3,16 +3,24 @@ #include #include +class FileDescriptor; + class LocalSocket final : public Socket { public: static RetainPtr create(int type); virtual ~LocalSocket() override; virtual bool bind(const sockaddr*, socklen_t, int& error) override; + virtual bool get_address(sockaddr*, socklen_t*) override; private: explicit LocalSocket(int type); + RetainPtr m_file; + + bool m_bound { false }; + sockaddr_un m_address; + DoubleBuffer m_for_client; DoubleBuffer m_for_server; }; diff --git a/Kernel/Process.cpp b/Kernel/Process.cpp index af9fcfc921..549cea4810 100644 --- a/Kernel/Process.cpp +++ b/Kernel/Process.cpp @@ -2258,7 +2258,12 @@ int Process::sys$socket(int domain, int type, int protocol) if (!socket) return error; auto descriptor = FileDescriptor::create(move(socket)); - m_fds[fd].set(move(descriptor)); + unsigned flags = 0; + if (type & SOCK_CLOEXEC) + flags |= O_CLOEXEC; + if (type & SOCK_NONBLOCK) + descriptor->set_blocking(false); + m_fds[fd].set(move(descriptor), flags); return fd; } @@ -2280,12 +2285,38 @@ int Process::sys$bind(int sockfd, const sockaddr* address, socklen_t address_len int Process::sys$listen(int sockfd, int backlog) { - return -ENOTIMPL; + auto* descriptor = file_descriptor(sockfd); + if (!descriptor) + return -EBADF; + if (!descriptor->is_socket()) + return -ENOTSOCK; + auto& socket = *descriptor->socket(); + int error; + if (!socket.listen(backlog, error)) + return error; + return 0; } -int Process::sys$accept(int sockfd, sockaddr*, socklen_t) +int Process::sys$accept(int sockfd, sockaddr* address, socklen_t* address_size) { - return -ENOTIMPL; + if (!validate_write_typed(address_size)) + return -EFAULT; + if (!validate_write(address, *address_size)) + return -EFAULT; + auto* descriptor = file_descriptor(sockfd); + if (!descriptor) + return -EBADF; + if (!descriptor->is_socket()) + return -ENOTSOCK; + auto& socket = *descriptor->socket(); + if (!socket.can_accept()) { + ASSERT(!descriptor->is_blocking()); + return -EAGAIN; + } + auto client = socket.accept(); + ASSERT(client); + client->get_address(address, address_size); + return 0; } int Process::sys$connect(int sockfd, const sockaddr*, socklen_t) diff --git a/Kernel/Process.h b/Kernel/Process.h index addf30fe10..4ea660fd53 100644 --- a/Kernel/Process.h +++ b/Kernel/Process.h @@ -221,7 +221,7 @@ public: int sys$socket(int domain, int type, int protocol); int sys$bind(int sockfd, const sockaddr* addr, socklen_t); int sys$listen(int sockfd, int backlog); - int sys$accept(int sockfd, sockaddr*, socklen_t); + int sys$accept(int sockfd, sockaddr*, socklen_t*); int sys$connect(int sockfd, const sockaddr*, socklen_t); DisplayInfo set_video_resolution(int width, int height); diff --git a/Kernel/Socket.cpp b/Kernel/Socket.cpp index 44e11c3753..18fb351779 100644 --- a/Kernel/Socket.cpp +++ b/Kernel/Socket.cpp @@ -8,7 +8,7 @@ RetainPtr Socket::create(int domain, int type, int protocol, int& error) (void)protocol; switch (domain) { case AF_LOCAL: - return LocalSocket::create(type); + return LocalSocket::create(type & SOCK_TYPE_MASK); default: error = EAFNOSUPPORT; return nullptr; @@ -26,4 +26,25 @@ Socket::~Socket() { } +bool Socket::listen(int backlog, int& error) +{ + LOCKER(m_lock); + if (m_type != SOCK_STREAM) { + error = -EOPNOTSUPP; + return false; + } + m_backlog = backlog; + m_listening = true; + kprintf("Socket{%p} listening with backlog=%d\n", m_backlog); + return true; +} +RetainPtr Socket::accept() +{ + LOCKER(m_lock); + if (m_pending.is_empty()) + return nullptr; + auto client = m_pending.take_first(); + m_clients.append(client.copy_ref()); + return client; +} diff --git a/Kernel/Socket.h b/Kernel/Socket.h index a6c8809f36..d5e91cfb52 100644 --- a/Kernel/Socket.h +++ b/Kernel/Socket.h @@ -1,7 +1,10 @@ #pragma once +#include #include #include +#include +#include #include class Socket : public Retainable { @@ -9,18 +12,30 @@ public: static RetainPtr create(int domain, int type, int protocol, int& error); virtual ~Socket(); + bool is_listening() const { return m_listening; } int domain() const { return m_domain; } int type() const { return m_type; } int protocol() const { return m_protocol; } + bool can_accept() const { return m_pending.is_empty(); } + RetainPtr accept(); + + bool listen(int backlog, int& error); + virtual bool bind(const sockaddr*, socklen_t, int& error) = 0; + virtual bool get_address(sockaddr*, socklen_t*) = 0; protected: Socket(int domain, int type, int protocol); private: + Lock m_lock; int m_domain { 0 }; int m_type { 0 }; int m_protocol { 0 }; -}; + int m_backlog { 0 }; + bool m_listening { false }; + Vector> m_pending; + Vector> m_clients; +}; diff --git a/Kernel/UnixTypes.h b/Kernel/UnixTypes.h index d8302f326a..18e661034d 100644 --- a/Kernel/UnixTypes.h +++ b/Kernel/UnixTypes.h @@ -306,9 +306,15 @@ struct pollfd { short revents; }; +#define AF_MASK 0xff #define AF_UNSPEC 0 #define AF_LOCAL 1 +#define SOCK_TYPE_MASK 0xff +#define SOCK_STREAM 1 +#define SOCK_NONBLOCK 04000 +#define SOCK_CLOEXEC 02000000 + struct sockaddr { word sa_family; char sa_data[14];