diff --git a/Kernel/IPv4Socket.cpp b/Kernel/IPv4Socket.cpp index 282a4b5fe1..547c1b0e5e 100644 --- a/Kernel/IPv4Socket.cpp +++ b/Kernel/IPv4Socket.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -12,28 +13,6 @@ #define IPV4_SOCKET_DEBUG -Lockable>& IPv4Socket::sockets_by_udp_port() -{ - static Lockable>* s_map; - if (!s_map) - s_map = new Lockable>; - return *s_map; -} - -IPv4SocketHandle IPv4Socket::from_udp_port(word port) -{ - RetainPtr socket; - { - LOCKER(sockets_by_udp_port().lock()); - auto it = sockets_by_udp_port().resource().find(port); - if (it == sockets_by_udp_port().resource().end()) - return { }; - socket = (*it).value; - ASSERT(socket); - } - return { move(socket) }; -} - Lockable>& IPv4Socket::all_sockets() { static Lockable>* s_table; @@ -46,6 +25,8 @@ Retained IPv4Socket::create(int type, int protocol) { if (type == SOCK_STREAM) return TCPSocket::create(protocol); + if (type == SOCK_DGRAM) + return UDPSocket::create(protocol); return adopt(*new IPv4Socket(type, protocol)); } @@ -59,14 +40,8 @@ IPv4Socket::IPv4Socket(int type, int protocol) IPv4Socket::~IPv4Socket() { - { - LOCKER(all_sockets().lock()); - all_sockets().resource().remove(this); - } - if (type() == SOCK_DGRAM) { - LOCKER(sockets_by_udp_port().lock()); - sockets_by_udp_port().resource().remove(m_source_port); - } + LOCKER(all_sockets().lock()); + all_sockets().resource().remove(this); } bool IPv4Socket::get_address(sockaddr* address, socklen_t* address_size) @@ -139,26 +114,7 @@ void IPv4Socket::allocate_source_port_if_needed() { if (m_source_port) return; - if (type() == SOCK_DGRAM) { - // This is not a very efficient allocation algorithm. - // FIXME: Replace it with a bitmap or some other fast-paced looker-upper. - LOCKER(sockets_by_udp_port().lock()); - for (word port = 2000; port < 60000; ++port) { - auto it = sockets_by_udp_port().resource().find(port); - if (it == sockets_by_udp_port().resource().end()) { - m_source_port = port; - sockets_by_udp_port().resource().set(port, this); - return; - } - } - ASSERT_NOT_REACHED(); - } - if (type() == SOCK_STREAM) { - protocol_allocate_source_port(); - return; - } - - ASSERT_NOT_REACHED(); + protocol_allocate_source_port(); } ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, const sockaddr* addr, socklen_t addr_length) @@ -193,26 +149,7 @@ ssize_t IPv4Socket::sendto(const void* data, size_t data_length, int flags, cons return data_length; } - if (type() == SOCK_DGRAM) { - auto buffer = ByteBuffer::create_zeroed(sizeof(UDPPacket) + data_length); - auto& udp_packet = *(UDPPacket*)(buffer.pointer()); - udp_packet.set_source_port(m_source_port); - udp_packet.set_destination_port(m_destination_port); - udp_packet.set_length(sizeof(UDPPacket) + data_length); - memcpy(udp_packet.payload(), data, data_length); - kprintf("sending as udp packet from %s:%u to %s:%u!\n", - adapter->ipv4_address().to_string().characters(), - source_port(), - m_destination_address.to_string().characters(), - m_destination_port); - adapter->send_ipv4(MACAddress(), m_destination_address, IPv4Protocol::UDP, move(buffer)); - return data_length; - } - - if (type() == SOCK_STREAM) - return protocol_send(data, data_length); - - ASSERT_NOT_REACHED(); + return protocol_send(data, data_length); } ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, sockaddr* addr, socklen_t* addr_length) @@ -266,22 +203,7 @@ ssize_t IPv4Socket::recvfrom(void* buffer, size_t buffer_length, int flags, sock return ipv4_packet.payload_size(); } - if (type() == SOCK_DGRAM) { - auto& udp_packet = *static_cast(ipv4_packet.payload()); - ASSERT(udp_packet.length() >= sizeof(UDPPacket)); // FIXME: This should be rejected earlier. - ASSERT(buffer_length >= (udp_packet.length() - sizeof(UDPPacket))); - if (addr) { - auto& ia = *(sockaddr_in*)addr; - ia.sin_port = htons(udp_packet.destination_port()); - } - memcpy(buffer, udp_packet.payload(), udp_packet.length() - sizeof(UDPPacket)); - return udp_packet.length() - sizeof(UDPPacket); - } - - if (type() == SOCK_STREAM) - return protocol_receive(packet_buffer, buffer, buffer_length, flags, addr, addr_length); - - ASSERT_NOT_REACHED(); + return protocol_receive(packet_buffer, buffer, buffer_length, flags, addr, addr_length); } void IPv4Socket::did_receive(ByteBuffer&& packet) diff --git a/Kernel/IPv4Socket.h b/Kernel/IPv4Socket.h index d67c0092a9..dc7ad8d217 100644 --- a/Kernel/IPv4Socket.h +++ b/Kernel/IPv4Socket.h @@ -20,9 +20,6 @@ public: static Lockable>& all_sockets(); - static Lockable>& sockets_by_udp_port(); - static IPv4SocketHandle from_udp_port(word); - virtual KResult bind(const sockaddr*, socklen_t) override; virtual KResult connect(const sockaddr*, socklen_t) override; virtual bool get_address(sockaddr*, socklen_t*) override; diff --git a/Kernel/Makefile b/Kernel/Makefile index 2f69877f12..5d42acab6d 100644 --- a/Kernel/Makefile +++ b/Kernel/Makefile @@ -36,6 +36,7 @@ KERNEL_OBJS = \ LocalSocket.o \ IPv4Socket.o \ TCPSocket.o \ + UDPSocket.o \ NetworkAdapter.o \ E1000NetworkAdapter.o \ NetworkTask.o diff --git a/Kernel/NetworkTask.cpp b/Kernel/NetworkTask.cpp index ef7a308db5..a2d8ff5a08 100644 --- a/Kernel/NetworkTask.cpp +++ b/Kernel/NetworkTask.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -235,7 +236,7 @@ void handle_udp(const EthernetFrameHeader& eth, int frame_size) ); #endif - auto socket = IPv4Socket::from_udp_port(udp_packet.destination_port()); + auto socket = UDPSocket::from_port(udp_packet.destination_port()); if (!socket) { kprintf("handle_udp: No UDP socket for port %u\n", udp_packet.destination_port()); return; diff --git a/Kernel/UDPSocket.cpp b/Kernel/UDPSocket.cpp new file mode 100644 index 0000000000..1214eae465 --- /dev/null +++ b/Kernel/UDPSocket.cpp @@ -0,0 +1,99 @@ +#include +#include +#include +#include + +Lockable>& UDPSocket::sockets_by_port() +{ + static Lockable>* s_map; + if (!s_map) + s_map = new Lockable>; + return *s_map; +} + +UDPSocketHandle UDPSocket::from_port(word port) +{ + RetainPtr socket; + { + LOCKER(sockets_by_port().lock()); + auto it = sockets_by_port().resource().find(port); + if (it == sockets_by_port().resource().end()) + return { }; + socket = (*it).value; + ASSERT(socket); + } + return { move(socket) }; +} + + +UDPSocket::UDPSocket(int protocol) + : IPv4Socket(SOCK_DGRAM, protocol) +{ +} + +UDPSocket::~UDPSocket() +{ + LOCKER(sockets_by_port().lock()); + sockets_by_port().resource().remove(source_port()); +} + +Retained UDPSocket::create(int protocol) +{ + return adopt(*new UDPSocket(protocol)); +} + +int UDPSocket::protocol_receive(const ByteBuffer& packet_buffer, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) +{ + (void)flags; + (void)addr_length; + ASSERT(!packet_buffer.is_null()); + auto& ipv4_packet = *(const IPv4Packet*)(packet_buffer.pointer()); + auto& udp_packet = *static_cast(ipv4_packet.payload()); + ASSERT(udp_packet.length() >= sizeof(UDPPacket)); // FIXME: This should be rejected earlier. + ASSERT(buffer_size >= (udp_packet.length() - sizeof(UDPPacket))); + if (addr) { + auto& ia = *(sockaddr_in*)addr; + ia.sin_port = htons(udp_packet.destination_port()); + } + memcpy(buffer, udp_packet.payload(), udp_packet.length() - sizeof(UDPPacket)); + return udp_packet.length() - sizeof(UDPPacket); +} + +int UDPSocket::protocol_send(const void* data, int data_length) +{ + // FIXME: Figure out the adapter somehow differently. + auto& adapter = *NetworkAdapter::from_ipv4_address(IPv4Address(192, 168, 5, 2)); + auto buffer = ByteBuffer::create_zeroed(sizeof(UDPPacket) + data_length); + auto& udp_packet = *(UDPPacket*)(buffer.pointer()); + udp_packet.set_source_port(source_port()); + udp_packet.set_destination_port(destination_port()); + udp_packet.set_length(sizeof(UDPPacket) + data_length); + memcpy(udp_packet.payload(), data, data_length); + kprintf("sending as udp packet from %s:%u to %s:%u!\n", + adapter.ipv4_address().to_string().characters(), + source_port(), + destination_address().to_string().characters(), + destination_port()); + adapter.send_ipv4(MACAddress(), destination_address(), IPv4Protocol::UDP, move(buffer)); + return data_length; +} + +KResult UDPSocket::protocol_connect() +{ + return KSuccess; +} + +void UDPSocket::protocol_allocate_source_port() +{ + // This is not a very efficient allocation algorithm. + // FIXME: Replace it with a bitmap or some other fast-paced looker-upper. + LOCKER(sockets_by_port().lock()); + for (word port = 2000; port < 60000; ++port) { + auto it = sockets_by_port().resource().find(port); + if (it == sockets_by_port().resource().end()) { + set_source_port(port); + sockets_by_port().resource().set(port, this); + return; + } + } +} diff --git a/Kernel/UDPSocket.h b/Kernel/UDPSocket.h new file mode 100644 index 0000000000..4bfe6ee712 --- /dev/null +++ b/Kernel/UDPSocket.h @@ -0,0 +1,47 @@ +#pragma once + +#include + +class UDPSocketHandle; + +class UDPSocket final : public IPv4Socket { +public: + static Retained create(int protocol); + virtual ~UDPSocket() override; + + static Lockable>& sockets_by_port(); + static UDPSocketHandle from_port(word); + +private: + explicit UDPSocket(int protocol); + + virtual int protocol_receive(const ByteBuffer&, void* buffer, size_t buffer_size, int flags, sockaddr* addr, socklen_t* addr_length) override; + virtual int protocol_send(const void*, int) override; + virtual KResult protocol_connect() override; + virtual void protocol_allocate_source_port() override; +}; + +class UDPSocketHandle : public SocketHandle { +public: + UDPSocketHandle() { } + + UDPSocketHandle(RetainPtr&& socket) + : SocketHandle(move(socket)) + { + } + + UDPSocketHandle(UDPSocketHandle&& other) + : SocketHandle(move(other)) + { + } + + UDPSocketHandle(const UDPSocketHandle&) = delete; + UDPSocketHandle& operator=(const UDPSocketHandle&) = delete; + + UDPSocket* operator->() { return &socket(); } + const UDPSocket* operator->() const { return &socket(); } + + UDPSocket& socket() { return static_cast(SocketHandle::socket()); } + const UDPSocket& socket() const { return static_cast(SocketHandle::socket()); } +}; +