From b39db0e7981791156a7bfa8118322ab1e5914a6c Mon Sep 17 00:00:00 2001 From: muathendirangu Date: Thu, 12 Oct 2023 13:35:44 +0300 Subject: [PATCH] extend our dns for recursive lookups --- src/main.rs | 128 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 124 insertions(+), 4 deletions(-) diff --git a/src/main.rs b/src/main.rs index 5774f3f..22879f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -718,12 +718,75 @@ impl DnsPacket { } Ok(()) } + + /// It's useful to be able to pick a random A record from a packet. When we + /// get multiple IP's for a single name, it doesn't matter which one we + /// choose, so in those cases we can now pick one at random. + pub fn get_random_a_record(&self) -> Option { + self.answers + .iter() + .filter_map(|record| match record { + DnsRecord::A { addr, .. } => Some(*addr), + _ => None, + }) + .next() + } + + /// A helper function which returns an iterator over all name servers in + /// the authorities section, represented as (domain, host) tuples + fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator { + self.authorities + .iter() + // In practice, these are always NS records in well formed packages. + // Convert the NS records to a tuple which has only the data we need + // to make it easy to work with. + .filter_map(|record| match record { + DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())), + _ => None, + }) + // Discard servers which aren't authoritative to our query + .filter(move |(domain, _)| qname.ends_with(*domain)) + } + + /// We'll use the fact that name servers often bundle the corresponding + /// A records when replying to an NS query to implement a function that + /// returns the actual IP for an NS record if possible. + pub fn get_resolved_ns(&self, qname: &str) -> Option { + // Get an iterator over the nameservers in the authorities section + self.get_ns(qname) + // Now we need to look for a matching A record in the additional + // section. Since we just want the first valid record, we can just + // build a stream of matching records. + .flat_map(|(_, host)| { + self.resources + .iter() + // Filter for A records where the domain match the host + // of the NS record that we are currently processing + .filter_map(move |record| match record { + DnsRecord::A { domain, addr, .. } if domain == host => Some(addr), + _ => None, + }) + }) + .map(|addr| *addr) + // Finally, pick the first valid entry + .next() + } + + /// However, not all name servers are as that nice. In certain cases there won't + /// be any A records in the additional section, and we'll have to perform *another* + /// lookup in the midst. For this, we introduce a method for returning the host + /// name of an appropriate name server. + pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> { + // Get an iterator over the nameservers in the authorities section + self.get_ns(qname) + .map(|(_, host)| host) + // Finally, pick the first valid entry + .next() + } } // Add lookup method to lookup DNS records -fn lookup(query_name: &str, query_type: QueryType) -> Result { - //forward query to public dns - let server = ("8.8.8.8", 53); +fn lookup(query_name: &str, query_type: QueryType, server: (Ipv4Addr, u16)) -> Result { // bind a UDP socket to arbitrary port let socket = UdpSocket::bind(("0.0.0.0", 42340))?; @@ -754,6 +817,63 @@ fn lookup(query_name: &str, query_type: QueryType) -> Result { DnsPacket::from_buffer(&mut res_buffer) } + +// Recursively query name servers until we get an answer or hit an error +fn recursive_lookup(qname: &str, qtype: QueryType) -> Result { + // For now we're always starting with *a.root-servers.net*. + let mut ns = "198.41.0.4".parse::().unwrap(); + + // Since it might take an arbitrary number of steps, we enter an unbounded loop. + loop { + println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); + + // The next step is to send the query to the active server. + let ns_copy = ns; + + let server = (ns_copy, 53); + let response = lookup(qname, qtype, server)?; + + // If there are entries in the answer section, and no errors, we are done! + if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { + return Ok(response); + } + + // We might also get a `NXDOMAIN` reply, which is the authoritative name servers + // way of telling us that the name doesn't exist. + if response.header.rescode == ResultCode::NXDOMAIN { + return Ok(response); + } + + // Otherwise, we'll try to find a new nameserver based on NS and a corresponding A + // record in the additional section. If this succeeds, we can switch name server + // and retry the loop. + if let Some(new_ns) = response.get_resolved_ns(qname) { + ns = new_ns; + + continue; + } + + // If not, we'll have to resolve the ip of a NS record. If no NS records exist, + // we'll go with what the last server told us. + let new_ns_name = match response.get_unresolved_ns(qname) { + Some(x) => x, + None => return Ok(response), + }; + + // Here we go down the rabbit hole by starting _another_ lookup sequence in the + // midst of our current one. Hopefully, this will give us the IP of an appropriate + // name server. + let recursive_response = recursive_lookup(&new_ns_name, QueryType::A)?; + + // Finally, we pick a random ip from the result, and restart the loop. If no such + // record is available, we again return the last result we got. + if let Some(new_ns) = recursive_response.get_random_a_record() { + ns = new_ns; + } else { + return Ok(response); + } + } +} /// Handle a single incoming packet fn handle_query(socket: &UdpSocket) -> Result<()> { // With a socket ready, we can go ahead and read a packet. This will @@ -786,7 +906,7 @@ fn handle_query(socket: &UdpSocket) -> Result<()> { // fail, in which case the `SERVFAIL` response code is set to indicate // as much to the client. If rather everything goes as planned, the // question and response records as copied into our response packet. - if let Ok(result) = lookup(&question.name, question.question_type) { + if let Ok(result) = recursive_lookup(&question.name, question.question_type) { packet.questions.push(question); packet.header.rescode = result.header.rescode;