diff --git a/Kernel/VM/MemoryManager.h b/Kernel/VM/MemoryManager.h index 138f16865b..af3db1243a 100644 --- a/Kernel/VM/MemoryManager.h +++ b/Kernel/VM/MemoryManager.h @@ -104,7 +104,6 @@ class MemoryManager { friend class PhysicalRegion; friend class AnonymousVMObject; friend class Region; - friend class ScatterGatherRefList; friend class VMObject; public: diff --git a/Kernel/VM/ScatterGatherList.cpp b/Kernel/VM/ScatterGatherList.cpp index 2ae1de1cec..6ebbfece2d 100644 --- a/Kernel/VM/ScatterGatherList.cpp +++ b/Kernel/VM/ScatterGatherList.cpp @@ -19,40 +19,4 @@ ScatterGatherList::ScatterGatherList(AsyncBlockDeviceRequest& request, NonnullRe m_dma_region = MM.allocate_kernel_region_with_vmobject(m_vm_object, page_round_up((request.block_count() * device_block_size)), "AHCI Scattered DMA", Region::Access::Read | Region::Access::Write, Region::Cacheable::Yes); } -ScatterGatherRefList ScatterGatherRefList::create_from_buffer(const u8* buffer, size_t size) -{ - VERIFY(buffer && size); - ScatterGatherRefList new_list; - auto* region = MM.find_region_from_vaddr(VirtualAddress(buffer)); - VERIFY(region); - while (size > 0) { - size_t offset_in_page = (VirtualAddress(buffer) - region->vaddr()).get() % PAGE_SIZE; - size_t size_in_page = min(PAGE_SIZE - offset_in_page, size); - VERIFY(offset_in_page + size_in_page - 1 <= PAGE_SIZE); - new_list.add_entry(region->physical_page(region->page_index_from_address(VirtualAddress(buffer)))->paddr().get(), offset_in_page, size_in_page); - size -= size_in_page; - buffer += size_in_page; - } - return new_list; -} - -ScatterGatherRefList ScatterGatherRefList::create_from_physical(PhysicalAddress paddr, size_t size) -{ - VERIFY(!paddr.is_null() && size); - ScatterGatherRefList new_list; - new_list.add_entry(paddr.page_base().get(), paddr.offset_in_page(), size); - return new_list; -} - -void ScatterGatherRefList::add_entry(FlatPtr addr, size_t offset, size_t size) -{ - m_entries.append({ addr, offset, size }); -} - -void ScatterGatherRefList::for_each_entry(Function callback) const -{ - for (auto& entry : m_entries) - callback(entry.page_base + entry.offset, entry.length); -} - } diff --git a/Kernel/VM/ScatterGatherList.h b/Kernel/VM/ScatterGatherList.h index 526e6acbcc..0d94cbcfa2 100644 --- a/Kernel/VM/ScatterGatherList.h +++ b/Kernel/VM/ScatterGatherList.h @@ -29,26 +29,4 @@ private: OwnPtr m_dma_region; }; -/// A Scatter-Gather List type that doesn't own its buffers - -class ScatterGatherRefList { - struct ScatterGatherRef { - FlatPtr page_base; - size_t offset; - size_t length; - }; - -public: - static ScatterGatherRefList create_from_buffer(const u8* buffer, size_t); - static ScatterGatherRefList create_from_physical(PhysicalAddress, size_t); - - void add_entry(FlatPtr, size_t offset, size_t size); - [[nodiscard]] size_t length() const { return m_entries.size(); } - - void for_each_entry(Function callback) const; - -private: - Vector m_entries; -}; - } diff --git a/Kernel/VirtIO/VirtIO.cpp b/Kernel/VirtIO/VirtIO.cpp index 6dece01201..3f81933637 100644 --- a/Kernel/VirtIO/VirtIO.cpp +++ b/Kernel/VirtIO/VirtIO.cpp @@ -335,13 +335,6 @@ void VirtIODevice::finish_init() dbgln_if(VIRTIO_DEBUG, "{}: Finished initialization", m_class_name); } -void VirtIODevice::supply_buffer_and_notify(u16 queue_index, const ScatterGatherRefList& scatter_list, BufferType buffer_type, void* token) -{ - VERIFY(queue_index < m_queue_count); - if (get_queue(queue_index).supply_buffer({}, scatter_list, buffer_type, token)) - notify_queue(queue_index); -} - u8 VirtIODevice::isr_status() { if (!m_isr_cfg) @@ -353,12 +346,14 @@ void VirtIODevice::handle_irq(const RegisterState&) { u8 isr_type = isr_status(); if (isr_type & DEVICE_CONFIG_INTERRUPT) { + dbgln_if(VIRTIO_DEBUG, "{}: VirtIO Device config interrupt!", m_class_name); if (!handle_device_config_change()) { set_status_bit(DEVICE_STATUS_FAILED); dbgln("{}: Failed to handle device config change!", m_class_name); } } if (isr_type & QUEUE_INTERRUPT) { + dbgln_if(VIRTIO_DEBUG, "{}: VirtIO Queue interrupt!", m_class_name); for (size_t i = 0; i < m_queues.size(); i++) { if (get_queue(i).new_data_available()) return handle_queue_update(i); @@ -369,4 +364,14 @@ void VirtIODevice::handle_irq(const RegisterState&) dbgln("{}: Handling interrupt with unknown type: {}", m_class_name, isr_type); } +void VirtIODevice::supply_chain_and_notify(u16 queue_index, VirtIOQueueChain& chain) +{ + auto& queue = get_queue(queue_index); + VERIFY(&chain.queue() == &queue); + VERIFY(queue.lock().is_locked()); + chain.submit_to_queue(); + if (queue.should_notify()) + notify_queue(queue_index); +} + } diff --git a/Kernel/VirtIO/VirtIO.h b/Kernel/VirtIO/VirtIO.h index b30159f2a3..800b5b7d2e 100644 --- a/Kernel/VirtIO/VirtIO.h +++ b/Kernel/VirtIO/VirtIO.h @@ -12,7 +12,6 @@ #include #include #include -#include #include namespace Kernel { @@ -196,7 +195,7 @@ protected: return is_feature_set(m_accepted_features, feature); } - void supply_buffer_and_notify(u16 queue_index, const ScatterGatherRefList&, BufferType, void* token); + void supply_chain_and_notify(u16 queue_index, VirtIOQueueChain& chain); virtual bool handle_device_config_change() = 0; virtual void handle_queue_update(u16 queue_index) = 0; diff --git a/Kernel/VirtIO/VirtIOConsole.cpp b/Kernel/VirtIO/VirtIOConsole.cpp index 3376732d11..69b80ac2c3 100644 --- a/Kernel/VirtIO/VirtIOConsole.cpp +++ b/Kernel/VirtIO/VirtIOConsole.cpp @@ -4,7 +4,6 @@ * SPDX-License-Identifier: BSD-2-Clause */ -#include #include namespace Kernel { @@ -41,11 +40,8 @@ VirtIOConsole::VirtIOConsole(PCI::Address address) } if (success) { finish_init(); - m_receive_region = MM.allocate_contiguous_kernel_region(PAGE_SIZE, "VirtIOConsole Receive", Region::Access::Read | Region::Access::Write); - if (m_receive_region) { - supply_buffer_and_notify(RECEIVEQ, ScatterGatherRefList::create_from_physical(m_receive_region->physical_page(0)->paddr(), m_receive_region->size()), BufferType::DeviceWritable, m_receive_region->vaddr().as_ptr()); - } - m_transmit_region = MM.allocate_contiguous_kernel_region(PAGE_SIZE, "VirtIOConsole Transmit", Region::Access::Read | Region::Access::Write); + m_receive_buffer = make("VirtIOConsole Receive", RINGBUFFER_SIZE); + m_transmit_buffer = make("VirtIOConsole Transmit", RINGBUFFER_SIZE); } } } @@ -62,14 +58,31 @@ bool VirtIOConsole::handle_device_config_change() void VirtIOConsole::handle_queue_update(u16 queue_index) { + dbgln_if(VIRTIO_DEBUG, "VirtIOConsole: Handle queue update"); VERIFY(queue_index <= TRANSMITQ); switch (queue_index) { - case RECEIVEQ: + case RECEIVEQ: { + ScopedSpinLock lock(get_queue(RECEIVEQ).lock()); get_queue(RECEIVEQ).discard_used_buffers(); // TODO: do something with incoming data (users writing into qemu console) instead of just clearing break; - case TRANSMITQ: - get_queue(TRANSMITQ).discard_used_buffers(); // clear outgoing buffers that the device finished with + } + case TRANSMITQ: { + ScopedSpinLock ringbuffer_lock(m_transmit_buffer->lock()); + auto& queue = get_queue(TRANSMITQ); + ScopedSpinLock queue_lock(queue.lock()); + size_t used; + VirtIOQueueChain popped_chain = queue.pop_used_buffer_chain(used); + do { + popped_chain.for_each([this](PhysicalAddress address, size_t length) { + m_transmit_buffer->reclaim_space(address, length); + }); + popped_chain.release_buffer_slots_to_queue(); + popped_chain = queue.pop_used_buffer_chain(used); + } while (!popped_chain.is_empty()); + // Unblock any IO tasks that were blocked because can_write() returned false + evaluate_block_conditions(); break; + } default: VERIFY_NOT_REACHED(); } @@ -77,31 +90,50 @@ void VirtIOConsole::handle_queue_update(u16 queue_index) bool VirtIOConsole::can_read(const FileDescription&, size_t) const { - return false; + return true; } -KResultOr VirtIOConsole::read(FileDescription&, u64, [[maybe_unused]] UserOrKernelBuffer& data, size_t size) +KResultOr VirtIOConsole::read(FileDescription&, u64, [[maybe_unused]] UserOrKernelBuffer& data, size_t) { - if (!size) - return 0; - - return 1; + return ENOTSUP; } bool VirtIOConsole::can_write(const FileDescription&, size_t) const { - return get_queue(TRANSMITQ).can_write(); + return get_queue(TRANSMITQ).has_free_slots() && m_transmit_buffer->has_space(); } -KResultOr VirtIOConsole::write(FileDescription&, u64, const UserOrKernelBuffer& data, size_t size) +KResultOr VirtIOConsole::write(FileDescription& desc, u64, const UserOrKernelBuffer& data, size_t size) { if (!size) return 0; - auto scatter_list = ScatterGatherRefList::create_from_buffer(static_cast(data.user_or_kernel_ptr()), size); - supply_buffer_and_notify(TRANSMITQ, scatter_list, BufferType::DeviceReadable, const_cast(data.user_or_kernel_ptr())); + if (!can_write(desc, size)) + return EAGAIN; - return size; + ScopedSpinLock ringbuffer_lock(m_transmit_buffer->lock()); + auto& queue = get_queue(TRANSMITQ); + ScopedSpinLock queue_lock(queue.lock()); + VirtIOQueueChain chain(queue); + + size_t total_bytes_copied = 0; + do { + PhysicalAddress start_of_chunk; + size_t length_of_chunk; + + if (!m_transmit_buffer->copy_data_in(data, total_bytes_copied, size - total_bytes_copied, start_of_chunk, length_of_chunk)) { + chain.release_buffer_slots_to_queue(); + return EINVAL; + } + + bool did_add_buffer = chain.add_buffer_to_chain(start_of_chunk, length_of_chunk, BufferType::DeviceReadable); + VERIFY(did_add_buffer); + total_bytes_copied += length_of_chunk; + } while (total_bytes_copied < size && can_write(desc, size - total_bytes_copied)); + + supply_chain_and_notify(TRANSMITQ, chain); + + return total_bytes_copied; } } diff --git a/Kernel/VirtIO/VirtIOConsole.h b/Kernel/VirtIO/VirtIOConsole.h index 7dc51be0bd..c3149dab65 100644 --- a/Kernel/VirtIO/VirtIOConsole.h +++ b/Kernel/VirtIO/VirtIOConsole.h @@ -7,6 +7,7 @@ #pragma once #include +#include #include namespace Kernel { @@ -25,6 +26,7 @@ public: virtual ~VirtIOConsole() override; private: + constexpr static size_t RINGBUFFER_SIZE = 2 * PAGE_SIZE; virtual const char* class_name() const override { return m_class_name.characters(); } virtual bool can_read(const FileDescription&, size_t) const override; @@ -38,8 +40,8 @@ private: virtual String device_name() const override { return String::formatted("hvc{}", minor()); } virtual void handle_queue_update(u16 queue_index) override; - OwnPtr m_receive_region; - OwnPtr m_transmit_region; + OwnPtr m_receive_buffer; + OwnPtr m_transmit_buffer; static unsigned next_device_id; }; diff --git a/Kernel/VirtIO/VirtIOQueue.cpp b/Kernel/VirtIO/VirtIOQueue.cpp index 1f1151f2d9..c890b50e11 100644 --- a/Kernel/VirtIO/VirtIOQueue.cpp +++ b/Kernel/VirtIO/VirtIOQueue.cpp @@ -4,7 +4,7 @@ * SPDX-License-Identifier: BSD-2-Clause */ -#include +#include #include namespace Kernel { @@ -25,10 +25,9 @@ VirtIOQueue::VirtIOQueue(u16 queue_size, u16 notify_offset) m_descriptors = reinterpret_cast(ptr); m_driver = reinterpret_cast(ptr + size_of_descriptors); m_device = reinterpret_cast(ptr + size_of_descriptors + size_of_driver); - m_tokens.resize(queue_size); - for (auto i = 0; i < queue_size; i++) { - m_descriptors[i].next = (i + 1) % queue_size; // link all of the descriptors in a circle + for (auto i = 0; i + 1 < queue_size; i++) { + m_descriptors[i].next = i + 1; // link all of the descriptors in a line } enable_interrupts(); @@ -40,94 +39,157 @@ VirtIOQueue::~VirtIOQueue() void VirtIOQueue::enable_interrupts() { + ScopedSpinLock lock(m_lock); m_driver->flags = 0; } void VirtIOQueue::disable_interrupts() { + ScopedSpinLock lock(m_lock); m_driver->flags = 1; } -bool VirtIOQueue::supply_buffer(Badge, const ScatterGatherRefList& scatter_list, BufferType buffer_type, void* token) -{ - VERIFY(scatter_list.length() && scatter_list.length() <= m_free_buffers); - m_free_buffers -= scatter_list.length(); - - auto descriptor_index = m_free_head; - auto last_index = descriptor_index; - scatter_list.for_each_entry([&](auto paddr, auto size) { - m_descriptors[descriptor_index].flags = static_cast(buffer_type) | VIRTQ_DESC_F_NEXT; - m_descriptors[descriptor_index].address = static_cast(paddr); - m_descriptors[descriptor_index].length = static_cast(size); - last_index = descriptor_index; - descriptor_index = m_descriptors[descriptor_index].next; // ensure we place the buffer in chain order - }); - m_descriptors[last_index].flags &= ~(VIRTQ_DESC_F_NEXT); // last descriptor in chain doesn't have a next descriptor - - m_driver->rings[m_driver_index_shadow % m_queue_size] = m_free_head; // m_driver_index_shadow is used to prevent accesses to index before the rings are updated - m_tokens[m_free_head] = token; - m_free_head = descriptor_index; - - full_memory_barrier(); - - m_driver_index_shadow++; - m_driver->index++; - - full_memory_barrier(); - - auto device_flags = m_device->flags; - return !(device_flags & 1); // if bit 1 is enabled the device disabled interrupts -} - bool VirtIOQueue::new_data_available() const { - return m_device->index != m_used_tail; + const auto index = AK::atomic_load(&m_device->index, AK::MemoryOrder::memory_order_relaxed); + const auto used_tail = AK::atomic_load(&m_used_tail, AK::MemoryOrder::memory_order_relaxed); + return index != used_tail; } -void* VirtIOQueue::get_buffer(size_t* size) +VirtIOQueueChain VirtIOQueue::pop_used_buffer_chain(size_t& used) { + VERIFY(m_lock.is_locked()); if (!new_data_available()) { - *size = 0; - return nullptr; + used = 0; + return VirtIOQueueChain(*this); } full_memory_barrier(); - auto descriptor_index = m_device->rings[m_used_tail % m_queue_size].index; - *size = m_device->rings[m_used_tail % m_queue_size].length; + // Determine used length + used = m_device->rings[m_used_tail % m_queue_size].length; + // Determine start, end and number of nodes in chain + auto descriptor_index = m_device->rings[m_used_tail % m_queue_size].index; + size_t length_of_chain = 1; + auto last_index = descriptor_index; + while (m_descriptors[last_index].flags & VIRTQ_DESC_F_NEXT) { + ++length_of_chain; + last_index = m_descriptors[last_index].next; + } + + // We are now done with this buffer chain m_used_tail++; - auto token = m_tokens[descriptor_index]; - pop_buffer(descriptor_index); - return token; + return VirtIOQueueChain(*this, descriptor_index, last_index, length_of_chain); } void VirtIOQueue::discard_used_buffers() { - size_t size; - while (!get_buffer(&size)) { + VERIFY(m_lock.is_locked()); + size_t used; + for (auto buffer = pop_used_buffer_chain(used); !buffer.is_empty(); buffer = pop_used_buffer_chain(used)) { + buffer.release_buffer_slots_to_queue(); } } -void VirtIOQueue::pop_buffer(u16 descriptor_index) +void VirtIOQueue::reclaim_buffer_chain(u16 chain_start_index, u16 chain_end_index, size_t length_of_chain) { - m_tokens[descriptor_index] = nullptr; + VERIFY(m_lock.is_locked()); + m_descriptors[chain_end_index].next = m_free_head; + m_free_head = chain_start_index; + m_free_buffers += length_of_chain; +} - auto i = descriptor_index; - while (m_descriptors[i].flags & VIRTQ_DESC_F_NEXT) { - m_free_buffers++; - i = m_descriptors[i].next; +bool VirtIOQueue::has_free_slots() const +{ + const auto free_buffers = AK::atomic_load(&m_free_buffers, AK::MemoryOrder::memory_order_relaxed); + return free_buffers > 0; +} + +Optional VirtIOQueue::take_free_slot() +{ + VERIFY(m_lock.is_locked()); + if (has_free_slots()) { + auto descriptor_index = m_free_head; + m_free_head = m_descriptors[descriptor_index].next; + --m_free_buffers; + return descriptor_index; + } else { + return {}; } - m_free_buffers++; // the last descriptor in the chain doesn't have the NEXT flag - - m_descriptors[i].next = m_free_head; // empend the popped descriptors to the free chain - m_free_head = descriptor_index; } -bool VirtIOQueue::can_write() const +bool VirtIOQueue::should_notify() const { - return m_free_buffers > 0; + VERIFY(m_lock.is_locked()); + auto device_flags = m_device->flags; + return !(device_flags & VIRTQ_USED_F_NO_NOTIFY); +} + +bool VirtIOQueueChain::add_buffer_to_chain(PhysicalAddress buffer_start, size_t buffer_length, BufferType buffer_type) +{ + VERIFY(m_queue.lock().is_locked()); + + // Ensure that no readable pages will be inserted after a writable one, as required by the VirtIO spec + VERIFY(buffer_type == BufferType::DeviceWritable || !m_chain_has_writable_pages); + m_chain_has_writable_pages |= (buffer_type == BufferType::DeviceWritable); + + // Take a free slot from the queue + auto descriptor_index = m_queue.take_free_slot(); + if (!descriptor_index.has_value()) + return false; + + if (!m_start_of_chain_index.has_value()) { + // Set start of chain if it hasn't been set + m_start_of_chain_index = descriptor_index.value(); + } else { + // Link from previous element in VirtIOQueueChain + m_queue.m_descriptors[m_end_of_chain_index.value()].flags |= VIRTQ_DESC_F_NEXT; + m_queue.m_descriptors[m_end_of_chain_index.value()].next = descriptor_index.value(); + } + + // Update end of chain + m_end_of_chain_index = descriptor_index.value(); + ++m_chain_length; + + // Populate buffer info + VERIFY(buffer_length <= NumericLimits::max()); + m_queue.m_descriptors[descriptor_index.value()].address = static_cast(buffer_start.get()); + m_queue.m_descriptors[descriptor_index.value()].flags = static_cast(buffer_type); + m_queue.m_descriptors[descriptor_index.value()].length = static_cast(buffer_length); + + return true; +} + +void VirtIOQueueChain::submit_to_queue() +{ + VERIFY(m_queue.lock().is_locked()); + VERIFY(m_start_of_chain_index.has_value()); + + auto next_index = m_queue.m_driver_index_shadow % m_queue.m_queue_size; + m_queue.m_driver->rings[next_index] = m_start_of_chain_index.value(); + m_queue.m_driver_index_shadow++; + full_memory_barrier(); + m_queue.m_driver->index = m_queue.m_driver_index_shadow; + + // Reset internal chain state + m_start_of_chain_index = m_end_of_chain_index = {}; + m_chain_has_writable_pages = false; + m_chain_length = 0; +} + +void VirtIOQueueChain::release_buffer_slots_to_queue() +{ + VERIFY(m_queue.lock().is_locked()); + if (m_start_of_chain_index.has_value()) { + // Add the currently stored chain back to the queue's free pool + m_queue.reclaim_buffer_chain(m_start_of_chain_index.value(), m_end_of_chain_index.value(), m_chain_length); + // Reset internal chain state + m_start_of_chain_index = m_end_of_chain_index = {}; + m_chain_has_writable_pages = false; + m_chain_length = 0; + } } } diff --git a/Kernel/VirtIO/VirtIOQueue.h b/Kernel/VirtIO/VirtIOQueue.h index 2ea8ac38a7..7e7fab667a 100644 --- a/Kernel/VirtIO/VirtIOQueue.h +++ b/Kernel/VirtIO/VirtIOQueue.h @@ -6,7 +6,6 @@ #pragma once -#include #include #include #include @@ -16,12 +15,16 @@ namespace Kernel { #define VIRTQ_DESC_F_NEXT 1 #define VIRTQ_DESC_F_INDIRECT 4 +#define VIRTQ_AVAIL_F_NO_INTERRUPT 1 +#define VIRTQ_USED_F_NO_NOTIFY 1 + enum class BufferType { DeviceReadable = 0, DeviceWritable = 2 }; class VirtIODevice; +class VirtIOQueueChain; class VirtIOQueue { public: @@ -38,14 +41,18 @@ public: PhysicalAddress driver_area() const { return to_physical(m_driver.ptr()); } PhysicalAddress device_area() const { return to_physical(m_device.ptr()); } - bool supply_buffer(Badge, const ScatterGatherRefList&, BufferType, void* token); bool new_data_available() const; - bool can_write() const; - void* get_buffer(size_t*); + bool has_free_slots() const; + Optional take_free_slot(); + VirtIOQueueChain pop_used_buffer_chain(size_t& used); void discard_used_buffers(); + SpinLock& lock() { return m_lock; } + + bool should_notify() const; + private: - void pop_buffer(u16 descriptor_index); + void reclaim_buffer_chain(u16 chain_start_index, u16 chain_end_index, size_t length_of_chain); PhysicalAddress to_physical(const void* ptr) const { @@ -86,9 +93,94 @@ private: OwnPtr m_descriptors { nullptr }; OwnPtr m_driver { nullptr }; OwnPtr m_device { nullptr }; - Vector m_tokens; OwnPtr m_queue_region; SpinLock m_lock; + + friend class VirtIOQueueChain; +}; + +class VirtIOQueueChain { +public: + VirtIOQueueChain(VirtIOQueue& queue) + : m_queue(queue) + { + } + + VirtIOQueueChain(VirtIOQueue& queue, u16 start_index, u16 end_index, size_t chain_length) + : m_queue(queue) + , m_start_of_chain_index(start_index) + , m_end_of_chain_index(end_index) + , m_chain_length(chain_length) + { + } + + VirtIOQueueChain(VirtIOQueueChain&& other) + : m_queue(other.m_queue) + , m_start_of_chain_index(other.m_start_of_chain_index) + , m_end_of_chain_index(other.m_end_of_chain_index) + , m_chain_length(other.m_chain_length) + , m_chain_has_writable_pages(other.m_chain_has_writable_pages) + { + other.m_start_of_chain_index = {}; + other.m_end_of_chain_index = {}; + other.m_chain_length = 0; + other.m_chain_has_writable_pages = false; + } + + VirtIOQueueChain& operator=(VirtIOQueueChain&& other) + { + VERIFY(&m_queue == &other.m_queue); + ensure_chain_is_empty(); + m_start_of_chain_index = other.m_start_of_chain_index; + m_end_of_chain_index = other.m_end_of_chain_index; + m_chain_length = other.m_chain_length; + m_chain_has_writable_pages = other.m_chain_has_writable_pages; + other.m_start_of_chain_index = {}; + other.m_end_of_chain_index = {}; + other.m_chain_length = 0; + other.m_chain_has_writable_pages = false; + return *this; + } + + ~VirtIOQueueChain() + { + ensure_chain_is_empty(); + } + + [[nodiscard]] VirtIOQueue& queue() const { return m_queue; } + [[nodiscard]] bool is_empty() const { return m_chain_length == 0; } + [[nodiscard]] size_t length() const { return m_chain_length; } + bool add_buffer_to_chain(PhysicalAddress buffer_start, size_t buffer_length, BufferType buffer_type); + void submit_to_queue(); + void release_buffer_slots_to_queue(); + + void for_each(Function callback) + { + VERIFY(m_queue.lock().is_locked()); + if (!m_start_of_chain_index.has_value()) + return; + auto index = m_start_of_chain_index.value(); + for (size_t i = 0; i < m_chain_length; ++i) { + auto addr = m_queue.m_descriptors[index].address; + auto length = m_queue.m_descriptors[index].length; + callback(PhysicalAddress(addr), length); + index = m_queue.m_descriptors[index].next; + } + } + +private: + void ensure_chain_is_empty() const + { + VERIFY(!m_start_of_chain_index.has_value()); + VERIFY(!m_end_of_chain_index.has_value()); + VERIFY(m_chain_length == 0); + } + + VirtIOQueue& m_queue; + Optional m_start_of_chain_index {}; + Optional m_end_of_chain_index {}; + size_t m_chain_length {}; + bool m_chain_has_writable_pages { false }; }; } diff --git a/Kernel/VirtIO/VirtIORNG.cpp b/Kernel/VirtIO/VirtIORNG.cpp index 2fe5b817e2..cbaa6eb853 100644 --- a/Kernel/VirtIO/VirtIORNG.cpp +++ b/Kernel/VirtIO/VirtIORNG.cpp @@ -23,7 +23,7 @@ VirtIORNG::VirtIORNG(PCI::Address address) m_entropy_buffer = MM.allocate_contiguous_kernel_region(PAGE_SIZE, "VirtIORNG", Region::Access::Read | Region::Access::Write); if (m_entropy_buffer) { memset(m_entropy_buffer->vaddr().as_ptr(), 0, m_entropy_buffer->size()); - supply_buffer_and_notify(REQUESTQ, ScatterGatherRefList::create_from_physical(m_entropy_buffer->physical_page(0)->paddr(), m_entropy_buffer->size()), BufferType::DeviceWritable, m_entropy_buffer->vaddr().as_ptr()); + request_entropy_from_host(); } } } @@ -40,14 +40,33 @@ bool VirtIORNG::handle_device_config_change() void VirtIORNG::handle_queue_update(u16 queue_index) { VERIFY(queue_index == REQUESTQ); - size_t available_entropy = 0; - if (!get_queue(REQUESTQ).get_buffer(&available_entropy)) - return; + size_t available_entropy = 0, used; + auto& queue = get_queue(REQUESTQ); + { + ScopedSpinLock lock(queue.lock()); + auto chain = queue.pop_used_buffer_chain(used); + if (chain.is_empty()) + return; + VERIFY(chain.length() == 1); + chain.for_each([&available_entropy](PhysicalAddress, size_t length) { + available_entropy = length; + }); + chain.release_buffer_slots_to_queue(); + } dbgln_if(VIRTIO_DEBUG, "VirtIORNG: received {} bytes of entropy!", available_entropy); for (auto i = 0u; i < available_entropy; i++) { m_entropy_source.add_random_event(m_entropy_buffer->vaddr().as_ptr()[i]); } - // TODO: when should we ask for more entropy from the host? + // TODO: When should we get some more entropy? +} + +void VirtIORNG::request_entropy_from_host() +{ + auto& queue = get_queue(REQUESTQ); + ScopedSpinLock lock(queue.lock()); + VirtIOQueueChain chain(queue); + chain.add_buffer_to_chain(m_entropy_buffer->physical_page(0)->paddr(), PAGE_SIZE, BufferType::DeviceWritable); + supply_chain_and_notify(REQUESTQ, chain); } } diff --git a/Kernel/VirtIO/VirtIORNG.h b/Kernel/VirtIO/VirtIORNG.h index 49ca1f0b2b..6f083fa40e 100644 --- a/Kernel/VirtIO/VirtIORNG.h +++ b/Kernel/VirtIO/VirtIORNG.h @@ -33,6 +33,7 @@ public: private: virtual bool handle_device_config_change() override; virtual void handle_queue_update(u16 queue_index) override; + void request_entropy_from_host(); OwnPtr m_entropy_buffer; EntropySource m_entropy_source;