1
Fork 0
mirror of https://github.com/RGBCube/serenity synced 2025-07-24 22:17:42 +00:00

Kernel: More sockets work. Fleshing out accept().

This commit is contained in:
Andreas Kling 2019-02-14 15:17:30 +01:00
parent 1d66670ad7
commit 54b1d6f57f
7 changed files with 103 additions and 9 deletions

View file

@ -19,6 +19,16 @@ LocalSocket::~LocalSocket()
{ {
} }
bool LocalSocket::get_address(sockaddr* address, socklen_t* address_size)
{
// FIXME: Look into what fallback behavior we should have here.
if (*address_size != sizeof(sockaddr_un))
return false;
memcpy(address, &m_address, sizeof(sockaddr_un));
*address_size = sizeof(sockaddr_un);
return true;
}
bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& error) bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& error)
{ {
if (address_size != sizeof(sockaddr_un)) { if (address_size != sizeof(sockaddr_un)) {
@ -37,11 +47,14 @@ bool LocalSocket::bind(const sockaddr* address, socklen_t address_size, int& err
kprintf("%s(%u) LocalSocket{%p} bind(%s)\n", current->name().characters(), current->pid(), safe_address); kprintf("%s(%u) LocalSocket{%p} bind(%s)\n", current->name().characters(), current->pid(), safe_address);
auto descriptor = VFS::the().open(safe_address, error, O_CREAT | O_EXCL, S_IFSOCK | 0666, *current->cwd_inode()); m_file = VFS::the().open(safe_address, error, O_CREAT | O_EXCL, S_IFSOCK | 0666, *current->cwd_inode());
if (!descriptor) { if (!m_file) {
if (error == -EEXIST) if (error == -EEXIST)
error = -EADDRINUSE; error = -EADDRINUSE;
return error; return error;
} }
m_address = local_address;
m_bound = true;
return true; return true;
} }

View file

@ -3,16 +3,24 @@
#include <Kernel/Socket.h> #include <Kernel/Socket.h>
#include <Kernel/DoubleBuffer.h> #include <Kernel/DoubleBuffer.h>
class FileDescriptor;
class LocalSocket final : public Socket { class LocalSocket final : public Socket {
public: public:
static RetainPtr<LocalSocket> create(int type); static RetainPtr<LocalSocket> create(int type);
virtual ~LocalSocket() override; virtual ~LocalSocket() override;
virtual bool bind(const sockaddr*, socklen_t, int& error) override; virtual bool bind(const sockaddr*, socklen_t, int& error) override;
virtual bool get_address(sockaddr*, socklen_t*) override;
private: private:
explicit LocalSocket(int type); explicit LocalSocket(int type);
RetainPtr<FileDescriptor> m_file;
bool m_bound { false };
sockaddr_un m_address;
DoubleBuffer m_for_client; DoubleBuffer m_for_client;
DoubleBuffer m_for_server; DoubleBuffer m_for_server;
}; };

View file

@ -2258,7 +2258,12 @@ int Process::sys$socket(int domain, int type, int protocol)
if (!socket) if (!socket)
return error; return error;
auto descriptor = FileDescriptor::create(move(socket)); auto descriptor = FileDescriptor::create(move(socket));
m_fds[fd].set(move(descriptor)); unsigned flags = 0;
if (type & SOCK_CLOEXEC)
flags |= O_CLOEXEC;
if (type & SOCK_NONBLOCK)
descriptor->set_blocking(false);
m_fds[fd].set(move(descriptor), flags);
return fd; return fd;
} }
@ -2280,12 +2285,38 @@ int Process::sys$bind(int sockfd, const sockaddr* address, socklen_t address_len
int Process::sys$listen(int sockfd, int backlog) int Process::sys$listen(int sockfd, int backlog)
{ {
return -ENOTIMPL; auto* descriptor = file_descriptor(sockfd);
if (!descriptor)
return -EBADF;
if (!descriptor->is_socket())
return -ENOTSOCK;
auto& socket = *descriptor->socket();
int error;
if (!socket.listen(backlog, error))
return error;
return 0;
} }
int Process::sys$accept(int sockfd, sockaddr*, socklen_t) int Process::sys$accept(int sockfd, sockaddr* address, socklen_t* address_size)
{ {
return -ENOTIMPL; if (!validate_write_typed(address_size))
return -EFAULT;
if (!validate_write(address, *address_size))
return -EFAULT;
auto* descriptor = file_descriptor(sockfd);
if (!descriptor)
return -EBADF;
if (!descriptor->is_socket())
return -ENOTSOCK;
auto& socket = *descriptor->socket();
if (!socket.can_accept()) {
ASSERT(!descriptor->is_blocking());
return -EAGAIN;
}
auto client = socket.accept();
ASSERT(client);
client->get_address(address, address_size);
return 0;
} }
int Process::sys$connect(int sockfd, const sockaddr*, socklen_t) int Process::sys$connect(int sockfd, const sockaddr*, socklen_t)

View file

@ -221,7 +221,7 @@ public:
int sys$socket(int domain, int type, int protocol); int sys$socket(int domain, int type, int protocol);
int sys$bind(int sockfd, const sockaddr* addr, socklen_t); int sys$bind(int sockfd, const sockaddr* addr, socklen_t);
int sys$listen(int sockfd, int backlog); int sys$listen(int sockfd, int backlog);
int sys$accept(int sockfd, sockaddr*, socklen_t); int sys$accept(int sockfd, sockaddr*, socklen_t*);
int sys$connect(int sockfd, const sockaddr*, socklen_t); int sys$connect(int sockfd, const sockaddr*, socklen_t);
DisplayInfo set_video_resolution(int width, int height); DisplayInfo set_video_resolution(int width, int height);

View file

@ -8,7 +8,7 @@ RetainPtr<Socket> Socket::create(int domain, int type, int protocol, int& error)
(void)protocol; (void)protocol;
switch (domain) { switch (domain) {
case AF_LOCAL: case AF_LOCAL:
return LocalSocket::create(type); return LocalSocket::create(type & SOCK_TYPE_MASK);
default: default:
error = EAFNOSUPPORT; error = EAFNOSUPPORT;
return nullptr; return nullptr;
@ -26,4 +26,25 @@ Socket::~Socket()
{ {
} }
bool Socket::listen(int backlog, int& error)
{
LOCKER(m_lock);
if (m_type != SOCK_STREAM) {
error = -EOPNOTSUPP;
return false;
}
m_backlog = backlog;
m_listening = true;
kprintf("Socket{%p} listening with backlog=%d\n", m_backlog);
return true;
}
RetainPtr<Socket> Socket::accept()
{
LOCKER(m_lock);
if (m_pending.is_empty())
return nullptr;
auto client = m_pending.take_first();
m_clients.append(client.copy_ref());
return client;
}

View file

@ -1,7 +1,10 @@
#pragma once #pragma once
#include <AK/Lock.h>
#include <AK/Retainable.h> #include <AK/Retainable.h>
#include <AK/RetainPtr.h> #include <AK/RetainPtr.h>
#include <AK/HashTable.h>
#include <AK/Vector.h>
#include <Kernel/UnixTypes.h> #include <Kernel/UnixTypes.h>
class Socket : public Retainable<Socket> { class Socket : public Retainable<Socket> {
@ -9,18 +12,30 @@ public:
static RetainPtr<Socket> create(int domain, int type, int protocol, int& error); static RetainPtr<Socket> create(int domain, int type, int protocol, int& error);
virtual ~Socket(); virtual ~Socket();
bool is_listening() const { return m_listening; }
int domain() const { return m_domain; } int domain() const { return m_domain; }
int type() const { return m_type; } int type() const { return m_type; }
int protocol() const { return m_protocol; } int protocol() const { return m_protocol; }
bool can_accept() const { return m_pending.is_empty(); }
RetainPtr<Socket> accept();
bool listen(int backlog, int& error);
virtual bool bind(const sockaddr*, socklen_t, int& error) = 0; virtual bool bind(const sockaddr*, socklen_t, int& error) = 0;
virtual bool get_address(sockaddr*, socklen_t*) = 0;
protected: protected:
Socket(int domain, int type, int protocol); Socket(int domain, int type, int protocol);
private: private:
Lock m_lock;
int m_domain { 0 }; int m_domain { 0 };
int m_type { 0 }; int m_type { 0 };
int m_protocol { 0 }; int m_protocol { 0 };
}; int m_backlog { 0 };
bool m_listening { false };
Vector<RetainPtr<Socket>> m_pending;
Vector<RetainPtr<Socket>> m_clients;
};

View file

@ -306,9 +306,15 @@ struct pollfd {
short revents; short revents;
}; };
#define AF_MASK 0xff
#define AF_UNSPEC 0 #define AF_UNSPEC 0
#define AF_LOCAL 1 #define AF_LOCAL 1
#define SOCK_TYPE_MASK 0xff
#define SOCK_STREAM 1
#define SOCK_NONBLOCK 04000
#define SOCK_CLOEXEC 02000000
struct sockaddr { struct sockaddr {
word sa_family; word sa_family;
char sa_data[14]; char sa_data[14];