From 956689bd08a8ed4036855a3065ff56777d702029 Mon Sep 17 00:00:00 2001 From: Harish Mohan Raj Date: Fri, 1 Dec 2023 12:51:42 +0530 Subject: [PATCH] Support for multiple teams in single conversation (#51) --- main.wasp | 39 ++-- .../migration.sql | 18 ++ src/client/AccountPage.tsx | 90 +++++--- src/client/PricingPage.tsx | 81 ++++--- src/client/chatConversationHelper.tsx | 72 ++++++ src/client/components/ConversationList.tsx | 197 ++++++++++------ src/client/components/ConversationWrapper.tsx | 167 +++++--------- src/client/helpers.tsx | 72 ++++++ src/client/tests/helpers.test.tsx | 175 ++++++++++++++ src/server/actions.ts | 218 ++++++------------ src/server/config.js | 4 + src/server/queries.ts | 66 ++---- src/server/webSocket.js | 35 +-- 13 files changed, 776 insertions(+), 458 deletions(-) create mode 100644 migrations/20231201044043_add_team_details_to_conversation_model/migration.sql create mode 100644 src/client/chatConversationHelper.tsx create mode 100644 src/client/helpers.tsx create mode 100644 src/client/tests/helpers.test.tsx diff --git a/main.wasp b/main.wasp index e076da7..83f1ed8 100644 --- a/main.wasp +++ b/main.wasp @@ -69,7 +69,7 @@ app chatApp { ("markdown-to-jsx", "7.3.2"), ], webSocket: { - fn: import { webSocketFn } from "@server/webSocket.js" + fn: import { checkTeamStatusAndUpdateInDB } from "@server/webSocket.js" }, } @@ -128,15 +128,19 @@ entity Chat {=psl psl=} entity Conversation {=psl - id Int @id @default(autoincrement()) - conversation Json - createdAt DateTime @default(now()) - updatedAt DateTime @updatedAt - status String? - chat Chat? @relation(fields: [chatId], references: [id]) - chatId Int? - user User? @relation(fields: [userId], references: [id]) - userId Int? + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + message String + role String + team_id Int? + team_name String? + team_status String? + is_question_from_agent Boolean @default(false) + chat Chat? @relation(fields: [chatId], references: [id]) + chatId Int? + user User? @relation(fields: [userId], references: [id]) + userId Int? psl=} @@ -208,12 +212,6 @@ page ChatPage { // 📝 Actions aka Mutations -action generateGptResponse { - fn: import { generateGptResponse } from "@server/actions.js", - // entities: [User, RelatedObject] - entities: [User] -} - action stripePayment { fn: import { stripePayment } from "@server/actions.js", entities: [User] @@ -224,8 +222,13 @@ action createChat { entities: [Chat, Conversation] } -action updateConversation { - fn: import { updateConversation } from "@server/actions.js", +action addNewConversationToChat { + fn: import { addNewConversationToChat } from "@server/actions.js", + entities: [Chat, Conversation] +} + +action updateExistingConversation { + fn: import { updateExistingConversation } from "@server/actions.js", entities: [Chat, Conversation] } diff --git a/migrations/20231201044043_add_team_details_to_conversation_model/migration.sql b/migrations/20231201044043_add_team_details_to_conversation_model/migration.sql new file mode 100644 index 0000000..6950631 --- /dev/null +++ b/migrations/20231201044043_add_team_details_to_conversation_model/migration.sql @@ -0,0 +1,18 @@ +/* + Warnings: + + - You are about to drop the column `conversation` on the `Conversation` table. All the data in the column will be lost. + - You are about to drop the column `status` on the `Conversation` table. All the data in the column will be lost. + - Added the required column `message` to the `Conversation` table without a default value. This is not possible if the table is not empty. + - Added the required column `role` to the `Conversation` table without a default value. This is not possible if the table is not empty. + +*/ +-- AlterTable +ALTER TABLE "Conversation" DROP COLUMN "conversation", +DROP COLUMN "status", +ADD COLUMN "is_question_from_agent" BOOLEAN NOT NULL DEFAULT false, +ADD COLUMN "message" TEXT NOT NULL, +ADD COLUMN "role" TEXT NOT NULL, +ADD COLUMN "team_id" INTEGER, +ADD COLUMN "team_name" TEXT, +ADD COLUMN "team_status" TEXT; diff --git a/src/client/AccountPage.tsx b/src/client/AccountPage.tsx index e8b7c5d..4277c33 100644 --- a/src/client/AccountPage.tsx +++ b/src/client/AccountPage.tsx @@ -1,12 +1,13 @@ -import { User } from '@wasp/entities'; -import { useQuery } from '@wasp/queries' +import { User } from "@wasp/entities"; +import { useQuery } from "@wasp/queries"; // import getRelatedObjects from '@wasp/queries/getRelatedObjects' -import logout from '@wasp/auth/logout'; -import stripePayment from '@wasp/actions/stripePayment'; -import { useState, Dispatch, SetStateAction } from 'react'; +import logout from "@wasp/auth/logout"; +import stripePayment from "@wasp/actions/stripePayment"; +import { useState, Dispatch, SetStateAction } from "react"; // get your own link from your stripe dashboard: https://dashboard.stripe.com/settings/billing/portal -const CUSTOMER_PORTAL_LINK = 'https://billing.stripe.com/p/login/test_8wM8x17JN7DT4zC000'; +const CUSTOMER_PORTAL_LINK = + "https://billing.stripe.com/p/login/test_8wM8x17JN7DT4zC000"; export default function Example({ user }: { user: User }) { const [isLoading, setIsLoading] = useState(false); @@ -14,16 +15,22 @@ export default function Example({ user }: { user: User }) { // const { data: relatedObjects, isLoading: isLoadingRelatedObjects } = useQuery(getRelatedObjects) return ( -
-
-
-

Account Information

+
+
+
+

+ Account Information +

-
-
-
-
Email address
-
{user.email}
+
+
+
+
+ Email address +
+
+ {user.email} +
{/*
Your Plan
@@ -46,8 +53,8 @@ export default function Example({ user }: { user: User }) {
I'm a cool customer.
*/} {/*
*/} - {/*
Most Recent User RelatedObject
*/} - {/*
+ {/*
Most Recent User RelatedObject
*/} + {/*
{!!relatedObjects && relatedObjects.length > 0 ? relatedObjects[relatedObjects.length - 1].content : "You don't have any at this time."} @@ -56,10 +63,10 @@ export default function Example({ user }: { user: User }) {
-
+
@@ -68,42 +75,63 @@ export default function Example({ user }: { user: User }) { ); } -function BuyMoreButton({ isLoading, setIsLoading }: { isLoading: boolean, setIsLoading: Dispatch> }) { +function BuyMoreButton({ + isLoading, + setIsLoading, +}: { + isLoading: boolean; + setIsLoading: Dispatch>; +}) { const handleClick = async () => { try { setIsLoading(true); const stripeResults = await stripePayment(); if (stripeResults?.sessionUrl) { - window.open(stripeResults.sessionUrl, '_self'); + window.open(stripeResults.sessionUrl, "_self"); } - } catch (error: any) { - alert(error?.message ?? 'Something went wrong.') + alert(error?.message ?? "Something went wrong."); } finally { setIsLoading(false); } }; return ( -
-
); } -function CustomerPortalButton({ isLoading, setIsLoading }: { isLoading: boolean, setIsLoading: Dispatch> }) { +function CustomerPortalButton({ + isLoading, + setIsLoading, +}: { + isLoading: boolean; + setIsLoading: Dispatch>; +}) { const handleClick = () => { setIsLoading(true); - window.open(CUSTOMER_PORTAL_LINK, '_blank'); + window.open(CUSTOMER_PORTAL_LINK, "_blank"); setIsLoading(false); }; return ( -
-
); diff --git a/src/client/PricingPage.tsx b/src/client/PricingPage.tsx index 7566a39..5a93c19 100644 --- a/src/client/PricingPage.tsx +++ b/src/client/PricingPage.tsx @@ -1,24 +1,28 @@ -import { AiOutlineCheck } from 'react-icons/ai'; -import stripePayment from '@wasp/actions/stripePayment'; -import { useState } from 'react'; +import { AiOutlineCheck } from "react-icons/ai"; +import stripePayment from "@wasp/actions/stripePayment"; +import { useState } from "react"; const prices = [ { - name: 'Credits', - id: 'credits', - href: '', - price: '$2.95', - description: 'Buy credits to use for your projects.', - features: ['10 credits', 'Use them any time', 'No expiration date'], + name: "Credits", + id: "credits", + href: "", + price: "$2.95", + description: "Buy credits to use for your projects.", + features: ["10 credits", "Use them any time", "No expiration date"], disabled: true, }, { - name: 'Monthly Subscription', - id: 'monthly', - href: '#', - priceMonthly: '$9.99', - description: 'Get unlimited usage for your projects.', - features: ['Unlimited usage of all features', 'Priority support', 'Cancel any time'], + name: "Monthly Subscription", + id: "monthly", + href: "#", + priceMonthly: "$9.99", + description: "Get unlimited usage for your projects.", + features: [ + "Unlimited usage of all features", + "Priority support", + "Cancel any time", + ], }, ]; @@ -30,43 +34,55 @@ export default function PricingPage() { try { const response = await stripePayment(); if (response?.sessionUrl) { - window.open(response.sessionUrl, '_self'); + window.open(response.sessionUrl, "_self"); } } catch (e) { - alert('Something went wrong. Please try again.'); + alert("Something went wrong. Please try again."); console.error(e); } finally { setIsLoading(false); } }; - return ( -
-
-
+
+
+
{prices.map((price) => (
-

+

{price.name}

-
- +
+ {price.priceMonthly || price.price} {price.priceMonthly && ( - /month + + /month + )}
-

{price.description}

-
    +

    + {price.description} +

    +
      {price.features.map((feature) => ( -
    • -
    • +
    • ))} @@ -77,10 +93,11 @@ export default function PricingPage() { aria-describedby={price.id} disabled={price.disabled} className={`${ - price.disabled && 'disabled:opacity-25 disabled:cursor-not-allowed' + price.disabled && + "disabled:opacity-25 disabled:cursor-not-allowed" } mt-8 block rounded-md bg-yellow-400 px-3.5 py-2 text-center text-sm font-semibold leading-6 text-black shadow-sm hover:bg-yellow-500 focus-visible:outline focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-yellow-600`} > - {isLoading ? 'Loading...' : 'Buy Now'} + {isLoading ? "Loading..." : "Buy Now"}
))} diff --git a/src/client/chatConversationHelper.tsx b/src/client/chatConversationHelper.tsx new file mode 100644 index 0000000..c7bb26a --- /dev/null +++ b/src/client/chatConversationHelper.tsx @@ -0,0 +1,72 @@ +import getAgentResponse from "@wasp/actions/getAgentResponse"; +import addNewConversationToChat from "@wasp/actions/addNewConversationToChat"; +import updateExistingConversation from "@wasp/actions/updateExistingConversation"; +import { prepareOpenAIRequest } from "./helpers"; + +export async function addUserMessageToConversation( + chat_id: number, + userQuery: string, + conv_id?: number, + team_name?: string, + team_id?: number +) { + let userMessage = userQuery; + let isAnswerToAgentQuestion = false; + let user_answer_to_team_id = null; + if (team_id) { + const payload = { + chat_id: chat_id, + conv_id: conv_id, + is_question_from_agent: false, + team_status: null, + }; + await updateExistingConversation(payload); + userMessage = `

Replying to ${team_name}:



` + userQuery; + isAnswerToAgentQuestion = true; + user_answer_to_team_id = team_id; + } + + const payload = { + chat_id: chat_id, + message: userMessage, + role: "user", + }; + + const updatedConversation: any = await addNewConversationToChat(payload); + + const [messages, latestConversationID]: [ + messages: any, + latestConversationID: number + ] = prepareOpenAIRequest(updatedConversation); + return [ + messages, + latestConversationID, + isAnswerToAgentQuestion, + user_answer_to_team_id, + ]; +} + +export async function addAgentMessageToConversation( + chat_id: number, + message: any, + conv_id: number, + isAnswerToAgentQuestion: boolean, + userResponseToTeamId: number | null | undefined +) { + const response: any = await getAgentResponse({ + message: message, + conv_id: conv_id, + isAnswerToAgentQuestion: isAnswerToAgentQuestion, + userResponseToTeamId: userResponseToTeamId, + }); + + const openAIResponse = { + chat_id: Number(chat_id), + message: response.content, + role: "assistant", + ...(response.team_name && { team_name: response.team_name }), + ...(response.team_id && { team_id: response.team_id }), + ...(response.team_status && { team_status: response.team_status }), + }; + await addNewConversationToChat(openAIResponse); +} diff --git a/src/client/components/ConversationList.tsx b/src/client/components/ConversationList.tsx index 508b930..3347573 100644 --- a/src/client/components/ConversationList.tsx +++ b/src/client/components/ConversationList.tsx @@ -1,85 +1,146 @@ +import React from "react"; +import { useState } from "react"; + import Markdown from "markdown-to-jsx"; import type { Conversation } from "@wasp/entities"; import logo from "../static/captn-logo.png"; -export default function ConversationsList(conversations: Conversation[]) { +type ConversationsListProps = { + conversations: Conversation[]; + onInlineFormSubmit: ( + userQuery: string, + conv_id: number, + team_name: string, + team_id: number + ) => void; +}; + +export default function ConversationsList({ + conversations, + onInlineFormSubmit, +}: ConversationsListProps) { return (
- { - // Todo: remove the below ignore comment - // @ts-ignore - conversations.conversations.map((conversation, idx) => { - const conversationBgColor = - conversation.role === "user" - ? "captn-light-blue" - : "captn-dark-blue"; - const conversationTextColor = - conversation.role === "user" - ? "captn-dark-blue" - : "captn-light-cream"; - const conversationLogo = - conversation.role === "user" ? ( -
-
You
-
- ) : ( - captn logo - ); - return ( -
+ {conversations.map((conversation, idx) => { + const conversationBgColor = + conversation.role === "user" ? "captn-light-blue" : "captn-dark-blue"; + const conversationTextColor = + conversation.role === "user" + ? "captn-dark-blue" + : "captn-light-cream"; + const conversationLogo = + conversation.role === "user" ? ( +
+
You
+
+ ) : ( + captn logo + ); + + const handleFormSubmit = ( + event: React.FormEvent, + conv_id: number, + team_name: string, + team_id: number + ) => { + event.preventDefault(); + const target = event.target as HTMLFormElement; + const userQuery = target.userQuery.value; + target.reset(); + onInlineFormSubmit(userQuery, conv_id, team_name, team_id); + }; + return ( +
+
-
- - {conversationLogo} - -
- {conversation.content} -
+ {conversationLogo} + +
+ {conversation.message}
+ {conversation.is_question_from_agent && ( +
+ handleFormSubmit( + event, + conversation.id, + conversation.team_name, + conversation.team_id + ) + } + className="relative block w-full mt-[15px]" + > + +
+ + +
+
+ )}
- ); - }) - } +
+ ); + })}
); } diff --git a/src/client/components/ConversationWrapper.tsx b/src/client/components/ConversationWrapper.tsx index d9bc096..b2237fc 100644 --- a/src/client/components/ConversationWrapper.tsx +++ b/src/client/components/ConversationWrapper.tsx @@ -1,47 +1,31 @@ import React from "react"; -import { useState, useRef, useEffect, useCallback } from "react"; +import { useState, useCallback, useRef } from "react"; import { useParams } from "react-router"; -import { Redirect, useLocation } from "react-router-dom"; +import { Redirect } from "react-router-dom"; import { useQuery } from "@wasp/queries"; -import updateConversation from "@wasp/actions/updateConversation"; -import getAgentResponse from "@wasp/actions/getAgentResponse"; import getConversations from "@wasp/queries/getConversations"; import ConversationsList from "./ConversationList"; import Loader from "./Loader"; -// A custom hook that builds on useLocation to parse -// the query string for you. -function getQueryParam(paramName: string) { - const { search } = useLocation(); - return React.useMemo(() => new URLSearchParams(search), [search]).get( - paramName - ); -} - -export function setRedirectMsg(formInputRef: any, loginMsgQuery: string) { - if (loginMsgQuery) { - formInputRef.value = decodeURIComponent(loginMsgQuery); - } -} +import { + addUserMessageToConversation, + addAgentMessageToConversation, +} from "../chatConversationHelper"; -export function triggerSubmit( - node: any, - loginMsgQuery: string, - formInputRef: any -) { - if (loginMsgQuery && formInputRef && formInputRef.value !== "") { - node.click(); - } -} +import { setRedirectMsg, getQueryParam, triggerSubmit } from "../helpers"; export default function ConversationWrapper() { - // Todo: remove the below ignore comment - // @ts-ignore - const { id } = useParams(); + const { id }: { id: string } = useParams(); const [isLoading, setIsLoading] = useState(false); - const chatContainerRef = useRef(null); + const { data: conversations } = useQuery( + getConversations, + { + chatId: Number(id), + }, + { enabled: !!id, refetchInterval: 1000 } + ); const loginMsgQuery: any = getQueryParam("msg"); const formInputRef = useCallback( @@ -62,80 +46,64 @@ export default function ConversationWrapper() { [loginMsgQuery, formInputRef] ); - const { - data: conversations, - isLoading: isConversationLoading, - error: isConversationError, - } = useQuery( - getConversations, - { - chatId: Number(id), - }, - { enabled: !!id, refetchInterval: 1000 } - ); - - useEffect(() => { - // if (chatContainerRef.current) { - // // Todo: remove the below ignore comment - // // @ts-ignore - // chatContainerRef.current.scrollTop = - // // Todo: remove the below ignore comment - // // @ts-ignore - // chatContainerRef.current.scrollHeight; - // } - }, [conversations]); - - async function callAgent(userQuery: string) { + async function addMessagesToConversation( + userQuery: string, + conv_id?: number, + team_name?: string, + team_id?: number + ) { try { - // 1. add new conversation to table - const payload = { - // @ts-ignore - conversation_id: conversations.id, - conversations: [...[{ role: "user", content: userQuery }]], - }; - - const updatedConversation = await updateConversation(payload); - // 2. call backend python server to get agent response + const [ + messages, + conversation_id, + isAnswerToAgentQuestion, + user_answer_to_team_id, + ] = await addUserMessageToConversation( + Number(id), + userQuery, + conv_id, + team_name, + team_id + ); setIsLoading(true); - const response = await getAgentResponse({ - message: updatedConversation["conversation"], - conv_id: updatedConversation.id, - // @ts-ignore - is_answer_to_agent_question: updatedConversation.status === "pause", - }); - // 3. add agent response as new conversation in the table - const openAIResponse = { - // @ts-ignore - conversation_id: conversations.id, - conversations: [ - // @ts-ignore - ...[{ role: "assistant", content: response.content }], - ], - // @ts-ignore - ...(response.team_status && { status: response.team_status }), - }; - await updateConversation(openAIResponse); + await addAgentMessageToConversation( + Number(id), + messages, + conversation_id, + isAnswerToAgentQuestion, + user_answer_to_team_id + ); setIsLoading(false); } catch (err: any) { setIsLoading(false); - window.alert("Error: " + err.message); + console.log("Error: " + err.message); + window.alert("Error: Something went wrong. Please try again later."); } } const handleFormSubmit = async (event: React.FormEvent) => { event.preventDefault(); - const target = event.target; - // Todo: remove the below ignore comment - // @ts-ignore + const target = event.target as HTMLFormElement; const userQuery = target.userQuery.value; - // Todo: remove the below ignore comment - // @ts-ignore target.reset(); - await callAgent(userQuery); + await addMessagesToConversation(userQuery); + }; + + const handleInlineFormSubmit = async ( + userQuery: string, + conv_id: number, + team_name: string, + team_id: number + ) => { + await addMessagesToConversation(userQuery, conv_id, team_name, team_id); }; - if (isConversationLoading && !!id) return ; - if (isConversationError) { + const chatContainerClass = `flex h-full flex-col items-center justify-between pb-24 overflow-y-auto bg-captn-light-blue ${ + isLoading ? "opacity-40" : "opacity-100" + }`; + + // check if user has access to chat + if (conversations && conversations.length === 0) { return ( <> @@ -143,24 +111,17 @@ export default function ConversationWrapper() { ); } - const chatContainerClass = `flex h-full flex-col items-center justify-between pb-24 overflow-y-auto bg-captn-light-blue ${ - isLoading ? "opacity-40" : "opacity-100" - }`; - return (
-
+
{conversations && ( - // Todo: remove the below ignore comment - // @ts-ignore - + )}
{isLoading && } diff --git a/src/client/helpers.tsx b/src/client/helpers.tsx new file mode 100644 index 0000000..5f75abb --- /dev/null +++ b/src/client/helpers.tsx @@ -0,0 +1,72 @@ +import React from "react"; +import { useLocation } from "react-router-dom"; + +export function areThereAnyTasks(): boolean { + return true; +} + +type InputMessage = { + chatId: number; + createdAt: string; + id: number; + message: string; + previousConversationId: number | null; + replyToConversationId: number | null; + role: string; + team_id: number | null; + team_name: string | null; + team_status: string | null; + is_question_from_agent: boolean; + updatedAt: string; + userId: number; +}; + +type OutputMessage = { + role: string; + content: string; +}; + +function getLatestConversationID(input: InputMessage[]): number { + const allMessageIDS: number[] = input.map((message) => message.id); + const sortedAllMessageIDS = allMessageIDS.sort((a, b) => b - a); + const latestConversationID = sortedAllMessageIDS[0]; + return latestConversationID; +} + +export function prepareOpenAIRequest( + input: InputMessage[] +): [OutputMessage[], number] { + const messages: OutputMessage[] = input.map((message) => { + return { + role: message.role, + content: message.message, + }; + }); + const latestConversationID = getLatestConversationID(input); + return [messages, latestConversationID]; +} + +// A custom hook that builds on useLocation to parse +// the query string for you. +export function getQueryParam(paramName: string) { + const { search } = useLocation(); + return React.useMemo(() => new URLSearchParams(search), [search]).get( + paramName + ); +} + +export function setRedirectMsg(formInputRef: any, loginMsgQuery: string) { + if (loginMsgQuery) { + formInputRef.value = decodeURIComponent(loginMsgQuery); + } +} + +export function triggerSubmit( + node: any, + loginMsgQuery: string, + formInputRef: any +) { + if (loginMsgQuery && formInputRef && formInputRef.value !== "") { + node.click(); + } +} diff --git a/src/client/tests/helpers.test.tsx b/src/client/tests/helpers.test.tsx new file mode 100644 index 0000000..85eb3c3 --- /dev/null +++ b/src/client/tests/helpers.test.tsx @@ -0,0 +1,175 @@ +import { test, expect } from "vitest"; + +import { areThereAnyTasks, prepareOpenAIRequest } from "../helpers"; + +test("areThereAnyTasks", () => { + expect(areThereAnyTasks()).toBe(true); +}); + +test("prepareOpenAIRequest_1", () => { + const input = [ + { + chatId: 4, + createdAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + id: 8, + message: "First Message", + previousConversationId: null, + replyToConversationId: null, + role: "user", + team_id: null, + team_name: null, + team_status: null, + is_question_from_agent: false, + updatedAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + userId: 1, + }, + { + chatId: 4, + createdAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + id: 9, + message: "Second Message", + previousConversationId: null, + replyToConversationId: null, + role: "user", + team_id: null, + team_name: null, + team_status: null, + is_question_from_agent: false, + updatedAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + userId: 1, + }, + ]; + const expected_message = [ + { + role: "user", + content: "First Message", + }, + { + role: "user", + content: "Second Message", + }, + ]; + const expected_conv_id = 9; + + const [actual_message, actual_conv_id] = prepareOpenAIRequest(input); + expect(actual_message).toStrictEqual(expected_message); + expect(actual_conv_id).toStrictEqual(expected_conv_id); +}); + +test("prepareOpenAIRequest_2", () => { + const input = [ + { + chatId: 4, + createdAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + id: 8, + message: "First Message", + previousConversationId: null, + replyToConversationId: null, + role: "user", + team_id: null, + team_name: null, + team_status: null, + is_question_from_agent: false, + updatedAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + userId: 1, + }, + ]; + const expected_message = [ + { + role: "user", + content: "First Message", + }, + ]; + const expected_conv_id = 8; + + const [actual_message, actual_conv_id] = prepareOpenAIRequest(input); + expect(actual_message).toStrictEqual(expected_message); + expect(actual_conv_id).toStrictEqual(expected_conv_id); +}); + +test("prepareOpenAIRequest_3", () => { + const input = [ + { + chatId: 4, + createdAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + id: 1, + message: "First Message", + previousConversationId: null, + replyToConversationId: null, + role: "assistant", + team_id: null, + team_name: null, + team_status: null, + is_question_from_agent: false, + updatedAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + userId: 1, + }, + { + chatId: 4, + createdAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + id: 2, + message: "Second Message", + previousConversationId: null, + replyToConversationId: null, + role: "user", + team_id: null, + team_name: null, + team_status: null, + is_question_from_agent: false, + updatedAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + userId: 1, + }, + { + chatId: 4, + createdAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + id: 10, + message: "Third Message", + previousConversationId: null, + replyToConversationId: null, + role: "assistant", + team_id: 123, + team_name: "google_ads_agent", + team_status: "pause", + is_question_from_agent: true, + updatedAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + userId: 1, + }, + { + chatId: 4, + createdAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + id: 22, + message: "Forth Message", + previousConversationId: null, + replyToConversationId: null, + role: "user", + team_id: null, + team_name: null, + team_status: null, + is_question_from_agent: false, + updatedAt: "Wed Nov 29 2023 06:37:27 GMT+0530 (India Standard Time)", + userId: 1, + }, + ]; + const expected_message = [ + { + role: "assistant", + content: "First Message", + }, + { + role: "user", + content: "Second Message", + }, + { + role: "assistant", + content: "Third Message", + }, + { + role: "user", + content: "Forth Message", + }, + ]; + const expected_conv_id = 22; + const [actual_message, actual_last_conv_id] = prepareOpenAIRequest(input); + expect(actual_message).toStrictEqual(expected_message); + expect(actual_last_conv_id).toStrictEqual(expected_conv_id); +}); diff --git a/src/server/actions.ts b/src/server/actions.ts index 730f6b1..17c7d94 100644 --- a/src/server/actions.ts +++ b/src/server/actions.ts @@ -4,24 +4,21 @@ import HttpError from "@wasp/core/HttpError.js"; import type { Chat } from "@wasp/entities"; import type { Conversation } from "@wasp/entities"; import type { - GenerateGptResponse, StripePayment, CreateChat, - UpdateConversation, + AddNewConversationToChat, + UpdateExistingConversation, GetAgentResponse, } from "@wasp/actions/types"; import type { StripePaymentResult, OpenAIResponse } from "./types"; import Stripe from "stripe"; -import { ADS_SERVER_URL } from "./config.js"; +import { ADS_SERVER_URL, DOMAIN } from "./config.js"; const stripe = new Stripe(process.env.STRIPE_KEY!, { apiVersion: "2022-11-15", }); -// WASP_WEB_CLIENT_URL will be set up by Wasp when deploying to production: https://wasp-lang.dev/docs/deploying -const DOMAIN = process.env.WASP_WEB_CLIENT_URL || "http://localhost:3000"; - export const stripePayment: StripePayment = async ( _args, context @@ -82,95 +79,6 @@ export const stripePayment: StripePayment = async ( } }; -type GptPayload = { - instructions: string; - command: string; - temperature: number; -}; - -// export const generateGptResponse: GenerateGptResponse = async ( -export const generateGptResponse: GenerateGptResponse = async ( - { instructions, command, temperature }, - context -) => { - if (!context.user) { - throw new HttpError(401); - } - - const payload = { - // model: 'gpt-3.5-turbo', - // engine:"airt-canada-gpt35-turbo-16k", - messages: [ - { - role: "system", - content: instructions, - }, - { - role: "user", - content: command, - }, - ], - temperature: Number(temperature), - }; - - try { - // if (!context.user.hasPaid && !context.user.credits) { - // throw new HttpError(402, 'User has not paid or is out of credits'); - // } else if (context.user.credits && !context.user.hasPaid) { - // console.log('decrementing credits'); - // await context.entities.User.update({ - // where: { id: context.user.id }, - // data: { - // credits: { - // decrement: 1, - // }, - // }, - // }); - // } - - console.log("fetching", payload); - // https://api.openai.com/v1/chat/completions - const response = await fetch( - "https://airt-openai-canada.openai.azure.com/openai/deployments/airt-canada-gpt35-turbo-16k/chat/completions?api-version=2023-07-01-preview", - { - headers: { - "Content-Type": "application/json", - // Authorization: `Bearer ${process.env.AZURE_OPENAI_API_KEY!}`, - "api-key": `${process.env.AZURE_OPENAI_API_KEY!}`, - }, - method: "POST", - body: JSON.stringify(payload), - } - ); - - const json = (await response.json()) as OpenAIResponse; - console.log("response json", json); - // return context.entities.RelatedObject.create({ - // data: { - // content: json?.choices[0].message.content, - // user: { connect: { id: context.user.id } }, - // }, - // }); - return { - content: json?.choices[0].message.content, - }; - } catch (error: any) { - if (!context.user.hasPaid && error?.statusCode != 402) { - await context.entities.User.update({ - where: { id: context.user.id }, - data: { - credits: { - increment: 1, - }, - }, - }); - } - console.error(error); - } - - throw new HttpError(500, "Something went wrong"); -}; - export const createChat: CreateChat = async ( _args, context @@ -186,68 +94,71 @@ export const createChat: CreateChat = async ( return await context.entities.Conversation.create({ data: { - conversation: [ - { - role: "assistant", - content: `Hi! I am Captn and I am here to help you with digital marketing. I can create and optimise marketing campaigns for you. But before I propose any activities, please let me know a little bit about your business and what your marketing goals are.`, - }, - ], + message: + "Hi! I am Captn and I am here to help you with digital marketing. I can create and optimise marketing campaigns for you. But before I propose any activities, please let me know a little bit about your business and what your marketing goals are.", + role: "assistant", chat: { connect: { id: chat.id } }, user: { connect: { id: context.user.id } }, }, }); }; -type UpdateConversationPayload = { - conversation_id: number; - conversations: any; - status?: string; -}; - -type ConversationItem = { +type AddNewConversationToChatPayload = { + message: string; role: string; - content: string; + chat_id: number; + team_name?: string; + team_id?: number; + team_status?: string; }; -function convertConversationList( - currentConversation: Conversation -): Array { - const conversationList: Array = Object.entries( - // @ts-ignore - currentConversation.conversation - ); - return conversationList.map((item) => item[1]); -} - -export const updateConversation: UpdateConversation< - UpdateConversationPayload, - Conversation +export const addNewConversationToChat: AddNewConversationToChat< + AddNewConversationToChatPayload, + Conversation[] > = async (args, context) => { if (!context.user) { throw new HttpError(401); } - const currentConversation = - await context.entities.Conversation.findFirstOrThrow({ - where: { id: args.conversation_id }, - }); - let currentConversationList = convertConversationList(currentConversation); - const existingRole = - currentConversationList[currentConversationList.length - 1]["role"]; - const openAIResponseRole = args.conversations[0]["role"]; + await context.entities.Conversation.create({ + data: { + message: args.message, + role: args.role, + chat: { connect: { id: args.chat_id } }, + user: { connect: { id: context.user.id } }, + ...(args.team_name && { team_name: args.team_name }), + ...(args.team_id && { team_id: args.team_id }), + ...(args.team_status && { team_status: args.team_status }), + }, + }); - if (!(existingRole === "assistant" && existingRole === openAIResponseRole)) { - currentConversationList = [ - ...currentConversationList, - ...args.conversations, - ]; - } + return context.entities.Conversation.findMany({ + where: { chatId: args.chat_id, userId: context.user.id }, + orderBy: { id: "asc" }, + }); +}; + +type UpdateExistingConversationPayload = { + chat_id: number; + conv_id: number; + is_question_from_agent: boolean; + team_status: null; +}; - return context.entities.Conversation.update({ - where: { id: args.conversation_id }, +export const updateExistingConversation: UpdateExistingConversation< + UpdateExistingConversationPayload, + void +> = async (args, context) => { + if (!context.user) { + throw new HttpError(401); + } + await context.entities.Conversation.update({ + where: { + id: args.conv_id, + }, data: { - conversation: currentConversationList, - ...(args.status && { status: args.status }), + team_status: args.team_status, + is_question_from_agent: args.is_question_from_agent, }, }); }; @@ -255,12 +166,23 @@ export const updateConversation: UpdateConversation< type AgentPayload = { message: any; conv_id: number; - is_answer_to_agent_question?: boolean; + isAnswerToAgentQuestion: boolean; + userResponseToTeamId: number | null | undefined; }; export const getAgentResponse: GetAgentResponse = async ( - { message, conv_id, is_answer_to_agent_question }, - context + { + message, + conv_id, + isAnswerToAgentQuestion, + userResponseToTeamId, + }: { + message: any; + conv_id: number; + isAnswerToAgentQuestion: boolean; + userResponseToTeamId: number | null | undefined; + }, + context: any ) => { if (!context.user) { throw new HttpError(401); @@ -270,7 +192,8 @@ export const getAgentResponse: GetAgentResponse = async ( message: message, conv_id: conv_id, user_id: context.user.id, - is_answer_to_agent_question: is_answer_to_agent_question, + is_answer_to_agent_question: isAnswerToAgentQuestion, + user_answer_to_team_id: userResponseToTeamId, }; console.log("==========="); console.log("Payload to Python server"); @@ -292,7 +215,12 @@ export const getAgentResponse: GetAgentResponse = async ( throw new Error(errorMsg); } - return { content: json["content"], team_status: json["team_status"] }; + return { + content: json["content"], + team_status: json["team_status"], + team_name: json["team_name"], + team_id: json["team_id"], + }; } catch (error: any) { throw new HttpError(500, "Something went wrong. Please try again later"); } diff --git a/src/server/config.js b/src/server/config.js index 80ad453..93f999c 100644 --- a/src/server/config.js +++ b/src/server/config.js @@ -1,2 +1,6 @@ export const ADS_SERVER_URL = process.env.ADS_SERVER_URL || "http://127.0.0.1:9000"; + +// WASP_WEB_CLIENT_URL will be set up by Wasp when deploying to production: https://wasp-lang.dev/docs/deploying +export const DOMAIN = + process.env.WASP_WEB_CLIENT_URL || "http://localhost:3000"; diff --git a/src/server/queries.ts b/src/server/queries.ts index 4aa674c..2db9f0d 100644 --- a/src/server/queries.ts +++ b/src/server/queries.ts @@ -1,26 +1,7 @@ -import HttpError from '@wasp/core/HttpError.js'; +import HttpError from "@wasp/core/HttpError.js"; -// import type { RelatedObject } from '@wasp/entities'; -// import type { GetRelatedObjects } from '@wasp/queries/types'; - -import type { Chat, Conversation } from '@wasp/entities'; -import type { GetChats, GetConversations } from '@wasp/queries/types'; - -// import type { Conversation } from '@wasp/entities'; -// import type { GetConversations } from '@wasp/queries/types'; - -// export const getRelatedObjects: GetRelatedObjects = async (args, context) => { -// if (!context.user) { -// throw new HttpError(401); -// } -// return context.entities.RelatedObject.findMany({ -// where: { -// user: { -// id: context.user.id -// } -// }, -// }) -// } +import type { Chat, Conversation } from "@wasp/entities"; +import type { GetChats, GetConversations } from "@wasp/queries/types"; export const getChats: GetChats = async (args, context) => { if (!context.user) { @@ -29,37 +10,26 @@ export const getChats: GetChats = async (args, context) => { return context.entities.Chat.findMany({ where: { user: { - id: context.user.id - } + id: context.user.id, + }, }, - orderBy: { id: 'desc' }, - }) -} + orderBy: { id: "desc" }, + }); +}; type GetConversationPayload = { - chatId: number -} + chatId: number; +}; -export const getConversations: GetConversations = async (args, context) => { +export const getConversations: GetConversations< + GetConversationPayload, + Conversation[] +> = async (args, context) => { if (!context.user) { throw new HttpError(401); } - return context.entities.Conversation.findFirstOrThrow({ + return context.entities.Conversation.findMany({ where: { chatId: args.chatId, userId: context.user.id }, - }) -} - - -// export const getConversations: GetConversations = async (args, context) => { -// if (!context.user) { -// throw new HttpError(401); -// } -// return context.entities.Task.findMany({ -// where: { -// chat: { -// id: args.chat_id -// } -// }, -// }) -// } - + orderBy: { id: "asc" }, + }); +}; diff --git a/src/server/webSocket.js b/src/server/webSocket.js index 48f8ac6..6127ba8 100644 --- a/src/server/webSocket.js +++ b/src/server/webSocket.js @@ -2,7 +2,7 @@ import HttpError from "@wasp/core/HttpError.js"; import { ADS_SERVER_URL } from "./config.js"; -export const webSocketFn = (io, context) => { +export const checkTeamStatusAndUpdateInDB = (io, context) => { io.on("connection", async (socket) => { if (socket.data.user) { const userEmail = socket.data.user.email; @@ -12,14 +12,14 @@ export const webSocketFn = (io, context) => { // Check for updates every 3 seconds const updateInterval = setInterval(async () => { const conversations = await context.entities.Conversation.findMany({ - where: { userId: socket.data.user.id, status: "inprogress" }, + where: { userId: socket.data.user.id, team_status: "inprogress" }, }); conversations.length > 0 && conversations.forEach(async function (conversation) { try { const payload = { - conversation_id: conversation.id, + team_id: conversation.team_id, }; const response = await fetch( `${ADS_SERVER_URL}/openai/get-team-status`, @@ -41,22 +41,31 @@ export const webSocketFn = (io, context) => { throw new Error(errorMsg); } - const conversation_status = json["status"]; - if ( - conversation_status === "completed" || - conversation_status === "pause" - ) { - const updated_conversation = conversation.conversation.concat([ - { role: "assistant", content: json["msg"] }, - ]); + const team_status = json["team_status"]; + if (team_status === "completed" || team_status === "pause") { + // const updated_conversation = conversation.conversation.concat([ + // { role: "assistant", content: json["msg"] }, + // ]); await context.entities.Conversation.update({ where: { // userId: socket.data.user.id, id: conversation.id, }, data: { - conversation: updated_conversation, - status: conversation_status, + team_status: null, + }, + }); + + await context.entities.Conversation.create({ + data: { + message: json["msg"], + role: "assistant", + team_name: json["team_name"], + team_id: Number(json["team_id"]), + team_status: team_status, + is_question_from_agent: team_status === "pause", + chat: { connect: { id: conversation.chatId } }, + user: { connect: { id: socket.data.user.id } }, }, }); }