1 #[cfg(feature = "async")]
2 use core::task::Waker;
3
4 use heapless::Vec;
5 use managed::ManagedSlice;
6
7 use crate::config::{DNS_MAX_NAME_SIZE, DNS_MAX_RESULT_COUNT, DNS_MAX_SERVER_COUNT};
8 use crate::socket::{Context, PollAt};
9 use crate::time::{Duration, Instant};
10 use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
11 use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
12
13 #[cfg(feature = "async")]
14 use super::WakerRegistration;
15
/// UDP port that unicast DNS queries are sent to and answers are expected from.
const DNS_PORT: u16 = 53;
/// UDP port used for multicast DNS (mDNS) queries and answers.
const MDNS_DNS_PORT: u16 = 5353;
/// Delay before the first retransmission of an unanswered query; it is doubled
/// after every transmission, up to `MAX_RETRANSMIT_DELAY`.
const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
/// Upper bound for the exponentially growing retransmission delay.
const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
/// How long one server is tried before the query moves on to the next server.
const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs

/// IPv6 destination (`ff02::fb`) used for queries that are sent as mDNS.
#[cfg(feature = "proto-ipv6")]
const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address([
    0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb,
]));

/// IPv4 destination (`224.0.0.251`) used for queries that are sent as mDNS.
#[cfg(feature = "proto-ipv4")]
const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address([224, 0, 0, 251]));
29
/// Error returned by [`Socket::start_query`]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum StartQueryError {
    /// Every query slot is occupied and the slot storage cannot grow.
    NoFreeSlot,
    /// The name is empty, or contains an empty label or a label longer than 63 bytes.
    InvalidName,
    /// The wire-format name does not fit into the fixed-size name buffer.
    NameTooLong,
}

impl core::fmt::Display for StartQueryError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(match self {
            StartQueryError::NoFreeSlot => "No free slot",
            StartQueryError::InvalidName => "Invalid name",
            StartQueryError::NameTooLong => "Name too long",
        })
    }
}
38
/// Error returned by [`Socket::get_query_result`]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum GetQueryResultError {
    /// Query is not done yet.
    Pending,
    /// Query failed.
    Failed,
}

impl core::fmt::Display for GetQueryResultError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(match self {
            GetQueryResultError::Pending => "Query is not done yet",
            GetQueryResultError::Failed => "Query failed",
        })
    }
}
48
/// State for an in-progress DNS query.
///
/// The only reason this struct is public is to allow the socket state
/// to be allocated externally.
#[derive(Debug)]
pub struct DnsQuery {
    // Current stage of the query: still pending, completed or failed.
    state: State,

    // Waker that is woken every time `set_state` replaces the state.
    #[cfg(feature = "async")]
    waker: WakerRegistration,
}
60
61 impl DnsQuery {
set_state(&mut self, state: State)62 fn set_state(&mut self, state: State) {
63 self.state = state;
64 #[cfg(feature = "async")]
65 self.waker.wake();
66 }
67 }
68
/// Lifecycle of a query slot.
#[derive(Debug)]
// NOTE(review): the variants differ a lot in size; presumably boxing is avoided
// because the socket has to work without an allocator — confirm.
#[allow(clippy::large_enum_variant)]
enum State {
    /// The query is still being (re)transmitted and waits for an answer.
    Pending(PendingQuery),
    /// An answer with at least one address was received.
    Completed(CompletedQuery),
    /// The query failed (e.g. NXDomain, no usable answer, or all servers tried).
    Failure,
}
76
/// Bookkeeping for a query that has not been answered yet.
#[derive(Debug)]
struct PendingQuery {
    // Name being asked for, in DNS wire format (length-prefixed labels, 0x00 terminator).
    // It is overwritten with the CNAME target while a response is being processed.
    name: Vec<u8, DNS_MAX_NAME_SIZE>,
    // Record type asked for.
    type_: Type,

    port: u16, // UDP port (src for request, dst for response)
    txid: u16, // transaction ID

