Skip to content

Commit

Permalink
Merge branch 'main' into research-agent
Browse files Browse the repository at this point in the history
  • Loading branch information
whimo authored Apr 28, 2024
2 parents 1dc5f48 + a78c80f commit aa023a3
Show file tree
Hide file tree
Showing 9 changed files with 480 additions and 0 deletions.
7 changes: 7 additions & 0 deletions examples/single_llama_index.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import logging

from dotenv import load_dotenv
from langchain_community.tools import DuckDuckGoSearchRun


from motleycrew import MotleyCrew, Task
from motleycrew.agent.llama_index import ReActLlamaIndexMotleyAgent
from motleycrew.caсhing import enable_cache, disable_cache

logging.basicConfig(level=logging.INFO)

load_dotenv()
enable_cache()

search_tool = DuckDuckGoSearchRun()

Expand Down Expand Up @@ -40,3 +46,4 @@
print(list(result._done)[0].outputs)

print("######################")
disable_cache()
2 changes: 2 additions & 0 deletions motleycrew/caсhing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .caching import enable_cache, disable_cache, set_cache_location, set_strong_cache
from .http_cache import CACHE_WHITELIST, CACHE_BLACKLIST
42 changes: 42 additions & 0 deletions motleycrew/caсhing/caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os

from motleycrew.caсhing.http_cache import (
BaseHttpCache,
RequestsHttpCaching,
HttpxHttpCaching,
CurlCffiHttpCaching,
)

is_caching = False
caching_http_library_list = [
RequestsHttpCaching(),
HttpxHttpCaching(),
CurlCffiHttpCaching(),
]


def set_strong_cache(val: bool):
"""Enabling disabling the strictly caching option"""
BaseHttpCache.strong_cache = bool(val)


def set_cache_location(location: str) -> str:
"""Sets the caching root directory, returns the absolute path of the derrictory"""
BaseHttpCache.root_cache_dir = location
return os.path.abspath(BaseHttpCache.root_cache_dir)


def enable_cache():
"""The function of enable the caching process"""
global is_caching
for http_cache in caching_http_library_list:
http_cache.enable()
is_caching = True


def disable_cache():
"""The function of disable the caching process"""
global is_caching
for http_cache in caching_http_library_list:
http_cache.disable()
is_caching = False
321 changes: 321 additions & 0 deletions motleycrew/caсhing/http_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,321 @@
import os
from pathlib import Path
from abc import ABC, abstractmethod
from typing import List, Callable, Any, Union
from urllib.parse import urlparse
import logging
import inspect
import fnmatch

from dotenv import load_dotenv
import requests
from httpx import Client
from curl_cffi.requests import AsyncSession
import cloudpickle
import platformdirs

from .utils import recursive_hash, hash_code, FakeRLock

CACHE_WHITELIST = []
CACHE_BLACKLIST = [
"*//api.lunary.ai/*",
]


class CacheException(Exception):
"""Exception for caching process"""


class StrongCacheException(BaseException):
"""Exception use of cache only"""


load_dotenv()


def file_cache(http_cache: "BaseHttpCache"):
"""Decorator to cache function output based on its inputs, ignoring specified parameters."""

def decorator(func):
def wrapper(*args, **kwargs):
return http_cache.get_response(func, *args, **kwargs)

return wrapper

return decorator


def afile_cache(http_cache: "BaseHttpCache"):
"""Async decorator to cache function output based on its inputs, ignoring specified parameters."""

def decorator(func):
async def wrapper(*args, **kwargs):
return await http_cache.aget_response(func, *args, **kwargs)

return wrapper

return decorator


class BaseHttpCache(ABC):
"""Basic abstract class for replacing http library methods"""

ignore_params: List[str] = [] # ignore params names for create hash file name
library_name: str = ""
app_name = os.environ.get("APP_NAME") or "motleycrew"
root_cache_dir = platformdirs.user_cache_dir(app_name)
strong_cache = False

def __init__(self, *args, **kwargs):
self.is_caching = False

@abstractmethod
def get_url(self, *args, **kwargs) -> str:
"""Finds the url in the arguments and returns it"""

@abstractmethod
def _enable(self):
"""Replacing the original function with a caching function"""

@abstractmethod
def _disable(self):
"""Replacing the caching function with the original one"""

def enable(self):
"""Enable caching"""
self._enable()
self.is_caching = True
library_log = (
"for {} library.".format(self.library_name) if self.library_name else "."
)
logging.info("Enable caching {} class {}".format(self.__class__, library_log))

def disable(self):
"""Disable caching"""
self._disable()
self.is_caching = False
library_log = (
"for {} library.".format(self.library_name) if self.library_name else "."
)
logging.info("Disable caching {} class {}".format(self.__class__, library_log))

def prepare_response(self, response: Any) -> Any:
"""Preparing the response object before saving"""
return response

def should_cache(self, url: str) -> bool:
if CACHE_WHITELIST and CACHE_BLACKLIST:
raise CacheException(
"It is necessary to fill in only the CACHE_WHITELIST or the CACHE_BLACKLIST"
)
elif CACHE_WHITELIST:
return self.url_matching(url, CACHE_WHITELIST)
elif CACHE_BLACKLIST:
return not self.url_matching(url, CACHE_BLACKLIST)
return True

def get_cache_file(self, func: Callable, *args, **kwargs) -> Union[tuple, None]:
url = self.get_url(*args, **kwargs)
url_parsed = urlparse(url)

# Check valid url
if not self.should_cache(url):
logging.info("Ignore url to cache: {}".format(url))
return None

# check or create cache dirs
root_dir = Path(self.root_cache_dir)
cache_dir = (
root_dir
/ url_parsed.hostname
/ url_parsed.path.strip("/").replace("/", "_")
)
cache_dir.mkdir(parents=True, exist_ok=True)

# Convert args to a dictionary based on the function's signature
args_names = func.__code__.co_varnames[: func.__code__.co_argcount]
args_dict = dict(zip(args_names, args))

# Remove ignored params
kwargs_clone = kwargs.copy()
for param in self.ignore_params:
args_dict.pop(param, None)
kwargs_clone.pop(param, None)

# Create hash based on argument names, argument values, and function source code
func_source_code_hash = hash_code(inspect.getsource(func))
arg_hash = (
recursive_hash(args_dict, ignore_params=self.ignore_params)
+ recursive_hash(kwargs_clone, ignore_params=self.ignore_params)
+ func_source_code_hash
)

cache_file = cache_dir / "{}.pkl".format(arg_hash)
return cache_file, url

def get_response(self, func: Callable, *args, **kwargs) -> Any:
"""Returns a response from the cache if it is found, or executes the request first"""
cache_data = self.get_cache_file(func, *args, **kwargs)
if cache_data is None:
return func(*args, **kwargs)
cache_file, url = cache_data

# If cache exists, load and return it
result = self.load_cache_response(cache_file, url)
if result is not None:
return result

# Otherwise, call the function and save its result to the cache
result = func(*args, **kwargs)

self.write_to_cache(result, cache_file, url)
return result

async def aget_response(self, func: Callable, *args, **kwargs) -> Any:
"""Async returns a response from the cache if it is found, or executes the request first"""
cache_data = self.get_cache_file(func, *args, **kwargs)
if cache_data is None:
return await func(*args, **kwargs)
cache_file, url = cache_data

# If cache exists, load and return it
result = self.load_cache_response(cache_file, url)
if result is not None:
return result

# Otherwise, call the function and save its result to the cache
result = await func(*args, **kwargs)

self.write_to_cache(result, cache_file, url)
return result

def load_cache_response(self, cache_file: Path, url: str) -> Union[Any, None]:
"""Loads and returns the cached response"""
if cache_file.exists():
return self.read_from_cache(cache_file, url)
elif self.strong_cache:
msg = "Cache file not found: {}\nthe strictly caching option is enabled.".format(
str(cache_file)
)
raise StrongCacheException(msg)

def read_from_cache(self, cache_file: Path, url: str = "") -> Union[Any, None]:
"""Reads and returns a serialized object from a file"""
try:
with cache_file.open("rb") as f:
logging.info("Used cache for {} url from {}".format(url, cache_file))
result = cloudpickle.load(f)
return result
except Exception as e:
logging.warning("Unpickling failed for {}".format(cache_file))
if self.strong_cache:
msg = "Error reading cached file: {}\n{}".format(
str(e), str(cache_file)
)
raise StrongCacheException(msg)
return None

def write_to_cache(self, response: Any, cache_file: Path, url: str = "") -> None:
"""Writes the response object to a file"""
response = self.prepare_response(response)
try:
with cache_file.open("wb") as f:
cloudpickle.dump(response, f)
logging.info("Write cache for {} url to {}".format(url, cache_file))
except Exception as e:
logging.warning("Pickling failed for {} url: {}".format(cache_file, e))

@staticmethod
def url_matching(url: str, patterns: List[str]) -> bool:
"""Checking the url for a match in the list of templates"""
return any([fnmatch.fnmatch(url, pat) for pat in patterns])


class RequestsHttpCaching(BaseHttpCache):
"""Requests library caching"""

ignore_params = ["timestamp", "runId", "parentRunId"]
library_name = "requests"

def __init__(self, *args, **kwargs):
super(RequestsHttpCaching, self).__init__(*args, **kwargs)
self.library_method = requests.api.request

def get_url(self, *args, **kwargs) -> str:
"""Finds the url in the arguments and returns it"""
return args[1]

def _enable(self):
"""Replacing the original function with a caching function"""

@file_cache(self)
def request_func(*args, **kwargs):
return self.library_method(*args, **kwargs)

requests.api.request = request_func

def _disable(self):
"""Replacing the caching function with the original one"""
requests.api.request = self.library_method


class HttpxHttpCaching(BaseHttpCache):
"""Httpx library caching"""

ignore_params = ["s", "headers", "stream", "extensions"]
library_name = "Httpx"

def __init__(self, *args, **kwargs):
super(HttpxHttpCaching, self).__init__(*args, **kwargs)
self.library_method = Client.send

def get_url(self, *args, **kwargs) -> str:
"""Finds the url in the arguments and returns it"""
return str(args[1].url)

def _enable(self):
"""Replacing the original function with a caching function"""

@file_cache(self)
def request_func(s, request, *args, **kwargs):
return self.library_method(s, request, **kwargs)

Client.send = request_func

def _disable(self):
"""Replacing the caching function with the original one"""
Client.send = self.library_method


class CurlCffiHttpCaching(BaseHttpCache):
"""Curl Cffi library caching"""

ignore_params = ["s"]
library_name = "Curl cffi"

def __init__(self, *args, **kwargs):
super(CurlCffiHttpCaching, self).__init__(*args, **kwargs)
self.library_method = AsyncSession.request

def get_url(self, *args, **kwargs) -> str:
"""Finds the url in the arguments and returns it"""
return args[2]

def prepare_response(self, response: Any) -> Any:
"""Preparing the response object before saving"""
response.curl = None
response.cookies.jar._cookies_lock = FakeRLock()
return response

def _enable(self):
"""Replacing the original function with a caching function"""

@afile_cache(self)
async def request_func(s, method, url, *args, **kwargs):
return await self.library_method(s, method, url, *args, **kwargs)

AsyncSession.request = request_func

def _disable(self):
"""Replacing the caching function with the original one"""
AsyncSession.request = self.library_method
Loading

0 comments on commit aa023a3

Please sign in to comment.