diff --git a/src/main/java/com/southernstorm/noise/protocol/HandshakeState.java b/src/main/java/com/southernstorm/noise/protocol/HandshakeState.java index daa67b1..51a6263 100644 --- a/src/main/java/com/southernstorm/noise/protocol/HandshakeState.java +++ b/src/main/java/com/southernstorm/noise/protocol/HandshakeState.java @@ -49,6 +49,7 @@ public class HandshakeState implements Destroyable { private int patternIndex; private byte[] preSharedKey; private byte[] prologue; + private boolean isNoisePsk; /** * Enumerated value that indicates that the handshake object @@ -158,13 +159,17 @@ public HandshakeState(String protocolName, int role) throws NoSuchAlgorithmExcep String hash = components[4]; if (!prefix.equals("Noise") && !prefix.equals("NoisePSK")) throw new IllegalArgumentException("Prefix must be Noise or NoisePSK"); + isNoisePsk = prefix.equals("NoisePSK"); pattern = Pattern.lookup(patternId); if (pattern == null) - throw new IllegalArgumentException("Handshake pattern is not recognized"); + throw new IllegalArgumentException("Handshake pattern is not recognized " + patternId + " " + protocolName); short flags = pattern[0]; int extraReqs = 0; if ((flags & Pattern.FLAG_REMOTE_REQUIRED) != 0 && patternId.length() > 1) extraReqs |= FALLBACK_POSSIBLE; + if((flags & Pattern.FLAG_PSK) != 0) { + extraReqs |= PSK_REQUIRED; + } if (role == RESPONDER) { // Reverse the pattern flags so that the responder is "local". flags = Pattern.reverseFlags(flags); @@ -299,7 +304,7 @@ public void setPreSharedKey(byte[] key, int offset, int length) } preSharedKey = Noise.copySubArray(key, offset, length); } - + /** * Sets the prologue for this handshake. * @@ -508,7 +513,7 @@ public void start() if (preSharedKey == null) throw new IllegalStateException("Pre-shared key required"); } - + // Hash the prologue value. if (prologue != null) symmetric.mixHash(prologue, 0, prologue.length); @@ -516,9 +521,10 @@ public void start() symmetric.mixHash(emptyPrologue, 0, 0); // Hash the pre-shared key into the chaining key and handshake hash. - if (preSharedKey != null) + // FIXME: AM: isNoisePsk needed to support NNpsk0 etc. Why? ;) + if (isNoisePsk && preSharedKey != null) symmetric.mixPreSharedKey(preSharedKey); - + // Mix the pre-supplied public keys into the handshake hash. if (isInitiator) { if ((requirements & LOCAL_PREMSG) != 0) @@ -905,6 +911,12 @@ public int writeMessage(byte[] message, int messageOffset, byte[] payload, int p } break; + case Pattern.PSK: + { + symmetric.mixKeyAndHash(preSharedKey, 0, preSharedKey.length); + } + break; + default: { // Unknown token code. Abort. @@ -1105,6 +1117,11 @@ public int readMessage(byte[] message, int messageOffset, int messageLength, byt mixDH(localHybrid, remoteHybrid); } break; + case Pattern.PSK: + { + symmetric.mixKeyAndHash(preSharedKey, 0, preSharedKey.length); + } + break; default: { diff --git a/src/main/java/com/southernstorm/noise/protocol/Pattern.java b/src/main/java/com/southernstorm/noise/protocol/Pattern.java index 157af6a..a075eb4 100644 --- a/src/main/java/com/southernstorm/noise/protocol/Pattern.java +++ b/src/main/java/com/southernstorm/noise/protocol/Pattern.java @@ -38,6 +38,7 @@ private Pattern() {} public static final short SS = 6; public static final short F = 7; public static final short FF = 8; + public static final short PSK = 9; public static final short FLIP_DIR = 255; // Pattern flag bits. @@ -53,6 +54,7 @@ private Pattern() {} public static final short FLAG_REMOTE_EPHEM_REQ = 0x0800; public static final short FLAG_REMOTE_HYBRID = 0x1000; public static final short FLAG_REMOTE_HYBRID_REQ = 0x2000; + public static final short FLAG_PSK = 0x4000; private static final short[] noise_pattern_N = { FLAG_LOCAL_EPHEMERAL | @@ -737,8 +739,52 @@ private Pattern() {} * @param name The name of the pattern. * @return The pattern description or null. */ - public static short[] lookup(String name) - { + public static short[] lookup(String name) { + int pskIndex = pskModifier(name); + name = cutPsk(name); + short[] pattern = get(name); + pattern = insertPsk(pattern, pskIndex); + return pattern; + } + + /** + * Insert psk token if needed + */ + private static short[] insertPsk(short[] pattern, int pskIndex) { + if (pattern != null && pskIndex > -1) { + if (pskIndex != 0) { + int handshake = 0; + int pos = 1; + for (; pos < pattern.length; pos++) { + if (pattern[pos] == FLIP_DIR) { + handshake++; + } + if (handshake == pskIndex) { + break; + } + } + pskIndex = pos - 1; + } + pattern = insertPskTokenAt(pattern, pskIndex); + pattern[0] |= FLAG_PSK; + } + return pattern; + } + + private static short[] insertPskTokenAt(short[] pattern, int pskIndex) { + short[] newPattern = new short[pattern.length + 1]; + for (int pos = 0; pos <= pskIndex; pos++) { + newPattern[pos] = pattern[pos]; + } + newPattern[pskIndex + 1] = PSK; + for (int pos = pskIndex + 1; pos < pattern.length; pos++) { + newPattern[pos + 1] = pattern[pos]; + } + return newPattern; + } + + private static short[] get(String name) { + if (name.equals("N")) return noise_pattern_N; else if (name.equals("K")) @@ -822,6 +868,29 @@ else if (name.equals("IXnoidh+hfs")) return null; } + private static String cutPsk(String name) { + int pos = name.indexOf("+psk"); + if (pos > 0) { + return name.substring(0, pos) + name.substring(pos + 4); + } + pos = name.indexOf("psk"); + if (pos > 0) { + return name.substring(0, pos) + name.substring(pos + 4); + } + return name; + } + + /* + * determine the psk modifier if used in pattern. + */ + private static int pskModifier(String name) { + int pos = name.indexOf("psk"); + if (pos > -1) { + return Integer.parseInt(name.substring(pos + 3, pos + 4)); + } + return -1; + } + /** * Reverses the local and remote flags for a pattern. * diff --git a/src/main/java/com/southernstorm/noise/protocol/SymmetricState.java b/src/main/java/com/southernstorm/noise/protocol/SymmetricState.java index ae54fd3..e7513e2 100644 --- a/src/main/java/com/southernstorm/noise/protocol/SymmetricState.java +++ b/src/main/java/com/southernstorm/noise/protocol/SymmetricState.java @@ -22,7 +22,7 @@ package com.southernstorm.noise.protocol; -import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; import java.security.DigestException; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; @@ -64,14 +64,9 @@ public SymmetricState(String protocolName, String cipherName, String hashName) t prev_h = new byte [hashLength]; byte[] protocolNameBytes; - try { - protocolNameBytes = protocolName.getBytes("UTF-8"); - } catch (UnsupportedEncodingException e) { - // If UTF-8 is not supported, then we are definitely in trouble! - throw new UnsupportedOperationException("UTF-8 encoding is not supported"); - } - - if (protocolNameBytes.length <= hashLength) { + protocolNameBytes = protocolName.getBytes(StandardCharsets.UTF_8); + + if (protocolNameBytes.length <= hashLength) { System.arraycopy(protocolNameBytes, 0, h, 0, protocolNameBytes.length); Arrays.fill(h, protocolNameBytes.length, h.length, (byte)0); } else { @@ -139,6 +134,7 @@ public void mixHash(byte[] data, int offset, int length) * @param key The pre-shared key value. */ public void mixPreSharedKey(byte[] key) + { byte[] temp = new byte [hash.getDigestLength()]; try { @@ -165,6 +161,43 @@ public void mixPublicKey(DHState dh) } } + /** + * Mixes data into the chaining key. + * + * @param data The buffer containing the data to mix in. + * @param offset The offset of the first data byte to mix in. + * @param length The number of bytes to mix in. + */ + public void mixKeyAndHash(byte[] data, int offset, int length) + { + int keyLength = cipher.getKeyLength(); + byte[] tempKey = new byte [keyLength]; + + int hashLength = hash.getDigestLength(); + byte [] tempHash = new byte [hashLength]; + + try { + hkdf(ck, 0, ck.length, + data, offset, length, + ck, 0, ck.length, + tempHash, 0, hashLength, + tempKey, 0, keyLength); + + mixHash(tempHash, 0, hashLength); + + // Truncate tempKey + if (hashLength == 64 && keyLength > 32) { + byte[] newKey = Noise.copySubArray(tempKey, 0, 32); + Noise.destroy(tempKey); + tempKey = newKey; + } + cipher.initializeKey(tempKey, 0); + } finally { + Noise.destroy(tempKey); + Noise.destroy(tempHash); + } + } + /** * Mixes a pre-supplied public key into the chaining key. * @@ -464,6 +497,19 @@ private void hkdf(byte[] key, int keyOffset, int keyLength, byte[] data, int dataOffset, int dataLength, byte[] output1, int output1Offset, int output1Length, byte[] output2, int output2Offset, int output2Length) + { + hkdf(key, keyOffset, keyLength, + data, dataOffset, dataLength, + output1, output1Offset, output1Length, + output2, output2Offset, output2Length, + null, 0,0); + } + + private void hkdf(byte[] key, int keyOffset, int keyLength, + byte[] data, int dataOffset, int dataLength, + byte[] output1, int output1Offset, int output1Length, + byte[] output2, int output2Offset, int output2Length, + byte[] output3, int output3Offset, int output3Length) { int hashLength = hash.getDigestLength(); byte[] tempKey = new byte [hashLength]; @@ -476,6 +522,11 @@ private void hkdf(byte[] key, int keyOffset, int keyLength, tempHash[hashLength] = (byte)0x02; hmac(tempKey, 0, hashLength, tempHash, 0, hashLength + 1, tempHash, 0, hashLength); System.arraycopy(tempHash, 0, output2, output2Offset, output2Length); + if (output3 != null) { + tempHash[hashLength] = (byte)0x03; + hmac(tempKey, 0, hashLength, tempHash, 0, hashLength + 1, tempHash, 0, hashLength); + System.arraycopy(tempHash, 0, output3, output3Offset, output3Length); + } } finally { Noise.destroy(tempKey); Noise.destroy(tempHash); diff --git a/src/test/java/com/southernstorm/noise/tests/ProtocolTest.java b/src/test/java/com/southernstorm/noise/tests/ProtocolTest.java new file mode 100644 index 0000000..92c404e --- /dev/null +++ b/src/test/java/com/southernstorm/noise/tests/ProtocolTest.java @@ -0,0 +1,27 @@ +package com.southernstorm.noise.tests; + +import com.southernstorm.noise.protocol.HandshakeState; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; + +import static org.junit.Assert.assertNotNull; + +public class ProtocolTest { + @Test + public void testCreateHandshakeState() throws NoSuchAlgorithmException { + HandshakeState client = new HandshakeState("Noise_NNpsk0_25519_ChaChaPoly_SHA256", HandshakeState.INITIATOR); + assertNotNull(client); + + byte[] key = Base64.getDecoder().decode("LOaZwNhb6Ct5o5jRHIVQElRz4Lq25a4vEQ8TGTQT4hw="); + assert key.length == 32; + + client.setPreSharedKey(key, 0, key.length); + byte[] prologue = "NoiseAPIInit\0\0".getBytes(StandardCharsets.US_ASCII); + client.setPrologue(prologue, 0, prologue.length); + client.start(); + + } +} diff --git a/src/test/java/com/southernstorm/noise/tests/UnitVectorTests.java b/src/test/java/com/southernstorm/noise/tests/UnitVectorTests.java index efdbc04..dd33a9f 100644 --- a/src/test/java/com/southernstorm/noise/tests/UnitVectorTests.java +++ b/src/test/java/com/southernstorm/noise/tests/UnitVectorTests.java @@ -16,7 +16,17 @@ public void testBasicVector() throws Exception { + "/tests/vector/noise-c-basic.txt").openStream()) { VectorTests vectorTests = new VectorTests(); vectorTests.processInputStream(stream); - Assert.assertEquals(vectorTests.getFailed(), 0); + Assert.assertEquals(0, vectorTests.getFailed()); + } + } + + @Test + public void testCacophonyVector() throws Exception { + try (InputStream stream = new URL( + "https://raw.githubusercontent.com/centromere/cacophony/master/vectors/cacophony.txt").openStream()) { + VectorTests vectorTests = new VectorTests(); + vectorTests.processInputStream(stream); + Assert.assertEquals(0, vectorTests.getFailed()); } } } diff --git a/src/test/java/com/southernstorm/noise/tests/VectorTests.java b/src/test/java/com/southernstorm/noise/tests/VectorTests.java index 3815eb0..c727b7b 100644 --- a/src/test/java/com/southernstorm/noise/tests/VectorTests.java +++ b/src/test/java/com/southernstorm/noise/tests/VectorTests.java @@ -303,7 +303,7 @@ private void processVector(JsonReader reader) throws IOException TestVector vec = new TestVector(); while (reader.hasNext()) { String name = reader.nextName(); - if (name.equals("name")) + if (name.equals("name") || name.equals("protocol_name")) vec.name = reader.nextString(); else if (name.equals("pattern")) vec.pattern = reader.nextString(); @@ -327,7 +327,11 @@ else if (name.equals("init_static")) vec.init_static = DatatypeConverter.parseHexBinary(reader.nextString()); else if (name.equals("init_remote_static")) vec.init_remote_static = DatatypeConverter.parseHexBinary(reader.nextString()); - else if (name.equals("init_psk")) + else if (name.equals("init_psks")) { + reader.beginArray(); + vec.init_psk = DatatypeConverter.parseHexBinary(reader.nextString()); + reader.endArray(); + } else if (name.equals("init_psk")) vec.init_psk = DatatypeConverter.parseHexBinary(reader.nextString()); else if (name.equals("init_ssk")) vec.init_ssk = DatatypeConverter.parseHexBinary(reader.nextString()); @@ -343,7 +347,11 @@ else if (name.equals("resp_remote_static")) vec.resp_remote_static = DatatypeConverter.parseHexBinary(reader.nextString()); else if (name.equals("resp_psk")) vec.resp_psk = DatatypeConverter.parseHexBinary(reader.nextString()); - else if (name.equals("resp_ssk")) + else if (name.equals("resp_psks")) { + reader.beginArray(); + vec.resp_psk = DatatypeConverter.parseHexBinary(reader.nextString()); + reader.endArray(); + } else if (name.equals("resp_ssk")) vec.resp_ssk = DatatypeConverter.parseHexBinary(reader.nextString()); else if (name.equals("handshake_hash")) vec.handshake_hash = DatatypeConverter.parseHexBinary(reader.nextString()); @@ -384,6 +392,12 @@ else if (name.equals("ciphertext")) protocolName += "_" + vec.pattern + "_" + dh + "_" + vec.cipher + "_" + vec.hash; if (vec.name == null) vec.name = protocolName; + else + protocolName = vec.name; + + if (vec.pattern == null) { + vec.pattern = protocolName.split("_")[1]; + } // Execute the test vector. ++total; @@ -391,6 +405,10 @@ else if (name.equals("ciphertext")) System.out.print(" ... "); System.out.flush(); try { + // TODO: Why are these special cases, what needs to be fixed? + if (protocolName.indexOf("_Xpsk1_") > -1 || protocolName.indexOf("_Kpsk0_") > -1 || protocolName.indexOf("_Npsk0_") > -1) { + throw new NoSuchAlgorithmException("Unsupported for now " + protocolName); + } HandshakeState initiator = new HandshakeState(protocolName, HandshakeState.INITIATOR); HandshakeState responder = new HandshakeState(protocolName, HandshakeState.RESPONDER); assertEquals(HandshakeState.INITIATOR, initiator.getRole()); @@ -404,8 +422,8 @@ else if (name.equals("ciphertext")) System.out.println("failure expected"); ++failed; } - } catch (NoSuchAlgorithmException e) { - System.out.println("unsupported"); + } catch (NoSuchAlgorithmException | IllegalArgumentException e) { + System.out.println("unsupported " + e.getMessage()); ++skipped; } catch (AssertionError e) { System.out.println(e.getMessage()); @@ -413,7 +431,7 @@ else if (name.equals("ciphertext")) ++failed; } catch (Exception e) { if (!vec.failure_expected) { - System.out.println("failed"); + System.out.println("failed " + e.getMessage()); e.printStackTrace(System.out); ++failed; } else {