Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for some pskN variants #19

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions src/main/java/com/southernstorm/noise/protocol/HandshakeState.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public class HandshakeState implements Destroyable {
private int patternIndex;
private byte[] preSharedKey;
private byte[] prologue;
private boolean isNoisePsk;

/**
* Enumerated value that indicates that the handshake object
Expand Down Expand Up @@ -158,13 +159,17 @@ public HandshakeState(String protocolName, int role) throws NoSuchAlgorithmExcep
String hash = components[4];
if (!prefix.equals("Noise") && !prefix.equals("NoisePSK"))
throw new IllegalArgumentException("Prefix must be Noise or NoisePSK");
isNoisePsk = prefix.equals("NoisePSK");
pattern = Pattern.lookup(patternId);
if (pattern == null)
throw new IllegalArgumentException("Handshake pattern is not recognized");
throw new IllegalArgumentException("Handshake pattern is not recognized " + patternId + " " + protocolName);
short flags = pattern[0];
int extraReqs = 0;
if ((flags & Pattern.FLAG_REMOTE_REQUIRED) != 0 && patternId.length() > 1)
extraReqs |= FALLBACK_POSSIBLE;
if((flags & Pattern.FLAG_PSK) != 0) {
extraReqs |= PSK_REQUIRED;
}
if (role == RESPONDER) {
// Reverse the pattern flags so that the responder is "local".
flags = Pattern.reverseFlags(flags);
Expand Down Expand Up @@ -299,7 +304,7 @@ public void setPreSharedKey(byte[] key, int offset, int length)
}
preSharedKey = Noise.copySubArray(key, offset, length);
}

/**
* Sets the prologue for this handshake.
*
Expand Down Expand Up @@ -508,17 +513,18 @@ public void start()
if (preSharedKey == null)
throw new IllegalStateException("Pre-shared key required");
}

// Hash the prologue value.
if (prologue != null)
symmetric.mixHash(prologue, 0, prologue.length);
else
symmetric.mixHash(emptyPrologue, 0, 0);

// Hash the pre-shared key into the chaining key and handshake hash.
if (preSharedKey != null)
// FIXME: AM: isNoisePsk needed to support NNpsk0 etc. Why? ;)
if (isNoisePsk && preSharedKey != null)
symmetric.mixPreSharedKey(preSharedKey);

// Mix the pre-supplied public keys into the handshake hash.
if (isInitiator) {
if ((requirements & LOCAL_PREMSG) != 0)
Expand Down Expand Up @@ -905,6 +911,12 @@ public int writeMessage(byte[] message, int messageOffset, byte[] payload, int p
}
break;

case Pattern.PSK:
{
symmetric.mixKeyAndHash(preSharedKey, 0, preSharedKey.length);
}
break;

default:
{
// Unknown token code. Abort.
Expand Down Expand Up @@ -1105,6 +1117,11 @@ public int readMessage(byte[] message, int messageOffset, int messageLength, byt
mixDH(localHybrid, remoteHybrid);
}
break;
case Pattern.PSK:
{
symmetric.mixKeyAndHash(preSharedKey, 0, preSharedKey.length);
}
break;

default:
{
Expand Down
73 changes: 71 additions & 2 deletions src/main/java/com/southernstorm/noise/protocol/Pattern.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ private Pattern() {}
public static final short SS = 6;
public static final short F = 7;
public static final short FF = 8;
public static final short PSK = 9;
public static final short FLIP_DIR = 255;

// Pattern flag bits.
Expand All @@ -53,6 +54,7 @@ private Pattern() {}
public static final short FLAG_REMOTE_EPHEM_REQ = 0x0800;
public static final short FLAG_REMOTE_HYBRID = 0x1000;
public static final short FLAG_REMOTE_HYBRID_REQ = 0x2000;
public static final short FLAG_PSK = 0x4000;

private static final short[] noise_pattern_N = {
FLAG_LOCAL_EPHEMERAL |
Expand Down Expand Up @@ -737,8 +739,52 @@ private Pattern() {}
* @param name The name of the pattern.
* @return The pattern description or null.
*/
public static short[] lookup(String name)
{
public static short[] lookup(String name) {
int pskIndex = pskModifier(name);
name = cutPsk(name);
short[] pattern = get(name);
pattern = insertPsk(pattern, pskIndex);
return pattern;
}

/**
* Insert psk token if needed
*/
private static short[] insertPsk(short[] pattern, int pskIndex) {
if (pattern != null && pskIndex > -1) {
if (pskIndex != 0) {
int handshake = 0;
int pos = 1;
for (; pos < pattern.length; pos++) {
if (pattern[pos] == FLIP_DIR) {
handshake++;
}
if (handshake == pskIndex) {
break;
}
}
pskIndex = pos - 1;
}
pattern = insertPskTokenAt(pattern, pskIndex);
pattern[0] |= FLAG_PSK;
}
return pattern;
}

private static short[] insertPskTokenAt(short[] pattern, int pskIndex) {
short[] newPattern = new short[pattern.length + 1];
for (int pos = 0; pos <= pskIndex; pos++) {
newPattern[pos] = pattern[pos];
}
newPattern[pskIndex + 1] = PSK;
for (int pos = pskIndex + 1; pos < pattern.length; pos++) {
newPattern[pos + 1] = pattern[pos];
}
return newPattern;
}

private static short[] get(String name) {

if (name.equals("N"))
return noise_pattern_N;
else if (name.equals("K"))
Expand Down Expand Up @@ -822,6 +868,29 @@ else if (name.equals("IXnoidh+hfs"))
return null;
}

private static String cutPsk(String name) {
int pos = name.indexOf("+psk");
if (pos > 0) {
return name.substring(0, pos) + name.substring(pos + 4);
}
pos = name.indexOf("psk");
if (pos > 0) {
return name.substring(0, pos) + name.substring(pos + 4);
}
return name;
}

/*
* determine the psk modifier if used in pattern.
*/
private static int pskModifier(String name) {
int pos = name.indexOf("psk");
if (pos > -1) {
return Integer.parseInt(name.substring(pos + 3, pos + 4));
}
return -1;
}

/**
* Reverses the local and remote flags for a pattern.
*
Expand Down
69 changes: 60 additions & 9 deletions src/main/java/com/southernstorm/noise/protocol/SymmetricState.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

package com.southernstorm.noise.protocol;

import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.security.DigestException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
Expand Down Expand Up @@ -64,14 +64,9 @@ public SymmetricState(String protocolName, String cipherName, String hashName) t
prev_h = new byte [hashLength];

byte[] protocolNameBytes;
try {
protocolNameBytes = protocolName.getBytes("UTF-8");
} catch (UnsupportedEncodingException e) {
// If UTF-8 is not supported, then we are definitely in trouble!
throw new UnsupportedOperationException("UTF-8 encoding is not supported");
}

if (protocolNameBytes.length <= hashLength) {
protocolNameBytes = protocolName.getBytes(StandardCharsets.UTF_8);

if (protocolNameBytes.length <= hashLength) {
System.arraycopy(protocolNameBytes, 0, h, 0, protocolNameBytes.length);
Arrays.fill(h, protocolNameBytes.length, h.length, (byte)0);
} else {
Expand Down Expand Up @@ -139,6 +134,7 @@ public void mixHash(byte[] data, int offset, int length)
* @param key The pre-shared key value.
*/
public void mixPreSharedKey(byte[] key)

{
byte[] temp = new byte [hash.getDigestLength()];
try {
Expand All @@ -165,6 +161,43 @@ public void mixPublicKey(DHState dh)
}
}

/**
* Mixes data into the chaining key.
*
* @param data The buffer containing the data to mix in.
* @param offset The offset of the first data byte to mix in.
* @param length The number of bytes to mix in.
*/
public void mixKeyAndHash(byte[] data, int offset, int length)
{
int keyLength = cipher.getKeyLength();
byte[] tempKey = new byte [keyLength];

int hashLength = hash.getDigestLength();
byte [] tempHash = new byte [hashLength];

try {
hkdf(ck, 0, ck.length,
data, offset, length,
ck, 0, ck.length,
tempHash, 0, hashLength,
tempKey, 0, keyLength);

mixHash(tempHash, 0, hashLength);

// Truncate tempKey
if (hashLength == 64 && keyLength > 32) {
byte[] newKey = Noise.copySubArray(tempKey, 0, 32);
Noise.destroy(tempKey);
tempKey = newKey;
}
cipher.initializeKey(tempKey, 0);
} finally {
Noise.destroy(tempKey);
Noise.destroy(tempHash);
}
}

/**
* Mixes a pre-supplied public key into the chaining key.
*
Expand Down Expand Up @@ -464,6 +497,19 @@ private void hkdf(byte[] key, int keyOffset, int keyLength,
byte[] data, int dataOffset, int dataLength,
byte[] output1, int output1Offset, int output1Length,
byte[] output2, int output2Offset, int output2Length)
{
hkdf(key, keyOffset, keyLength,
data, dataOffset, dataLength,
output1, output1Offset, output1Length,
output2, output2Offset, output2Length,
null, 0,0);
}

private void hkdf(byte[] key, int keyOffset, int keyLength,
byte[] data, int dataOffset, int dataLength,
byte[] output1, int output1Offset, int output1Length,
byte[] output2, int output2Offset, int output2Length,
byte[] output3, int output3Offset, int output3Length)
{
int hashLength = hash.getDigestLength();
byte[] tempKey = new byte [hashLength];
Expand All @@ -476,6 +522,11 @@ private void hkdf(byte[] key, int keyOffset, int keyLength,
tempHash[hashLength] = (byte)0x02;
hmac(tempKey, 0, hashLength, tempHash, 0, hashLength + 1, tempHash, 0, hashLength);
System.arraycopy(tempHash, 0, output2, output2Offset, output2Length);
if (output3 != null) {
tempHash[hashLength] = (byte)0x03;
hmac(tempKey, 0, hashLength, tempHash, 0, hashLength + 1, tempHash, 0, hashLength);
System.arraycopy(tempHash, 0, output3, output3Offset, output3Length);
}
} finally {
Noise.destroy(tempKey);
Noise.destroy(tempHash);
Expand Down
27 changes: 27 additions & 0 deletions src/test/java/com/southernstorm/noise/tests/ProtocolTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.southernstorm.noise.tests;

import com.southernstorm.noise.protocol.HandshakeState;
import org.junit.Test;

import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;

import static org.junit.Assert.assertNotNull;

public class ProtocolTest {
@Test
public void testCreateHandshakeState() throws NoSuchAlgorithmException {
HandshakeState client = new HandshakeState("Noise_NNpsk0_25519_ChaChaPoly_SHA256", HandshakeState.INITIATOR);
assertNotNull(client);

byte[] key = Base64.getDecoder().decode("LOaZwNhb6Ct5o5jRHIVQElRz4Lq25a4vEQ8TGTQT4hw=");
assert key.length == 32;

client.setPreSharedKey(key, 0, key.length);
byte[] prologue = "NoiseAPIInit\0\0".getBytes(StandardCharsets.US_ASCII);
client.setPrologue(prologue, 0, prologue.length);
client.start();

}
}
12 changes: 11 additions & 1 deletion src/test/java/com/southernstorm/noise/tests/UnitVectorTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,17 @@ public void testBasicVector() throws Exception {
+ "/tests/vector/noise-c-basic.txt").openStream()) {
VectorTests vectorTests = new VectorTests();
vectorTests.processInputStream(stream);
Assert.assertEquals(vectorTests.getFailed(), 0);
Assert.assertEquals(0, vectorTests.getFailed());
}
}

@Test
public void testCacophonyVector() throws Exception {
try (InputStream stream = new URL(
"https://raw.githubusercontent.com/centromere/cacophony/master/vectors/cacophony.txt").openStream()) {
VectorTests vectorTests = new VectorTests();
vectorTests.processInputStream(stream);
Assert.assertEquals(0, vectorTests.getFailed());
}
}
}
Loading