Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Iris: Fix server tests #7133

Merged
merged 9 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import java.io.IOException;

import org.jetbrains.annotations.NotNull;
import javax.validation.constraints.NotNull;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Profile;
import org.springframework.http.HttpHeaders;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,11 @@
package de.tum.in.www1.artemis.repository.iris;

import java.util.List;

import javax.validation.constraints.NotNull;

import org.springframework.data.jpa.repository.JpaRepository;

import de.tum.in.www1.artemis.domain.iris.IrisMessageContent;
import de.tum.in.www1.artemis.web.rest.errors.EntityNotFoundException;

/**
* Spring Data repository for the IrisMessageContent entity.
*/
public interface IrisMessageContentRepository extends JpaRepository<IrisMessageContent, Long> {

List<IrisMessageContent> findAllByMessageId(Long messageId);

@NotNull
default IrisMessageContent findByIdElseThrow(long messageContentId) throws EntityNotFoundException {
return findById(messageContentId).orElseThrow(() -> new EntityNotFoundException("Iris Message Content", messageContentId));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.client.*;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
Expand Down Expand Up @@ -67,11 +66,16 @@ public CompletableFuture<IrisMessageResponseDTO> sendRequest(IrisTemplate templa
* @return A list of available Models as IrisModelDTO
*/
public List<IrisModelDTO> getOfferedModels() throws IrisConnectorException {
var response = restTemplate.getForEntity(irisUrl + "/api/v1/models", JsonNode.class);
if (!response.getStatusCode().is2xxSuccessful() || !response.hasBody()) {
try {
var response = restTemplate.getForEntity(irisUrl + "/api/v1/models", JsonNode.class);
if (!response.getStatusCode().is2xxSuccessful() || !response.hasBody()) {
throw new IrisConnectorException("Could not fetch offered models");
}
return Arrays.asList((IrisModelDTO[]) parseResponse(response.getBody(), IrisModelDTO.class.arrayType()));
}
catch (HttpStatusCodeException e) {
throw new IrisConnectorException("Could not fetch offered models");
}
return Arrays.asList((IrisModelDTO[]) parseResponse(response.getBody(), IrisModelDTO.class.arrayType()));
}

private CompletableFuture<IrisMessageResponseDTO> sendRequest(IrisRequestDTO request) {
Expand All @@ -83,7 +87,7 @@ private CompletableFuture<IrisMessageResponseDTO> sendRequest(IrisRequestDTO req
}
return CompletableFuture.completedFuture(parseResponse(response.getBody(), IrisMessageResponseDTO.class));
}
catch (HttpClientErrorException e) {
catch (HttpStatusCodeException e) {
switch (e.getStatusCode()) {
case BAD_REQUEST -> {
var badRequestDTO = parseResponse(objectMapper.readTree(e.getResponseBodyAsString()).get("detail"), IrisErrorResponseDTO.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.Collections;
import java.util.Map;
import java.util.Objects;

import org.springframework.stereotype.Service;

Expand Down Expand Up @@ -110,5 +111,29 @@ public Map<String, Object> getTranslationParams() {
public enum IrisWebsocketMessageType {
MESSAGE, ERROR
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || getClass() != other.getClass()) {
return false;
}
IrisWebsocketDTO that = (IrisWebsocketDTO) other;
return type == that.type && Objects.equals(message, that.message) && Objects.equals(errorMessage, that.errorMessage)
&& Objects.equals(errorTranslationKey, that.errorTranslationKey) && Objects.equals(translationParams, that.translationParams);
}

@Override
public int hashCode() {
return Objects.hash(type, message, errorMessage, errorTranslationKey, translationParams);
}

@Override
public String toString() {
return "IrisWebsocketDTO{" + "type=" + type + ", message=" + message + ", errorMessage='" + errorMessage + '\'' + ", errorTranslationKey='" + errorTranslationKey + '\''
+ ", translationParams=" + translationParams + '}';
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package de.tum.in.www1.artemis.web.rest.admin.iris;

import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import de.tum.in.www1.artemis.domain.iris.settings.IrisSettings;
import de.tum.in.www1.artemis.security.annotations.EnforceAdmin;
import de.tum.in.www1.artemis.service.iris.IrisSettingsService;

/**
* REST controller for managing {@link IrisSettings}.
*/
@RestController
@RequestMapping("api/admin/")
public class AdminIrisSettingsResource {

private final IrisSettingsService irisSettingsService;

public AdminIrisSettingsResource(IrisSettingsService irisSettingsService) {
this.irisSettingsService = irisSettingsService;
}

/**
* PUT iris/global-iris-settings: Update the global iris settings.
*
* @param settings the settings to update
* @return the {@link ResponseEntity} with status {@code 200 (Ok)} and with body the updated settings.
*/
@PutMapping("iris/global-iris-settings")
@EnforceAdmin
public ResponseEntity<IrisSettings> updateGlobalSettings(@RequestBody IrisSettings settings) {
var updatedSettings = irisSettingsService.saveGlobalIrisSettings(settings);
return ResponseEntity.ok(updatedSettings);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,6 @@ public ResponseEntity<IrisSettings> getProgrammingExerciseSettings(@PathVariable
return ResponseEntity.ok(combinedIrisSettings);
}

/**
* PUT iris/global-iris-settings: Update the global iris settings.
*
* @param settings the settings to update
* @return the {@link ResponseEntity} with status {@code 200 (Ok)} and with body the updated settings.
*/
@PutMapping("iris/global-iris-settings")
@EnforceAdmin
public ResponseEntity<IrisSettings> updateGlobalSettings(@RequestBody IrisSettings settings) {
var updatedSettings = irisSettingsService.saveGlobalIrisSettings(settings);
return ResponseEntity.ok(updatedSettings);
}

/**
* PUT courses/{courseId}/raw-iris-settings: Update the raw iris settings for the course.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ export class IrisSettingsService {
* @param settings the settings to set
*/
setGlobalSettings(settings: IrisSettings): Observable<EntityResponseType> {
return this.http.put<IrisSettings>(`${this.resourceUrl}/iris/global-iris-settings`, settings, { observe: 'response' });
return this.http.put<IrisSettings>(`${this.resourceUrl}/admin/iris/global-iris-settings`, settings, { observe: 'response' });
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
@ExtendWith(SpringExtension.class)
@AutoConfigureEmbeddedDatabase
// NOTE: we use a common set of active profiles to reduce the number of application launches during testing. This significantly saves time and memory!
@ActiveProfiles({ SPRING_PROFILE_TEST, "artemis", "bamboo", "bitbucket", "jira", "ldap", "scheduling", "athena", "apollon" })
@ActiveProfiles({ SPRING_PROFILE_TEST, "artemis", "bamboo", "bitbucket", "jira", "ldap", "scheduling", "athena", "apollon", "iris" })
public abstract class AbstractSpringIntegrationBambooBitbucketJiraTest extends AbstractArtemisIntegrationTest {

@SpyBean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.net.URL;
import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.Map;

import org.mockito.MockitoAnnotations;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -26,7 +27,9 @@
import de.tum.in.www1.artemis.domain.iris.IrisMessage;
import de.tum.in.www1.artemis.domain.iris.IrisMessageContent;
import de.tum.in.www1.artemis.domain.iris.IrisMessageSender;
import de.tum.in.www1.artemis.service.connectors.iris.dto.IrisErrorResponseDTO;
import de.tum.in.www1.artemis.service.connectors.iris.dto.IrisMessageResponseDTO;
import de.tum.in.www1.artemis.service.connectors.iris.dto.IrisModelDTO;

@Component
@Profile("iris")
Expand All @@ -37,7 +40,10 @@ public class IrisRequestMockProvider {
private MockRestServiceServer mockServer;

@Value("${artemis.iris.url}/api/v1/messages")
private URL apiURL;
private URL messagesApiURL;

@Value("${artemis.iris.url}/api/v1/models")
private URL modelsApiURL;

@Autowired
private ObjectMapper mapper;
Expand Down Expand Up @@ -66,9 +72,9 @@ public void reset() throws Exception {
/**
* Mocks response call for the pyris call
*/
public void mockResponse(String responseMessage) throws JsonProcessingException {
public void mockMessageResponse(String responseMessage) throws JsonProcessingException {
if (responseMessage == null) {
mockServer.expect(ExpectedCount.once(), requestTo(apiURL.toString())).andExpect(method(HttpMethod.POST)).andRespond(withSuccess());
mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())).andExpect(method(HttpMethod.POST)).andRespond(withSuccess());
return;
}
var irisMessage = new IrisMessage();
Expand All @@ -81,10 +87,28 @@ public void mockResponse(String responseMessage) throws JsonProcessingException
var response = new IrisMessageResponseDTO(null, irisMessage);
var json = mapper.writeValueAsString(response);

mockServer.expect(ExpectedCount.once(), requestTo(apiURL.toString())).andExpect(method(HttpMethod.POST)).andRespond(withSuccess(json, MediaType.APPLICATION_JSON));
mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())).andExpect(method(HttpMethod.POST)).andRespond(withSuccess(json, MediaType.APPLICATION_JSON));
}

public void mockMessageError() throws JsonProcessingException {
mockMessageError(500);
}

public void mockMessageError(int status) throws JsonProcessingException {
var errorResponseDTO = new IrisErrorResponseDTO("Test error");
var json = Map.of("detail", errorResponseDTO);
mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())).andExpect(method(HttpMethod.POST))
.andRespond(withRawStatus(status).body(mapper.writeValueAsString(json)));
}

public void mockModelsResponse() throws JsonProcessingException {
var irisModelDTO = new IrisModelDTO("TEST_MODEL", "Test model", "Test description");
var irisModelDTOArray = new IrisModelDTO[] { irisModelDTO };
mockServer.expect(ExpectedCount.once(), requestTo(modelsApiURL.toString())).andExpect(method(HttpMethod.GET))
.andRespond(withSuccess(mapper.writeValueAsString(irisModelDTOArray), MediaType.APPLICATION_JSON));
}

public void mockError() {
mockServer.expect(ExpectedCount.once(), requestTo(apiURL.toString())).andExpect(method(HttpMethod.POST)).andRespond(withBadRequest());
public void mockModelsError() throws JsonProcessingException {
mockServer.expect(ExpectedCount.once(), requestTo(modelsApiURL.toString())).andExpect(method(HttpMethod.GET)).andRespond(withRawStatus(418));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@

import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
import static tech.jhipster.config.JHipsterConstants.SPRING_PROFILE_TEST;

import java.util.Objects;
import java.util.stream.Collectors;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.mockito.ArgumentMatchers;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.test.context.ActiveProfiles;

import de.tum.in.www1.artemis.AbstractSpringIntegrationBambooBitbucketJiraTest;
import de.tum.in.www1.artemis.connector.IrisRequestMockProvider;
import de.tum.in.www1.artemis.domain.Course;
import de.tum.in.www1.artemis.domain.ProgrammingExercise;
import de.tum.in.www1.artemis.domain.iris.IrisMessage;
import de.tum.in.www1.artemis.domain.iris.IrisMessageContent;
import de.tum.in.www1.artemis.domain.iris.IrisTemplate;
import de.tum.in.www1.artemis.exercise.ExerciseUtilService;
import de.tum.in.www1.artemis.exercise.programmingexercise.ProgrammingExerciseUtilService;
Expand All @@ -28,8 +28,7 @@
import de.tum.in.www1.artemis.service.iris.IrisWebsocketService;
import de.tum.in.www1.artemis.user.UserUtilService;

@ActiveProfiles({ SPRING_PROFILE_TEST, "artemis", "bamboo", "bitbucket", "jira", "ldap", "scheduling", "athene", "apollon", "iris" })
public class AbstractIrisIntegrationTest extends AbstractSpringIntegrationBambooBitbucketJiraTest {
public abstract class AbstractIrisIntegrationTest extends AbstractSpringIntegrationBambooBitbucketJiraTest {

@Autowired
protected CourseRepository courseRepository;
Expand Down Expand Up @@ -100,18 +99,13 @@ protected IrisTemplate createDummyTemplate() {
return template;
}

protected void verifyNoMessageWasSentOverWebsocket() throws InterruptedException {
Thread.sleep(1000);
verifyNoInteractions(websocketMessagingService);
}

/**
* Wait for the iris message to be processed by Iris, the LLM mock and the websocket service.
*
* @throws InterruptedException if the thread is interrupted
*/
protected void waitForIrisMessageToBeProcessed() throws InterruptedException {
Thread.sleep(500);
Thread.sleep(100);
}

/**
Expand All @@ -124,8 +118,8 @@ protected void waitForIrisMessageToBeProcessed() throws InterruptedException {
protected void verifyMessageWasSentOverWebsocket(String user, Long sessionId, String message) {
verify(websocketMessagingService, times(1)).sendMessageToUser(eq(user), eq("/topic/iris/sessions/" + sessionId),
ArgumentMatchers.argThat(object -> object instanceof IrisWebsocketService.IrisWebsocketDTO websocketDTO
&& websocketDTO.getType() == IrisWebsocketService.IrisWebsocketDTO.IrisWebsocketMessageType.MESSAGE && websocketDTO.getMessage().getContent().size() == 1
&& Objects.equals(websocketDTO.getMessage().getContent().get(0).getTextContent(), message)));
&& websocketDTO.getType() == IrisWebsocketService.IrisWebsocketDTO.IrisWebsocketMessageType.MESSAGE
&& Objects.equals(websocketDTO.getMessage().getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n")), message)));
}

/**
Expand All @@ -138,8 +132,9 @@ protected void verifyMessageWasSentOverWebsocket(String user, Long sessionId, St
protected void verifyMessageWasSentOverWebsocket(String user, Long sessionId, IrisMessage message) {
verify(websocketMessagingService, times(1)).sendMessageToUser(eq(user), eq("/topic/iris/sessions/" + sessionId),
ArgumentMatchers.argThat(object -> object instanceof IrisWebsocketService.IrisWebsocketDTO websocketDTO
&& websocketDTO.getType() == IrisWebsocketService.IrisWebsocketDTO.IrisWebsocketMessageType.MESSAGE && websocketDTO.getMessage().getContent().size() == 1
&& Objects.equals(websocketDTO.getMessage(), message)));
&& websocketDTO.getType() == IrisWebsocketService.IrisWebsocketDTO.IrisWebsocketMessageType.MESSAGE
&& Objects.equals(websocketDTO.getMessage().getContent().stream().map(IrisMessageContent::getTextContent).toList(),
message.getContent().stream().map(IrisMessageContent::getTextContent).toList())));
}

/**
Expand Down Expand Up @@ -175,7 +170,7 @@ protected void verifyNothingElseWasSentOverWebsocket(String user, Long sessionId
}

/**
* Verify that an error was sent through the websocket.
* Verify that no error was sent through the websocket.
*
* @param user the user
* @param sessionId the session id
Expand Down
Loading