diff --git a/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/RakChannel.java b/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/RakChannel.java index e379a6b3..35c67833 100644 --- a/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/RakChannel.java +++ b/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/RakChannel.java @@ -17,7 +17,6 @@ package org.cloudburstmc.netty.channel.raknet; import io.netty.channel.Channel; -import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelPipeline; import org.cloudburstmc.netty.channel.raknet.config.RakChannelConfig; diff --git a/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/config/DefaultRakServerConfig.java b/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/config/DefaultRakServerConfig.java index 957ad2f6..2bad184a 100644 --- a/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/config/DefaultRakServerConfig.java +++ b/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/config/DefaultRakServerConfig.java @@ -46,6 +46,7 @@ public class DefaultRakServerConfig extends DefaultChannelConfig implements RakS private volatile int packetLimit = RakConstants.DEFAULT_PACKET_LIMIT; private volatile int globalPacketLimit = RakConstants.DEFAULT_GLOBAL_PACKET_LIMIT; private volatile RakServerMetrics metrics; + private volatile boolean sendCookie; public DefaultRakServerConfig(RakServerChannel channel) { super(channel); @@ -56,7 +57,8 @@ public Map, Object> getOptions() { return getOptions( super.getOptions(), RakChannelOption.RAK_GUID, RakChannelOption.RAK_MAX_CHANNELS, RakChannelOption.RAK_MAX_CONNECTIONS, RakChannelOption.RAK_SUPPORTED_PROTOCOLS, RakChannelOption.RAK_UNCONNECTED_MAGIC, - RakChannelOption.RAK_ADVERTISEMENT, RakChannelOption.RAK_HANDLE_PING, RakChannelOption.RAK_PACKET_LIMIT, RakChannelOption.RAK_GLOBAL_PACKET_LIMIT, RakChannelOption.RAK_SERVER_METRICS); + RakChannelOption.RAK_ADVERTISEMENT, RakChannelOption.RAK_HANDLE_PING, RakChannelOption.RAK_PACKET_LIMIT, RakChannelOption.RAK_GLOBAL_PACKET_LIMIT, RakChannelOption.RAK_SEND_COOKIE, + RakChannelOption.RAK_SERVER_METRICS); } @SuppressWarnings("unchecked") @@ -98,6 +100,9 @@ public T getOption(ChannelOption option) { if (option == RakChannelOption.RAK_SERVER_METRICS) { return (T) this.getMetrics(); } + if (option == RakChannelOption.RAK_SEND_COOKIE) { + return (T) Boolean.valueOf(this.sendCookie); + } return this.channel.parent().config().getOption(option); } @@ -127,6 +132,8 @@ public boolean setOption(ChannelOption option, T value) { this.setPacketLimit((Integer) value); } else if (option == RakChannelOption.RAK_GLOBAL_PACKET_LIMIT) { this.setGlobalPacketLimit((Integer) value); + } else if (option == RakChannelOption.RAK_SEND_COOKIE) { + this.sendCookie = (Boolean) value; } else if (option == RakChannelOption.RAK_SERVER_METRICS) { this.setMetrics((RakServerMetrics) value); } else{ @@ -265,6 +272,16 @@ public void setGlobalPacketLimit(int globalPacketLimit) { this.globalPacketLimit = globalPacketLimit; } + @Override + public void setSendCookie(boolean sendCookie) { + this.sendCookie = sendCookie; + } + + @Override + public boolean getSendCookie() { + return this.sendCookie; + } + @Override public void setMetrics(RakServerMetrics metrics) { this.metrics = metrics; diff --git a/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/config/RakChannelOption.java b/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/config/RakChannelOption.java index 289a01fe..571b7a64 100644 --- a/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/config/RakChannelOption.java +++ b/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/config/RakChannelOption.java @@ -152,6 +152,12 @@ public class RakChannelOption extends ChannelOption { public static final ChannelOption RAK_GLOBAL_PACKET_LIMIT = valueOf(RakChannelOption.class, "RAK_GLOBAL_PACKET_LIMIT"); + /** + * Whether to send a cookie to the client during the connection process. + */ + public static final ChannelOption RAK_SEND_COOKIE = + valueOf(RakChannelOption.class, "RAK_SEND_COOKIE"); + @SuppressWarnings("deprecation") protected RakChannelOption() { super(null); diff --git a/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/config/RakServerChannelConfig.java b/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/config/RakServerChannelConfig.java index f8711d53..b9028aa1 100644 --- a/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/config/RakServerChannelConfig.java +++ b/transport-raknet/src/main/java/org/cloudburstmc/netty/channel/raknet/config/RakServerChannelConfig.java @@ -65,6 +65,10 @@ public interface RakServerChannelConfig extends ChannelConfig { void setGlobalPacketLimit(int limit); + void setSendCookie(boolean sendCookie); + + boolean getSendCookie(); + void setMetrics(RakServerMetrics metrics); RakServerMetrics getMetrics(); diff --git a/transport-raknet/src/main/java/org/cloudburstmc/netty/handler/codec/raknet/client/RakClientOfflineHandler.java b/transport-raknet/src/main/java/org/cloudburstmc/netty/handler/codec/raknet/client/RakClientOfflineHandler.java index 098d8c07..7f0a16e5 100644 --- a/transport-raknet/src/main/java/org/cloudburstmc/netty/handler/codec/raknet/client/RakClientOfflineHandler.java +++ b/transport-raknet/src/main/java/org/cloudburstmc/netty/handler/codec/raknet/client/RakClientOfflineHandler.java @@ -19,7 +19,6 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.channel.*; -import io.netty.channel.socket.DatagramPacket; import io.netty.handler.codec.CorruptedFrameException; import io.netty.util.concurrent.ScheduledFuture; import org.cloudburstmc.netty.channel.raknet.RakChannel; @@ -45,6 +44,8 @@ public class RakClientOfflineHandler extends SimpleChannelInboundHandler pendingConnections = ExpiringMap.builder() + private final ThreadLocal random = ThreadLocal.withInitial(() -> { + try { + return SecureRandom.getInstance(SecureAlgorithmProvider.getSecurityAlgorithm()); + } catch (NoSuchAlgorithmException e) { + return new SecureRandom(); + } + }); + + private final ExpiringMap pendingConnections = ExpiringMap.builder() .expiration(10, TimeUnit.SECONDS) .expirationPolicy(ExpirationPolicy.CREATED) .expirationListener((key, value) -> ReferenceCountUtil.release(value)) @@ -163,16 +173,31 @@ private void onOpenConnectionRequest1(ChannelHandlerContext ctx, DatagramPacket // TODO: banned address check? // TODO: max connections check? - Integer version = this.pendingConnections.put(sender, protocolVersion); - if (version != null && log.isTraceEnabled()) { + + boolean sendCookie = ctx.channel().config().getOption(RakChannelOption.RAK_SEND_COOKIE); + int cookie; + + if (sendCookie) { + cookie = this.random.get().nextInt(); + } else { + cookie = 0; + } + + PendingConnection connection = this.pendingConnections.put(sender, new PendingConnection(protocolVersion, cookie)); + if (connection != null && log.isTraceEnabled()) { log.trace("Received duplicate open connection request 1 from {}", sender); } - ByteBuf replyBuffer = ctx.alloc().ioBuffer(28, 28); + int bufferCapacity = sendCookie ? 32 : 28; // 4 byte cookie + + ByteBuf replyBuffer = ctx.alloc().ioBuffer(bufferCapacity, bufferCapacity); replyBuffer.writeByte(ID_OPEN_CONNECTION_REPLY_1); replyBuffer.writeBytes(magicBuf, magicBuf.readerIndex(), magicBuf.readableBytes()); replyBuffer.writeLong(guid); - replyBuffer.writeBoolean(false); // Security + replyBuffer.writeBoolean(sendCookie); // Security + if (sendCookie) { + replyBuffer.writeInt(cookie); + } replyBuffer.writeShort(RakUtils.clamp(mtu, ctx.channel().config().getOption(RakChannelOption.RAK_MIN_MTU), ctx.channel().config().getOption(RakChannelOption.RAK_MAX_MTU))); ctx.writeAndFlush(new DatagramPacket(replyBuffer, sender)); } @@ -183,18 +208,31 @@ private void onOpenConnectionRequest2(ChannelHandlerContext ctx, DatagramPacket // Skip already verified magic buffer.skipBytes(magicBuf.readableBytes()); - Integer version = this.pendingConnections.remove(sender); - if (version == null) { - // We can't determine the version without the previous request, so assume it's the wrong version. + + PendingConnection connection = this.pendingConnections.remove(sender); + if (connection == null) { if (log.isTraceEnabled()) { log.trace("Received open connection request 2 from {} without open connection request 1", sender); } - int[] supportedProtocols = ctx.channel().config().getOption(RakChannelOption.RAK_SUPPORTED_PROTOCOLS); - int latestVersion = supportedProtocols == null ? RakConstants.RAKNET_PROTOCOL_VERSION : supportedProtocols[supportedProtocols.length - 1]; - this.sendIncompatibleVersion(ctx, sender, latestVersion, magicBuf, guid); + // Don't respond yet as we cannot verify the connection source IP return; } + boolean sendCookie = ctx.channel().config().getOption(RakChannelOption.RAK_SEND_COOKIE); + if (sendCookie) { + int cookie = buffer.readInt(); + int expectedCookie = connection.getCookie(); + if (expectedCookie != cookie) { + if (log.isTraceEnabled()) { + log.trace("Received open connection request 2 from {} with invalid cookie (expected {}, but received {})", sender, expectedCookie, cookie); + } + // Incorrect cookie provided + // This is likely source IP spoofing so we will not reply + return; + } + buffer.readBoolean(); // Client wrote challenge + } + // TODO: Verify serverAddress matches? InetSocketAddress serverAddress = RakUtils.readAddress(buffer); int mtu = buffer.readUnsignedShort(); @@ -207,7 +245,7 @@ private void onOpenConnectionRequest2(ChannelHandlerContext ctx, DatagramPacket } RakServerChannel serverChannel = (RakServerChannel) ctx.channel(); - RakChildChannel channel = serverChannel.createChildChannel(sender, clientGuid, version, mtu); + RakChildChannel channel = serverChannel.createChildChannel(sender, clientGuid, connection.getProtocolVersion(), mtu); if (channel == null) { // Already connected this.sendAlreadyConnected(ctx, sender, magicBuf, guid); @@ -240,4 +278,22 @@ private void sendAlreadyConnected(ChannelHandlerContext ctx, InetSocketAddress s buffer.writeLong(guid); ctx.writeAndFlush(new DatagramPacket(buffer, sender)); } + + private class PendingConnection { + private final int protocolVersion; + private final int cookie; + + public PendingConnection(int protocolVersion, int cookie) { + this.protocolVersion = protocolVersion; + this.cookie = cookie; + } + + public int getProtocolVersion() { + return this.protocolVersion; + } + + public int getCookie() { + return this.cookie; + } + } } diff --git a/transport-raknet/src/main/java/org/cloudburstmc/netty/util/RakUtils.java b/transport-raknet/src/main/java/org/cloudburstmc/netty/util/RakUtils.java index 7d9f217f..5fee28dc 100644 --- a/transport-raknet/src/main/java/org/cloudburstmc/netty/util/RakUtils.java +++ b/transport-raknet/src/main/java/org/cloudburstmc/netty/util/RakUtils.java @@ -88,7 +88,7 @@ public static InetSocketAddress readAddress(ByteBuf buffer) { int scopeId = buffer.readInt(); address = Inet6Address.getByAddress(null, addressBytes, scopeId); } else { - throw new UnsupportedOperationException("Unknown Internet Protocol version."); + throw new UnsupportedOperationException("Unknown Internet Protocol version. Expected 4 or 6, got " + type); } } catch (UnknownHostException e) { throw new IllegalArgumentException(e); diff --git a/transport-raknet/src/main/java/org/cloudburstmc/netty/util/SecureAlgorithmProvider.java b/transport-raknet/src/main/java/org/cloudburstmc/netty/util/SecureAlgorithmProvider.java new file mode 100644 index 00000000..9e634518 --- /dev/null +++ b/transport-raknet/src/main/java/org/cloudburstmc/netty/util/SecureAlgorithmProvider.java @@ -0,0 +1,37 @@ +package org.cloudburstmc.netty.util; + +import java.security.Provider; +import java.security.SecureRandom; +import java.security.Security; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +public class SecureAlgorithmProvider { + private static final String SECURITY_ALGORITHM; + + static { + // SecureRandom algorithms in order of most preferred to least preferred. + final List preferredAlgorithms = Arrays.asList( + "SHA1PRNG", + "NativePRNGNonBlocking", + "Windows-PRNG", + "NativePRNG", + "PKCS11", + "DRBG", + "NativePRNGBlocking" + ); + + SECURITY_ALGORITHM = Stream.of(Security.getProviders()) + .flatMap(provider -> provider.getServices().stream()) + .filter(service -> "SecureRandom".equals(service.getType())) + .map(Provider.Service::getAlgorithm) + .filter(preferredAlgorithms::contains) + .min((s1, s2) -> Integer.compare(preferredAlgorithms.indexOf(s1), preferredAlgorithms.indexOf(s2))) + .orElse(new SecureRandom().getAlgorithm()); + } + + public static String getSecurityAlgorithm() { + return SECURITY_ALGORITHM; + } +} diff --git a/transport-raknet/src/test/java/org/cloudburstmc/netty/RakTests.java b/transport-raknet/src/test/java/org/cloudburstmc/netty/RakTests.java index 6e4b8d4a..2b54acb6 100644 --- a/transport-raknet/src/test/java/org/cloudburstmc/netty/RakTests.java +++ b/transport-raknet/src/test/java/org/cloudburstmc/netty/RakTests.java @@ -27,7 +27,6 @@ import org.cloudburstmc.netty.channel.raknet.*; import org.cloudburstmc.netty.channel.raknet.config.RakChannelOption; import org.cloudburstmc.netty.channel.raknet.packet.RakMessage; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; @@ -100,7 +99,20 @@ private static ServerBootstrap serverBootstrap() { .option(RakChannelOption.RAK_MAX_CONNECTIONS, 1) .childOption(RakChannelOption.RAK_ORDERING_CHANNELS, 1) .option(RakChannelOption.RAK_GUID, ThreadLocalRandom.current().nextLong()) - .option(RakChannelOption.RAK_ADVERTISEMENT, Unpooled.wrappedBuffer(ADVERTISEMENT)); + .option(RakChannelOption.RAK_ADVERTISEMENT, Unpooled.wrappedBuffer(ADVERTISEMENT)) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(RakServerChannel ch) throws Exception { + System.out.println("Initialised server channel"); + } + }) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(RakChildChannel ch) throws Exception { + System.out.println("Server child channel initialized " + ch.remoteAddress()); + ch.pipeline().addLast(RESEND_HANDLER()); + } + }); } private static Bootstrap clientBootstrap(int mtu) { @@ -117,32 +129,44 @@ private static IntStream validMtu() { .filter(i -> i % 12 == 0); } - @BeforeEach public void setupServer() { serverBootstrap() - .handler(new ChannelInitializer() { - @Override - protected void initChannel(RakServerChannel ch) throws Exception { - System.out.println("Initialised server channel"); - } - }) - .childHandler(new ChannelInitializer() { - @Override - protected void initChannel(RakChildChannel ch) throws Exception { - System.out.println("Server child channel initialized " + ch.remoteAddress()); - ch.pipeline().addLast(RESEND_HANDLER()); - } - }) + .bind(new InetSocketAddress("127.0.0.1", 19132)) + .awaitUninterruptibly(); + } + + public void setupCookieServer() { + serverBootstrap() + .option(RakChannelOption.RAK_SEND_COOKIE, true) .bind(new InetSocketAddress("127.0.0.1", 19132)) .awaitUninterruptibly(); } @Test public void testClientConnect() { + setupServer(); int mtu = RakConstants.MAXIMUM_MTU_SIZE; System.out.println("Testing client with MTU " + mtu); - Channel channel = clientBootstrap(mtu) + clientBootstrap(mtu) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(RakClientChannel ch) throws Exception { + System.out.println("Client channel initialized"); + } + }) + .connect(new InetSocketAddress("127.0.0.1", 19132)) + .awaitUninterruptibly() + .channel(); + } + + @Test + public void testClientConnectWithCookie() { + setupCookieServer(); + int mtu = RakConstants.MAXIMUM_MTU_SIZE; + System.out.println("Testing client with MTU " + mtu + " and cookie enabled"); + + clientBootstrap(mtu) .handler(new ChannelInitializer() { @Override protected void initChannel(RakClientChannel ch) throws Exception { @@ -158,6 +182,7 @@ protected void initChannel(RakClientChannel ch) throws Exception { @ParameterizedTest @MethodSource("validMtu") public void testClientResend(int mtu) { + setupServer(); System.out.println("Testing client with MTU " + mtu); SecureRandom random = new SecureRandom();