Skip to content

Commit

Permalink
Replace spring session jdbc with own in-memory map implementation
Browse files Browse the repository at this point in the history
related to #691
  • Loading branch information
honnel authored and derTobsch committed May 29, 2024
1 parent 34ddf37 commit 2f79895
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
<!-- SESSION -->
<dependency>
<groupId>org.springframework.session</groupId>
<artifactId>spring-session-jdbc</artifactId>
<artifactId>spring-session-core</artifactId>
</dependency>

<!-- Web -->
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package de.focusshift.zeiterfassung.security;

import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.session.FindByIndexNameSessionRepository;
import org.springframework.session.MapSession;
import org.springframework.session.Session;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import static java.util.stream.Collectors.toMap;

class IndexNameMapSessionRepository implements FindByIndexNameSessionRepository<Session> {

private final ConcurrentHashMap<String, Session> sessions = new ConcurrentHashMap<>();
// <AttributeName, <AttributeValue, SessionId>>
private final ConcurrentHashMap<String, ConcurrentHashMap<String, String>> indexMap = new ConcurrentHashMap<>();

@Override
public Session findById(String id) {
return sessions.get(id);
}

@Override
public Session createSession() {

return new MapSession();
}

@Override
public void save(Session session) {

final Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication instanceof OAuth2AuthenticationToken token) {
final OAuth2User oAuth2User = token.getPrincipal();
if (oAuth2User instanceof OidcUser oidcUser) {
session.setAttribute(PRINCIPAL_NAME_INDEX_NAME, oidcUser.getSubject());
}
}

sessions.put(session.getId(), session);
for (String attributeName : session.getAttributeNames()) {
indexMap.computeIfAbsent(attributeName, k -> new ConcurrentHashMap<>()).put(session.getId(), session.getAttribute(attributeName).toString());
}
}

@Override
public void deleteById(String id) {
Session session = sessions.remove(id);
if (session != null) {
for (String attributeName : session.getAttributeNames()) {
final ConcurrentHashMap<String, String> index = indexMap.get(attributeName);
if (index != null) {
index.remove(id);
}
}
}
}

@Override
public Map<String, Session> findByIndexNameAndIndexValue(String indexName, String indexValue) {
ConcurrentHashMap<String, String> index = indexMap.get(indexName);
if (index != null) {
return index.entrySet().stream()
.filter(entry -> entry.getValue().equals(indexValue))
.collect(toMap(Map.Entry::getKey, entry -> sessions.get(entry.getKey())));
}
return Map.of();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package de.focusshift.zeiterfassung.security;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.session.FindByIndexNameSessionRepository;
import org.springframework.session.Session;
import org.springframework.session.config.annotation.web.http.EnableSpringHttpSession;

@Configuration
@EnableSpringHttpSession
class SessionConfiguration {

@Bean
FindByIndexNameSessionRepository<Session> sessionRepository() {
return new IndexNameMapSessionRepository();
}
}
3 changes: 0 additions & 3 deletions src/main/resources/application.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ spring:
hibernate:
ddl-auto: none
open-in-view: false
session:
jdbc:
initialize-schema: always
liquibase:
change-log: classpath:/db/changelog/db.changelog-main.xml
mail:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package de.focusshift.zeiterfassung.security;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser;
import org.springframework.session.MapSession;
import org.springframework.session.Session;

import java.util.List;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.entry;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.springframework.session.FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME;

class IndexNameMapSessionRepositoryTest {

private IndexNameMapSessionRepository sut;

@BeforeEach
void setUp() {
sut = new IndexNameMapSessionRepository();
}

@Test
void createSession() {
assertThat(sut.createSession()).isInstanceOf(MapSession.class);
}

@Test
void saveAndFindById() {

final MapSession session = new MapSession("id");

sut.save(session);

assertThat(sut.findById("id")).isEqualTo(session);
}

@Test
void saveAndFindByIndexNameAndIndexValue() {

final SecurityContext context = SecurityContextHolder.getContext();
context.setAuthentication(prepareOAuth2Authentication("user1"));
final MapSession session1 = new MapSession("id1");
final MapSession session2 = new MapSession("id2");
sut.save(session1);
sut.save(session2);

context.setAuthentication(prepareOAuth2Authentication("user2"));
final MapSession session3 = new MapSession("id3");
sut.save(session3);

final Map<String, Session> byIndexNameAndIndexValue = sut.findByIndexNameAndIndexValue(PRINCIPAL_NAME_INDEX_NAME, "user1");
assertThat(byIndexNameAndIndexValue)
.containsExactlyInAnyOrderEntriesOf(Map.of("id1", session1, "id2", session2))
.doesNotContain(entry("id3", session3));

assertThat(sut.findByIndexNameAndIndexValue("unknown-index-name", "unknown-index-value")).isEmpty();
}

@Test
void deleteById() {

final SecurityContext context = SecurityContextHolder.getContext();
context.setAuthentication(prepareOAuth2Authentication("user"));
final MapSession session = new MapSession("id");

sut.save(session);

assertThat(sut.findById("id")).isEqualTo(session);
assertThat(sut.findByIndexNameAndIndexValue(PRINCIPAL_NAME_INDEX_NAME, "user")).containsExactlyInAnyOrderEntriesOf(Map.of("id", session));

sut.deleteById(session.getId());

assertThat(sut.findById("id")).isNull();
assertThat(sut.findByIndexNameAndIndexValue(PRINCIPAL_NAME_INDEX_NAME, "user")).isEmpty();
}

private OAuth2AuthenticationToken prepareOAuth2Authentication(String subject) {

final DefaultOidcUser oidcUser = new DefaultOidcUser(
List.of(),
OidcIdToken.withTokenValue("token-value").subject(subject).build()
);

final OAuth2AuthenticationToken authentication = mock(OAuth2AuthenticationToken.class);
when(authentication.getPrincipal()).thenReturn(oidcUser);

return authentication;
}
}

0 comments on commit 2f79895

Please sign in to comment.