diff --git a/src/IPAddressExtensions.cs b/src/IPAddressExtensions.cs new file mode 100644 index 0000000..73c016e --- /dev/null +++ b/src/IPAddressExtensions.cs @@ -0,0 +1,79 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.NetworkInformation; +using System.Net.Sockets; +using System.Text; + +namespace Makaretu.Dns +{ + /// + /// Extensions for . + /// + public static class IPAddressExtensions + { + /// + /// Gets the subnet mask associated with the IP address. + /// + /// + /// An IP Addresses. + /// + /// + /// The subnet mask; for example "127.0.0.1" returns "255.0.0.0". + /// Or null When does not belong to + /// the localhost. + /// s + public static IPAddress GetSubnetMask(this IPAddress address) + { + return NetworkInterface.GetAllNetworkInterfaces() + .SelectMany(nic => nic.GetIPProperties().UnicastAddresses) + .Where(a => a.Address.Equals(address)) + .Select(a => a.IPv4Mask) + .FirstOrDefault(); + } + + /// + /// Determines if the local IP address can be used by the + /// remote address. + /// + /// + /// + /// + /// true if can be used by ; + /// otherwise, false. + /// + public static bool IsReachable(this IPAddress local, IPAddress remote) + { + // Loopback addresses are only reachable when the remote is + // the same host. + if (local.Equals(IPAddress.Loopback) || local.Equals(IPAddress.IPv6Loopback)) + { + return MulticastService.GetIPAddresses().Contains(remote); + } + + // IPv4 addresses are reachable when on the same subnet. + if (local.AddressFamily == AddressFamily.InterNetwork && remote.AddressFamily == AddressFamily.InterNetwork) + { + var mask = local.GetSubnetMask(); + if (mask != null) + { + var network = IPNetwork.Parse(local, mask); + return network.Contains(remote); + } + } + + // IPv6 link local addresses are reachabe when using the same scope id. + if (local.AddressFamily == AddressFamily.InterNetworkV6 && remote.AddressFamily == AddressFamily.InterNetworkV6) + { + if (local.IsIPv6LinkLocal || remote.IsIPv6LinkLocal) + { + return local.Equals(remote); + } + } + + // Can not determine reachability, assume that network routing can do it. + return true; + } + } +} diff --git a/src/MessageEventArgs.cs b/src/MessageEventArgs.cs index c9e5c54..59abc09 100644 --- a/src/MessageEventArgs.cs +++ b/src/MessageEventArgs.cs @@ -18,10 +18,10 @@ public class MessageEventArgs : EventArgs public Message Message { get; set; } /// - /// The DNS message sender endpoint. + /// Where the message originated from. /// /// - /// The endpoint from the message was received. + /// The IP address and port of the sender. /// public IPEndPoint RemoteEndPoint { get; set; } diff --git a/src/ServiceDiscovery.cs b/src/ServiceDiscovery.cs index 2ca89a0..f3607f6 100644 --- a/src/ServiceDiscovery.cs +++ b/src/ServiceDiscovery.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Net; using Common.Logging; using Makaretu.Dns.Resolving; @@ -381,6 +382,11 @@ void OnQuery(object sender, MessageEventArgs e) ; } + // Only return address records that the querier can reach. + response.Answers.RemoveAll(rr => IsUnreachable(rr, e.RemoteEndPoint)); + response.AuthorityRecords.RemoveAll(rr => IsUnreachable(rr, e.RemoteEndPoint)); + response.AdditionalRecords.RemoveAll(rr => IsUnreachable(rr, e.RemoteEndPoint)); + if (QU) { // TODO: Send a Unicast response if required. @@ -402,6 +408,12 @@ void OnQuery(object sender, MessageEventArgs e) //Console.WriteLine($"Response time {(DateTime.Now - request.CreationTime).TotalMilliseconds}ms"); } + bool IsUnreachable(ResourceRecord rr, IPEndPoint sender) + { + var arecord = rr as AddressRecord; + return !arecord?.Address.IsReachable(sender.Address) ?? false; + } + #region IDisposable Support /// diff --git a/src/ServiceProfile.cs b/src/ServiceProfile.cs index 3fdf652..fcb4096 100644 --- a/src/ServiceProfile.cs +++ b/src/ServiceProfile.cs @@ -74,7 +74,7 @@ public ServiceProfile(DomainName instanceName, DomainName serviceName, ushort po Strings = { "txtvers=1" } }); - foreach (var address in addresses ?? MulticastService.GetLinkLocalAddresses()) + foreach (var address in addresses ?? MulticastService.GetIPAddresses()) { Resources.Add(AddressRecord.Create(HostName, address)); } diff --git a/test/IPAddressExtensionsTest.cs b/test/IPAddressExtensionsTest.cs new file mode 100644 index 0000000..3c8b259 --- /dev/null +++ b/test/IPAddressExtensionsTest.cs @@ -0,0 +1,106 @@ +using Makaretu.Dns; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Net.NetworkInformation; +using System.Net.Sockets; + +namespace Makaretu.Dns +{ + + [TestClass] + public class IPAddressExtensionsTest + { + [TestMethod] + public void SubnetMask_Ipv4Loopback() + { + var mask = IPAddress.Loopback.GetSubnetMask(); + Assert.AreEqual(IPAddress.Parse("255.0.0.0"), mask); + var network = IPNetwork.Parse(IPAddress.Loopback, mask); + Console.Write(network.ToString()); + } + + [TestMethod] + [TestCategory("IPv6")] + public void SubnetMask_Ipv6Loopback() + { + var mask = IPAddress.IPv6Loopback.GetSubnetMask(); + Assert.AreEqual(IPAddress.Parse("0.0.0.0"), mask); + } + + [TestMethod] + public void SubmetMask_NotLocalhost() + { + var mask = IPAddress.Parse("1.1.1.1").GetSubnetMask(); + Assert.IsNull(mask); + } + + [TestMethod] + public void SubnetMask_All() + { + foreach (var a in MulticastService.GetIPAddresses()) + { + var network = IPNetwork.Parse(a, a.GetSubnetMask()); + + Console.WriteLine($"{a} mask {a.GetSubnetMask()} {network}"); + + Assert.IsTrue(network.Contains(a), $"{a} is not reachable"); + } + } + + [TestMethod] + public void LinkLocal() + { + foreach (var a in MulticastService.GetIPAddresses()) + { + Console.WriteLine($"{a} ll={a.IsIPv6LinkLocal} ss={a.IsIPv6SiteLocal}"); + } + } + + [TestMethod] + public void Reachable_Loopback_From_Localhost() + { + var me = IPAddress.Loopback; + foreach (var a in MulticastService.GetIPAddresses()) + { + Assert.IsTrue(me.IsReachable(a), $"{a}"); + } + Assert.IsFalse(me.IsReachable(IPAddress.Parse("1.1.1.1"))); + Assert.IsFalse(me.IsReachable(IPAddress.Parse("2606:4700:4700::1111"))); + + me = IPAddress.IPv6Loopback; + foreach (var a in MulticastService.GetIPAddresses()) + { + Assert.IsTrue(me.IsReachable(a), $"{a}"); + } + Assert.IsFalse(me.IsReachable(IPAddress.Parse("1.1.1.1"))); + Assert.IsFalse(me.IsReachable(IPAddress.Parse("2606:4700:4700::1111"))); + } + + [TestMethod] + public void Reachable_Ipv4() + { + var me = MulticastService.GetIPAddresses() + .First(a => a.AddressFamily == AddressFamily.InterNetwork && !IPAddress.IsLoopback(a)); + Assert.IsTrue(me.IsReachable(me)); + Assert.IsFalse(me.IsReachable(IPAddress.Parse("1.1.1.1"))); + + var nat = IPAddress.Parse("165.84.19.151"); // NAT PCP assigned address + Assert.IsTrue(nat.IsReachable(IPAddress.Parse("1.1.1.1"))); + } + + [TestMethod] + public void Reachable_Ipv6_LinkLocal() + { + var me1 = IPAddress.Parse("fe80::1:2:3:4%1"); + var me2 = IPAddress.Parse("fe80::1:2:3:4%2"); + var me5 = IPAddress.Parse("fe80::1:2:3:5%1"); + Assert.IsTrue(me1.IsReachable(me1)); + Assert.IsTrue(me2.IsReachable(me2)); + Assert.IsFalse(me1.IsReachable(me2)); + Assert.IsFalse(me1.IsReachable(me5)); + } + } +} diff --git a/test/MulticastServiceTest.cs b/test/MulticastServiceTest.cs index 9c25936..d4220ed 100644 --- a/test/MulticastServiceTest.cs +++ b/test/MulticastServiceTest.cs @@ -35,6 +35,7 @@ public void SendQuery() var ready = new ManualResetEvent(false); var done = new ManualResetEvent(false); Message msg = null; + IPEndPoint sender = null; var mdns = new MulticastService(); mdns.NetworkInterfaceDiscovered += (s, e) => ready.Set(); @@ -43,7 +44,7 @@ public void SendQuery() if ("some-service.local" == e.Message.Questions.First().Name) { msg = e.Message; - Assert.IsFalse(e.IsLegacyUnicast); + sender = e.RemoteEndPoint; done.Set(); } }; @@ -53,6 +54,8 @@ public void SendQuery() Assert.IsTrue(ready.WaitOne(TimeSpan.FromSeconds(1)), "ready timeout"); mdns.SendQuery("some-service.local"); Assert.IsTrue(done.WaitOne(TimeSpan.FromSeconds(1)), "query timeout"); + Assert.IsNotNull(msg); + Assert.IsNotNull(sender); Assert.AreEqual("some-service.local", msg.Questions.First().Name); Assert.AreEqual(DnsClass.IN, msg.Questions.First().Class); } @@ -97,6 +100,7 @@ public void ReceiveAnswer() var service = Guid.NewGuid().ToString() + ".local"; var done = new ManualResetEvent(false); Message response = null; + IPEndPoint sender = null; using (var mdns = new MulticastService()) { @@ -121,12 +125,14 @@ public void ReceiveAnswer() if (msg.Answers.Any(answer => answer.Name == service)) { response = msg; + sender = e.RemoteEndPoint; done.Set(); } }; mdns.Start(); Assert.IsTrue(done.WaitOne(TimeSpan.FromSeconds(1)), "answer timeout"); Assert.IsNotNull(response); + Assert.IsNotNull(sender); Assert.IsTrue(response.IsResponse); Assert.AreEqual(MessageStatus.NoError, response.Status); Assert.IsTrue(response.AA);