diff --git a/src/IPAddressExtensions.cs b/src/IPAddressExtensions.cs
new file mode 100644
index 0000000..73c016e
--- /dev/null
+++ b/src/IPAddressExtensions.cs
@@ -0,0 +1,79 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Net.NetworkInformation;
+using System.Net.Sockets;
+using System.Text;
+
+namespace Makaretu.Dns
+{
+ ///
+ /// Extensions for .
+ ///
+ public static class IPAddressExtensions
+ {
+ ///
+ /// Gets the subnet mask associated with the IP address.
+ ///
+ ///
+ /// An IP Addresses.
+ ///
+ ///
+ /// The subnet mask; for example "127.0.0.1" returns "255.0.0.0".
+ /// Or null When does not belong to
+ /// the localhost.
+ /// s
+ public static IPAddress GetSubnetMask(this IPAddress address)
+ {
+ return NetworkInterface.GetAllNetworkInterfaces()
+ .SelectMany(nic => nic.GetIPProperties().UnicastAddresses)
+ .Where(a => a.Address.Equals(address))
+ .Select(a => a.IPv4Mask)
+ .FirstOrDefault();
+ }
+
+ ///
+ /// Determines if the local IP address can be used by the
+ /// remote address.
+ ///
+ ///
+ ///
+ ///
+ /// true if can be used by ;
+ /// otherwise, false.
+ ///
+ public static bool IsReachable(this IPAddress local, IPAddress remote)
+ {
+ // Loopback addresses are only reachable when the remote is
+ // the same host.
+ if (local.Equals(IPAddress.Loopback) || local.Equals(IPAddress.IPv6Loopback))
+ {
+ return MulticastService.GetIPAddresses().Contains(remote);
+ }
+
+ // IPv4 addresses are reachable when on the same subnet.
+ if (local.AddressFamily == AddressFamily.InterNetwork && remote.AddressFamily == AddressFamily.InterNetwork)
+ {
+ var mask = local.GetSubnetMask();
+ if (mask != null)
+ {
+ var network = IPNetwork.Parse(local, mask);
+ return network.Contains(remote);
+ }
+ }
+
+ // IPv6 link local addresses are reachabe when using the same scope id.
+ if (local.AddressFamily == AddressFamily.InterNetworkV6 && remote.AddressFamily == AddressFamily.InterNetworkV6)
+ {
+ if (local.IsIPv6LinkLocal || remote.IsIPv6LinkLocal)
+ {
+ return local.Equals(remote);
+ }
+ }
+
+ // Can not determine reachability, assume that network routing can do it.
+ return true;
+ }
+ }
+}
diff --git a/src/MessageEventArgs.cs b/src/MessageEventArgs.cs
index c9e5c54..59abc09 100644
--- a/src/MessageEventArgs.cs
+++ b/src/MessageEventArgs.cs
@@ -18,10 +18,10 @@ public class MessageEventArgs : EventArgs
public Message Message { get; set; }
///
- /// The DNS message sender endpoint.
+ /// Where the message originated from.
///
///
- /// The endpoint from the message was received.
+ /// The IP address and port of the sender.
///
public IPEndPoint RemoteEndPoint { get; set; }
diff --git a/src/ServiceDiscovery.cs b/src/ServiceDiscovery.cs
index 2ca89a0..f3607f6 100644
--- a/src/ServiceDiscovery.cs
+++ b/src/ServiceDiscovery.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Net;
using Common.Logging;
using Makaretu.Dns.Resolving;
@@ -381,6 +382,11 @@ void OnQuery(object sender, MessageEventArgs e)
;
}
+ // Only return address records that the querier can reach.
+ response.Answers.RemoveAll(rr => IsUnreachable(rr, e.RemoteEndPoint));
+ response.AuthorityRecords.RemoveAll(rr => IsUnreachable(rr, e.RemoteEndPoint));
+ response.AdditionalRecords.RemoveAll(rr => IsUnreachable(rr, e.RemoteEndPoint));
+
if (QU)
{
// TODO: Send a Unicast response if required.
@@ -402,6 +408,12 @@ void OnQuery(object sender, MessageEventArgs e)
//Console.WriteLine($"Response time {(DateTime.Now - request.CreationTime).TotalMilliseconds}ms");
}
+ bool IsUnreachable(ResourceRecord rr, IPEndPoint sender)
+ {
+ var arecord = rr as AddressRecord;
+ return !arecord?.Address.IsReachable(sender.Address) ?? false;
+ }
+
#region IDisposable Support
///
diff --git a/src/ServiceProfile.cs b/src/ServiceProfile.cs
index 3fdf652..fcb4096 100644
--- a/src/ServiceProfile.cs
+++ b/src/ServiceProfile.cs
@@ -74,7 +74,7 @@ public ServiceProfile(DomainName instanceName, DomainName serviceName, ushort po
Strings = { "txtvers=1" }
});
- foreach (var address in addresses ?? MulticastService.GetLinkLocalAddresses())
+ foreach (var address in addresses ?? MulticastService.GetIPAddresses())
{
Resources.Add(AddressRecord.Create(HostName, address));
}
diff --git a/test/IPAddressExtensionsTest.cs b/test/IPAddressExtensionsTest.cs
new file mode 100644
index 0000000..3c8b259
--- /dev/null
+++ b/test/IPAddressExtensionsTest.cs
@@ -0,0 +1,106 @@
+using Makaretu.Dns;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Net.NetworkInformation;
+using System.Net.Sockets;
+
+namespace Makaretu.Dns
+{
+
+ [TestClass]
+ public class IPAddressExtensionsTest
+ {
+ [TestMethod]
+ public void SubnetMask_Ipv4Loopback()
+ {
+ var mask = IPAddress.Loopback.GetSubnetMask();
+ Assert.AreEqual(IPAddress.Parse("255.0.0.0"), mask);
+ var network = IPNetwork.Parse(IPAddress.Loopback, mask);
+ Console.Write(network.ToString());
+ }
+
+ [TestMethod]
+ [TestCategory("IPv6")]
+ public void SubnetMask_Ipv6Loopback()
+ {
+ var mask = IPAddress.IPv6Loopback.GetSubnetMask();
+ Assert.AreEqual(IPAddress.Parse("0.0.0.0"), mask);
+ }
+
+ [TestMethod]
+ public void SubmetMask_NotLocalhost()
+ {
+ var mask = IPAddress.Parse("1.1.1.1").GetSubnetMask();
+ Assert.IsNull(mask);
+ }
+
+ [TestMethod]
+ public void SubnetMask_All()
+ {
+ foreach (var a in MulticastService.GetIPAddresses())
+ {
+ var network = IPNetwork.Parse(a, a.GetSubnetMask());
+
+ Console.WriteLine($"{a} mask {a.GetSubnetMask()} {network}");
+
+ Assert.IsTrue(network.Contains(a), $"{a} is not reachable");
+ }
+ }
+
+ [TestMethod]
+ public void LinkLocal()
+ {
+ foreach (var a in MulticastService.GetIPAddresses())
+ {
+ Console.WriteLine($"{a} ll={a.IsIPv6LinkLocal} ss={a.IsIPv6SiteLocal}");
+ }
+ }
+
+ [TestMethod]
+ public void Reachable_Loopback_From_Localhost()
+ {
+ var me = IPAddress.Loopback;
+ foreach (var a in MulticastService.GetIPAddresses())
+ {
+ Assert.IsTrue(me.IsReachable(a), $"{a}");
+ }
+ Assert.IsFalse(me.IsReachable(IPAddress.Parse("1.1.1.1")));
+ Assert.IsFalse(me.IsReachable(IPAddress.Parse("2606:4700:4700::1111")));
+
+ me = IPAddress.IPv6Loopback;
+ foreach (var a in MulticastService.GetIPAddresses())
+ {
+ Assert.IsTrue(me.IsReachable(a), $"{a}");
+ }
+ Assert.IsFalse(me.IsReachable(IPAddress.Parse("1.1.1.1")));
+ Assert.IsFalse(me.IsReachable(IPAddress.Parse("2606:4700:4700::1111")));
+ }
+
+ [TestMethod]
+ public void Reachable_Ipv4()
+ {
+ var me = MulticastService.GetIPAddresses()
+ .First(a => a.AddressFamily == AddressFamily.InterNetwork && !IPAddress.IsLoopback(a));
+ Assert.IsTrue(me.IsReachable(me));
+ Assert.IsFalse(me.IsReachable(IPAddress.Parse("1.1.1.1")));
+
+ var nat = IPAddress.Parse("165.84.19.151"); // NAT PCP assigned address
+ Assert.IsTrue(nat.IsReachable(IPAddress.Parse("1.1.1.1")));
+ }
+
+ [TestMethod]
+ public void Reachable_Ipv6_LinkLocal()
+ {
+ var me1 = IPAddress.Parse("fe80::1:2:3:4%1");
+ var me2 = IPAddress.Parse("fe80::1:2:3:4%2");
+ var me5 = IPAddress.Parse("fe80::1:2:3:5%1");
+ Assert.IsTrue(me1.IsReachable(me1));
+ Assert.IsTrue(me2.IsReachable(me2));
+ Assert.IsFalse(me1.IsReachable(me2));
+ Assert.IsFalse(me1.IsReachable(me5));
+ }
+ }
+}
diff --git a/test/MulticastServiceTest.cs b/test/MulticastServiceTest.cs
index 9c25936..d4220ed 100644
--- a/test/MulticastServiceTest.cs
+++ b/test/MulticastServiceTest.cs
@@ -35,6 +35,7 @@ public void SendQuery()
var ready = new ManualResetEvent(false);
var done = new ManualResetEvent(false);
Message msg = null;
+ IPEndPoint sender = null;
var mdns = new MulticastService();
mdns.NetworkInterfaceDiscovered += (s, e) => ready.Set();
@@ -43,7 +44,7 @@ public void SendQuery()
if ("some-service.local" == e.Message.Questions.First().Name)
{
msg = e.Message;
- Assert.IsFalse(e.IsLegacyUnicast);
+ sender = e.RemoteEndPoint;
done.Set();
}
};
@@ -53,6 +54,8 @@ public void SendQuery()
Assert.IsTrue(ready.WaitOne(TimeSpan.FromSeconds(1)), "ready timeout");
mdns.SendQuery("some-service.local");
Assert.IsTrue(done.WaitOne(TimeSpan.FromSeconds(1)), "query timeout");
+ Assert.IsNotNull(msg);
+ Assert.IsNotNull(sender);
Assert.AreEqual("some-service.local", msg.Questions.First().Name);
Assert.AreEqual(DnsClass.IN, msg.Questions.First().Class);
}
@@ -97,6 +100,7 @@ public void ReceiveAnswer()
var service = Guid.NewGuid().ToString() + ".local";
var done = new ManualResetEvent(false);
Message response = null;
+ IPEndPoint sender = null;
using (var mdns = new MulticastService())
{
@@ -121,12 +125,14 @@ public void ReceiveAnswer()
if (msg.Answers.Any(answer => answer.Name == service))
{
response = msg;
+ sender = e.RemoteEndPoint;
done.Set();
}
};
mdns.Start();
Assert.IsTrue(done.WaitOne(TimeSpan.FromSeconds(1)), "answer timeout");
Assert.IsNotNull(response);
+ Assert.IsNotNull(sender);
Assert.IsTrue(response.IsResponse);
Assert.AreEqual(MessageStatus.NoError, response.Status);
Assert.IsTrue(response.AA);