Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interface specific addresses #26

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions src/IPAddressExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.NetworkInformation;
using System.Net.Sockets;
using System.Text;

namespace Makaretu.Dns
{
/// <summary>
/// Extensions for <see cref="IPAddress"/>.
/// </summary>
public static class IPAddressExtensions
{
/// <summary>
/// Gets the subnet mask associated with the IP address.
/// </summary>
/// <param name="address">
/// An IP Addresses.
/// </param>
/// <returns>
/// The subnet mask; for example "127.0.0.1" returns "255.0.0.0".
/// Or <b>null</b> When <paramref name="address"/> does not belong to
/// the localhost.
/// s</returns>
public static IPAddress GetSubnetMask(this IPAddress address)
{
return NetworkInterface.GetAllNetworkInterfaces()
.SelectMany(nic => nic.GetIPProperties().UnicastAddresses)
.Where(a => a.Address.Equals(address))
.Select(a => a.IPv4Mask)
.FirstOrDefault();
}

/// <summary>
/// Determines if the local IP address can be used by the
/// remote address.
/// </summary>
/// <param name="local"></param>
/// <param name="remote"></param>
/// <returns>
/// <b>true</b> if <paramref name="local"/> can be used by <paramref name="remote"/>;
/// otherwise, <b>false</b>.
/// </returns>
public static bool IsReachable(this IPAddress local, IPAddress remote)
{
// Loopback addresses are only reachable when the remote is
// the same host.
if (local.Equals(IPAddress.Loopback) || local.Equals(IPAddress.IPv6Loopback))
{
return MulticastService.GetIPAddresses().Contains(remote);
}

// IPv4 addresses are reachable when on the same subnet.
if (local.AddressFamily == AddressFamily.InterNetwork && remote.AddressFamily == AddressFamily.InterNetwork)
{
var mask = local.GetSubnetMask();
if (mask != null)
{
var network = IPNetwork.Parse(local, mask);
return network.Contains(remote);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When testing this on multi-interface setup, I was still seeing some weird results, like my local Hyper-V interface addresses being used in answers to queries sent from remote machines on a completely different network, which I'd expect would be filtered out.

}
}

// IPv6 link local addresses are reachabe when using the same scope id.
if (local.AddressFamily == AddressFamily.InterNetworkV6 && remote.AddressFamily == AddressFamily.InterNetworkV6)
{
if (local.IsIPv6LinkLocal || remote.IsIPv6LinkLocal)
{
return local.Equals(remote);
}
}

// Can not determine reachability, assume that network routing can do it.
return true;
Copy link

@jbrestan jbrestan Sep 3, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there no further heuristics we could use if all else fails? Maybe something like local.AddressFamily == remote.AddressFamily? In my multi-interface test setup I was seeing a lot of IPv6 responses that get removed if I add that additional comparison, but I'm not sure if that's right or wrong...

}
}
}
4 changes: 2 additions & 2 deletions src/MessageEventArgs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ public class MessageEventArgs : EventArgs
public Message Message { get; set; }

/// <summary>
/// The DNS message sender endpoint.
/// Where the message originated from.
/// </summary>
/// <value>
/// The endpoint from the message was received.
/// The IP address and port of the sender.
/// </value>
public IPEndPoint RemoteEndPoint { get; set; }

Expand Down
12 changes: 12 additions & 0 deletions src/ServiceDiscovery.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using Common.Logging;
using Makaretu.Dns.Resolving;

Expand Down Expand Up @@ -381,6 +382,11 @@ void OnQuery(object sender, MessageEventArgs e)
;
}

// Only return address records that the querier can reach.
response.Answers.RemoveAll(rr => IsUnreachable(rr, e.RemoteEndPoint));
response.AuthorityRecords.RemoveAll(rr => IsUnreachable(rr, e.RemoteEndPoint));
response.AdditionalRecords.RemoveAll(rr => IsUnreachable(rr, e.RemoteEndPoint));

if (QU)
{
// TODO: Send a Unicast response if required.
Expand All @@ -402,6 +408,12 @@ void OnQuery(object sender, MessageEventArgs e)
//Console.WriteLine($"Response time {(DateTime.Now - request.CreationTime).TotalMilliseconds}ms");
}

bool IsUnreachable(ResourceRecord rr, IPEndPoint sender)
{
var arecord = rr as AddressRecord;
return !arecord?.Address.IsReachable(sender.Address) ?? false;
}

#region IDisposable Support

/// <inheritdoc />
Expand Down
2 changes: 1 addition & 1 deletion src/ServiceProfile.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public ServiceProfile(DomainName instanceName, DomainName serviceName, ushort po
Strings = { "txtvers=1" }
});

foreach (var address in addresses ?? MulticastService.GetLinkLocalAddresses())
foreach (var address in addresses ?? MulticastService.GetIPAddresses())
{
Resources.Add(AddressRecord.Create(HostName, address));
}
Expand Down
106 changes: 106 additions & 0 deletions test/IPAddressExtensionsTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using Makaretu.Dns;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.NetworkInformation;
using System.Net.Sockets;

