Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle mls-welcome events, creating necessary data with core-crypto #WPB-12154 #111

Merged
merged 2 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/maven-release.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
name: Release to Maven Central

on:
push:
tags:
- '*'
workflow_dispatch:
release:
types: [ published ]
Comment on lines +4 to +6
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🙌🏻


jobs:
tests:
Expand Down
14 changes: 13 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>com.wire</groupId>
<artifactId>xenon</artifactId>
<version>1.7.0</version>
<version>1.7.1</version>

<name>Xenon</name>
<description>Base Wire Bots Library</description>
Expand Down Expand Up @@ -115,6 +115,12 @@
<version>5.10.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>5.14.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
Expand Down Expand Up @@ -198,6 +204,12 @@
<groupId>org.jetbrains.dokka</groupId>
<artifactId>dokka-maven-plugin</artifactId>
<version>1.9.20</version>
<configuration>
<sourceDirectories>
<dir>${project.basedir}/src/main/java</dir>
<dir>${project.basedir}/src/main/kotlin</dir>
</sourceDirectories>
</configuration>
<executions>
<execution>
<phase>compile</phase>
Expand Down
22 changes: 20 additions & 2 deletions src/main/java/com/wire/xenon/MessageResourceBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

public abstract class MessageResourceBase {
protected final MessageHandlerBase handler;
protected static final Integer PREKEYS_DEFAULT_REPLENISH = 10;

public MessageResourceBase(MessageHandlerBase handler) {
this.handler = handler;
Expand Down Expand Up @@ -57,6 +58,14 @@ protected void handleMessage(UUID eventId, Payload payload, WireClient client) t

handler.onEvent(client, fromMls, genericMessageMls);
break;
case "conversation.mls-welcome":
Logger.info("conversation.mls-welcome: bot: %s in: %s", botId, payload.conversation);

client.processWelcomeMessage(payload.data.text);

SystemMessage welcomeSystemMessage = getSystemMessage(eventId, payload);
handler.onNewConversation(client, welcomeSystemMessage);
break;
case "conversation.member-join":
Logger.debug("conversation.member-join: bot: %s", botId);

Expand All @@ -71,7 +80,7 @@ protected void handleMessage(UUID eventId, Payload payload, WireClient client) t
return;
}

// Check if we still have some prekeys available. Upload new prekeys if needed
// Check if we still have some prekeys available. Upload them if needed
handler.validatePreKeys(client, participants.size());

SystemMessage systemMessage = getSystemMessage(eventId, payload);
Expand Down Expand Up @@ -107,16 +116,25 @@ protected void handleMessage(UUID eventId, Payload payload, WireClient client) t
Logger.debug("conversation.create: bot: %s", botId);

systemMessage = getSystemMessage(eventId, payload);
Integer preKeysUserCount = PREKEYS_DEFAULT_REPLENISH;
if (systemMessage.conversation.members != null) {
preKeysUserCount = systemMessage.conversation.members.others.size();
Member self = new Member();
String selfDomain = null;
if (systemMessage.conversation != null && systemMessage.conversation.id != null) {
if (systemMessage.conversation.id != null) {
selfDomain = systemMessage.conversation.id.domain;
}
self.id = new QualifiedId(botId, selfDomain);
systemMessage.conversation.members.others.add(self);
}

// Check if we still have some prekeys and keyPackages available. Upload them if needed
if (systemMessage.conversation.protocol == Conversation.Protocol.PROTEUS)
handler.validatePreKeys(client, preKeysUserCount);
else {
client.checkAndReplenishKeyPackages();
}

handler.onNewConversation(client, systemMessage);
break;
case "conversation.rename":
Expand Down
17 changes: 17 additions & 0 deletions src/main/java/com/wire/xenon/WireClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
*/
public interface WireClient extends Closeable {

Integer KEY_PACKAGES_LOWER_THRESHOLD = 10;
Integer KEY_PACKAGES_REPLENISH_AMOUNT = 50;

/**
* Post a generic message into conversation
*
Expand Down Expand Up @@ -173,6 +176,20 @@ public interface WireClient extends Closeable {
*/
void joinMlsConversation(QualifiedId conversationId, String mlsGroupId);

/**
* When a mls-welcome event is received, this method is called to process it.
* It will create a MLS conversation record in the local core-crypto storage.
* @param welcome base64 encoded welcome message
* @return the MLS group id of the conversation
*/
byte[] processWelcomeMessage(String welcome);

/**
* Checks if the number of available key packages is below the threshold and replenishes them if necessary.
* NOTE: Will make an API call to publish the new key packages if needed.
*/
void checkAndReplenishKeyPackages();

/**
* Invoked by the sdk. Called once when the conversation is created
*
Expand Down
20 changes: 19 additions & 1 deletion src/main/java/com/wire/xenon/WireClientBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,32 @@ public void uploadMlsKeyPackages(int keyPackageAmount) {

@Override
public void joinMlsConversation(QualifiedId conversationId, String mlsGroupId) {
if (cryptoMlsClient.conversationExists(mlsGroupId)) {
Logger.info("Conversation %s already exists, ignore it", conversationId);
return;
}
final byte[] conversationGroupInfo = api.getConversationGroupInfo(conversationId);
final byte[] commitBundle = cryptoMlsClient.createJoinConversationRequest(conversationGroupInfo);
api.commitMlsBundle(commitBundle);
// TODO some error recovery
// TODO Add error recovery, maybe a simple 3 times retry on api calls with quadratic backoff

cryptoMlsClient.markConversationAsJoined(mlsGroupId);
}

@Override
public byte[] processWelcomeMessage(String welcome) {
checkAndReplenishKeyPackages();
return cryptoMlsClient.processWelcomeMessage(welcome);
}

@Override
public void checkAndReplenishKeyPackages() {
if (cryptoMlsClient.validKeyPackageCount() < KEY_PACKAGES_LOWER_THRESHOLD) {
Logger.info("Too few Key packages, replenish them");
cryptoMlsClient.generateKeyPackages(KEY_PACKAGES_REPLENISH_AMOUNT);
}
}

@Override
public PreKey newLastPreKey() throws CryptoException {
return crypto.newLastPreKey();
Expand Down
20 changes: 14 additions & 6 deletions src/main/kotlin/com/wire/xenon/crypto/mls/CryptoMlsClient.kt
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,12 @@ class CryptoMlsClient : Closeable {
return keyPackages.map { it.value }
}

// TODO handle conversation marked as complete, after both welcomeMessage and member-join events have been received,
// TODO remember checking there are enough key packages
// https://wearezeta.atlassian.net/wiki/spaces/ENGINEERIN/pages/563053166/Use+case+being+added+to+a+conversation+MLS
/**
* Process a welcome message, adding this client to a conversation, and return the group id.
*/
fun welcomeMessage(welcome: ByteArray): ByteArray {
val welcomeBundle = runBlocking { mlsClient.processWelcomeMessage(Welcome(welcome)) }
fun processWelcomeMessage(welcome: String): ByteArray {
val welcomeBytes: ByteArray = Base64.getDecoder().decode(welcome)
val welcomeBundle = runBlocking { mlsClient.processWelcomeMessage(Welcome(welcomeBytes)) }
return welcomeBundle.id.value
}

Expand All @@ -77,6 +75,11 @@ class CryptoMlsClient : Closeable {
return packageCount.toLong()
}

fun conversationExists(mlsGroupId: String): Boolean {
val mlsGroupIdBytes: ByteArray = Base64.getDecoder().decode(mlsGroupId)
return runBlocking { mlsClient.conversationExists(MLSGroupId(mlsGroupIdBytes)) }
}

/**
* Create a request to join a conversation.
* Needs to be followed by a call to markConversationAsJoined() to complete the process.
Expand All @@ -97,9 +100,14 @@ class CryptoMlsClient : Closeable {
return bundle.commit.value + bundle.groupInfoBundle.payload.value + (bundle.welcome?.value ?: ByteArray(0))
}

/**
* Completes the process of joining a conversation.
* To be called after createJoinConversationRequest(), and having a successful response from the backend
* while uploading the commitBundle.
*/
fun markConversationAsJoined(mlsGroupId: String) {
val mlsGroupIdBytes: ByteArray = Base64.getDecoder().decode(mlsGroupId)
val commitBundle = runBlocking { mlsClient.mergePendingGroupFromExternalCommit(MLSGroupId(mlsGroupIdBytes)) }
runBlocking { mlsClient.mergePendingGroupFromExternalCommit(MLSGroupId(mlsGroupIdBytes)) }
// TODO support the possibility of merging returning some decrypted messages ?
}

Expand Down
10 changes: 9 additions & 1 deletion src/test/java/com/wire/xenon/MlsClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ public void testMlsClientCreateConversationAndEncrypt() throws IOException {

// Create a new client and join the conversation
CryptoMlsClient mlsClient = new CryptoMlsClient(client1, "pwd");
assert !mlsClient.conversationExists(groupIdBase64);
final byte[] commitBundle = mlsClient.createJoinConversationRequest(groupInfo);
assert commitBundle.length > groupInfo.length;
mlsClient.markConversationAsJoined(groupIdBase64);
assert mlsClient.conversationExists(groupIdBase64);

// Encrypt a message for the joined conversation
String plainMessage = UUID.randomUUID().toString();
final byte[] encryptedMessage = mlsClient.encrypt(groupIdBase64, plainMessage.getBytes());
Expand Down Expand Up @@ -91,17 +94,22 @@ public void testMlsClientsEncryptAndDecrypt() throws IOException {

// Create a new client and join the conversation
CryptoMlsClient mlsClient = new CryptoMlsClient(client1, "pwd");
assert !mlsClient.conversationExists(groupIdBase64);
final byte[] commitBundle = mlsClient.createJoinConversationRequest(groupInfo);
assert commitBundle.length > groupInfo.length;
mlsClient.markConversationAsJoined(groupIdBase64);
assert mlsClient.conversationExists(groupIdBase64);

// Create a second client and make the first client invite the second one
String client2 = "bob1_" + UUID.randomUUID();
CryptoMlsClient mlsClient2 = new CryptoMlsClient(client2, "pwd");
assert !mlsClient2.conversationExists(groupIdBase64);
final List<byte[]> keyPackages = mlsClient2.generateKeyPackages(1);
final byte[] welcome = mlsClient.addMemberToConversation(groupIdBase64, keyPackages);
mlsClient.acceptLatestCommit(groupIdBase64);
mlsClient2.welcomeMessage(welcome);
String welcomeBase64 = new String(Base64.getEncoder().encode(welcome));
mlsClient2.processWelcomeMessage(welcomeBase64);
assert mlsClient2.conversationExists(groupIdBase64);

// Encrypt a message for the joined conversation
String plainMessage = UUID.randomUUID().toString();
Expand Down
56 changes: 56 additions & 0 deletions src/test/java/com/wire/xenon/WireClientBaseTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.wire.xenon;

import com.wire.xenon.backend.models.NewBot;
import com.wire.xenon.crypto.mls.CryptoMlsClient;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.mockito.Mockito.*;

public class WireClientBaseTest {
private WireClientBase wireClientBase;
private WireAPI mockApi;
private CryptoMlsClient mockCryptoMlsClient;
private NewBot mockState;

@BeforeEach
public void setUp() {
mockApi = mock(WireAPI.class);
mockCryptoMlsClient = mock(CryptoMlsClient.class);
mockState = mock(NewBot.class);
wireClientBase = new WireClientBase(mockApi, null, mockCryptoMlsClient, mockState);
}

@Test
public void checkAndReplenishKeyPackages_replenishesWhenBelowThreshold() {
when(mockCryptoMlsClient.validKeyPackageCount()).thenReturn(WireClientBase.KEY_PACKAGES_LOWER_THRESHOLD - 5L);

wireClientBase.checkAndReplenishKeyPackages();

verify(mockCryptoMlsClient, times(1)).generateKeyPackages(WireClientBase.KEY_PACKAGES_REPLENISH_AMOUNT);
}

@Test
public void checkAndReplenishKeyPackages_doesNotReplenishWhenAboveThreshold() {
when(mockCryptoMlsClient.validKeyPackageCount()).thenReturn(WireClientBase.KEY_PACKAGES_LOWER_THRESHOLD + 5L);

wireClientBase.checkAndReplenishKeyPackages();

verify(mockCryptoMlsClient, never()).generateKeyPackages(anyInt());
}

@Test
public void processWelcomeMessage_callsCheckAndReplenishKeyPackages() {
String welcomeMessage = "welcomeMessage";
byte[] expectedResponse = new byte[]{1, 2, 3};

when(mockCryptoMlsClient.processWelcomeMessage(welcomeMessage)).thenReturn(expectedResponse);
when(mockCryptoMlsClient.validKeyPackageCount()).thenReturn(WireClientBase.KEY_PACKAGES_LOWER_THRESHOLD + 5L);

byte[] response = wireClientBase.processWelcomeMessage(welcomeMessage);

verify(mockCryptoMlsClient, times(1)).processWelcomeMessage(welcomeMessage);
assertArrayEquals(expectedResponse, response);
}
}