diff --git a/src/main/java/io/supertokens/inmemorydb/queries/SessionQueries.java b/src/main/java/io/supertokens/inmemorydb/queries/SessionQueries.java index b325f53b2..dbcb6e215 100644 --- a/src/main/java/io/supertokens/inmemorydb/queries/SessionQueries.java +++ b/src/main/java/io/supertokens/inmemorydb/queries/SessionQueries.java @@ -106,28 +106,43 @@ public static SessionInfo getSessionInfo_Transaction(Start start, Connection con ((ConnectionWithLocks) con).lock( tenantIdentifier.getAppId() + "~" + tenantIdentifier.getTenantId() + "~" + sessionHandle + Config.getConfig(start).getSessionInfoTable()); - + // we do this as two separate queries and not one query with left join cause psql does not + // support left join with for update if the right table returns null. String QUERY = - "SELECT sess.session_handle, sess.user_id, sess.refresh_token_hash_2, sess.session_data, sess" + - ".expires_at, " - + - "sess.created_at_time, sess.jwt_user_payload, sess.use_static_key, users" + - ".primary_or_recipe_user_id FROM " + + "SELECT session_handle, user_id, refresh_token_hash_2, session_data, " + + "expires_at, created_at_time, jwt_user_payload, use_static_key FROM " + getConfig(start).getSessionInfoTable() - + " AS sess LEFT JOIN " + getConfig(start).getUsersTable() + - " as users ON sess.app_id = users.app_id AND sess.user_id = users.user_id WHERE sess.app_id =" + - " ? AND " + - "sess.tenant_id = ? AND sess.session_handle = ?"; - return execute(con, QUERY, pst -> { + + " WHERE app_id = ? AND tenant_id = ? AND session_handle = ?"; + SessionInfo sessionInfo = execute(con, QUERY, pst -> { pst.setString(1, tenantIdentifier.getAppId()); pst.setString(2, tenantIdentifier.getTenantId()); pst.setString(3, sessionHandle); }, result -> { if (result.next()) { - return SessionInfoRowMapper.getInstance().mapOrThrow(result); + return SessionInfoRowMapper.getInstance().mapOrThrow(result, false); } return null; }); + + if (sessionInfo == null) { + return null; + } + + QUERY = "SELECT primary_or_recipe_user_id FROM " + getConfig(start).getUsersTable() + + " WHERE app_id = ? AND user_id = ?"; + + return execute(con, QUERY, pst -> { + pst.setString(1, tenantIdentifier.getAppId()); + pst.setString(2, sessionInfo.recipeUserId); + }, result -> { + if (result.next()) { + String primaryUserId = result.getString("primary_or_recipe_user_id"); + if (primaryUserId != null) { + sessionInfo.userId = primaryUserId; + } + } + return sessionInfo; + }); } public static void updateSessionInfo_Transaction(Start start, Connection con, TenantIdentifier tenantIdentifier, @@ -331,7 +346,7 @@ public static SessionInfo getSession(Start start, TenantIdentifier tenantIdentif pst.setString(3, sessionHandle); }, result -> { if (result.next()) { - return SessionInfoRowMapper.getInstance().mapOrThrow(result); + return SessionInfoRowMapper.getInstance().mapOrThrow(result, true); } return null; }); @@ -386,7 +401,7 @@ public static void removeAccessTokenSigningKeysBefore(Start start, AppIdentifier }); } - static class SessionInfoRowMapper implements RowMapper { + static class SessionInfoRowMapper { public static final SessionInfoRowMapper INSTANCE = new SessionInfoRowMapper(); private SessionInfoRowMapper() { @@ -396,19 +411,23 @@ private static SessionInfoRowMapper getInstance() { return INSTANCE; } - @Override - public SessionInfo map(ResultSet result) throws Exception { + public SessionInfo mapOrThrow(ResultSet result, boolean hasPrimaryOrRecipeUserId) throws StorageQueryException { JsonParser jp = new JsonParser(); // if result.getString("primary_or_recipe_user_id") is null, it will be handled by SessionInfo // constructor - return new SessionInfo(result.getString("session_handle"), - result.getString("primary_or_recipe_user_id"), - result.getString("user_id"), - result.getString("refresh_token_hash_2"), - jp.parse(result.getString("session_data")).getAsJsonObject(), - result.getLong("expires_at"), - jp.parse(result.getString("jwt_user_payload")).getAsJsonObject(), - result.getLong("created_at_time"), result.getBoolean("use_static_key")); + try { + return new SessionInfo(result.getString("session_handle"), + hasPrimaryOrRecipeUserId ? result.getString("primary_or_recipe_user_id") : + result.getString("user_id"), + result.getString("user_id"), + result.getString("refresh_token_hash_2"), + jp.parse(result.getString("session_data")).getAsJsonObject(), + result.getLong("expires_at"), + jp.parse(result.getString("jwt_user_payload")).getAsJsonObject(), + result.getLong("created_at_time"), result.getBoolean("use_static_key")); + } catch (Exception e) { + throw new StorageQueryException(e); + } } }