diff --git a/Userland/Libraries/LibWebSocket/WebSocket.cpp b/Userland/Libraries/LibWebSocket/WebSocket.cpp index a1e692d277..887d0e4547 100644 --- a/Userland/Libraries/LibWebSocket/WebSocket.cpp +++ b/Userland/Libraries/LibWebSocket/WebSocket.cpp @@ -134,6 +134,13 @@ void WebSocket::drain_read() } break; case InternalState::Open: case InternalState::Closing: { + auto result = m_impl->read(65536); + if (result.is_error()) { + fatal_error(WebSocket::Error::ServerClosedSocket); + return; + } + auto bytes = result.release_value(); + m_buffered_data.append(bytes.data(), bytes.size()); read_frame(); } break; case InternalState::Closed: @@ -375,16 +382,23 @@ void WebSocket::read_frame() VERIFY(m_impl); VERIFY(m_state == WebSocket::InternalState::Open || m_state == WebSocket::InternalState::Closing); - auto head_bytes_result = m_impl->read(2); - if (head_bytes_result.is_error() || head_bytes_result.value().is_empty()) { + size_t cursor = 0; + auto get_buffered_bytes = [&](size_t count) -> ReadonlyBytes { + if (cursor + count > m_buffered_data.size()) + return {}; + auto bytes = m_buffered_data.span().slice(cursor, count); + cursor += count; + return bytes; + }; + + auto head_bytes = get_buffered_bytes(2); + if (head_bytes.is_null() || head_bytes.is_empty()) { // The connection got closed. m_state = WebSocket::InternalState::Closed; notify_close(m_last_close_code, m_last_close_message, true); discard_connection(); return; } - auto head_bytes = head_bytes_result.release_value(); - VERIFY(head_bytes.size() == 2); bool is_final_frame = head_bytes[0] & 0x80; if (!is_final_frame) { @@ -400,8 +414,9 @@ void WebSocket::read_frame() auto payload_length_bits = head_bytes[1] & 0x7f; if (payload_length_bits == 127) { // A code of 127 means that the next 8 bytes contains the payload length - auto actual_bytes = MUST(m_impl->read(8)); - VERIFY(actual_bytes.size() == 8); + auto actual_bytes = get_buffered_bytes(8); + if (actual_bytes.is_null()) + return; u64 full_payload_length = (u64)((u64)(actual_bytes[0] & 0xff) << 56) | (u64)((u64)(actual_bytes[1] & 0xff) << 48) | (u64)((u64)(actual_bytes[2] & 0xff) << 40) @@ -414,8 +429,9 @@ void WebSocket::read_frame() payload_length = (size_t)full_payload_length; } else if (payload_length_bits == 126) { // A code of 126 means that the next 2 bytes contains the payload length - auto actual_bytes = MUST(m_impl->read(2)); - VERIFY(actual_bytes.size() == 2); + auto actual_bytes = get_buffered_bytes(2); + if (actual_bytes.is_null()) + return; payload_length = (size_t)((size_t)(actual_bytes[0] & 0xff) << 8) | (size_t)((size_t)(actual_bytes[1] & 0xff) << 0); } else { @@ -430,8 +446,9 @@ void WebSocket::read_frame() // But because it doesn't cost much, we can support receiving masked frames anyways. u8 masking_key[4]; if (is_masked) { - auto masking_key_data = MUST(m_impl->read(4)); - VERIFY(masking_key_data.size() == 4); + auto masking_key_data = get_buffered_bytes(4); + if (masking_key_data.is_null()) + return; masking_key[0] = masking_key_data[0]; masking_key[1] = masking_key_data[1]; masking_key[2] = masking_key_data[2]; @@ -441,19 +458,22 @@ void WebSocket::read_frame() auto payload = ByteBuffer::create_uninitialized(payload_length).release_value_but_fixme_should_propagate_errors(); // FIXME: Handle possible OOM situation. u64 read_length = 0; while (read_length < payload_length) { - auto payload_part_result = m_impl->read(payload_length - read_length); - if (payload_part_result.is_error() || payload_part_result.value().is_empty()) { - // We got disconnected, somehow. - dbgln("Websocket: Server disconnected while sending payload ({} bytes read out of {})", read_length, payload_length); - fatal_error(WebSocket::Error::ServerClosedSocket); + auto payload_part = get_buffered_bytes(payload_length - read_length); + if (payload_part.is_null()) return; - } - auto payload_part = payload_part_result.release_value(); // We read at most "actual_length - read" bytes, so this is safe to do. payload.overwrite(read_length, payload_part.data(), payload_part.size()); read_length += payload_part.size(); } + if (cursor == m_buffered_data.size()) { + m_buffered_data.clear(); + } else { + Vector new_buffered_data; + new_buffered_data.append(m_buffered_data.data() + cursor, m_buffered_data.size() - cursor); + m_buffered_data = move(new_buffered_data); + } + if (is_masked) { // Unmask the payload for (size_t i = 0; i < payload.size(); ++i) { diff --git a/Userland/Libraries/LibWebSocket/WebSocket.h b/Userland/Libraries/LibWebSocket/WebSocket.h index 734e46f496..030745ada9 100644 --- a/Userland/Libraries/LibWebSocket/WebSocket.h +++ b/Userland/Libraries/LibWebSocket/WebSocket.h @@ -106,6 +106,8 @@ private: ConnectionInfo m_connection; RefPtr m_impl; + + Vector m_buffered_data; }; }