diff --git a/Kernel/Net/IPv4Socket.cpp b/Kernel/Net/IPv4Socket.cpp index 58abbbd7a7..f590ad8b51 100644 --- a/Kernel/Net/IPv4Socket.cpp +++ b/Kernel/Net/IPv4Socket.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #define IPV4_SOCKET_DEBUG @@ -63,7 +64,13 @@ KResult IPv4Socket::bind(const sockaddr* address, socklen_t address_size) if (address->sa_family != AF_INET) return KResult(-EINVAL); - ASSERT_NOT_REACHED(); + auto& ia = *(const sockaddr_in*)address; + m_source_address = IPv4Address((const byte*)&ia.sin_addr.s_addr); + m_source_port = ntohs(ia.sin_port); + + dbgprintf("IPv4Socket::bind %s{%p} to port %u\n", class_name(), this, m_source_port); + + return protocol_bind(); } KResult IPv4Socket::connect(FileDescriptor& descriptor, const sockaddr* address, socklen_t address_size, ShouldBlock should_block) @@ -91,8 +98,10 @@ void IPv4Socket::detach(FileDescriptor&) --m_attached_fds; } -bool IPv4Socket::can_read(FileDescriptor&) const +bool IPv4Socket::can_read(FileDescriptor& descriptor) const { + if (descriptor.socket_role() == SocketRole::Listener) + return can_accept(); if (protocol_is_disconnected()) return true; return m_can_read; diff --git a/Kernel/Net/IPv4Socket.h b/Kernel/Net/IPv4Socket.h index e31064740d..f011182621 100644 --- a/Kernel/Net/IPv4Socket.h +++ b/Kernel/Net/IPv4Socket.h @@ -48,6 +48,7 @@ protected: int allocate_source_port_if_needed(); + virtual KResult protocol_bind() { return KSuccess; } virtual int protocol_receive(const ByteBuffer&, void*, size_t, int, sockaddr*, socklen_t*) { return -ENOTIMPL; } virtual int protocol_send(const void*, int) { return -ENOTIMPL; } virtual KResult protocol_connect(FileDescriptor&, ShouldBlock) { return KSuccess; } @@ -59,6 +60,8 @@ private: bool m_bound { false }; int m_attached_fds { 0 }; + + IPv4Address m_source_address; IPv4Address m_destination_address; DoubleBuffer m_for_client; diff --git a/Kernel/Net/TCPSocket.cpp b/Kernel/Net/TCPSocket.cpp index d3073b11e9..00d30e1476 100644 --- a/Kernel/Net/TCPSocket.cpp +++ b/Kernel/Net/TCPSocket.cpp @@ -203,3 +203,12 @@ bool TCPSocket::protocol_is_disconnected() const { return m_state == State::Disconnecting || m_state == State::Disconnected; } + +KResult TCPSocket::protocol_bind() +{ + LOCKER(sockets_by_port().lock()); + if (sockets_by_port().resource().contains(source_port())) + return KResult(-EADDRINUSE); + sockets_by_port().resource().set(source_port(), this); + return KSuccess; +} diff --git a/Kernel/Net/TCPSocket.h b/Kernel/Net/TCPSocket.h index 186c3be078..a1ebce3a0d 100644 --- a/Kernel/Net/TCPSocket.h +++ b/Kernel/Net/TCPSocket.h @@ -38,6 +38,7 @@ private: virtual KResult protocol_connect(FileDescriptor&, ShouldBlock) override; virtual int protocol_allocate_source_port() override; virtual bool protocol_is_disconnected() const override; + virtual KResult protocol_bind() override; dword m_sequence_number { 0 }; dword m_ack_number { 0 }; diff --git a/Kernel/Net/UDPSocket.cpp b/Kernel/Net/UDPSocket.cpp index 0e96d99a57..e631054630 100644 --- a/Kernel/Net/UDPSocket.cpp +++ b/Kernel/Net/UDPSocket.cpp @@ -104,3 +104,12 @@ int UDPSocket::protocol_allocate_source_port() } return -EADDRINUSE; } + +KResult UDPSocket::protocol_bind() +{ + LOCKER(sockets_by_port().lock()); + if (sockets_by_port().resource().contains(source_port())) + return KResult(-EADDRINUSE); + sockets_by_port().resource().set(source_port(), this); + return KSuccess; +} diff --git a/Kernel/Net/UDPSocket.h b/Kernel/Net/UDPSocket.h index 41506d1df0..693f6eb349 100644 --- a/Kernel/Net/UDPSocket.h +++ b/Kernel/Net/UDPSocket.h @@ -20,6 +20,7 @@ private: virtual int protocol_send(const void*, int) override; virtual KResult protocol_connect(FileDescriptor&, ShouldBlock) override { return KSuccess; } virtual int protocol_allocate_source_port() override; + virtual KResult protocol_bind() override; }; class UDPSocketHandle : public SocketHandle {