diff --git a/Kernel/IPv4Socket.cpp b/Kernel/IPv4Socket.cpp index 163e7597ff..282a4b5fe1 100644 --- a/Kernel/IPv4Socket.cpp +++ b/Kernel/IPv4Socket.cpp @@ -20,28 +20,6 @@ Lockable>& IPv4Socket::sockets_by_udp_port() return *s_map; } -Lockable>& IPv4Socket::sockets_by_tcp_port() -{ - static Lockable>* s_map; - if (!s_map) - s_map = new Lockable>; - return *s_map; -} - -TCPSocketHandle IPv4Socket::from_tcp_port(word port) -{ - RetainPtr socket; - { - LOCKER(sockets_by_tcp_port().lock()); - auto it = sockets_by_tcp_port().resource().find(port); - if (it == sockets_by_tcp_port().resource().end()) - return { }; - socket = (*it).value; - ASSERT(socket); - } - return { move(socket) }; -} - IPv4SocketHandle IPv4Socket::from_udp_port(word port) { RetainPtr socket; @@ -89,10 +67,6 @@ IPv4Socket::~IPv4Socket() LOCKER(sockets_by_udp_port().lock()); sockets_by_udp_port().resource().remove(m_source_port); } - if (type() == SOCK_STREAM) { - LOCKER(sockets_by_tcp_port().lock()); - sockets_by_tcp_port().resource().remove(m_source_port); - } } bool IPv4Socket::get_address(sockaddr* address, socklen_t* address_size) @@ -180,19 +154,11 @@ void IPv4Socket::allocate_source_port_if_needed() ASSERT_NOT_REACHED(); } if (type() == SOCK_STREAM) { - // This is not a very efficient allocation algorithm. - // FIXME: Replace it with a bitmap or some other fast-paced looker-upper. - LOCKER(sockets_by_tcp_port().lock()); - for (word port = 2000; port < 60000; ++port) { - auto it = sockets_by_tcp_port().resource().find(port); - if (it == sockets_by_tcp_port().resource().end()) { - m_source_port = port; - sockets_by_tcp_port().resource().set(port, static_cast(this)); - return; - } - } - ASSERT_NOT_REACHED(); + protocol_allocate_source_port(); + return; } + + ASSERT_NOT_REACHED(); } ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, const sockaddr* addr, socklen_t addr_length) diff --git a/Kernel/IPv4Socket.h b/Kernel/IPv4Socket.h index 5951099854..d67c0092a9 100644 --- a/Kernel/IPv4Socket.h +++ b/Kernel/IPv4Socket.h @@ -19,10 +19,8 @@ public: virtual ~IPv4Socket() override; static Lockable>& all_sockets(); - static Lockable>& sockets_by_udp_port(); - static Lockable>& sockets_by_tcp_port(); - static TCPSocketHandle from_tcp_port(word); + static Lockable>& sockets_by_udp_port(); static IPv4SocketHandle from_udp_port(word); virtual KResult bind(const sockaddr*, socklen_t) override; @@ -41,17 +39,21 @@ public: const IPv4Address& source_address() const; word source_port() const { return m_source_port; } + void set_source_port(word port) { m_source_port = port; } const IPv4Address& destination_address() const { return m_destination_address; } word destination_port() const { return m_destination_port; } + void set_destination_port(word port) { m_destination_port = port; } protected: IPv4Socket(int type, int protocol); + void allocate_source_port_if_needed(); virtual int protocol_receive(const ByteBuffer&, void*, size_t, int, sockaddr*, socklen_t*) { return -ENOTIMPL; } virtual int protocol_send(const void*, int) { return -ENOTIMPL; } virtual KResult protocol_connect() { return KSuccess; } + virtual void protocol_allocate_source_port() { } private: virtual bool is_ipv4() const override { return true; } diff --git a/Kernel/NetworkTask.cpp b/Kernel/NetworkTask.cpp index e79f68d9e3..ef7a308db5 100644 --- a/Kernel/NetworkTask.cpp +++ b/Kernel/NetworkTask.cpp @@ -276,7 +276,7 @@ void handle_tcp(const EthernetFrameHeader& eth, int frame_size) ); #endif - auto socket = IPv4Socket::from_tcp_port(tcp_packet.destination_port()); + auto socket = TCPSocket::from_port(tcp_packet.destination_port()); if (!socket) { kprintf("handle_tcp: No TCP socket for port %u\n", tcp_packet.destination_port()); return; diff --git a/Kernel/TCPSocket.cpp b/Kernel/TCPSocket.cpp index 0433d6d809..00a14dcf3f 100644 --- a/Kernel/TCPSocket.cpp +++ b/Kernel/TCPSocket.cpp @@ -3,6 +3,29 @@ #include #include +Lockable>& TCPSocket::sockets_by_port() +{ + static Lockable>* s_map; + if (!s_map) + s_map = new Lockable>; + return *s_map; +} + +TCPSocketHandle TCPSocket::from_port(word port) +{ + RetainPtr socket; + { + LOCKER(sockets_by_port().lock()); + auto it = sockets_by_port().resource().find(port); + if (it == sockets_by_port().resource().end()) + return { }; + socket = (*it).value; + ASSERT(socket); + } + return { move(socket) }; +} + + TCPSocket::TCPSocket(int protocol) : IPv4Socket(SOCK_STREAM, protocol) { @@ -10,6 +33,8 @@ TCPSocket::TCPSocket(int protocol) TCPSocket::~TCPSocket() { + LOCKER(sockets_by_port().lock()); + sockets_by_port().resource().remove(source_port()); } Retained TCPSocket::create(int protocol) @@ -146,3 +171,18 @@ KResult TCPSocket::protocol_connect() ASSERT(is_connected()); return KSuccess; } + +void TCPSocket::protocol_allocate_source_port() +{ + // This is not a very efficient allocation algorithm. + // FIXME: Replace it with a bitmap or some other fast-paced looker-upper. + LOCKER(sockets_by_port().lock()); + for (word port = 2000; port < 60000; ++port) { + auto it = sockets_by_port().resource().find(port); + if (it == sockets_by_port().resource().end()) { + set_source_port(port); + sockets_by_port().resource().set(port, this); + return; + } + } +} diff --git a/Kernel/TCPSocket.h b/Kernel/TCPSocket.h index 92b431463a..48e1284417 100644 --- a/Kernel/TCPSocket.h +++ b/Kernel/TCPSocket.h @@ -24,6 +24,9 @@ public: void send_tcp_packet(word flags, const void* = nullptr, int = 0); + static Lockable>& sockets_by_port(); + static TCPSocketHandle from_port(word); + private: explicit TCPSocket(int protocol); @@ -32,6 +35,7 @@ private: virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override; virtual int protocol_send(const void*, int) override; virtual KResult protocol_connect() override; + virtual void protocol_allocate_source_port() override; dword m_sequence_number { 0 }; dword m_ack_number { 0 };