Skip to content

Commit

Permalink
feat(dia-1033): infiniteDiscovery queries and temporal direct integra…
Browse files Browse the repository at this point in the history
…tion with OpenSearch (#6306)

* feat(dia-1033): infiniteDiscovery queries and temporal direct integration with OpenSearch

* delete tasteProfileVector from the schema

* remove tasteProfileVector parameter

* change type name

* fix description

* move opensearch related code into old type

* remove deprecation warning

* random curators picks for inital batch of artworks
  • Loading branch information
nickskalkin authored Dec 13, 2024
1 parent 3194328 commit 7569b27
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 56 deletions.
20 changes: 16 additions & 4 deletions _schemaV2.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -16851,13 +16851,19 @@ type Query {
after: String
before: String
certainty: Float

# (Only for when useOpenSearch is true) Exclude these artworks from the response
excludeArtworkIds: [String]
first: Int
last: Int

# (Only for when useOpenSearch is true) These artworks are used to calculate the taste profile vector. Such artworks are excluded from the response
likedArtworkIds: [String]
limit: Int
offset: Int
sort: DiscoverArtworksSort
useRelatedArtworks: Boolean = false
userId: String!
useOpenSearch: Boolean = false
userId: String
): ArtworkConnection

# A namespace external partners (provided by Galaxy)
Expand Down Expand Up @@ -21500,13 +21506,19 @@ type Viewer {
after: String
before: String
certainty: Float

# (Only for when useOpenSearch is true) Exclude these artworks from the response
excludeArtworkIds: [String]
first: Int
last: Int

# (Only for when useOpenSearch is true) These artworks are used to calculate the taste profile vector. Such artworks are excluded from the response
likedArtworkIds: [String]
limit: Int
offset: Int
sort: DiscoverArtworksSort
useRelatedArtworks: Boolean = false
userId: String!
useOpenSearch: Boolean = false
userId: String
): ArtworkConnection

# A namespace external partners (provided by Galaxy)
Expand Down
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"lodash": "4.17.21",
"longjohn": "0.2.12",
"marked": "2.0.1",
"mathjs": "^14.0.1",
"memcached": "2.2.2",
"moment": "2.29.4",
"moment-timezone": "0.5.37",
Expand Down
4 changes: 4 additions & 0 deletions src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ const {
METAPHYSICS_PRODUCTION_ENDPOINT,
METAPHYSICS_STAGING_ENDPOINT,
NODE_ENV,
OPENSEARCH_API_BASE,
OPENSEARCH_ARTWORKS_INFINITE_DISCOVERY_INDEX,
PORT,
POSITRON_API_BASE,
PREDICTION_ENDPOINT,
Expand Down Expand Up @@ -196,6 +198,8 @@ export default {
METAPHYSICS_STAGING_ENDPOINT,
METAPHYSICS_PRODUCTION_ENDPOINT,
NODE_ENV: NODE_ENV || "development",
OPENSEARCH_API_BASE,
OPENSEARCH_ARTWORKS_INFINITE_DISCOVERY_INDEX,
PORT: Number(PORT) || 3000,
POSITRON_API_BASE,
PREDICTION_ENDPOINT,
Expand Down
26 changes: 26 additions & 0 deletions src/lib/apis/opensearch.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import urljoin from "url-join"
import { assign } from "lodash"
import config from "config"
import fetch from "node-fetch"

const { OPENSEARCH_API_BASE } = config

export const opensearch = async (
path,
_accessToken,
fetchOptions: any = {}
) => {
const headers = {
Accept: "application/json",
"Content-Type": "application/json",
}

const response = await (
await fetch(
urljoin(OPENSEARCH_API_BASE, path),
assign({}, fetchOptions, { headers })
)
).json()

return response
}
30 changes: 30 additions & 0 deletions src/lib/infiniteDiscovery/calculateMeanArtworksVector.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import config from "config"
import { opensearch } from "lib/apis/opensearch"
import { mean } from "mathjs"

export const calculateMeanArtworksVector = async (artworkIds) => {
const getVectorsQuery = {
size: artworkIds.length,
_source: ["_id", "vector_embedding"],
query: {
ids: {
values: artworkIds,
},
},
}

const artworksResponse = await opensearch(
`/${config.OPENSEARCH_ARTWORKS_INFINITE_DISCOVERY_INDEX}/_search`,
undefined,
{
method: "POST",
body: JSON.stringify(getVectorsQuery),
}
)

const vectorEmbeddings = artworksResponse.hits?.hits?.map(
(hit) => hit._source.vector_embedding
)

return mean(vectorEmbeddings, 0)
}
55 changes: 55 additions & 0 deletions src/lib/infiniteDiscovery/findSimilarArtworks.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import config from "config"
import { opensearch } from "lib/apis/opensearch"

/**
* Perform kNN operation to find artworks similiar to vectorEmbedding
* and then return the artworks loaded by artworksLoader
*
* @param vectorEmbedding - vector embedding of the artwork
* @param size - number of similar artworks to return
* @param excludeArtworkIds - list of artwork ids to exclude from the response
* @param artworksLoader - artworks loader
*/
export const findSimilarArtworks = async (
vectorEmbedding: number[],
size = 10,
excludeArtworkIds: string[] = [],
artworksLoader
) => {
const knnQuery = {
size: size,
_source: ["_id"],
query: {
bool: {
must_not: {
terms: {
_id: excludeArtworkIds,
},
},
should: [
{
knn: {
vector_embedding: {
vector: vectorEmbedding,
k: size,
},
},
},
],
},
},
}

const knnResponse = await opensearch(
`/${config.OPENSEARCH_ARTWORKS_INFINITE_DISCOVERY_INDEX}/_search`,
undefined,
{
method: "POST",
body: JSON.stringify(knnQuery),
}
)

const artworkIds = knnResponse.hits?.hits?.map((hit) => hit._id) || []

return await artworksLoader({ ids: artworkIds })
}
27 changes: 27 additions & 0 deletions src/lib/infiniteDiscovery/getInitialArtworksSample.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { opensearch } from "lib/apis/opensearch"

export const getInitialArtworksSample = async (limit, artworksLoader) => {
// initial artworks sample comes from indexed curators picks, but
// in future we plan to come up with a more sophisticated approach
const curatorsPicks = await opensearch(`/curators_picks/_search`, undefined, {
method: "POST",
body: JSON.stringify({
size: limit,
query: {
function_score: {
functions: [
{
random_score: {
seed: Math.floor(Math.random() * 1000),
},
},
],
},
},
}),
})

const artworkIds = curatorsPicks.hits?.hits?.map((hit) => hit._id) || []

return await artworksLoader({ ids: artworkIds })
}
91 changes: 39 additions & 52 deletions src/schema/v2/infiniteDiscovery/discoverArtworks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import {
GraphQLInt,
GraphQLEnumType,
GraphQLFloat,
GraphQLNonNull,
GraphQLBoolean,
GraphQLList,
} from "graphql"
import { ResolverContext } from "types/graphql"
import { artworkConnection } from "../artwork"
import { connectionFromArray } from "graphql-relay"
import { pageable } from "relay-cursor-paging"
import { sampleSize, shuffle, uniqBy } from "lodash"
import { sampleSize, uniqBy } from "lodash"
import {
insertSampleCuratedWorks,
getUserFilterList,
Expand All @@ -23,11 +23,14 @@ import {
getArtworkIds,
getFilteredIdList,
} from "lib/infiniteDiscovery/weaviate"
import { getInitialArtworksSample } from "lib/infiniteDiscovery/getInitialArtworksSample"
import { calculateMeanArtworksVector } from "lib/infiniteDiscovery/calculateMeanArtworksVector"
import { findSimilarArtworks } from "lib/infiniteDiscovery/findSimilarArtworks"

export const DiscoverArtworks: GraphQLFieldConfig<void, ResolverContext> = {
type: artworkConnection.connectionType,
args: pageable({
userId: { type: GraphQLNonNull(GraphQLString) },
userId: { type: GraphQLString },
limit: { type: GraphQLInt },
offset: { type: GraphQLInt },
certainty: { type: GraphQLFloat },
Expand All @@ -48,19 +51,22 @@ export const DiscoverArtworks: GraphQLFieldConfig<void, ResolverContext> = {
},
}),
},
useRelatedArtworks: { type: GraphQLBoolean, defaultValue: false },
useOpenSearch: { type: GraphQLBoolean, defaultValue: false },
excludeArtworkIds: {
type: new GraphQLList(GraphQLString),
description:
"(Only for when useOpenSearch is true) Exclude these artworks from the response",
},
likedArtworkIds: {
type: new GraphQLList(GraphQLString),
description:
"(Only for when useOpenSearch is true) These artworks are used to calculate the taste profile vector. Such artworks are excluded from the response",
},
}),
resolve: async (
_root,
args,
{
weaviateCreateObjectLoader,
weaviateGraphqlLoader,
artworksLoader,
relatedArtworksLoader,
marketingCollectionLoader,
savedArtworksLoader,
}
{ weaviateCreateObjectLoader, weaviateGraphqlLoader, artworksLoader }
) => {
if (
!artworksLoader ||
Expand All @@ -76,51 +82,32 @@ export const DiscoverArtworks: GraphQLFieldConfig<void, ResolverContext> = {
offset = 0,
certainty = 0.5,
sort,
useRelatedArtworks,
useOpenSearch,
} = args

if (useRelatedArtworks) {
if (!savedArtworksLoader) {
return new Error("You need to be signed in to perform this action")
}

const { body: savedArtworks } = await savedArtworksLoader({
size: 28,
sort: "-position",
user_id: userId,
private: true,
})

const savedArtworkIds = savedArtworks.map((artwork) => artwork.id)

const curatedArtworksCollection = await marketingCollectionLoader(
"curators-picks"
)

const curatedArtworkIds = curatedArtworksCollection.artwork_ids
if (useOpenSearch) {
const { excludeArtworkIds, likedArtworkIds } = args

// Select two random artworks from curated artworks
const randomCuratedArtworksIds = sampleSize(curatedArtworkIds, 2)
let result = []

const curatedArtworks = await artworksLoader({
ids: randomCuratedArtworksIds,
})

// use curated artworks if there are no saved artworks
const finalArtworkIds =
savedArtworkIds.length > 0 ? [...savedArtworkIds] : curatedArtworkIds

// Limit the number of artwork IDs to a maximum of 10
const queryArtworkIds = finalArtworkIds.slice(0, 10)

const relatedArtworks = await relatedArtworksLoader({
artwork_id: queryArtworkIds,
size: 8,
})
if (!likedArtworkIds) {
result = await getInitialArtworksSample(limit, artworksLoader)
} else {
const tasteProfileVector = await calculateMeanArtworksVector(
likedArtworkIds
)
// we don't want to recommend the same artworks that the user already liked
excludeArtworkIds.push(...likedArtworkIds)

result = await findSimilarArtworks(
tasteProfileVector,
limit,
excludeArtworkIds,
artworksLoader
)
}

// inject curated artworks and shuffle the list
const shuffledArtworks = shuffle([...relatedArtworks, ...curatedArtworks])
return connectionFromArray(shuffledArtworks, args)
return connectionFromArray(result, args)
}

const userQueryResponse = await weaviateGraphqlLoader({
Expand Down
Loading

0 comments on commit 7569b27

Please sign in to comment.