mirror of
https://github.com/RGBCube/serenity
synced 2025-07-25 19:07:35 +00:00
Kernel: Handle OOM from DoubleBuffer usage in IPv4Socket
The IPv4Socket requires a DoubleBuffer for storage of any data it received on the socket. However it was previously using the default constructor which can not observe allocation failure. Address this by plumbing the receive buffer through the various derived classes.
This commit is contained in:
parent
109c885585
commit
ca94a83337
6 changed files with 41 additions and 26 deletions
|
@ -35,22 +35,31 @@ Lockable<HashTable<IPv4Socket*>>& IPv4Socket::all_sockets()
|
|||
return *s_table;
|
||||
}
|
||||
|
||||
OwnPtr<DoubleBuffer> IPv4Socket::create_receive_buffer()
|
||||
{
|
||||
return DoubleBuffer::try_create(256 * KiB);
|
||||
}
|
||||
|
||||
KResultOr<NonnullRefPtr<Socket>> IPv4Socket::create(int type, int protocol)
|
||||
{
|
||||
auto receive_buffer = IPv4Socket::create_receive_buffer();
|
||||
if (!receive_buffer)
|
||||
return ENOMEM;
|
||||
|
||||
if (type == SOCK_STREAM) {
|
||||
auto tcp_socket = TCPSocket::create(protocol);
|
||||
auto tcp_socket = TCPSocket::create(protocol, receive_buffer.release_nonnull());
|
||||
if (tcp_socket.is_error())
|
||||
return tcp_socket.error();
|
||||
return tcp_socket.release_value();
|
||||
}
|
||||
if (type == SOCK_DGRAM) {
|
||||
auto udp_socket = UDPSocket::create(protocol);
|
||||
auto udp_socket = UDPSocket::create(protocol, receive_buffer.release_nonnull());
|
||||
if (udp_socket.is_error())
|
||||
return udp_socket.error();
|
||||
return udp_socket.release_value();
|
||||
}
|
||||
if (type == SOCK_RAW) {
|
||||
auto raw_socket = adopt_ref_if_nonnull(new (nothrow) IPv4Socket(type, protocol));
|
||||
auto raw_socket = adopt_ref_if_nonnull(new (nothrow) IPv4Socket(type, protocol, receive_buffer.release_nonnull()));
|
||||
if (raw_socket)
|
||||
return raw_socket.release_nonnull();
|
||||
return ENOMEM;
|
||||
|
@ -58,8 +67,9 @@ KResultOr<NonnullRefPtr<Socket>> IPv4Socket::create(int type, int protocol)
|
|||
return EINVAL;
|
||||
}
|
||||
|
||||
IPv4Socket::IPv4Socket(int type, int protocol)
|
||||
IPv4Socket::IPv4Socket(int type, int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
|
||||
: Socket(AF_INET, type, protocol)
|
||||
, m_receive_buffer(move(receive_buffer))
|
||||
{
|
||||
dbgln_if(IPV4_SOCKET_DEBUG, "IPv4Socket({}) created with type={}, protocol={}", this, type, protocol);
|
||||
m_buffer_mode = type == SOCK_STREAM ? BufferMode::Bytes : BufferMode::Packets;
|
||||
|
@ -248,7 +258,7 @@ KResultOr<size_t> IPv4Socket::sendto(FileDescription&, const UserOrKernelBuffer&
|
|||
KResultOr<size_t> IPv4Socket::receive_byte_buffered(FileDescription& description, UserOrKernelBuffer& buffer, size_t buffer_length, int flags, Userspace<sockaddr*>, Userspace<socklen_t*>)
|
||||
{
|
||||
MutexLocker locker(lock());
|
||||
if (m_receive_buffer.is_empty()) {
|
||||
if (m_receive_buffer->is_empty()) {
|
||||
if (protocol_is_disconnected())
|
||||
return 0;
|
||||
if (!description.is_blocking())
|
||||
|
@ -270,14 +280,14 @@ KResultOr<size_t> IPv4Socket::receive_byte_buffered(FileDescription& description
|
|||
|
||||
KResultOr<size_t> nreceived_or_error { 0 };
|
||||
if (flags & MSG_PEEK)
|
||||
nreceived_or_error = m_receive_buffer.peek(buffer, buffer_length);
|
||||
nreceived_or_error = m_receive_buffer->peek(buffer, buffer_length);
|
||||
else
|
||||
nreceived_or_error = m_receive_buffer.read(buffer, buffer_length);
|
||||
nreceived_or_error = m_receive_buffer->read(buffer, buffer_length);
|
||||
|
||||
if (!nreceived_or_error.is_error() && nreceived_or_error.value() > 0 && !(flags & MSG_PEEK))
|
||||
Thread::current()->did_ipv4_socket_read(nreceived_or_error.value());
|
||||
|
||||
set_can_read(!m_receive_buffer.is_empty());
|
||||
set_can_read(!m_receive_buffer->is_empty());
|
||||
return nreceived_or_error;
|
||||
}
|
||||
|
||||
|
@ -406,7 +416,7 @@ bool IPv4Socket::did_receive(const IPv4Address& source_address, u16 source_port,
|
|||
auto packet_size = packet.size();
|
||||
|
||||
if (buffer_mode() == BufferMode::Bytes) {
|
||||
size_t space_in_receive_buffer = m_receive_buffer.space_for_writing();
|
||||
size_t space_in_receive_buffer = m_receive_buffer->space_for_writing();
|
||||
if (packet_size > space_in_receive_buffer) {
|
||||
dbgln("IPv4Socket({}): did_receive refusing packet since buffer is full.", this);
|
||||
VERIFY(m_can_read);
|
||||
|
@ -416,10 +426,10 @@ bool IPv4Socket::did_receive(const IPv4Address& source_address, u16 source_port,
|
|||
auto nreceived_or_error = protocol_receive(ReadonlyBytes { packet.data(), packet.size() }, scratch_buffer, m_scratch_buffer.value().size(), 0);
|
||||
if (nreceived_or_error.is_error())
|
||||
return false;
|
||||
auto nwritten_or_error = m_receive_buffer.write(scratch_buffer, nreceived_or_error.value());
|
||||
auto nwritten_or_error = m_receive_buffer->write(scratch_buffer, nreceived_or_error.value());
|
||||
if (nwritten_or_error.is_error())
|
||||
return false;
|
||||
set_can_read(!m_receive_buffer.is_empty());
|
||||
set_can_read(!m_receive_buffer->is_empty());
|
||||
} else {
|
||||
if (m_receive_queue.size() > 2000) {
|
||||
dbgln("IPv4Socket({}): did_receive refusing packet since queue is full.", this);
|
||||
|
|
|
@ -74,7 +74,7 @@ public:
|
|||
BufferMode buffer_mode() const { return m_buffer_mode; }
|
||||
|
||||
protected:
|
||||
IPv4Socket(int type, int protocol);
|
||||
IPv4Socket(int type, int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer);
|
||||
virtual StringView class_name() const override { return "IPv4Socket"; }
|
||||
|
||||
PortAllocationResult allocate_local_port_if_needed();
|
||||
|
@ -92,6 +92,8 @@ protected:
|
|||
void set_local_address(IPv4Address address) { m_local_address = address; }
|
||||
void set_peer_address(IPv4Address address) { m_peer_address = address; }
|
||||
|
||||
static OwnPtr<DoubleBuffer> create_receive_buffer();
|
||||
|
||||
private:
|
||||
virtual bool is_ipv4() const override { return true; }
|
||||
|
||||
|
@ -115,7 +117,7 @@ private:
|
|||
|
||||
SinglyLinkedListWithCount<ReceivedPacket> m_receive_queue;
|
||||
|
||||
DoubleBuffer m_receive_buffer { 256 * KiB };
|
||||
NonnullOwnPtr<DoubleBuffer> m_receive_buffer;
|
||||
|
||||
u16 m_local_port { 0 };
|
||||
u16 m_peer_port { 0 };
|
||||
|
|
|
@ -96,7 +96,10 @@ RefPtr<TCPSocket> TCPSocket::create_client(const IPv4Address& new_local_address,
|
|||
return {};
|
||||
}
|
||||
|
||||
auto result = TCPSocket::create(protocol());
|
||||
auto receive_buffer = create_receive_buffer();
|
||||
if (!receive_buffer)
|
||||
return {};
|
||||
auto result = TCPSocket::create(protocol(), receive_buffer.release_nonnull());
|
||||
if (result.is_error())
|
||||
return {};
|
||||
|
||||
|
@ -131,8 +134,8 @@ void TCPSocket::release_for_accept(RefPtr<TCPSocket> socket)
|
|||
[[maybe_unused]] auto rc = queue_connection_from(*socket);
|
||||
}
|
||||
|
||||
TCPSocket::TCPSocket(int protocol)
|
||||
: IPv4Socket(SOCK_STREAM, protocol)
|
||||
TCPSocket::TCPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
|
||||
: IPv4Socket(SOCK_STREAM, protocol, move(receive_buffer))
|
||||
{
|
||||
m_last_retransmit_time = kgettimeofday();
|
||||
}
|
||||
|
@ -147,9 +150,9 @@ TCPSocket::~TCPSocket()
|
|||
dbgln_if(TCP_SOCKET_DEBUG, "~TCPSocket in state {}", to_string(state()));
|
||||
}
|
||||
|
||||
KResultOr<NonnullRefPtr<TCPSocket>> TCPSocket::create(int protocol)
|
||||
KResultOr<NonnullRefPtr<TCPSocket>> TCPSocket::create(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
|
||||
{
|
||||
auto socket = adopt_ref_if_nonnull(new (nothrow) TCPSocket(protocol));
|
||||
auto socket = adopt_ref_if_nonnull(new (nothrow) TCPSocket(protocol, move(receive_buffer)));
|
||||
if (socket)
|
||||
return socket.release_nonnull();
|
||||
return ENOMEM;
|
||||
|
|
|
@ -18,7 +18,7 @@ namespace Kernel {
|
|||
class TCPSocket final : public IPv4Socket {
|
||||
public:
|
||||
static void for_each(Function<void(const TCPSocket&)>);
|
||||
static KResultOr<NonnullRefPtr<TCPSocket>> create(int protocol);
|
||||
static KResultOr<NonnullRefPtr<TCPSocket>> create(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer);
|
||||
virtual ~TCPSocket() override;
|
||||
|
||||
enum class Direction {
|
||||
|
@ -165,7 +165,7 @@ protected:
|
|||
void set_direction(Direction direction) { m_direction = direction; }
|
||||
|
||||
private:
|
||||
explicit TCPSocket(int protocol);
|
||||
explicit TCPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer);
|
||||
virtual StringView class_name() const override { return "TCPSocket"; }
|
||||
|
||||
virtual void shut_down_for_writing() override;
|
||||
|
|
|
@ -43,8 +43,8 @@ SocketHandle<UDPSocket> UDPSocket::from_port(u16 port)
|
|||
return { *socket };
|
||||
}
|
||||
|
||||
UDPSocket::UDPSocket(int protocol)
|
||||
: IPv4Socket(SOCK_DGRAM, protocol)
|
||||
UDPSocket::UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
|
||||
: IPv4Socket(SOCK_DGRAM, protocol, move(receive_buffer))
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -54,9 +54,9 @@ UDPSocket::~UDPSocket()
|
|||
sockets_by_port().resource().remove(local_port());
|
||||
}
|
||||
|
||||
KResultOr<NonnullRefPtr<UDPSocket>> UDPSocket::create(int protocol)
|
||||
KResultOr<NonnullRefPtr<UDPSocket>> UDPSocket::create(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer)
|
||||
{
|
||||
auto socket = adopt_ref_if_nonnull(new (nothrow) UDPSocket(protocol));
|
||||
auto socket = adopt_ref_if_nonnull(new (nothrow) UDPSocket(protocol, move(receive_buffer)));
|
||||
if (socket)
|
||||
return socket.release_nonnull();
|
||||
return ENOMEM;
|
||||
|
|
|
@ -13,14 +13,14 @@ namespace Kernel {
|
|||
|
||||
class UDPSocket final : public IPv4Socket {
|
||||
public:
|
||||
static KResultOr<NonnullRefPtr<UDPSocket>> create(int protocol);
|
||||
static KResultOr<NonnullRefPtr<UDPSocket>> create(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer);
|
||||
virtual ~UDPSocket() override;
|
||||
|
||||
static SocketHandle<UDPSocket> from_port(u16);
|
||||
static void for_each(Function<void(const UDPSocket&)>);
|
||||
|
||||
private:
|
||||
explicit UDPSocket(int protocol);
|
||||
explicit UDPSocket(int protocol, NonnullOwnPtr<DoubleBuffer> receive_buffer);
|
||||
virtual StringView class_name() const override { return "UDPSocket"; }
|
||||
static Lockable<HashMap<u16, UDPSocket*>>& sockets_by_port();
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue