From c662768b5ce5eb7f0c56dc69f61feb110818a09c Mon Sep 17 00:00:00 2001 From: Alex Dixon Date: Tue, 6 Aug 2024 19:17:14 -0700 Subject: [PATCH] i guess just use claude --- src/ell/studio/__main__.py | 73 ++++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/src/ell/studio/__main__.py b/src/ell/studio/__main__.py index 8d04042c..6481cf72 100644 --- a/src/ell/studio/__main__.py +++ b/src/ell/studio/__main__.py @@ -5,7 +5,9 @@ from ell.studio.data_server import create_app from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse -from watchfiles import awatch +import watchfiles +import importlib +import sys import time def main(): @@ -31,13 +33,11 @@ async def serve_react_app(full_path: str): db_path = os.path.join(args.storage_dir, "ell.db") async def db_watcher(db_path, app): + print("Starting db watcher") last_stat = None - while True: - await asyncio.sleep(0.1) # Fixed interval of 0.1 seconds try: current_stat = os.stat(db_path) - if last_stat is None: print(f"Database file found: {db_path}") await app.notify_clients("database_updated") @@ -64,21 +64,58 @@ async def db_watcher(db_path, app): except Exception as e: print(f"Error checking database file: {e}") await asyncio.sleep(1) # Wait a bit longer on errors + finally: + await asyncio.sleep(1) # Use a consistent sleep interval + + def get_dependencies(module_name): + module = importlib.import_module(module_name) + return list(set(sys.modules[name].__file__ for name in sys.modules if name.startswith(module_name.split('.')[0]))) + + def reload_app(): + importlib.reload(sys.modules["ell.studio.data_server"]) + return create_app() + + async def run_server(server): + await server.serve() + + async def watch_files(dependencies, server, config, loop): + async for changes in watchfiles.awatch(*dependencies): + print(f"Detected changes in {changes}. Reloading...") + new_app = reload_app() + await server.shutdown() + config.app = new_app + server.force_exit = False + loop.create_task(run_server(server)) + + async def main_async(args): + db_path = os.path.join(args.storage_dir, "ell.db") + dependencies = get_dependencies("ell.studio.data_server") + app = create_app() + + config = uvicorn.Config( + app=app, + host=args.host, + port=args.port, + loop=asyncio.get_event_loop(), + ) + server = uvicorn.Server(config) + + tasks = [ + asyncio.create_task(run_server(server)), + asyncio.create_task(watch_files(dependencies, server, config, asyncio.get_event_loop())), + asyncio.create_task(db_watcher(db_path, app)) + ] + + try: + await asyncio.gather(*tasks) + except asyncio.CancelledError: + pass + finally: + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) - # Start the database watcher - loop = asyncio.new_event_loop() - - # config = uvicorn.Config(app=app, port=args.port, loop=loop,reload=True,#if args.dev else False, - # reload_delay=1) - # server = uvicorn.Server(config) - uvicorn.run("ell.studio.data_server:create_app", - reload=True, - reload_delay=5, - host=args.host, - port=args.port) - # loop.create_task(server.serve()) - loop.create_task(db_watcher(db_path, app)) - loop.run_forever() + asyncio.run(main_async(args)) if __name__ == "__main__": main() \ No newline at end of file