Skip to content

Commit

Permalink
Merge pull request #114 from MadcowD/wguss/webosckets
Browse files Browse the repository at this point in the history
added websockets. Closes #32
  • Loading branch information
MadcowD authored Aug 5, 2024
2 parents 75c5c73 + 576afa2 commit 55b15be
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 65 deletions.
38 changes: 29 additions & 9 deletions ell-studio/src/App.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,41 @@ import Traces from './pages/Traces';
import { ThemeProvider } from './contexts/ThemeContext';
import './styles/globals.css';
import './styles/sourceCode.css';
import { useWebSocketConnection } from './hooks/useBackend';
import { Toaster, toast } from 'react-hot-toast';

const WebSocketConnectionProvider = ({children}) => {
const { isConnected } = useWebSocketConnection();

React.useEffect(() => {
if (isConnected) {
toast.success('Store connected', {
duration: 1000,
});
} else {
toast('Connecting to store...', {
icon: '🔄',
duration: 500,
});
}
}, [isConnected]);

return (
<>
{children}
<Toaster position="top-right" />
</>
);
};

// Create a client
const queryClient = new QueryClient({
defaultOptions: {
queries: {
refetchOnWindowFocus: false, // default: true
retry: false, // default: 3
staleTime: 5 * 60 * 1000, // 5 minutes
},
},
});
const queryClient = new QueryClient();

