diff --git a/Servers/LookupServer/DNSQuestion.h b/Servers/LookupServer/DNSQuestion.h index 42fab135ef..1918c5a81f 100644 --- a/Servers/LookupServer/DNSQuestion.h +++ b/Servers/LookupServer/DNSQuestion.h @@ -16,6 +16,16 @@ public: u16 class_code() const { return m_class_code; } const String& name() const { return m_name; } + bool operator==(const DNSQuestion& other) const + { + return m_name == other.m_name && m_record_type == other.m_record_type && m_class_code == other.m_class_code; + } + + bool operator!=(const DNSQuestion& other) const + { + return !(*this == other); + } + private: String m_name; u16 m_record_type { 0 }; diff --git a/Servers/LookupServer/DNSRequest.h b/Servers/LookupServer/DNSRequest.h index 4289989cc2..4a73b35718 100644 --- a/Servers/LookupServer/DNSRequest.h +++ b/Servers/LookupServer/DNSRequest.h @@ -18,6 +18,8 @@ public: void add_question(const String& name, u16 record_type); + const Vector& questions() const { return m_questions; } + u16 question_count() const { ASSERT(m_questions.size() < UINT16_MAX); diff --git a/Servers/LookupServer/DNSResponse.cpp b/Servers/LookupServer/DNSResponse.cpp index ef202477c9..4d86dcb528 100644 --- a/Servers/LookupServer/DNSResponse.cpp +++ b/Servers/LookupServer/DNSResponse.cpp @@ -28,7 +28,6 @@ private: static_assert(sizeof(DNSRecordWithoutName) == 10); - Optional DNSResponse::from_raw_response(const u8* raw_data, size_t raw_size) { if (raw_size < sizeof(DNSPacket)) { @@ -79,7 +78,7 @@ Optional DNSResponse::from_raw_response(const u8* raw_data, size_t // FIXME: Parse some other record types perhaps? dbg() << " data=(unimplemented record type " << record.type() << ")"; } - dbg() << "Answer #" << i << ": type=" << record.type() << ", ttl=" << record.ttl() << ", length=" << record.data_length() << ", data=_" << data << "_"; + dbg() << "Answer #" << i << ": name=_" << name << "_, type=" << record.type() << ", ttl=" << record.ttl() << ", length=" << record.data_length() << ", data=_" << data << "_"; response.m_answers.empend(name, record.type(), record.record_class(), record.ttl(), data); offset += record.data_length(); } diff --git a/Servers/LookupServer/LookupServer.cpp b/Servers/LookupServer/LookupServer.cpp index 818c0a61ff..ae7605d76f 100644 --- a/Servers/LookupServer/LookupServer.cpp +++ b/Servers/LookupServer/LookupServer.cpp @@ -216,10 +216,22 @@ Vector LookupServer::lookup(const String& hostname, bool& did_timeout, u dbgprintf("LookupServer: ID mismatch (%u vs %u) :(\n", response.id(), request.id()); return {}; } - if (response.question_count() != 1) { + if (response.question_count() != request.question_count()) { dbgprintf("LookupServer: Question count (%u vs %u) :(\n", response.question_count(), request.question_count()); return {}; } + + for (size_t i = 0; i < request.question_count(); ++i) { + auto& request_question = request.questions()[i]; + auto& response_question = response.questions()[i]; + if (request_question != response_question) { + dbg() << "Request and response questions do not match"; + dbg() << " Request: {_" << request_question.name() << "_, " << request_question.record_type() << ", " << request_question.class_code() << "}"; + dbg() << " Response: {_" << response_question.name() << "_, " << response_question.record_type() << ", " << response_question.class_code() << "}"; + return {}; + } + } + if (response.answer_count() < 1) { dbgprintf("LookupServer: Not enough answers (%u) :(\n", response.answer_count()); return {};