namespace Makaretu.Dns
{

[TestClass]
public class IPAddressExtensionsTest
{
[TestMethod]
public void SubnetMask_Ipv4Loopback()
{
var mask = IPAddress.Loopback.GetSubnetMask();
Assert.AreEqual(IPAddress.Parse("255.0.0.0"), mask);
var network = IPNetwork.Parse(IPAddress.Loopback, mask);
Console.Write(network.ToString());
}

[TestMethod]
[TestCategory("IPv6")]
public void SubnetMask_Ipv6Loopback()
{
var mask = IPAddress.IPv6Loopback.GetSubnetMask();
Assert.AreEqual(IPAddress.Parse("0.0.0.0"), mask);
}

[TestMethod]
public void SubmetMask_NotLocalhost()
{
var mask = IPAddress.Parse("1.1.1.1").GetSubnetMask();
Assert.IsNull(mask);
}

[TestMethod]
public void SubnetMask_All()
{
foreach (var a in MulticastService.GetIPAddresses())
{
var network = IPNetwork.Parse(a, a.GetSubnetMask());

Console.WriteLine($"{a} mask {a.GetSubnetMask()} {network}");

Assert.IsTrue(network.Contains(a), $"{a} is not reachable");
}
}

[TestMethod]
public void LinkLocal()
{
foreach (var a in MulticastService.GetIPAddresses())
{
Console.WriteLine($"{a} ll={a.IsIPv6LinkLocal} ss={a.IsIPv6SiteLocal}");
}
}

[TestMethod]
public void Reachable_Loopback_From_Localhost()
{
var me = IPAddress.Loopback;
foreach (var a in MulticastService.GetIPAddresses())
{
Assert.IsTrue(me.IsReachable(a), $"{a}");
}
Assert.IsFalse(me.IsReachable(IPAddress.Parse("1.1.1.1")));
Assert.IsFalse(me.IsReachable(IPAddress.Parse("2606:4700:4700::1111")));

me = IPAddress.IPv6Loopback;
foreach (var a in MulticastService.GetIPAddresses())
{
Assert.IsTrue(me.IsReachable(a), $"{a}");
}
Assert.IsFalse(me.IsReachable(IPAddress.Parse("1.1.1.1")));
Assert.IsFalse(me.IsReachable(IPAddress.Parse("2606:4700:4700::1111")));
}

[TestMethod]
public void Reachable_Ipv4()
{
var me = MulticastService.GetIPAddresses()
.First(a => a.AddressFamily == AddressFamily.InterNetwork && !IPAddress.IsLoopback(a));
Assert.IsTrue(me.IsReachable(me));
Assert.IsFalse(me.IsReachable(IPAddress.Parse("1.1.1.1")));

var nat = IPAddress.Parse("165.84.19.151"); // NAT PCP assigned address
Assert.IsTrue(nat.IsReachable(IPAddress.Parse("1.1.1.1")));
}

[TestMethod]
public void Reachable_Ipv6_LinkLocal()
{
var me1 = IPAddress.Parse("fe80::1:2:3:4%1");
var me2 = IPAddress.Parse("fe80::1:2:3:4%2");
var me5 = IPAddress.Parse("fe80::1:2:3:5%1");
Assert.IsTrue(me1.IsReachable(me1));
Assert.IsTrue(me2.IsReachable(me2));
Assert.IsFalse(me1.IsReachable(me2));
Assert.IsFalse(me1.IsReachable(me5));
}
}
}
8 changes: 7 additions & 1 deletion test/MulticastServiceTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public void SendQuery()
var ready = new ManualResetEvent(false);
var done = new ManualResetEvent(false);
Message msg = null;
IPEndPoint sender = null;

var mdns = new MulticastService();
mdns.NetworkInterfaceDiscovered += (s, e) => ready.Set();
Expand All @@ -43,7 +44,7 @@ public void SendQuery()
if ("some-service.local" == e.Message.Questions.First().Name)
{
msg = e.Message;
Assert.IsFalse(e.IsLegacyUnicast);
sender = e.RemoteEndPoint;
done.Set();
}
};
Expand All @@ -53,6 +54,8 @@ public void SendQuery()
Assert.IsTrue(ready.WaitOne(TimeSpan.FromSeconds(1)), "ready timeout");
mdns.SendQuery("some-service.local");
Assert.IsTrue(done.WaitOne(TimeSpan.FromSeconds(1)), "query timeout");
Assert.IsNotNull(msg);
Assert.IsNotNull(sender);
Assert.AreEqual("some-service.local", msg.Questions.First().Name);
Assert.AreEqual(DnsClass.IN, msg.Questions.First().Class);
}
Expand Down Expand Up @@ -97,6 +100,7 @@ public void ReceiveAnswer()
var service = Guid.NewGuid().ToString() + ".local";
var done = new ManualResetEvent(false);
Message response = null;
IPEndPoint sender = null;

using (var mdns = new MulticastService())
{
Expand All @@ -121,12 +125,14 @@ public void ReceiveAnswer()
if (msg.Answers.Any(answer => answer.Name == service))
{
response = msg;
sender = e.RemoteEndPoint;
done.Set();
}
};
mdns.Start();
Assert.IsTrue(done.WaitOne(TimeSpan.FromSeconds(1)), "answer timeout");
Assert.IsNotNull(response);
Assert.IsNotNull(sender);
Assert.IsTrue(response.IsResponse);
Assert.AreEqual(MessageStatus.NoError, response.Status);
Assert.IsTrue(response.AA);
Expand Down