function App() {
return (
<QueryClientProvider client={queryClient}>
<ThemeProvider>
<WebSocketConnectionProvider>
<Router>
<div className="flex min-h-screen max-h-screen bg-gray-900 text-gray-100">
<Sidebar />
Expand All @@ -38,6 +57,7 @@ function App() {
</div>
</div>
</Router>
</WebSocketConnectionProvider>
</ThemeProvider>
</QueryClientProvider>
);
Expand Down
13 changes: 12 additions & 1 deletion ell-studio/src/components/HierarchicalTable.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,24 @@ const TableRow = ({ item, schema, level = 0, onRowClick, columnWidths, updateWid
const hasChildren = item.children && item.children.length > 0;
const isExpanded = expandedRows[item.id];
const isSelected = isItemSelected(item);
const [isNew, setIsNew] = useState(true);

const customRowClassName = rowClassName ? rowClassName(item) : '';

useEffect(() => {
if (isNew) {
const timer = setTimeout(() => setIsNew(false), 200);
return () => clearTimeout(timer);
}
}, [isNew]);

return (
<React.Fragment>
<tr
className={`border-b border-gray-800 hover:bg-gray-800/30 cursor-pointer transition-colors duration-500 ease-in-out ${isSelected ? 'bg-blue-900/30' : ''} ${customRowClassName}`}
className={`border-b border-gray-800 hover:bg-gray-800/30 cursor-pointer transition-all duration-500 ease-in-out
${isSelected ? 'bg-blue-900/30' : ''}
${customRowClassName}
${isNew ? 'animate-fade-in bg-green-900/30' : ''}`}
onClick={() => {
if (onRowClick) onRowClick(item);
}}
Expand Down
1 change: 0 additions & 1 deletion ell-studio/src/components/invocations/InvocationsTable.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ const InvocationsTable = ({ invocations, currentPage, setCurrentPage, pageSize,
const navigate = useNavigate();



const onClickLMP = useCallback(({lmp, id : invocationId}) => {
navigate(`/lmp/${lmp.name}/${lmp.lmp_id}?i=${invocationId}`);
}, [navigate]);
Expand Down
70 changes: 55 additions & 15 deletions ell-studio/src/hooks/useBackend.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,45 @@
import { useQuery, useQueries } from '@tanstack/react-query';
import { useQuery, useQueryClient, useQueries } from '@tanstack/react-query';
import axios from 'axios';

import { useEffect, useState } from 'react';

const API_BASE_URL = "http://localhost:8080";
const WS_URL = "ws://localhost:8080/ws";

export const useWebSocketConnection = () => {
const queryClient = useQueryClient();
const [isConnected, setIsConnected] = useState(false);
useEffect(() => {
const socket = new WebSocket(WS_URL);

socket.onopen = () => {
console.log('WebSocket connected');
setIsConnected(true);
};

socket.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.entity === 'database_updated') {
// Invalidate relevant queries
queryClient.invalidateQueries({queryKey: ['traces']});
queryClient.invalidateQueries({queryKey: ['latestLMPs']});
queryClient.invalidateQueries({queryKey: ['invocations']}) ;
queryClient.invalidateQueries({queryKey: ['lmpDetails']});
console.log('Database updated, invalidating queries');
}
};

socket.onclose = () => {
console.log('WebSocket disconnected');
setIsConnected(false);
};

return () => {
console.log('WebSocket connection closed');
socket.close();
};
}, [queryClient]);
return { isConnected };
};

export const useLMPs = (name, id) => {
return useQuery({
Expand All @@ -21,7 +58,7 @@ export const useLMPs = (name, id) => {
});
};

export const useInvocations = (name, id, page = 0, pageSize = 50) => {
export const useInvocationsFromLMP = (name, id, page = 0, pageSize = 50) => {
return useQuery({
queryKey: ['invocations', name, id, page, pageSize],
queryFn: async () => {
Expand All @@ -39,6 +76,18 @@ export const useInvocations = (name, id, page = 0, pageSize = 50) => {
});
};

export const useInvocation = (id) => {
return useQuery({
queryKey: ['invocation', id],
queryFn: async () => {
const response = await axios.get(`${API_BASE_URL}/api/invocation/${id}`);
return response.data;
},
enabled: !!id,
});
}


export const useMultipleLMPs = (usesIds) => {
const multipleLMPs = useQueries({
queries: (usesIds || []).map(use => ({
Expand All @@ -55,26 +104,17 @@ export const useMultipleLMPs = (usesIds) => {
return { isLoading, data };
};




export const useLatestLMPs = (page = 0, pageSize = 100) => {
return useQuery({
queryKey: ['allLMPs', page, pageSize],
queryKey: ['latestLMPs', page, pageSize],
queryFn: async () => {
const skip = page * pageSize;
const response = await axios.get(`${API_BASE_URL}/api/latest/lmps?skip=${skip}&limit=${pageSize}`);
const lmps = response.data;

return lmps;
return response.data;
}
});
};





export const useTraces = (lmps) => {
return useQuery({
queryKey: ['traces', lmps],
Expand Down Expand Up @@ -103,4 +143,4 @@ export const useTraces = (lmps) => {
},
enabled: !!lmps && lmps.length > 0,
});
};
};
17 changes: 12 additions & 5 deletions ell-studio/src/pages/LMP.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import React, { useState, useEffect, useMemo } from "react";
import { useParams, useSearchParams, useNavigate, Link } from "react-router-dom";
import { useLMPs, useInvocations, useMultipleLMPs } from "../hooks/useBackend";
import { useLMPs, useInvocationsFromLMP, useMultipleLMPs, useInvocation } from "../hooks/useBackend";
import InvocationsTable from "../components/invocations/InvocationsTable";
import DependencyGraphPane from "../components/DependencyGraphPane";
import LMPSourceView from "../components/source/LMPSourceView";
Expand Down Expand Up @@ -35,6 +35,7 @@ function LMP() {
const requestedInvocationId = searchParams.get("i");

const [currentPage, setCurrentPage] = useState(0);
const pageSize = 50;

// TODO: Could be expensive if you have a funct on of versions.
const { data: versionHistory, isLoading: isLoadingLMP } = useLMPs(name);
Expand All @@ -47,7 +48,7 @@ function LMP() {
}
}, [versionHistory, id]);

const { data: invocations } = useInvocations(name, id);
const { data: invocations } = useInvocationsFromLMP(name, id, currentPage, pageSize);
const { data: uses } = useMultipleLMPs(lmp?.uses);


Expand All @@ -65,9 +66,14 @@ function LMP() {
: null;
}, [versionHistory, lmp]);

const requestedInvocation = useMemo(() => invocations?.find(
(invocation) => invocation.id === requestedInvocationId
), [invocations, requestedInvocationId]);
const {data: requestedInvocationQueryData} = useInvocation(requestedInvocationId);
const requestedInvocation = useMemo(() => {
if (!requestedInvocationQueryData)
return invocations?.find(i => i.id === requestedInvocationId);
else
return requestedInvocationQueryData;

}, [requestedInvocationQueryData, invocations, requestedInvocationId]);

useEffect(() => {
setSelectedTrace(requestedInvocation);
Expand Down Expand Up @@ -233,6 +239,7 @@ function LMP() {
<InvocationsTable
invocations={invocations}
currentPage={currentPage}
pageSize={pageSize}
setCurrentPage={setCurrentPage}
producingLmp={lmp}
onSelectTrace={(trace) => {
Expand Down
18 changes: 5 additions & 13 deletions ell-studio/src/pages/Traces.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,21 @@ import { FiCopy, FiZap, FiEdit2, FiFilter, FiClock, FiColumns, FiPause, FiPlay }
import InvocationsTable from '../components/invocations/InvocationsTable';
import InvocationsLayout from '../components/invocations/InvocationsLayout';
import { useNavigate, useLocation } from 'react-router-dom';
import { useInvocations } from '../hooks/useBackend';
import { useInvocationsFromLMP } from '../hooks/useBackend';

const Traces = () => {
const [selectedTrace, setSelectedTrace] = useState(null);
const [isPolling, setIsPolling] = useState(true);
const navigate = useNavigate();
const location = useLocation();


// TODO Unify invocation search behaviour with the LMP page.
const [currentPage, setCurrentPage] = useState(0);
const pageSize = 10;
const pageSize = 50;

const { data: invocations, refetch , isLoading } = useInvocations(null, null, currentPage, pageSize);
const { data: invocations , isLoading } = useInvocationsFromLMP(null, null, currentPage, pageSize);

useEffect(() => {
let intervalId;
if (isPolling) {
intervalId = setInterval(refetch, 200); // Poll every 200ms
}

return () => {
if (intervalId) clearInterval(intervalId);
};
}, [isPolling, refetch]);

useEffect(() => {
const searchParams = new URLSearchParams(location.search);
Expand Down
Binary file added examples/sqlite_example/ell.db-shm
Binary file not shown.
Empty file.
1 change: 1 addition & 0 deletions src/ell/stores/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def get_lmps(self, skip: int = 0, limit: int = 10, subquery=None, **filters: Opt
))

if filters:
print(f"Filters: {filters}")
for key, value in filters.items():
query = query.where(getattr(SerializedLMP, key) == value)

Expand Down
23 changes: 20 additions & 3 deletions src/ell/studio/__main__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import os
import uvicorn
from argparse import ArgumentParser
from ell.studio.data_server import create_app
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from watchfiles import run_process
from watchfiles import awatch


def main():
parser = ArgumentParser(description="ELL Studio Data Server")
Expand All @@ -26,8 +28,23 @@ def main():
async def serve_react_app(full_path: str):
return FileResponse(os.path.join(static_dir, "index.html"))

# In production mode, run without auto-reloading
uvicorn.run(app, host=args.host, port=args.port)
db_path = os.path.join(args.storage_dir, "ell.db")

async def db_watcher():
async for changes in awatch(db_path):
print(f"Database changed: {changes}")
await app.notify_clients("database_updated")

# Start the database watcher


loop = asyncio.new_event_loop()

config = uvicorn.Config(app=app, port=args.port, loop=loop)
server = uvicorn.Server(config)
loop.create_task(server.serve())
loop.create_task(db_watcher())
loop.run_forever()

if __name__ == "__main__":
main()
Loading

0 comments on commit 55b15be

Please sign in to comment.