From d0f314b23cbf430848010669d9ccd92f2aee5879 Mon Sep 17 00:00:00 2001 From: Sahan Fernando Date: Tue, 18 May 2021 21:49:34 +1000 Subject: [PATCH] Kernel: Fix subtle race condition in sys$write implementation There is a slight race condition in our implementation of write(). We call File::can_write() before attempting to write to it (blocking if it returns false). If it returns true, we assume that we can write to the file, and our code assumes that File::write() cannot possibly fail by being blocked. There is, however, the rare case where another process writes to the file and prevents further writes in between the call to Files::can_write() and File::write() in the first process. This would result in the first process calling File::write() when it cannot be written to. We fix this by adding a mechanism for File::can_write() to signal that it was blocked, making it the responsibilty of File::write() to check whether it can write and then finally making sys$write() check if the write failed due to it being blocked. --- Kernel/Devices/SerialDevice.cpp | 8 +++++--- Kernel/Devices/SerialDevice.h | 1 + Kernel/Syscalls/write.cpp | 20 +++++++++----------- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/Kernel/Devices/SerialDevice.cpp b/Kernel/Devices/SerialDevice.cpp index d3bb36e166..ef41d620be 100644 --- a/Kernel/Devices/SerialDevice.cpp +++ b/Kernel/Devices/SerialDevice.cpp @@ -31,6 +31,7 @@ KResultOr SerialDevice::read(FileDescription&, u64, UserOrKernelBuffer& if (!size) return 0; + ScopedSpinLock lock(m_serial_lock); if (!(get_line_status() & DataReady)) return 0; @@ -46,13 +47,14 @@ bool SerialDevice::can_write(const FileDescription&, size_t) const return (get_line_status() & EmptyTransmitterHoldingRegister) != 0; } -KResultOr SerialDevice::write(FileDescription&, u64, const UserOrKernelBuffer& buffer, size_t size) +KResultOr SerialDevice::write(FileDescription& description, u64, const UserOrKernelBuffer& buffer, size_t size) { if (!size) return 0; - if (!(get_line_status() & EmptyTransmitterHoldingRegister)) - return 0; + ScopedSpinLock lock(m_serial_lock); + if (!can_write(description, size)) + return EAGAIN; return buffer.read_buffered<128>(size, [&](u8 const* data, size_t data_size) { for (size_t i = 0; i < data_size; i++) diff --git a/Kernel/Devices/SerialDevice.h b/Kernel/Devices/SerialDevice.h index 11c36f32df..71a860000a 100644 --- a/Kernel/Devices/SerialDevice.h +++ b/Kernel/Devices/SerialDevice.h @@ -135,6 +135,7 @@ private: bool m_break_enable { false }; u8 m_modem_control { 0 }; bool m_last_put_char_was_carriage_return { false }; + SpinLock m_serial_lock; }; } diff --git a/Kernel/Syscalls/write.cpp b/Kernel/Syscalls/write.cpp index c0185f9b11..0ba686ca48 100644 --- a/Kernel/Syscalls/write.cpp +++ b/Kernel/Syscalls/write.cpp @@ -60,10 +60,6 @@ KResultOr Process::sys$writev(int fd, Userspace io KResultOr Process::do_write(FileDescription& description, const UserOrKernelBuffer& data, size_t data_size) { ssize_t total_nwritten = 0; - if (!description.is_blocking()) { - if (!description.can_write()) - return EAGAIN; - } if (description.should_append() && description.file().is_seekable()) { auto seek_result = description.seek(0, SEEK_END); @@ -72,11 +68,12 @@ KResultOr Process::do_write(FileDescription& description, const UserOrK } while ((size_t)total_nwritten < data_size) { - if (!description.can_write()) { + while (!description.can_write()) { if (!description.is_blocking()) { - // Short write: We can no longer write to this non-blocking description. - VERIFY(total_nwritten > 0); - return total_nwritten; + if (total_nwritten > 0) + return total_nwritten; + else + return EAGAIN; } auto unblock_flags = Thread::FileBlocker::BlockFlags::None; if (Thread::current()->block({}, description, unblock_flags).was_interrupted()) { @@ -87,12 +84,13 @@ KResultOr Process::do_write(FileDescription& description, const UserOrK } auto nwritten_or_error = description.write(data.offset(total_nwritten), data_size - total_nwritten); if (nwritten_or_error.is_error()) { - if (total_nwritten) + if (total_nwritten > 0) return total_nwritten; + if (nwritten_or_error.error() == EAGAIN) + continue; return nwritten_or_error.error(); } - if (nwritten_or_error.value() == 0) - break; + VERIFY(nwritten_or_error.value() > 0); total_nwritten += nwritten_or_error.value(); } return total_nwritten;