    // Deadline for the current server; `None` until the first dispatch attempt.
    timeout_at: Option<Instant>,
    // Earliest instant at which the query may be (re)transmitted.
    retransmit_at: Instant,
    // Delay added to `retransmit_at` after the next transmission.
    delay: Duration,

    // Index of the server currently being tried.
    server_idx: usize,
    // Whether the query is sent to the mDNS multicast addresses instead of the servers.
    mdns: MulticastDns,
}
92
/// Selects whether a query is resolved through the configured DNS servers or
/// through multicast DNS.
///
/// The value is passed by value to [`Socket::start_query_raw`], so it is cheap
/// to copy and can be compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MulticastDns {
    /// Send the query to the configured unicast DNS servers.
    Disabled,
    /// Send the query to the mDNS multicast addresses.
    #[cfg(feature = "socket-mdns")]
    Enabled,
}
99
/// Result of a successfully answered query.
#[derive(Debug)]
struct CompletedQuery {
    // Addresses collected from the A/AAAA records of the response.
    addresses: Vec<IpAddress, DNS_MAX_RESULT_COUNT>,
}
104
/// A handle to an in-progress DNS query.
///
/// The handle wraps the index of the query slot inside the socket. It stays
/// valid until the slot is freed by [`Socket::get_query_result`] or
/// [`Socket::cancel_query`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct QueryHandle(usize);
108
/// A Domain Name System socket.
///
/// The socket keeps a list of DNS servers and a set of query slots. Each
/// started query occupies one slot until its result is fetched or the query
/// is cancelled.
#[derive(Debug)]
pub struct Socket<'a> {
    // Servers that are tried in order for non-mDNS queries.
    servers: Vec<IpAddress, DNS_MAX_SERVER_COUNT>,
    // Query slots; `None` marks a free slot.
    queries: ManagedSlice<'a, Option<DnsQuery>>,

    /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
    hop_limit: Option<u8>,
}
121
122 impl<'a> Socket<'a> {
123 /// Create a DNS socket.
124 ///
125 /// # Panics
126 ///
127 /// Panics if `servers.len() > MAX_SERVER_COUNT`
new<Q>(servers: &[IpAddress], queries: Q) -> Socket<'a> where Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,128 pub fn new<Q>(servers: &[IpAddress], queries: Q) -> Socket<'a>
129 where
130 Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
131 {
132 Socket {
133 servers: Vec::from_slice(servers).unwrap(),
134 queries: queries.into(),
135 hop_limit: None,
136 }
137 }
138
139 /// Update the list of DNS servers, will replace all existing servers
140 ///
141 /// # Panics
142 ///
143 /// Panics if `servers.len() > MAX_SERVER_COUNT`
update_servers(&mut self, servers: &[IpAddress])144 pub fn update_servers(&mut self, servers: &[IpAddress]) {
145 self.servers = Vec::from_slice(servers).unwrap();
146 }
147
148 /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
149 ///
150 /// See also the [set_hop_limit](#method.set_hop_limit) method
hop_limit(&self) -> Option<u8>151 pub fn hop_limit(&self) -> Option<u8> {
152 self.hop_limit
153 }
154
155 /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
156 ///
157 /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
158 /// value (64).
159 ///
160 /// # Panics
161 ///
162 /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
163 ///
164 /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
165 /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
set_hop_limit(&mut self, hop_limit: Option<u8>)166 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
167 // A host MUST NOT send a datagram with a hop limit value of 0
168 if let Some(0) = hop_limit {
169 panic!("the time-to-live value of a packet must not be zero")
170 }
171
172 self.hop_limit = hop_limit
173 }
174
find_free_query(&mut self) -> Option<QueryHandle>175 fn find_free_query(&mut self) -> Option<QueryHandle> {
176 for (i, q) in self.queries.iter().enumerate() {
177 if q.is_none() {
178 return Some(QueryHandle(i));
179 }
180 }
181
182 match &mut self.queries {
183 ManagedSlice::Borrowed(_) => None,
184 #[cfg(feature = "alloc")]
185 ManagedSlice::Owned(queries) => {
186 queries.push(None);
187 let index = queries.len() - 1;
188 Some(QueryHandle(index))
189 }
190 }
191 }
192
193 /// Start a query.
194 ///
195 /// `name` is specified in human-friendly format, such as `"rust-lang.org"`.
196 /// It accepts names both with and without trailing dot, and they're treated
197 /// the same (there's no support for DNS search path).
start_query( &mut self, cx: &mut Context, name: &str, query_type: Type, ) -> Result<QueryHandle, StartQueryError>198 pub fn start_query(
199 &mut self,
200 cx: &mut Context,
201 name: &str,
202 query_type: Type,
203 ) -> Result<QueryHandle, StartQueryError> {
204 let mut name = name.as_bytes();
205
206 if name.is_empty() {
207 net_trace!("invalid name: zero length");
208 return Err(StartQueryError::InvalidName);
209 }
210
211 // Remove trailing dot, if any
212 if name[name.len() - 1] == b'.' {
213 name = &name[..name.len() - 1];
214 }
215
216 let mut raw_name: Vec<u8, DNS_MAX_NAME_SIZE> = Vec::new();
217
218 let mut mdns = MulticastDns::Disabled;
219 #[cfg(feature = "socket-mdns")]
220 if name.split(|&c| c == b'.').last().unwrap() == b"local" {
221 net_trace!("Starting a mDNS query");
222 mdns = MulticastDns::Enabled;
223 }
224
225 for s in name.split(|&c| c == b'.') {
226 if s.len() > 63 {
227 net_trace!("invalid name: too long label");
228 return Err(StartQueryError::InvalidName);
229 }
230 if s.is_empty() {
231 net_trace!("invalid name: zero length label");
232 return Err(StartQueryError::InvalidName);
233 }
234
235 // Push label
236 raw_name
237 .push(s.len() as u8)
238 .map_err(|_| StartQueryError::NameTooLong)?;
239 raw_name
240 .extend_from_slice(s)
241 .map_err(|_| StartQueryError::NameTooLong)?;
242 }
243
244 // Push terminator.
245 raw_name
246 .push(0x00)
247 .map_err(|_| StartQueryError::NameTooLong)?;
248
249 self.start_query_raw(cx, &raw_name, query_type, mdns)
250 }
251
252 /// Start a query with a raw (wire-format) DNS name.
253 /// `b"\x09rust-lang\x03org\x00"`
254 ///
255 /// You probably want to use [`start_query`] instead.
start_query_raw( &mut self, cx: &mut Context, raw_name: &[u8], query_type: Type, mdns: MulticastDns, ) -> Result<QueryHandle, StartQueryError>256 pub fn start_query_raw(
257 &mut self,
258 cx: &mut Context,
259 raw_name: &[u8],
260 query_type: Type,
261 mdns: MulticastDns,
262 ) -> Result<QueryHandle, StartQueryError> {
263 let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
264
265 self.queries[handle.0] = Some(DnsQuery {
266 state: State::Pending(PendingQuery {
267 name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
268 type_: query_type,
269 txid: cx.rand().rand_u16(),
270 port: cx.rand().rand_source_port(),
271 delay: RETRANSMIT_DELAY,
272 timeout_at: None,
273 retransmit_at: Instant::ZERO,
274 server_idx: 0,
275 mdns,
276 }),
277 #[cfg(feature = "async")]
278 waker: WakerRegistration::new(),
279 });
280 Ok(handle)
281 }
282
283 /// Get the result of a query.
284 ///
285 /// If the query is completed, the query slot is automatically freed.
286 ///
287 /// # Panics
288 /// Panics if the QueryHandle corresponds to a free slot.
get_query_result( &mut self, handle: QueryHandle, ) -> Result<Vec<IpAddress, DNS_MAX_RESULT_COUNT>, GetQueryResultError>289 pub fn get_query_result(
290 &mut self,
291 handle: QueryHandle,
292 ) -> Result<Vec<IpAddress, DNS_MAX_RESULT_COUNT>, GetQueryResultError> {
293 let slot = &mut self.queries[handle.0];
294 let q = slot.as_mut().unwrap();
295 match &mut q.state {
296 // Query is not done yet.
297 State::Pending(_) => Err(GetQueryResultError::Pending),
298 // Query is done
299 State::Completed(q) => {
300 let res = q.addresses.clone();
301 *slot = None; // Free up the slot for recycling.
302 Ok(res)
303 }
304 State::Failure => {
305 *slot = None; // Free up the slot for recycling.
306 Err(GetQueryResultError::Failed)
307 }
308 }
309 }
310
311 /// Cancels a query, freeing the slot.
312 ///
313 /// # Panics
314 ///
315 /// Panics if the QueryHandle corresponds to an already free slot.
cancel_query(&mut self, handle: QueryHandle)316 pub fn cancel_query(&mut self, handle: QueryHandle) {
317 let slot = &mut self.queries[handle.0];
318 if slot.is_none() {
319 panic!("Canceling query in a free slot.")
320 }
321 *slot = None; // Free up the slot for recycling.
322 }
323
/// Assign a waker to a query slot
///
/// The waker will be woken when the query completes, either successfully or failed.
///
/// # Panics
///
/// Panics if the QueryHandle corresponds to an already free slot.
#[cfg(feature = "async")]
pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) {
    let query = self.queries[handle.0].as_mut().unwrap();
    query.waker.register(waker);
}
339
accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool340 pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
341 (udp_repr.src_port == DNS_PORT
342 && self
343 .servers
344 .iter()
345 .any(|server| *server == ip_repr.src_addr()))
346 || (udp_repr.src_port == MDNS_DNS_PORT)
347 }
348
process( &mut self, _cx: &mut Context, ip_repr: &IpRepr, udp_repr: &UdpRepr, payload: &[u8], )349 pub(crate) fn process(
350 &mut self,
351 _cx: &mut Context,
352 ip_repr: &IpRepr,
353 udp_repr: &UdpRepr,
354 payload: &[u8],
355 ) {
356 debug_assert!(self.accepts(ip_repr, udp_repr));
357
358 let size = payload.len();
359
360 net_trace!(
361 "receiving {} octets from {:?}:{}",
362 size,
363 ip_repr.src_addr(),
364 udp_repr.dst_port
365 );
366
367 let p = match Packet::new_checked(payload) {
368 Ok(x) => x,
369 Err(_) => {
370 net_trace!("dns packet malformed");
371 return;
372 }
373 };
374 if p.opcode() != Opcode::Query {
375 net_trace!("unwanted opcode {:?}", p.opcode());
376 return;
377 }
378
379 if !p.flags().contains(Flags::RESPONSE) {
380 net_trace!("packet doesn't have response bit set");
381 return;
382 }
383
384 if p.question_count() != 1 {
385 net_trace!("bad question count {:?}", p.question_count());
386 return;
387 }
388
389 // Find pending query
390 for q in self.queries.iter_mut().flatten() {
391 if let State::Pending(pq) = &mut q.state {
392 if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
393 continue;
394 }
395
396 if p.rcode() == Rcode::NXDomain {
397 net_trace!("rcode NXDomain");
398 q.set_state(State::Failure);
399 continue;
400 }
401
402 let payload = p.payload();
403 let (mut payload, question) = match Question::parse(payload) {
404 Ok(x) => x,
405 Err(_) => {
406 net_trace!("question malformed");
407 return;
408 }
409 };
410
411 if question.type_ != pq.type_ {
412 net_trace!("question type mismatch");
413 return;
414 }
415
416 match eq_names(p.parse_name(question.name), p.parse_name(&pq.name)) {
417 Ok(true) => {}
418 Ok(false) => {
419 net_trace!("question name mismatch");
420 return;
421 }
422 Err(_) => {
423 net_trace!("dns question name malformed");
424 return;
425 }
426 }
427
428 let mut addresses = Vec::new();
429
430 for _ in 0..p.answer_record_count() {
431 let (payload2, r) = match Record::parse(payload) {
432 Ok(x) => x,
433 Err(_) => {
434 net_trace!("dns answer record malformed");
435 return;
436 }
437 };
438 payload = payload2;
439
440 match eq_names(p.parse_name(r.name), p.parse_name(&pq.name)) {
441 Ok(true) => {}
442 Ok(false) => {
443 net_trace!("answer name mismatch: {:?}", r);
444 continue;
445 }
446 Err(_) => {
447 net_trace!("dns answer record name malformed");
448 return;
449 }
450 }
451
452 match r.data {
453 #[cfg(feature = "proto-ipv4")]
454 RecordData::A(addr) => {
455 net_trace!("A: {:?}", addr);
456 if addresses.push(addr.into()).is_err() {
457 net_trace!("too many addresses in response, ignoring {:?}", addr);
458 }
459 }
460 #[cfg(feature = "proto-ipv6")]
461 RecordData::Aaaa(addr) => {
462 net_trace!("AAAA: {:?}", addr);
463 if addresses.push(addr.into()).is_err() {
464 net_trace!("too many addresses in response, ignoring {:?}", addr);
465 }
466 }
467 RecordData::Cname(name) => {
468 net_trace!("CNAME: {:?}", name);
469
470 // When faced with a CNAME, recursive resolvers are supposed to
471 // resolve the CNAME and append the results for it.
472 //
473 // We update the query with the new name, so that we pick up the A/AAAA
474 // records for the CNAME when we parse them later.
475 // I believe it's mandatory the CNAME results MUST come *after* in the
476 // packet, so it's enough to do one linear pass over it.
477 if copy_name(&mut pq.name, p.parse_name(name)).is_err() {
478 net_trace!("dns answer cname malformed");
479 return;
480 }
481 }
482 RecordData::Other(type_, data) => {
483 net_trace!("unknown: {:?} {:?}", type_, data)
484 }
485 }
486 }
487
488 q.set_state(if addresses.is_empty() {
489 State::Failure
490 } else {
491 State::Completed(CompletedQuery { addresses })
492 });
493
494 // If we get here, packet matched the current query, stop processing.
495 return;
496 }
497 }
498
499 // If we get here, packet matched with no query.
500 net_trace!("no query matched");
501 }
502
dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E> where F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,503 pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
504 where
505 F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
506 {
507 let hop_limit = self.hop_limit.unwrap_or(64);
508
509 for q in self.queries.iter_mut().flatten() {
510 if let State::Pending(pq) = &mut q.state {
511 // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns
512 // so we internally overwrite the servers for any of those queries
513 // in this function.
514 let servers = match pq.mdns {
515 #[cfg(feature = "socket-mdns")]
516 MulticastDns::Enabled => &[
517 #[cfg(feature = "proto-ipv6")]
518 MDNS_IPV6_ADDR,
519 #[cfg(feature = "proto-ipv4")]
520 MDNS_IPV4_ADDR,
521 ],
522 MulticastDns::Disabled => self.servers.as_slice(),
523 };
524
525 let timeout = if let Some(timeout) = pq.timeout_at {
526 timeout
527 } else {
528 let v = cx.now() + RETRANSMIT_TIMEOUT;
529 pq.timeout_at = Some(v);
530 v
531 };
532
533 // Check timeout
534 if timeout < cx.now() {
535 // DNS timeout
536 pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT);
537 pq.retransmit_at = Instant::ZERO;
538 pq.delay = RETRANSMIT_DELAY;
539
540 // Try next server. We check below whether we've tried all servers.
541 pq.server_idx += 1;
542 }
543 // Check if we've run out of servers to try.
544 if pq.server_idx >= servers.len() {
545 net_trace!("already tried all servers.");
546 q.set_state(State::Failure);
547 continue;
548 }
549
550 // Check so the IP address is valid
551 if servers[pq.server_idx].is_unspecified() {
552 net_trace!("invalid unspecified DNS server addr.");
553 q.set_state(State::Failure);
554 continue;
555 }
556
557 if pq.retransmit_at > cx.now() {
558 // query is waiting for retransmit
559 continue;
560 }
561
562 let repr = Repr {
563 transaction_id: pq.txid,
564 flags: Flags::RECURSION_DESIRED,
565 opcode: Opcode::Query,
566 question: Question {
567 name: &pq.name,
568 type_: pq.type_,
569 },
570 };
571
572 let mut payload = [0u8; 512];
573 let payload = &mut payload[..repr.buffer_len()];
574 repr.emit(&mut Packet::new_unchecked(payload));
575
576 let dst_port = match pq.mdns {
577 #[cfg(feature = "socket-mdns")]
578 MulticastDns::Enabled => MDNS_DNS_PORT,
579 MulticastDns::Disabled => DNS_PORT,
580 };
581
582 let udp_repr = UdpRepr {
583 src_port: pq.port,
584 dst_port,
585 };
586
587 let dst_addr = servers[pq.server_idx];
588 let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap
589 let ip_repr = IpRepr::new(
590 src_addr,
591 dst_addr,
592 IpProtocol::Udp,
593 udp_repr.header_len() + payload.len(),
594 hop_limit,
595 );
596
597 net_trace!(
598 "sending {} octets to {} from port {}",
599 payload.len(),
600 ip_repr.dst_addr(),
601 udp_repr.src_port
602 );
603
604 emit(cx, (ip_repr, udp_repr, payload))?;
605
606 pq.retransmit_at = cx.now() + pq.delay;
607 pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
608
609 return Ok(());
610 }
611 }
612
613 // Nothing to dispatch
614 Ok(())
615 }
616
poll_at(&self, _cx: &Context) -> PollAt617 pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
618 self.queries
619 .iter()
620 .flatten()
621 .filter_map(|q| match &q.state {
622 State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
623 State::Completed(_) => None,
624 State::Failure => None,
625 })
626 .min()
627 .unwrap_or(PollAt::Ingress)
628 }
629 }
630
eq_names<'a>( mut a: impl Iterator<Item = wire::Result<&'a [u8]>>, mut b: impl Iterator<Item = wire::Result<&'a [u8]>>, ) -> wire::Result<bool>631 fn eq_names<'a>(
632 mut a: impl Iterator<Item = wire::Result<&'a [u8]>>,
633 mut b: impl Iterator<Item = wire::Result<&'a [u8]>>,
634 ) -> wire::Result<bool> {
635 loop {
636 match (a.next(), b.next()) {
637 // Handle errors
638 (Some(Err(e)), _) => return Err(e),
639 (_, Some(Err(e))) => return Err(e),
640
641 // Both finished -> equal
642 (None, None) => return Ok(true),
643
644 // One finished before the other -> not equal
645 (None, _) => return Ok(false),
646 (_, None) => return Ok(false),
647
648 // Got two labels, check if they're equal
649 (Some(Ok(la)), Some(Ok(lb))) => {
650 if la != lb {
651 return Ok(false);
652 }
653 }
654 }
655 }
656 }
657
copy_name<'a, const N: usize>( dest: &mut Vec<u8, N>, name: impl Iterator<Item = wire::Result<&'a [u8]>>, ) -> Result<(), wire::Error>658 fn copy_name<'a, const N: usize>(
659 dest: &mut Vec<u8, N>,
660 name: impl Iterator<Item = wire::Result<&'a [u8]>>,
661 ) -> Result<(), wire::Error> {
662 dest.truncate(0);
663
664 for label in name {
665 let label = label?;
666 dest.push(label.len() as u8).map_err(|_| wire::Error)?;
667 dest.extend_from_slice(label).map_err(|_| wire::Error)?;
668 }
669
670 // Write terminator 0x00
671 dest.push(0).map_err(|_| wire::Error)?;
672
673 Ok(())
674 }
675