diff --git a/api-bank/README.md b/api-bank/README.md index e5367da6..b254ae8e 100644 --- a/api-bank/README.md +++ b/api-bank/README.md @@ -1,25 +1,26 @@ -# API-Bank: A Benchmark for Tool-Augmented LLMs -Minghao Li, Feifan Song, Bowen Yu, Haiyang Yu, Zhoujun Li, Fei Huang, Yongbin Li +# API-Bank: A Comprehensive Benchmark for Tool-Augmented LLMs + +Minghao Li, Yingxiu Zhao, Bowen Yu, Feifan Song, Hangyu Li, Haiyang Yu, Zhoujun Li, Fei Huang, Yongbin Li arXiv: [[Abstract]](https://arxiv.org/abs/2304.08244)/[[PDF]](https://arxiv.org/pdf/2304.08244.pdf) - ## News +- **The Lynx model is released on [HuggingFace Hub](https://huggingface.co/liminghao1630/Lynx-7b).** +- **API-Bank is accepted by EMNLP 2023.** - **The code and data of API-Bank have been released.** ## Abstract -Recent research has shown that Large Language Models (LLMs) can utilize external tools to improve their contextual processing abilities, moving away from the pure language modeling paradigm and paving the way for Artificial General Intelligence. Despite this, there has been a lack of systematic evaluation to demonstrate the efficacy of LLMs using tools to respond to human instructions. This paper presents API-Bank, the first benchmark tailored for Tool-Augmented LLMs. API-Bank includes 53 commonly used API tools, a complete Tool-Augmented LLM workflow, and 264 annotated dialogues that encompass a total of 568 API calls. These resources have been designed to thoroughly evaluate LLMs' ability to plan step-by-step API calls, retrieve relevant APIs, and correctly execute API calls to meet human needs. The experimental results show that GPT-3.5 emerges the ability to use the tools relative to GPT3, while GPT-4 has stronger planning performance. Nevertheless, there remains considerable scope for further improvement when compared to human performance. Additionally, detailed error analysis and case studies demonstrate the feasibility of Tool-Augmented LLMs for daily use, as well as the primary challenges that future research needs to address. +Recent research has demonstrated that Large Language Models (LLMs) can enhance their capabilities by utilizing external tools. However, three pivotal questions remain unanswered: (1) How effective are current LLMs in utilizing tools? (2) How can we enhance LLMs' ability to utilize tools? (3) What obstacles need to be overcome to leverage tools? To address these questions, we introduce API-Bank, a groundbreaking benchmark, specifically designed for tool-augmented LLMs. For the first question, we develop a runnable evaluation system consisting of 73 API tools. We annotate 314 tool-use dialogues with 753 API calls to assess the existing LLMs' capabilities in planning, retrieving, and calling APIs. For the second question, we construct a comprehensive training set containing 1,888 tool-use dialogues from 2,138 APIs spanning 1,000 distinct domains. Using this dataset, we train Lynx, a tool-augmented LLM initialized from Alpaca. Experimental results demonstrate that GPT-3.5 exhibits improved tool utilization compared to GPT-3, while GPT-4 excels in planning. However, there is still significant potential for further improvement. Moreover, Lynx surpasses Alpaca's tool utilization performance by more than 26 pts and approaches the effectiveness of GPT-3.5. Through error analysis, we highlight the key challenges for future research in this field to answer the third question. -## Tool-Augmented LLMs Paradigm +## Multi-Agent Dataset Synthesis -![Paradigm](https://cdn.jsdelivr.net/gh/liminghao1630/auxiliary_use/figures/flowchart.png) +![multiagent](./figures/multi-agent.png) -## System Design +## Evaluation Tasks -![System](https://cdn.jsdelivr.net/gh/liminghao1630/auxiliary_use/figures/system.png) +![ability](./figures/three_ability.png) ## Demo As far as we know, there is a conflict between the dependencies of the `googletrans` package and the dependencies of the `gradio` package, which may cause the demo not to run properly. There is no good solution, you can uninstall `googletrans` first when using the demo. @@ -50,8 +51,9 @@ JsDelivr: https://cdn.jsdelivr.net/gh/liminghao1630/auxiliary_use/gpt-3.5-demo.g ## Evaluation -The conversation data of level-1 and level-2 are stored in the `lv1-lv2-samples` directory, please follow the code in `evaluator.py` to design the evaluation script. -The evaluation of level-3 needs to be done manually, you can use `simulator.py` or `demo.py` for testing. +The datasets are released on [HuggingFace Hub](https://huggingface.co/datasets/liminghao1630/API-Bank). +The conversation data of level-1 and level-2 are stored in the `lv1-lv2-samples` directory or `test-data`, please follow the code in `evaluator.py`/`evaluator_by_json.py` to design the evaluation script. +The evaluation of level-3 requires `lv3_evaluator.py`. diff --git a/api-bank/api_call_extraction.py b/api-bank/api_call_extraction.py index ff853b50..ea23bfce 100644 --- a/api-bank/api_call_extraction.py +++ b/api-bank/api_call_extraction.py @@ -1,5 +1,8 @@ import re +def fn(**kwargs): + return kwargs + def get_api_call(model_output): api_call_pattern = r"\[(\w+)\((.*)\)\]" api_call_pattern = re.compile(api_call_pattern) @@ -16,6 +19,13 @@ def parse_api_call(text): api_name = match.group(1) params = match.group(2) + # params = params.replace('\'[', '[') + # params = params.replace(']\'', ']') + # params = params.replace('\'{', '{') + # params = params.replace('}\'', '}') + + # param_dict = eval('fn(' + params + ')') + param_pattern = r"(\w+)\s*=\s*['\"](.+?)['\"]|(\w+)\s*=\s*(\[.*\])|(\w+)\s*=\s*(\w+)" param_dict = {} for m in re.finditer(param_pattern, params): diff --git a/api-bank/apis/add_agenda.py b/api-bank/apis/add_agenda.py index 8cc9a3a6..d0d115c2 100644 --- a/api-bank/apis/add_agenda.py +++ b/api-bank/apis/add_agenda.py @@ -5,7 +5,7 @@ class AddAgenda(API): - description = "The API for adding a schedule item includes parameters for token, content, time, and location." + description = "The API for adding a agenda item includes content, time and location." input_parameters = { 'token': {'type': 'str', 'description': "User's token."}, 'content': {'type': 'str', 'description': 'The content of the agenda.'}, diff --git a/api-bank/apis/add_alarm.py b/api-bank/apis/add_alarm.py index 23fea1aa..8eab7422 100644 --- a/api-bank/apis/add_alarm.py +++ b/api-bank/apis/add_alarm.py @@ -5,7 +5,7 @@ class AddAlarm(API): - description = "The API for setting an alarm includes a parameter for the time." + description = "The API for setting an alarm includes a parameter for the alarm time." input_parameters = { 'token': {'type': 'str', 'description': "User's token."}, 'time': {'type': 'str', 'description': 'The time for alarm. Format: %Y-%m-%d %H:%M:%S'} diff --git a/api-bank/apis/add_meeting.py b/api-bank/apis/add_meeting.py index 3cf03f3a..81594630 100644 --- a/api-bank/apis/add_meeting.py +++ b/api-bank/apis/add_meeting.py @@ -5,12 +5,8 @@ class AddMeeting(API): - description = "This API allows users to make a reservation for a meeting and store the meeting information in the database." \ - "Function:" \ - "Allow users to make a reservation for a meeting." \ - "Exception Handling:" \ - "1. If the reservation is successful, return a success message." \ - "2. If the reservation fails, return a corresponding error message." + + description = "This API allows users to make a reservation for a meeting and store the meeting information (e.g., topic, time, location, attendees) in the database." input_parameters = { 'token': {'type': 'str', 'description': "User's token."}, 'meeting_topic': {'type': 'str', 'description': 'The title of the meeting, no more than 50 characters.'}, diff --git a/api-bank/apis/add_reminder.py b/api-bank/apis/add_reminder.py index 503b2159..685fbdba 100644 --- a/api-bank/apis/add_reminder.py +++ b/api-bank/apis/add_reminder.py @@ -5,11 +5,8 @@ class AddReminder(API): - description = "Add a reminder API that takes three parameters - 'token','content' and 'time'. " \ - "The 'token' parameter refers to the user's token " \ - "and the 'content' parameter refers to the description of the reminder " \ - "and the 'time' parameter specifies the time at which the reminder " \ - "should be triggered." + + description = "The API for adding a reminder item includes content and time." input_parameters = { 'token': {'type': 'str', 'description': "User's token."}, 'content': {'type': 'str', 'description': 'The content of the conference.'}, diff --git a/api-bank/apis/delete_meeting.py b/api-bank/apis/delete_meeting.py index 6a1c4bf7..7c0b258d 100644 --- a/api-bank/apis/delete_meeting.py +++ b/api-bank/apis/delete_meeting.py @@ -5,12 +5,8 @@ class DeleteMeeting(API): - description = "This API allows users to delete a reservation for a meeting and remove the meeting information in the database." \ - "Function:" \ - "Delete user's reservation for a meeting." \ - "Exception Handling:" \ - "1. If the deletion is successful, return a success message." \ - "2. If the deletion fails, return a corresponding error message." + + description = "This API allows users to delete a reservation for a meeting and remove the meeting information in the database." input_parameters = { 'token': {'type': 'str', 'description': "User's token."}, 'meeting_topic': {'type': 'str', 'description': 'The title of the meeting, no more than 50 characters.'}, diff --git a/api-bank/apis/delete_reminder.py b/api-bank/apis/delete_reminder.py index 9431d629..9b5a46c8 100644 --- a/api-bank/apis/delete_reminder.py +++ b/api-bank/apis/delete_reminder.py @@ -5,10 +5,8 @@ class DeleteReminder(API): - description = "Delete a reminder API that takes three parameters - 'token','content' and 'time'. " \ - "The 'token' parameter refers to the user's token " \ - "and the 'content' parameter refers to the description of the reminder " \ - "and the 'time' parameter specifies the time at which the reminder should be triggered." + + description = "The API for deleting a reminder item includes content and time." input_parameters = { 'token': {'type': 'str', 'description': "User's token."}, 'content': {'type': 'str', 'description': 'The content of the conference.'}, diff --git a/api-bank/apis/delete_scene.py b/api-bank/apis/delete_scene.py index 91d2dade..16669ca4 100644 --- a/api-bank/apis/delete_scene.py +++ b/api-bank/apis/delete_scene.py @@ -1,7 +1,8 @@ from apis.api import API class DeleteScene(API): - description = 'This API deletes a scene.' + + description = 'This API deletes a scene by its name.' input_parameters = { "name": {'type': 'str', 'description': 'The name of the scene.'}, } diff --git a/api-bank/apis/dictionary.py b/api-bank/apis/dictionary.py index b2fa4fea..c72047a6 100644 --- a/api-bank/apis/dictionary.py +++ b/api-bank/apis/dictionary.py @@ -3,7 +3,7 @@ import requests class Dictionary(API): - description = 'This API searches for a given keyword.' + description = 'This API searches the dictionary for a given keyword.' input_parameters = { "keyword": {'type': 'str', 'description': 'The keyword to search.'}, } @@ -85,4 +85,3 @@ def check_api_call_correctness(self, response, groundtruth) -> bool: if response['exception'] != groundtruth['exception']: return False return True - diff --git a/api-bank/apis/emergency_knowledge.py b/api-bank/apis/emergency_knowledge.py index 045a44b4..9dcd180d 100644 --- a/api-bank/apis/emergency_knowledge.py +++ b/api-bank/apis/emergency_knowledge.py @@ -1,7 +1,8 @@ from apis.api import API class EmergencyKnowledge(API): - description = 'This API searches for a given symptom.' + + description = 'This API searches for a given symptom for emergency knowledge.' input_parameters = { "symptom": {'type': 'str', 'description': 'The symptom to search.'}, } diff --git a/api-bank/apis/get_user_token.py b/api-bank/apis/get_user_token.py index 1da511a9..a2fe6b02 100644 --- a/api-bank/apis/get_user_token.py +++ b/api-bank/apis/get_user_token.py @@ -3,7 +3,7 @@ import os class GetUserToken(API): - description = 'Get the user token.' + description = 'Get the user token by username and password.' input_parameters = { 'username': {'type': 'str', 'description': 'The username of the user.'}, 'password': {'type': 'str', 'description': 'The password of the user.'}, diff --git a/api-bank/apis/modify_agenda.py b/api-bank/apis/modify_agenda.py index c801ee00..39e7e640 100644 --- a/api-bank/apis/modify_agenda.py +++ b/api-bank/apis/modify_agenda.py @@ -5,7 +5,7 @@ class ModifyAgenda(API): - description = "The API for modifying a schedule item includes parameters for token, content, time, and location." + description = "The API for modifying a schedule item includes parameters for content, time, and location." input_parameters = { 'token': {'type': 'str', 'description': "User's token."}, 'content': {'type': 'str', 'description': 'The content of the agenda.'}, diff --git a/api-bank/apis/modify_meeting.py b/api-bank/apis/modify_meeting.py index 0cc08d17..03801930 100644 --- a/api-bank/apis/modify_meeting.py +++ b/api-bank/apis/modify_meeting.py @@ -5,12 +5,8 @@ class ModifyMeeting(API): - description = "This API allows users to modify a reservation for a meeting" \ - "Function:" \ - "Delete user's reservation for a meeting." \ - "Exception Handling:" \ - "1. If the modification is successful, return a success message." \ - "2. If the modification fails, return a corresponding error message." + + description = "This API allows users to modify a reservation for a meeting" input_parameters = { 'token': {'type': 'str', 'description': "User's token."}, 'meeting_topic': {'type': 'str', 'description': 'The title of the meeting, no more than 50 characters.'}, diff --git a/api-bank/apis/modify_password.py b/api-bank/apis/modify_password.py index b0fe703f..36f7174b 100644 --- a/api-bank/apis/modify_password.py +++ b/api-bank/apis/modify_password.py @@ -3,7 +3,7 @@ class ModifyPassword(API): - description = 'Modify the password of an account.' + description = 'The API for modifying the password of the account.' input_parameters = { 'token': {'type': 'str', 'description': 'The token of the user.'}, 'old_password': {'type': 'str', 'description': 'The old password of the user.'}, diff --git a/api-bank/apis/modify_registration.py b/api-bank/apis/modify_registration.py index a56fc22f..c79a640e 100644 --- a/api-bank/apis/modify_registration.py +++ b/api-bank/apis/modify_registration.py @@ -2,6 +2,7 @@ import datetime class ModifyRegistration(API): + description = 'This API modifies the registration of a patient given appointment ID.' input_parameters = { "appointment_id": {'type': 'str', 'description': 'The ID of appointment.'}, @@ -129,7 +130,7 @@ def check_api_call_correctness(self, response, groundtruth) -> bool: Returns: - correctness (bool): the correctness of the API call. """ - response_appointment_id = response['input']['appointment_id'] + response_appointment_id = str(response['input']['appointment_id']) groundtruth_appointment_id = groundtruth['input']['appointment_id'] response_new_appointment_date = response['input']['new_appointment_date'] groundtruth_new_appointment_date = groundtruth['input']['new_appointment_date'] diff --git a/api-bank/apis/modify_reminder.py b/api-bank/apis/modify_reminder.py index 8fa1528b..802b3b56 100644 --- a/api-bank/apis/modify_reminder.py +++ b/api-bank/apis/modify_reminder.py @@ -5,11 +5,8 @@ class ModifyReminder(API): - description = "Modify a reminder API that takes three parameters - 'token','content' and 'time'. " \ - "The 'token' parameter refers to the user's token " \ - "and the 'content' parameter refers to the description of the reminder " \ - "and the 'time' parameter specifies the time at which the reminder " \ - "should be triggered." + + description = "The API for deleting a reminder item includes content and time." input_parameters = { 'token': {'type': 'str', 'description': "User's token."}, 'content': {'type': 'str', 'description': 'The content of the conference.'}, diff --git a/api-bank/apis/open_bank_account.py b/api-bank/apis/open_bank_account.py index 968577a3..6823ea04 100644 --- a/api-bank/apis/open_bank_account.py +++ b/api-bank/apis/open_bank_account.py @@ -5,11 +5,7 @@ class OpenBankAccount(API): - description = "This is an API for opening a bank account with three required parameters:" \ - " account (string), password (string), and name (string). The API creates a new " \ - "account with the specified account identifier, password, and account holder's name. " \ - "If an account with the same identifier already exists, the API will return an error message. " \ - "If the account is successfully created, the API will return a success message." + description = "This is an API for opening a bank account for a user, given the account, password and name." input_parameters = { 'account': {'type': 'str', 'description': 'The account for the user.'}, 'password': {'type': 'str', 'description': 'The password.'}, diff --git a/api-bank/apis/query_history_today.py b/api-bank/apis/query_history_today.py index 685fd45a..cffa443c 100644 --- a/api-bank/apis/query_history_today.py +++ b/api-bank/apis/query_history_today.py @@ -3,7 +3,7 @@ import datetime class QueryHistoryToday(API): - description = 'This API queries the history of a given user today.' + description = 'This API queries the history of the given date.' input_parameters = { 'date': {'type': 'str', 'description': 'The date of the history. Format: %m-%d'}, } diff --git a/api-bank/apis/query_meeting.py b/api-bank/apis/query_meeting.py index 3626833e..b6447401 100644 --- a/api-bank/apis/query_meeting.py +++ b/api-bank/apis/query_meeting.py @@ -5,12 +5,8 @@ class QueryMeeting(API): - description = "This API allows users to query a reservation for a meeting." \ - "Function:" \ - "Query infomation for a meeting." \ - "Exception Handling:" \ - "1. If the Query is successful, return a meeting infomation with json." \ - "2. If the Query fails, return a error message." + + description = "This API allows users to query the information of a meeting." input_parameters = { 'token': {'type': 'str', 'description': "User's token."}, 'meeting_topic': {'type': 'str', 'description': 'The title of the meeting, no more than 50 characters.'}, diff --git a/api-bank/apis/query_registration.py b/api-bank/apis/query_registration.py index e4891cb6..f9e119af 100644 --- a/api-bank/apis/query_registration.py +++ b/api-bank/apis/query_registration.py @@ -2,7 +2,8 @@ import datetime class QueryRegistration(API): - description = 'This API queries the registration of a patient given patient ID.' + + description = 'This API queries the registration of a patient, given patient ID.' input_parameters = { "patient_name": {'type': 'str', 'description': 'The name of patient.'}, "date": {'type': 'str', 'description': 'The date of appointment. Format be like %Y-%m-%d'}, diff --git a/api-bank/apis/query_reminder.py b/api-bank/apis/query_reminder.py index b2208ccf..24d996c0 100644 --- a/api-bank/apis/query_reminder.py +++ b/api-bank/apis/query_reminder.py @@ -5,8 +5,8 @@ class QueryReminder(API): - description = "Query a reminder API that takes three parameters - 'token','content' and 'time'. " \ - "The API used to help user query a reminder. If the reminder exists, the API will return the reminder information. " + + description = "The API for querying a reminder item includes content and time." input_parameters = { 'token': {'type': 'str', 'description': "User's token."}, 'content': {'type': 'str', 'description': 'The content of the reminder.'}, diff --git a/api-bank/apis/query_stock.py b/api-bank/apis/query_stock.py index 4ecf1d6c..50411d12 100644 --- a/api-bank/apis/query_stock.py +++ b/api-bank/apis/query_stock.py @@ -2,7 +2,8 @@ import datetime class QueryStock(API): - description = 'This API queries the stock price of a given stock.' + + description = 'This API queries the stock price of a given stock code and date.' input_parameters = { "stock_code": {'type': 'str', 'description': 'The stock code of the given stock.'}, "date": {'type': 'str', 'description': 'The date of the stock price. Format: %Y-%m-%d'} diff --git a/api-bank/apis/record_health_data.py b/api-bank/apis/record_health_data.py index 245e76de..556e50e4 100644 --- a/api-bank/apis/record_health_data.py +++ b/api-bank/apis/record_health_data.py @@ -2,7 +2,8 @@ import datetime class RecordHealthData(API): - description = 'This API records the health history of a patient given user ID, time and health data.' + + description = 'This API records the health data of a user.' input_parameters = { "user_id": {'type': 'str', 'description': 'The ID of user.'}, "time": {'type': 'str', 'description': 'The time of health data. Format: %Y-%m-%d %H:%M:%S'}, @@ -126,7 +127,7 @@ def check_api_call_correctness(self, response, groundtruth) -> bool: Returns: - correctness (bool): the correctness of the API call. """ - response_user_id = response['input']['user_id'] + response_user_id = str(response['input']['user_id']) groundtruth_user_id = groundtruth['input']['user_id'] response_time = response['input']['time'] groundtruth_time = groundtruth['input']['time'] @@ -137,18 +138,18 @@ def check_api_call_correctness(self, response, groundtruth) -> bool: groundtruth_user_id = groundtruth_user_id.upper().strip() response_time = self.format_check(response_time) groundtruth_time = self.format_check(groundtruth_time) - response_health_data = [{"name":str(i["name"]),"value":str(i["value"])} for i in response_health_data] - groundtruth_health_data = [{"name":str(i["name"]),"value":str(i["value"])} for i in groundtruth_health_data] - response_health_data.sort(key=lambda x: str(x)) - groundtruth_health_data.sort(key=lambda x: str(x)) + # response_health_data = [{"name":str(i["name"]),"value":str(i["value"])} for i in response_health_data] + # groundtruth_health_data = [{"name":str(i["name"]),"value":str(i["value"])} for i in groundtruth_health_data] + # response_health_data.sort(key=lambda x: str(x)) + # groundtruth_health_data.sort(key=lambda x: str(x)) if response_user_id != groundtruth_user_id: return False if response_time != groundtruth_time: return False - if response_health_data != groundtruth_health_data: - return False + # if response_health_data != groundtruth_health_data: + # return False if response['output'] != groundtruth['output']: return False if response['exception'] != groundtruth['exception']: diff --git a/api-bank/apis/register_user.py b/api-bank/apis/register_user.py index 2c55ea60..ba4cc8bc 100644 --- a/api-bank/apis/register_user.py +++ b/api-bank/apis/register_user.py @@ -5,7 +5,7 @@ class RegisterUser(API): - description = 'Register a user.' + description = 'The API for registering a account, given the username, password and email.' input_parameters = { 'username': {'type': 'str', 'description': 'The username of the user.'}, 'password': {'type': 'str', 'description': 'The password of the user.'}, diff --git a/api-bank/apis/search_engine.py b/api-bank/apis/search_engine.py index 729b2518..b09a7d5a 100644 --- a/api-bank/apis/search_engine.py +++ b/api-bank/apis/search_engine.py @@ -9,7 +9,7 @@ from nltk.tokenize import word_tokenize class SearchEngine(API): - description = 'This API searches for a given keyword.' + description = 'This API searches for a given keyword for search engine.' input_parameters = { "keyword": {'type': 'str', 'description': 'The keyword to search.'}, } @@ -96,7 +96,6 @@ def search(self, keyword: str) -> list: if len(rankings) > 2: rankings = rankings[:2] results = [self.database["raw_documents"][i] for i in rankings] - return results def check_api_call_correctness(self, response, groundtruth) -> bool: diff --git a/api-bank/apis/send_email.py b/api-bank/apis/send_email.py index 03620b48..c49c1600 100644 --- a/api-bank/apis/send_email.py +++ b/api-bank/apis/send_email.py @@ -3,7 +3,7 @@ import re class SendEmail(API): - description = 'This API sends an email.' + description = 'This API for sending email, given the receiver, subject and content.' input_parameters = { "receiver": {'type': 'str', 'description': 'The receiver address of the email.'}, "subject": {'type': 'str', 'description': 'The subject address of the email.'}, diff --git a/api-bank/apis/timed_switch.py b/api-bank/apis/timed_switch.py index 2ebaa50b..ba05994e 100644 --- a/api-bank/apis/timed_switch.py +++ b/api-bank/apis/timed_switch.py @@ -2,7 +2,8 @@ import datetime class TimedSwitch(API): - description = 'This API switches a smart device on or off at a specified time' + + description = 'This API for setting a timed switch for a smart device.' input_parameters = { "name": {'type': 'str', 'description': 'The name of the smart device.'}, "time": {'type': 'str', 'description': 'The time to switch the device on or off. Format: %Y-%m-%d %H:%M:%S'}, diff --git a/api-bank/apis/tool_search.py b/api-bank/apis/tool_search.py index 66c1a1a5..1df69610 100644 --- a/api-bank/apis/tool_search.py +++ b/api-bank/apis/tool_search.py @@ -7,7 +7,8 @@ import os class ToolSearcher(API): - description = 'Searches for relevant tools in library based on the keyword.' + + description = 'Searches for relevant tools in library based on the keywords.' input_parameters = { 'keywords': {'type': 'str', 'description': 'The keyword to search for.'} } diff --git a/api-bank/apis/wiki.py b/api-bank/apis/wiki.py index d2fcd4d8..56992970 100644 --- a/api-bank/apis/wiki.py +++ b/api-bank/apis/wiki.py @@ -1,7 +1,7 @@ from apis.api import API class Wiki(API): - description = 'This API searches for a given keyword.' + description = 'This API for searching a keyword in Wikipedia.' input_parameters = { "keyword": {'type': 'str', 'description': 'The keyword to search.'}, } diff --git a/api-bank/evaluator.py b/api-bank/evaluator.py index 3d2ac0aa..9e7dddb8 100644 --- a/api-bank/evaluator.py +++ b/api-bank/evaluator.py @@ -3,14 +3,13 @@ import re from rouge import Rouge import os -from utils import ChatGPTWrapper, DavinciWrapper +from utils import ChatGPTWrapper, DavinciWrapper, GPT4Wrapper import logging from tqdm import tqdm from api_call_extraction import parse_api_call from datetime import datetime import numpy as np -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filename=f'evaluator-{datetime.now().strftime("%Y-%m-%d-%H-%M-%S")}.log', filemode='w') def calculate_rouge_l_score(reference, hypothesis): rouge = Rouge() @@ -111,13 +110,16 @@ def get_api_call(model_output): if __name__ == '__main__': data_dir = 'lv1-lv2-samples/level-1-given-desc' api_test_enabled = False - dialog_test_enabled = True + dialog_test_enabled = not api_test_enabled + + api_str = 'api' if api_test_enabled else 'response' + if os.path.basename(data_dir).endswith('given-desc'): tool_search_enabled = False else: tool_search_enabled = True - chatgpt = DavinciWrapper(api_key='YOUR_API_KEY') + chatgpt = GPT4Wrapper(api_key='YOUR_API_KEY') api_call_prompt = ''' Based on the given API description and the existing conversation history 1..t, please generate the API request that the AI should call in step t+1 and output it in the format of [ApiName(key1='value1', key2='value2', ...)], replace the ApiName with the actual API name, and replace the key and value with the actual parameters. Your output should start with a square bracket "[" and end with a square bracket "]". Do not output any other explanation or prompt or the result of the API call in your output. diff --git a/api-bank/evaluator_by_json.py b/api-bank/evaluator_by_json.py new file mode 100644 index 00000000..271759ee --- /dev/null +++ b/api-bank/evaluator_by_json.py @@ -0,0 +1,313 @@ +import json +from tool_manager import ToolManager +import re +from rouge import Rouge +import os +from utils import ChatGPTWrapper, DavinciWrapper +import logging +from tqdm import tqdm +from api_call_extraction import parse_api_call +from datetime import datetime +import numpy as np + +def calculate_rouge_l_score(reference, hypothesis): + rouge = Rouge() + scores = rouge.get_scores(hypothesis, reference) + rouge_l_score = scores[0]['rouge-l']['f'] + return rouge_l_score + +class Sample: + def __init__(self, chat_history, apis, ground_truth): + self.chat_history = chat_history + self.apis = apis + self.ground_truth = ground_truth + + def __repr__(self): + return 'Sample(chat_history={}, apis={}, ground_truth={})'.format(self.chat_history, self.apis, self.ground_truth) + # return 'Chat history: {}, apis: {}, ground truth: {}'.format(self.chat_history, self.apis, self.ground_truth) + + def __str__(self) -> str: + return self.__repr__() + + @classmethod + def from_chat_history(cls, chat_history): + apis = set() + api_positions = [] + for i, item in enumerate(chat_history): + if item['role'] == 'API': + apis.add(item['api_name']) + api_positions.append(i) + + samples = [] + for i in api_positions: + sample = cls(chat_history[:i], apis, chat_history[i]) + samples.append(sample) + sample = cls(chat_history[:i + 1], apis, chat_history[i + 1]) + samples.append(sample) + + return samples + + +class Evaluator: + + def __init__(self, samples): + self.dataset = samples + self.sample_ids = list(range(len(self.dataset))) + + def get_all_sample_ids(self): + return self.sample_ids + + def get_api_description(self, api_name): + tool_manager = ToolManager() + return tool_manager.get_api_description(api_name) + + + def get_model_input(self, sample_id): + sample = self.dataset[sample_id] + apis = sample.apis + chat_history = sample.chat_history + tool_manager = ToolManager() + api_descriptions = [] + for api_name in apis: + api_descriptions.append(tool_manager.get_api_description(api_name)) + api_descriptions = '\n'.join(api_descriptions) + return api_descriptions, chat_history + + + def evaluate(self, sample_id, model_output): + # model_output [ApiName(param1=value1, param2=value2), ...)] + tool_manager = ToolManager() + + sample = self.dataset[sample_id] + ground_truth = sample.ground_truth + if ground_truth['role'] == 'API': + try: + api_name, param_dict = parse_api_call(model_output) + except Exception as e: + raise Exception('Parse API Call Error: {}'.format(model_output)) + if api_name != ground_truth['api_name']: + return False, 'API Name Mismatch: {} vs {}'.format(api_name, ground_truth['api_name']) + # try: + result = tool_manager.api_call(api_name, **param_dict) + # except Exception as e: + # return False, str(e) + api = tool_manager.init_tool(api_name) + try: + correct = api.check_api_call_correctness(result, ground_truth['result']) + except KeyError: + correct = False + result = 'KeyError' + str(result) + return correct, result + elif ground_truth['role'] == 'AI': + score = calculate_rouge_l_score(ground_truth['text'], model_output) + return round(score, 4) + + +def get_api_call(model_output): + api_call_pattern = r"\[(\w+)\((.*)\)\]" + api_call_pattern = re.compile(api_call_pattern) + match = api_call_pattern.search(model_output) + if match: + return match.group(0) + else: + return None + +if __name__ == '__main__': + data_dir = 'lv1-lv2-samples/level-1-given-desc' + evaluation_path = 'path-to-json' + api_test_enabled = True + dialog_test_enabled = not api_test_enabled + + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', filename=evaluation_path.replace('.json', '.log'), filemode='w') + + if os.path.basename(data_dir).endswith('given-desc'): + tool_search_enabled = False + else: + tool_search_enabled = True + + with open(evaluation_path, 'r') as f: + predictions = [json.loads(line) for line in f] + pred_map = {} + for pred in predictions: + if pred['file'] not in pred_map: + pred_map[pred['file']] = {} + pred_map[pred['file']][pred['id']] = pred + + total_api_calls = 0 + correct_api_calls = 0 + total_after_api_response = 0 + total_score_after_api_response = 0 + + rougel_scores = [] + + error_statistic = { + 'NO_API_CALL': { + 'count': 0, + 'samples': [] + }, + 'API_NAME_MISMATCH': { + 'count': 0, + 'samples': [] + }, + 'HAS_EXCEPTION': { + 'count': 0, + 'samples': [] + }, + 'INPUT_MISMATCH': { + 'count': 0, + 'samples': [] + }, + 'OUTPUT_MISMATCH': { + 'count': 0, + 'samples': [] + }, + 'INVALID_INPUT_PARAMETER': { + 'count': 0, + 'samples': [] + }, + 'KEY_ERROR': { + 'count': 0, + 'samples': [] + }, + 'FAILED_PARSE_API_CALL': { + 'count': 0, + 'samples': [] + }, + 'MISS_INPUT_ARGUMENT': { + 'count': 0, + 'samples': [] + }, + } + + jsonl_files = [f for f in os.listdir(data_dir) if f.endswith('.jsonl')] + + for file in tqdm(jsonl_files, desc='Processing files', ncols=100): + history = [] + with open(os.path.join(data_dir, file), 'r') as f: + for line in f: + history.append(json.loads(line)) + samples = Sample.from_chat_history(history) + evaluator = Evaluator(samples) + + for sample_id in evaluator.get_all_sample_ids(): + sample = evaluator.dataset[sample_id] + if sample.ground_truth['role'] == 'API' and api_test_enabled: + total_api_calls += 1 + + # assert file in pred_map + # assert sample_id in pred_map[file] + if sample_id not in pred_map[file]: + continue + model_output = pred_map[file][sample_id]['pred'] + + api_call = get_api_call(model_output) + if api_call: + try: + correct, model_output_result = evaluator.evaluate(sample_id, api_call) + except AssertionError as e: + if 'The API name is not correct.' in str(e): + error_statistic['API_NAME_MISMATCH']['count'] += 1 + error_statistic['API_NAME_MISMATCH']['samples'].append( + {'file': file, 'sample_id': sample_id, 'model_output': model_output, 'sample': sample, 'pred': pred_map[file][sample_id]} + ) + continue + elif 'invalid parameter name' in str(e): + error_statistic['INVALID_INPUT_PARAMETER']['count'] += 1 + error_statistic['INVALID_INPUT_PARAMETER']['samples'].append( + {'file': file, 'sample_id': sample_id, 'model_output': model_output, 'sample': sample, 'pred': pred_map[file][sample_id]} + ) + continue + raise e + except Exception as e: + if 'Parse API Call Error' in str(e): + error_statistic['FAILED_PARSE_API_CALL']['count'] += 1 + error_statistic['FAILED_PARSE_API_CALL']['samples'].append( + {'file': file, 'sample_id': sample_id, 'model_output': model_output, 'sample': sample, 'pred': pred_map[file][sample_id]} + ) + continue + if 'missing' in str(e) and 'required positional argument' in str(e): + error_statistic['MISS_INPUT_ARGUMENT']['count'] += 1 + error_statistic['MISS_INPUT_ARGUMENT']['samples'].append( + {'file': file, 'sample_id': sample_id, 'model_output': model_output, 'sample': sample, 'pred': pred_map[file][sample_id]} + ) + continue + raise e + + else: + model_output_result = 'No API call found' + error_statistic['NO_API_CALL']['count'] += 1 + error_statistic['NO_API_CALL']['samples'].append( + {'file': file, 'sample_id': sample_id, 'model_output': model_output, 'sample': sample, 'pred': pred_map[file][sample_id]} + ) + continue + if isinstance(model_output_result, str) and model_output_result.startswith('API Name Mismatch'): + error_statistic['API_NAME_MISMATCH']['count'] += 1 + error_statistic['API_NAME_MISMATCH']['samples'].append( + {'file': file, 'sample_id': sample_id, 'model_output': model_output, 'sample': sample, 'pred': pred_map[file][sample_id]} + ) + continue + if isinstance(model_output_result, str) and model_output_result.startswith('KeyError'): + error_statistic['KEY_ERROR']['count'] += 1 + error_statistic['KEY_ERROR']['samples'].append( + {'file': file, 'sample_id': sample_id, 'model_output': model_output, 'sample': sample, 'pred': pred_map[file][sample_id]} + ) + continue + + if correct: + correct_api_calls += 1 + logging.info('Correct API call: {} Ground truth: {}'.format(api_call, sample.ground_truth)) + else: + logging.info('Incorrect model output: {} Result: {} Ground truth: {} File: {} Sample ID: {} '.format(model_output.replace('\n', ' '), model_output_result, sample.ground_truth, file, sample_id)) + assert isinstance(model_output_result, dict) + if model_output_result['exception']: + error_statistic['HAS_EXCEPTION']['count'] += 1 + error_statistic['HAS_EXCEPTION']['samples'].append( + {'file': file, 'sample_id': sample_id, 'model_output': model_output, 'model_output_result': model_output_result, 'sample': sample, 'pred': pred_map[file][sample_id]} + ) + continue + if model_output_result['output'] != sample.ground_truth['result']['output']: + error_statistic['OUTPUT_MISMATCH']['count'] += 1 + error_statistic['OUTPUT_MISMATCH']['samples'].append( + {'file': file, 'sample_id': sample_id, 'model_output': model_output, 'model_output_result': model_output_result, 'sample': sample, 'pred': pred_map[file][sample_id]} + ) + continue + if model_output_result['input'] != sample.ground_truth['result']['input']: + error_statistic['INPUT_MISMATCH']['count'] += 1 + error_statistic['INPUT_MISMATCH']['samples'].append( + {'file': file, 'sample_id': sample_id, 'model_output': model_output, 'model_output_result': model_output_result, 'sample': sample, 'pred': pred_map[file][sample_id]} + ) + continue + elif sample.ground_truth['role'] == 'AI' and dialog_test_enabled: + assert file in pred_map + assert sample_id in pred_map[file] + model_output = pred_map[file][sample_id]['pred'] + + if model_output: + score = evaluator.evaluate(sample_id, model_output) + else: + score = 0 + rougel_scores.append(score) + if score < 0.2: + logging.info('Low score: {} Score: {} Ground truth: {} File: {} Sample ID: {}'.format(model_output.replace('\n', ' '), score, sample.ground_truth, file, sample_id)) + + def print_error_samples(sample): + print('Instruction: \n{}\n'.format(sample['pred']['instruction'])) + print('Input: \n{}\n'.format(sample['pred']['input'])) + print('Output: \n{}\n'.format(sample['model_output'])) + print('Ground truth: \n{}\n'.format(sample['pred']['expected_output'])) + + + print('Error statistic: {}'.format(error_statistic)) + for key in error_statistic: + print(key, error_statistic[key]['count']) + + if dialog_test_enabled: + print('Dialog score: {:.4f}'.format(np.mean(rougel_scores))) + + if api_test_enabled: + print('Total API calls: {}'.format(total_api_calls)) + print('Correct API calls: {}'.format(correct_api_calls)) + print('Accuracy: {:.4f}'.format(correct_api_calls / total_api_calls)) + logging.info('Total API calls: {}'.format(total_api_calls)) + logging.info('Correct API calls: {}'.format(correct_api_calls)) + logging.info('Accuracy: {:.4f}'.format(correct_api_calls / total_api_calls)) diff --git a/api-bank/figures/flowchart.png b/api-bank/figures/flowchart.png deleted file mode 100644 index 386fd118..00000000 Binary files a/api-bank/figures/flowchart.png and /dev/null differ diff --git a/api-bank/figures/multi-agent.png b/api-bank/figures/multi-agent.png new file mode 100644 index 00000000..5d63a7b2 Binary files /dev/null and b/api-bank/figures/multi-agent.png differ diff --git a/api-bank/figures/system.png b/api-bank/figures/system.png deleted file mode 100644 index 672721fe..00000000 Binary files a/api-bank/figures/system.png and /dev/null differ diff --git a/api-bank/figures/three_ability.png b/api-bank/figures/three_ability.png new file mode 100644 index 00000000..e11aaaad Binary files /dev/null and b/api-bank/figures/three_ability.png differ diff --git a/api-bank/lv3_apis/account_info.py b/api-bank/lv3_apis/account_info.py new file mode 100644 index 00000000..38726146 --- /dev/null +++ b/api-bank/lv3_apis/account_info.py @@ -0,0 +1,54 @@ +from apis.api import API +class AccountInfo(API): + description = "API for retrieving and updating user account information." + input_parameters = { + 'username': {'type': 'str', 'description': 'Name of the user.'}, + 'password': {'type': 'str', 'description': 'Password of the user.'}, + } + output_parameters = { + 'status': {'type': 'str', 'description': 'success or failed'}, + 'account_info': {'type': 'dict', 'description': 'User account information'} + } + + def __init__(self): + self.database = [ + (('John', '123456'), {'email': 'john@example.com', 'phone': '1234567890'}), + (('Mary', 'abcdef'), {'email': 'mary@example.com', 'phone': '0987654321'}), + (('Peter', 'qwerty'), {'email': 'peter@example.com', 'phone': '1231231234'}), + (('Tom', 'asdfgh'), {'email': 'tom@example.com', 'phone': '4564564567'}), + (('Jerry', 'zxcvbn'), {'email': 'jerry@example.com', 'phone': '7897897890'}), + ] + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] + + def call(self, username: str, password: str) -> dict: + input_parameters = { + 'username': username, + 'password': password + } + try: + account_info = self.retrieve_account_info(username, password) + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, + 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': account_info, 'exception': None} + + def retrieve_account_info(self, username, password): + for (user, pwd), account_info in self.database: + if user == username and pwd == password: + return account_info + raise Exception('User not found.') \ No newline at end of file diff --git a/api-bank/lv3_apis/add_meeting.py b/api-bank/lv3_apis/add_meeting.py new file mode 100644 index 00000000..871bc7a1 --- /dev/null +++ b/api-bank/lv3_apis/add_meeting.py @@ -0,0 +1,94 @@ +from apis.api import API +import json +import os +import datetime + + +class AddMeeting(API): + description = "This API allows users to make a reservation for a meeting and store the meeting information in the database." \ + "Function:" \ + "Allow users to make a reservation for a meeting." \ + "Exception Handling:" \ + "1. If the reservation is successful, return a success message." \ + "2. If the reservation fails, return a corresponding error message." + input_parameters = { + 'meeting_topic': {'type': 'str', 'description': 'The title of the meeting, no more than 50 characters.'}, + 'start_time': {'type': 'str', + 'description': 'The start time of the meeting, in the pattern of %Y-%m-%d %H:%M:%S'}, + 'end_time': {'type': 'str', + 'description': 'The end time of the meeting, in the pattern of %Y-%m-%d %H:%M:%S'}, + 'location': {'type': 'str', + 'description': 'The location where the meeting to be held, no more than 100 characters.'}, + 'attendees': {'type': 'list(str)', + 'description': 'The attendees of the meeting, including names, positions and other information.'} + } + output_parameters = { + 'status': {'type': 'str', 'description': 'success or failed'} + } + + def __init__(self): + self.database = [] + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + response_content, groundtruth_content = response['input']['meeting_topic'].split(" "), groundtruth['input'][ + 'meeting_topic'].split(" ") + content_satisfied = False + if len(set(response_content).intersection(set(groundtruth_content))) / len(set(response_content).union( + set(groundtruth_content))) > 0.5: + content_satisfied = True + + response['input'].pop('meeting_topic') + groundtruth['input'].pop('meeting_topic') + + if content_satisfied and response['input'] == groundtruth['input'] and response['output'] == \ + groundtruth['output'] and response['exception'] == groundtruth['exception']: + return True + else: + return False + + def call(self, meeting_topic: str, start_time: str, end_time: str, location: str, + attendees: list) -> dict: + input_parameters = { + 'meeting_topic': meeting_topic, + 'start_time': start_time, + 'end_time': end_time, + 'location': location, + 'attendees': attendees + } + try: + status = self.add_meeting(meeting_topic, start_time, end_time, location, attendees) + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, + 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': status, + 'exception': None} + + def add_meeting(self, meeting_topic: str, start_time: str, end_time: str, location: str, + attendees: list) -> str: + + # Check the format of the input parameters. + datetime.datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S') + datetime.datetime.strptime(end_time, '%Y-%m-%d %H:%M:%S') + + if meeting_topic.strip() == "": + raise Exception('Meeting Topic should not be null') + self.database.append({ + 'meeting_topic': meeting_topic, + 'start_time': start_time, + 'end_time': end_time, + 'location': location, + 'attendees': attendees + }) + return "success" diff --git a/api-bank/lv3_apis/calculator.py b/api-bank/lv3_apis/calculator.py new file mode 100644 index 00000000..21aae873 --- /dev/null +++ b/api-bank/lv3_apis/calculator.py @@ -0,0 +1,207 @@ +from apis.api import API + +class Calculator(API): + description = 'This API provides basic arithmetic operations: addition, subtraction, multiplication, and division.' + input_parameters = { + 'formula': {'type': 'str', 'description': 'The formula that needs to be calculated. Only integers are supported. Valid operators are +, -, *, /, and (, ). For example, \'(1 + 2) * 3\'.'}, + } + output_parameters = { + 'result': {'type': 'float', 'description': 'The result of the formula.'}, + } + def __init__(self) -> None: + pass + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + re_formula = response['input']['formula'].replace(' ', '') + gt_formula = groundtruth['input']['formula'].replace(' ', '') + + if re_formula == gt_formula and response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception']: + return True + else: + return False + + def call(self, formula: str) -> float: + """ + Calculates the result of the formula. + + Parameters: + - formula (str): the formula that needs to be calculated. Valid operators are +, -, *, /, and (, ). For example, '(1 + 2) * 3'. + + Returns: + - result (float): the result of the formula. + - formula (str): the formula that was calculated. + """ + input_parameters = { + 'formula': formula, + } + try: + result = self.calculate(formula) + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': result, 'exception': None} + + def calculate(self, formula: str) -> float: + """ + Calculates the result of the formula. + + Parameters: + - formula (str): the formula that needs to be calculated. Valid operators are +, -, *, /, and (, ). For example, '(1 + 2) * 3'. + + Returns: + - result (float): the result of the formula. + """ + # Remove all spaces from the formula + formula = formula.replace(' ', '') + + # Check if the formula is valid + if not self.is_valid_formula(formula): + raise Exception('invalid formula') + + # Convert the formula to a list + formula = self.convert_formula_to_list(formula) + + # Calculate the result + result = self.calculate_formula(formula) + + return result + + def is_valid_formula(self, formula: str) -> bool: + """ + Checks if the formula is valid. + + Parameters: + - formula (str): the formula that needs to be checked. + + Returns: + - is_valid (bool): True if the formula is valid, False otherwise. + """ + # Check if the formula is empty + if len(formula) == 0: + return False + + # Check if the formula contains invalid characters + for c in formula: + if c not in '0123456789+-*/()': + return False + + # Check if the formula contains an invalid number of parentheses + if formula.count('(') != formula.count(')'): + return False + + # Check if the formula contains an invalid number of operators + if formula.count('+') + formula.count('-') + formula.count('*') + formula.count('/') == 0: + return False + + # Check if the formula contains an invalid number of operands + if formula.count('+') + formula.count('-') + formula.count('*') + formula.count('/') + 1 == len(formula): + return False + + return True + + def convert_formula_to_list(self, formula: str) -> list: + """ + Converts the formula to a list. + + Parameters: + - formula (str): the formula that needs to be converted. + + Returns: + - formula_list (list): the formula converted to a list. + """ + formula_list = [] + number = '' + for c in formula: + if c in '0123456789': + number += c + else: + if number != '': + formula_list.append(float(number)) + number = '' + formula_list.append(c) + if number != '': + formula_list.append(float(number)) + + return formula_list + + def calculate_formula(self, formula: list) -> float: + """ + Calculates the result of the formula. + + Parameters: + - formula (list): the formula that needs to be calculated. + + Returns: + - result (float): the result of the formula. + """ + # Calculate the result of the parentheses + while '(' in formula: + left_parenthesis_index = formula.index('(') + right_parenthesis_index = formula.index(')') + formula[left_parenthesis_index:right_parenthesis_index + 1] = [self.calculate_formula(formula[left_parenthesis_index + 1:right_parenthesis_index])] + + # Calculate the result of the multiplication and division + while '*' in formula or '/' in formula: + if '*' in formula and '/' in formula: + if formula.index('*') < formula.index('/'): + index = formula.index('*') + else: + index = formula.index('/') + elif '*' in formula: + index = formula.index('*') + else: + index = formula.index('/') + formula[index - 1:index + 2] = [self.calculate_operation(formula[index - 1], formula[index], formula[index + 1])] + + # Calculate the result of the addition and subtraction + while '+' in formula or '-' in formula: + if '+' in formula and '-' in formula: + if formula.index('+') < formula.index('-'): + index = formula.index('+') + else: + index = formula.index('-') + elif '+' in formula: + index = formula.index('+') + else: + index = formula.index('-') + formula[index - 1:index + 2] = [self.calculate_operation(formula[index - 1], formula[index], formula[index + 1])] + + return formula[0] + + def calculate_operation(self, operand1: float, operator: str, operand2: float) -> float: + """ + Calculates the result of the operation. + + Parameters: + - operand1 (float): the first operand. + - operator (str): the operator. + - operand2 (float): the second operand. + + Returns: + - result (float): the result of the operation. + """ + if operator == '+': + return operand1 + operand2 + elif operator == '-': + return operand1 - operand2 + elif operator == '*': + return operand1 * operand2 + elif operator == '/': + return operand1 / operand2 + +if __name__ == '__main__': + # Create the API + api = Calculator() + response = api.call('(1 + 2) * 3') + print(response) \ No newline at end of file diff --git a/api-bank/lv3_apis/clothing_recommandation.py b/api-bank/lv3_apis/clothing_recommandation.py new file mode 100644 index 00000000..eb5251f8 --- /dev/null +++ b/api-bank/lv3_apis/clothing_recommandation.py @@ -0,0 +1,81 @@ +from apis.api import API +class ClothingRecommendation(API): + description = "API for providing clothing recommendations based on weather conditions." + input_parameters = { + 'temperature': {'type': 'float', 'description': 'Temperature in Celsius.'}, + 'humidity': {'type': 'float', 'description': 'Relative humidity in percentage.'}, + 'weather_conditions': {'type': 'str', 'description': 'Description of weather conditions.'}, + } + output_parameters = { + 'clothing_options': {'type': 'list', 'description': 'List of recommended clothing options.'}, + } + + def call(self, temperature: float, humidity: float, weather_conditions: str) -> dict: + input_parameters = { + 'temperature': temperature, + 'humidity': humidity, + 'weather_conditions': weather_conditions, + } + try: + clothing_options = self.get_clothing_recommendation(temperature, humidity, weather_conditions) + return { + 'api_name': self.__class__.__name__, + 'input': input_parameters, + 'output': {'clothing_options': clothing_options}, + 'exception': None + } + except Exception as e: + exception = str(e) + return { + 'api_name': self.__class__.__name__, + 'input': input_parameters, + 'output': None, + 'exception': exception + } + + def get_clothing_recommendation(self, temperature: float, humidity: float, + weather_conditions: str) -> list: + # Clothing recommendation logic based on weather conditions + clothing_options = [] + + if temperature < 10: + clothing_options.append('Warm coat') + clothing_options.append('Hat') + clothing_options.append('Gloves') + + if temperature >= 10 and temperature < 20: + clothing_options.append('Light jacket') + clothing_options.append('Long-sleeved shirt') + + if temperature >= 20 and temperature < 30: + clothing_options.append('T-shirt') + clothing_options.append('Shorts') + + if temperature >= 30: + clothing_options.append('Sun hat') + clothing_options.append('Sunglasses') + clothing_options.append('Loose-fitting clothes') + + if humidity > 70: + clothing_options.append('Umbrella') + clothing_options.append('Waterproof shoes') + + + if 'rain' in weather_conditions.lower(): + clothing_options.append('Raincoat') + + return clothing_options + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] \ No newline at end of file diff --git a/api-bank/lv3_apis/email_reminder.py b/api-bank/lv3_apis/email_reminder.py new file mode 100644 index 00000000..24c8eeb4 --- /dev/null +++ b/api-bank/lv3_apis/email_reminder.py @@ -0,0 +1,63 @@ +import datetime +from apis.api import API + +class EmailReminder(API): + description = "This API sends an email reminder to the user with the meeting details." + input_parameters = { + 'content': {'type': 'str', 'description': 'The content of the email.'}, + 'time': {'type': 'str', 'description': 'The time for the meeting. Format: %Y-%m-%d %H:%M:%S'}, + 'location': {'type': 'str', 'description': 'The location of the meeting.'}, + 'recipient': {'type': 'str', 'description': 'The email address of the recipient.'}, + } + output_parameters = { + 'status': {'type': 'str', 'description': 'success or failed'} + } + + def __init__(self): + pass + + def call(self, content: str, time: str, location: str, recipient: str) -> dict: + input_parameters = { + 'content': content, + 'time': time, + 'location': location, + 'recipient': recipient + } + try: + self.send_email(content, time, location, recipient) + status = 'success' + except Exception as e: + status = 'failed' + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': status, 'exception': None} + + def send_email(self, content: str, time: str, location: str, recipient: str): + # Validate the input parameters + datetime.datetime.strptime(time, '%Y-%m-%d %H:%M:%S') + + if content.strip() == "": + raise Exception('Content should not be empty') + + # Send the email to the recipient + # email_subject = f"Meeting Reminder: {content}" + # email_body = f"Meeting Details:\n\nContent: {content}\nTime: {time}\nLocation: {location}" + # self.email_sender.send_email(token, recipient, email_subject, email_body) + return 'success' + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + response['input'].pop('content') + groundtruth['input'].pop('content') + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] \ No newline at end of file diff --git a/api-bank/lv3_apis/flight_search.py b/api-bank/lv3_apis/flight_search.py new file mode 100644 index 00000000..3ea28c4a --- /dev/null +++ b/api-bank/lv3_apis/flight_search.py @@ -0,0 +1,87 @@ +import datetime +from apis.api import API + +class FlightSearch(API): + description = "API to retrieve flight options based on the destination and travel dates." + input_parameters = { + 'source': {'type': 'str', 'description': "Source for the flight."}, + 'destination': {'type': 'str', 'description': "Destination for the flight."}, + 'travel_dates': {'type': 'str', 'description': 'Travel dates. Format: %Y-%m-%d'} + } + output_parameters = { + 'flights': {'type': 'list', 'description': 'List of available flight options.'} + } + + def __init__(self): + self.flight_data = [{ + 'source': 'New York', + 'destination': 'San Francisco', + 'departure_date': datetime.datetime(2022, 1, 1, 12, 0, 0), + 'arrival_date': datetime.datetime(2022, 1, 1, 15, 0, 0), + }, + { + 'source': 'Los Angeles', + 'destination': 'San Francisco', + 'departure_date': datetime.datetime(2022, 1, 2, 12, 0, 0), + 'arrival_date': datetime.datetime(2022, 1, 2, 15, 0, 0), + }, + { + 'source': 'London', + 'destination': 'San Francisco', + 'departure_date': datetime.datetime(2022, 1, 3, 12, 0, 0), + 'arrival_date': datetime.datetime(2022, 1, 3, 15, 0, 0), + }, + { + 'source': 'New York', + 'destination': 'London', + 'departure_date': datetime.datetime(2022, 1, 4, 12, 0, 0), + 'arrival_date': datetime.datetime(2022, 1, 4, 15, 0, 0), + }, + { + 'source': 'New York', + 'destination': 'Los Angeles', + 'departure_date': datetime.datetime(2022, 1, 5, 12, 0, 0), + 'arrival_date': datetime.datetime(2022, 1, 5, 15, 0, 0), + }] + + def get_flights(self, source, destination, travel_dates): + travel_dates = datetime.datetime.strptime(travel_dates, '%Y-%m-%d') + flights = [] + for flight in self.flight_data: + if flight['source'] == source and flight['destination'] == destination and flight['departure_date'].date() == travel_dates.date(): + flight['departure_date'] = flight['departure_date'].strftime('%Y-%m-%d %H:%M:%S') + flight['arrival_date'] = flight['arrival_date'].strftime('%Y-%m-%d %H:%M:%S') + flights.append(flight) + return flights + + def call(self, source, destination, travel_dates): + input_parameters = { + 'source': source, + 'destination': destination, + 'travel_dates': travel_dates + } + try: + flights = self.get_flights(source, destination, travel_dates) + output_parameters = {'flights': flights} + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': output_parameters, + 'exception': None} + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, + 'exception': exception} + + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] + diff --git a/api-bank/lv3_apis/geocoding.py b/api-bank/lv3_apis/geocoding.py new file mode 100644 index 00000000..617cdf5b --- /dev/null +++ b/api-bank/lv3_apis/geocoding.py @@ -0,0 +1,53 @@ +from apis.api import API + +class Geocoding(API): + description = "The API for converting an address or place name to geographical coordinates." + input_parameters = { + 'address': {'type': 'str', 'description': 'The address or place name to be converted.'}, + } + output_parameters = { + 'latitude': {'type': 'float', 'description': 'The latitude of the location.'}, + 'longitude': {'type': 'float', 'description': 'The longitude of the location.'}, + } + + def __init__(self): + self.address_to_coordinates = { + 'New York City': (40.7128, 74.0060), + 'San Francisco': (37.7749, 122.4194), + 'London': (51.5074, 0.1278), + 'Paris': (48.8566, 2.3522), + 'Tokyo': (35.6762, 139.6503), + } + + def call(self, address: str) -> dict: + input_parameters = { + 'address': address, + } + try: + latitude, longitude = self.convert_address_to_coordinates(address) + output_parameters = { + 'latitude': latitude, + 'longitude': longitude, + } + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': output_parameters, 'exception': None} + + def convert_address_to_coordinates(self, address: str) -> tuple: + return self.address_to_coordinates[address] + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] \ No newline at end of file diff --git a/api-bank/lv3_apis/get_occupation_salary.py b/api-bank/lv3_apis/get_occupation_salary.py new file mode 100644 index 00000000..b0d4937a --- /dev/null +++ b/api-bank/lv3_apis/get_occupation_salary.py @@ -0,0 +1,65 @@ +from apis.api import API + +class GetOccupationSalary(API): + description = "API for querying the salary of a given occupation." + input_parameters = { + 'occupation': {'type': 'str', 'description': 'The occupation to query.'}, + } + output_parameters = { + 'salary': {'type': 'float', 'description': 'The salary of the given occupation.'} + } + + def __init__(self): + self.salary_info = { + 'Financial Analyst': 100000, + 'Software Engineer': 120000, + 'Data Scientist': 150000, + 'Product Manager': 130000, + 'Doctor': 200000, + } + + def check_api_call_correctness(self, response, groundtruth): + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] + + def call(self, occupation: str) -> dict: + input_parameters = { + 'occupation': occupation, + } + try: + salary = self.query_salary(occupation) + output_parameters = { + 'salary': salary, + } + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, + 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': output_parameters, + 'exception': None} + + def query_salary(self, occupation: str) -> float: + """ + Queries the salary of the given occupation. + + Parameters: + - occupation (str): the occupation to query. + + Returns: + - salary (float): the salary of the given occupation. + """ + if occupation not in self.salary_info: + raise Exception("The occupation is not in the database.") + return self.salary_info[occupation] + diff --git a/api-bank/lv3_apis/get_weather.py b/api-bank/lv3_apis/get_weather.py new file mode 100644 index 00000000..57e85d4d --- /dev/null +++ b/api-bank/lv3_apis/get_weather.py @@ -0,0 +1,79 @@ +import requests +from apis.api import API + +class GetWeatherForCoordinates(API): + description = "Retrieves current weather information based on the provided coordinates." + input_parameters = { + 'latitude': {'type': 'float', 'description': 'Latitude of the location.'}, + 'longitude': {'type': 'float', 'description': 'Longitude of the location.'} + } + output_parameters = { + 'temperature': {'type': 'float', 'description': 'Current temperature in Celsius.'}, + 'humidity': {'type': 'float', 'description': 'Current humidity level.'}, + 'description': {'type': 'str', 'description': 'Weather description.'} + } + + def __init__(self): + self.coordinates_to_weather = [ + ((40.7128, 74.0060), { + 'temperature': 10, + 'humidity': 0.5, + 'description': 'Clear' + }), + ((37.7749, 122.4194), { + 'temperature': 20, + 'humidity': 0.8, + 'description': 'Cloudy' + }), + ((51.5074, 0.1278), { + 'temperature': 5, + 'humidity': 0.9, + 'description': 'Rainy' + }), + ((48.8566, 2.3522), { + 'temperature': 15, + 'humidity': 0.7, + 'description': 'Sunny' + }), + ((35.6762, 139.6503), { + 'temperature': 25, + 'humidity': 0.6, + 'description': 'Rainy' + }), + ] + + def check_api_call_correctness(self, response, groundtruth): + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] + + def call(self, latitude: float, longitude: float) -> dict: + input_parameters = { + 'latitude': latitude, + 'longitude': longitude + } + try: + response = self.fetch_weather(latitude, longitude) + output_parameters = { + 'temperature': response['temperature'], + 'humidity': response['humidity'], + 'description': response['description'] + } + except requests.exceptions.RequestException as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': output_parameters, 'exception': None} + + def fetch_weather(self, latitude, longitude): + for coordinates, weather in self.coordinates_to_weather: + if coordinates == (latitude, longitude): + return weather \ No newline at end of file diff --git a/api-bank/lv3_apis/hotel_availability.py b/api-bank/lv3_apis/hotel_availability.py new file mode 100644 index 00000000..41c6f735 --- /dev/null +++ b/api-bank/lv3_apis/hotel_availability.py @@ -0,0 +1,103 @@ +import datetime +from apis.api import API + +class HotelAvailability(API): + description = "API for checking hotel availability based on the destination and travel dates." + input_parameters = { + 'destination': {'type': 'str', 'description': "Destination for hotel search."}, + 'check_in_date': {'type': 'str', 'description': 'Check-in date. Format: %Y-%m-%d'}, + 'check_out_date': {'type': 'str', 'description': 'Check-out date. Format: %Y-%m-%d'}, + } + output_parameters = { + 'hotels': {'type': 'list', 'description': 'List of available hotels.'}, + } + + def __init__(self): + self.hotel_database = { + 'hotel_1': { + 'destination': 'San Francisco', + 'check_in_date': datetime.datetime(2022, 1, 1), + 'check_out_date': datetime.datetime(2022, 1, 2), + }, + 'hotel_2': { + 'destination': 'San Francisco', + 'check_in_date': datetime.datetime(2022, 1, 2), + 'check_out_date': datetime.datetime(2022, 1, 3), + }, + 'hotel_3': { + 'destination': 'San Francisco', + 'check_in_date': datetime.datetime(2022, 1, 3), + 'check_out_date': datetime.datetime(2022, 1, 4), + }, + 'hotel_4': { + 'destination': 'San Francisco', + 'check_in_date': datetime.datetime(2022, 1, 4), + 'check_out_date': datetime.datetime(2022, 1, 5), + }, + 'hotel_5': { + 'destination': 'San Francisco', + 'check_in_date': datetime.datetime(2022, 1, 5), + 'check_out_date': datetime.datetime(2022, 1, 6), + }, + } + + def check_api_call_correctness(self, response, groundtruth): + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + if response['input']['destination'] == groundtruth['input']['destination'] \ + and response['input']['check_in_date'] == groundtruth['input']['check_in_date'] \ + and response['input']['check_out_date'] == groundtruth['input']['check_out_date'] \ + and response['output']['hotels'] == groundtruth['output']['hotels']: + return True + else: + return False + + def call(self, destination: str, check_in_date: str, check_out_date: str) -> dict: + input_parameters = { + 'destination': destination, + 'check_in_date': check_in_date, + 'check_out_date': check_out_date + } + try: + hotels = self.get_available_hotels(destination, check_in_date, check_out_date) + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, + 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': {'hotels': hotels}, + 'exception': None} + + def get_available_hotels(self, destination: str, check_in_date: str, check_out_date: str) -> list: + # Check the format of the input parameters. + datetime.datetime.strptime(check_in_date, '%Y-%m-%d') + datetime.datetime.strptime(check_out_date, '%Y-%m-%d') + + available_hotels = [] + for hotel_id, hotel in self.hotel_database.items(): + if hotel['destination'] == destination and self.is_hotel_available(hotel, check_in_date, check_out_date): + hotel['check_in_date'] = hotel['check_in_date'].strftime('%Y-%m-%d') + hotel['check_out_date'] = hotel['check_out_date'].strftime('%Y-%m-%d') + available_hotels.append(hotel) + + return available_hotels + + def is_hotel_available(self, hotel, check_in_date, check_out_date) -> bool: + hotel_check_in_date = hotel['check_in_date'] if isinstance(hotel['check_in_date'], datetime.datetime) \ + else datetime.datetime.strptime(hotel['check_in_date'], '%Y-%m-%d') + hotel_check_out_date = hotel['check_out_date'] if isinstance(hotel['check_out_date'], datetime.datetime) \ + else datetime.datetime.strptime(hotel['check_out_date'], '%Y-%m-%d') + user_check_in_date = datetime.datetime.strptime(check_in_date, '%Y-%m-%d') + user_check_out_date = datetime.datetime.strptime(check_out_date, '%Y-%m-%d') + + if user_check_in_date >= hotel_check_in_date and user_check_out_date <= hotel_check_out_date: + return True diff --git a/api-bank/lv3_apis/like_count.py b/api-bank/lv3_apis/like_count.py new file mode 100644 index 00000000..8d8cb1ef --- /dev/null +++ b/api-bank/lv3_apis/like_count.py @@ -0,0 +1,64 @@ +from apis.api import API + +class LikeCount(API): + description = "API to retrieve the number of likes for a given post ID." + input_parameters = { + 'post_id': {'type': 'int', 'description': "Post ID."}, + } + output_parameters = { + 'like_count': {'type': 'int', 'description': 'Number of likes for the post.'}, + } + + def __init__(self): + self.database = { + 1: 10, + 2: 20, + 3: 30, + 4: 40, + 5: 50, + 6: 60, + 7: 70, + 8: 80, + 9: 90, + 10: 100, + 11: 110, + 12: 120, + 13: 130, + 14: 140, + 15: 150, + } + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] + + def call(self, post_id: int) -> dict: + input_parameters = { + 'post_id': post_id, + } + try: + like_count = self.retrieve_like_count(post_id) + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, + 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, + 'output': {'like_count': like_count}, 'exception': None} + + def retrieve_like_count(self, post_id: int) -> int: + if post_id not in self.database: + raise Exception(f"Post with ID {post_id} does not exist.") + + return self.database[post_id] + diff --git a/api-bank/lv3_apis/movie_recommandation.py b/api-bank/lv3_apis/movie_recommandation.py new file mode 100644 index 00000000..6c0bd4c9 --- /dev/null +++ b/api-bank/lv3_apis/movie_recommandation.py @@ -0,0 +1,58 @@ +from apis.api import API +class MovieRecommendations(API): + description = "The API for retrieving recommended movies based on user preferences and filtering watched movies." + input_parameters = { + 'preferences': {'type': 'list', 'description': "User's movie preferences."}, + } + output_parameters = { + 'recommended_movies': {'type': 'list', 'description': 'List of recommended movies.'} + } + + def __init__(self): + self.database = { + 'Action': ['The Dark Knight', 'The Matrix', 'The Lord of the Rings'], + 'Comedy': ['The Hangover', 'Knives Out', 'Deadpool'], + 'Drama': ['The Shawshank Redemption', 'Forrest Gump', 'Joker'], + 'Romance': ['Titanic', 'La La Land', 'The Notebook'], + 'Thriller': ['Inception', 'Parasite', 'Get Out'], + } + + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] + + + def call(self, preferences: list) -> dict: + input_parameters = { + 'preferences': preferences, + } + try: + recommended_movies = self.get_recommendations(preferences) + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, + 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': {'recommended_movies': recommended_movies}, + 'exception': None} + + + def get_recommendations(self, preferences: list) -> list: + recommended_movies = [] + + for preference in preferences: + movies = self.database[preference] + recommended_movies.extend(movies) + + return recommended_movies diff --git a/api-bank/lv3_apis/nearby_restaurants.py b/api-bank/lv3_apis/nearby_restaurants.py new file mode 100644 index 00000000..3a6ad1bf --- /dev/null +++ b/api-bank/lv3_apis/nearby_restaurants.py @@ -0,0 +1,65 @@ +from apis.api import API +class NearbyRestaurants(API): + description = "Retrieves nearby restaurants based on the provided coordinates and search parameters." + input_parameters = { + 'latitude': {'type': 'float', 'description': 'Latitude of the location.'}, + 'longitude': {'type': 'float', 'description': 'Longitude of the location.'}, + 'distance': {'type': 'int', 'description': 'The distance in meters from the location to search for restaurants.'}, + } + output_parameters = { + 'restaurants': {'type': 'list', 'description': 'A list of nearby restaurants.'}, + } + + def __init__(self): + self.restaurants = [ + {'coordinates': (40.7128, 74.0060), 'name': 'Restaurant A'}, + {'coordinates': (37.7749, 122.4194), 'name': 'Restaurant B'}, + {'coordinates': (40.7128, 74.0060), 'name': 'Restaurant C'}, + {'coordinates': (37.7749, 122.4194), 'name': 'Restaurant D'}, + ] + + def get_nearby_restaurants(self, latitude: float, longitude: float, distance: int) -> list: + """ + Retrieves nearby restaurants based on the provided coordinates and search parameters. + + Parameters: + - latitude (float): latitude of the location. + - longitude (float): longitude of the location. + - distance (int): the distance in meters from the location to search for restaurants. + + Returns: + - restaurants (list): a list of nearby restaurants. + """ + return self.restaurants + + def check_api_call_correctness(self, response, groundtruth): + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['input'] == groundtruth['input'] + + def call(self, latitude: float, longitude: float, distance: int) -> dict: + input_parameters = { + 'latitude': latitude, + 'longitude': longitude, + 'distance': distance + } + try: + restaurants = self.get_nearby_restaurants(latitude, longitude, distance) + output_parameters = { + 'restaurants': restaurants + } + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': output_parameters, 'exception': None} + diff --git a/api-bank/lv3_apis/organization_members.py b/api-bank/lv3_apis/organization_members.py new file mode 100644 index 00000000..13a0f3bc --- /dev/null +++ b/api-bank/lv3_apis/organization_members.py @@ -0,0 +1,77 @@ +from apis.api import API + +class OrganizationMembers(API): + description = "API to retrieve the list of members in the organization." + input_parameters = { + 'organization': {'type': 'str', 'description': "Name of the organization."}, + } + output_parameters = { + 'members': {'type': 'list', 'description': 'List of organization members.'} + } + + + def __init__(self): + self.database = { + 'Alibaba': [ + 'John', + 'Mary', + 'Peter', + ], + 'Tencent': [ + 'Tom', + 'Jerry', + ], + 'Baidu': [ + 'Jack', + 'Rose', + ], + 'ByteDance': [ + 'Bob', + 'Alice', + ], + 'JD': [ + 'Mike', + 'Jane', + ], + } + + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + if response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input']: + return True + else: + return False + + + def call(self, organization: str) -> dict: + input_parameters = { + 'organization': organization + } + try: + members = self.get_organization_members(organization) + output_parameters = { + 'members': members + } + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, + 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': output_parameters, + 'exception': None} + + + def get_organization_members(self, organization): + return self.database[organization] + diff --git a/api-bank/lv3_apis/query_meeting.py b/api-bank/lv3_apis/query_meeting.py new file mode 100644 index 00000000..d2432985 --- /dev/null +++ b/api-bank/lv3_apis/query_meeting.py @@ -0,0 +1,111 @@ +import datetime +from apis.api import API + +class QueryMeeting(API): + description = "The API for retrieving the meeting details from the user's calendar." + input_parameters = { + 'user_name': {'type': 'str', 'description': 'Name of the user.'}, + } + output_parameters = { + 'meetings': {'type': 'list', 'description': 'List of meetings.'}, + } + + def __init__(self): + self.database = { + 'John': [ + { + 'meeting_id': 1, + 'meeting_name': 'Meeting with the client', + 'meeting_time': datetime.datetime(2021, 1, 1, 10, 0, 0).strftime('%Y-%m-%d %H:%M:%S'), + 'meeting_location': 'Room 1', + 'meeting_attendees': ['John', 'Mary', 'Peter'], + }, + { + 'meeting_id': 2, + 'meeting_name': 'Meeting about the new project', + 'meeting_time': datetime.datetime(2021, 1, 2, 10, 0, 0).strftime('%Y-%m-%d %H:%M:%S'), + 'meeting_location': 'Room 2', + 'meeting_attendees': ['John', 'Mary', 'Peter'], + }, + ], + 'Mary': [ + { + 'meeting_id': 1, + 'meeting_name': 'Meeting with the client', + 'meeting_time': datetime.datetime(2021, 1, 1, 10, 0, 0).strftime('%Y-%m-%d %H:%M:%S'), + 'meeting_location': 'Room 1', + 'meeting_attendees': ['John', 'Mary', 'Peter'], + }, + { + 'meeting_id': 2, + 'meeting_name': 'Meeting about the new project', + 'meeting_time': datetime.datetime(2021, 1, 2, 10, 0, 0).strftime('%Y-%m-%d %H:%M:%S'), + 'meeting_location': 'Room 2', + 'meeting_attendees': ['John', 'Mary', 'Peter'], + }, + ], + 'Peter': [ + { + 'meeting_id': 1, + 'meeting_name': 'Meeting with the client', + 'meeting_time': datetime.datetime(2021, 1, 1, 10, 0, 0).strftime('%Y-%m-%d %H:%M:%S'), + 'meeting_location': 'Room 1', + 'meeting_attendees': ['John', 'Mary', 'Peter'], + }, + { + 'meeting_id': 2, + 'meeting_name': 'Meeting about the new project', + 'meeting_time': datetime.datetime(2021, 1, 2, 10, 0, 0).strftime('%Y-%m-%d %H:%M:%S'), + 'meeting_location': 'Room 2', + 'meeting_attendees': ['John', 'Mary', 'Peter'], + }, + ], + 'Tom': [ + { + 'meeting_id': 1, + 'meeting_name': 'Meeting', + 'meeting_time': datetime.datetime(2021, 1, 1, 10, 0, 0).strftime('%Y-%m-%d %H:%M:%S'), + 'meeting_location': 'Room 1', + 'meeting_attendees': ['Tom', 'Jerry'], + }, + ], + 'Jerry': [ + { + 'meeting_id': 1, + 'meeting_name': 'Meeting', + 'meeting_time': datetime.datetime(2021, 1, 1, 10, 0, 0).strftime('%Y-%m-%d %H:%M:%S'), + 'meeting_location': 'Room 1', + 'meeting_attendees': ['Tom', 'Jerry'], + }, + ], + } + + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] + + def call(self, user_name): + input_parameters = { + 'user_name': user_name, + } + try: + meetings = self.retrieve_user_meetings(user_name) + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': {'meetings': meetings}, 'exception': None} + + def retrieve_user_meetings(self, user_name): + return self.database[user_name] \ No newline at end of file diff --git a/api-bank/lv3_apis/tax_calculator.py b/api-bank/lv3_apis/tax_calculator.py new file mode 100644 index 00000000..f6dc42a5 --- /dev/null +++ b/api-bank/lv3_apis/tax_calculator.py @@ -0,0 +1,64 @@ +from apis.api import API +class TaxCalculator(API): + description = "API for calculating tax deductions based on the given salary." + input_parameters = { + 'salary': {'type': 'float', 'description': 'The salary to calculate tax deductions for.'}, + } + output_parameters = { + 'salary_after_tax': {'type': 'float', 'description': 'The salary after tax deductions.'} + } + + def __init__(self, tax_rates=None): + if tax_rates is not None: + self.tax_rates = tax_rates + else: + self.tax_rates = { + 0: 0.0, # 0% tax rate + 1000: 0.1, # 10% tax rate for income up to 1000 + 3000: 0.2, # 20% tax rate for income up to 3000 + 5000: 0.3, # 30% tax rate for income up to 5000 + float('inf'): 0.4 # 40% tax rate for income above 5000 + } + + def calculate_tax_deductions(self, salary): + tax_rate = next(rate for threshold, rate in sorted(self.tax_rates.items(), reverse=True) if salary >= threshold) + tax_deduction = salary * tax_rate + salary_after_tax = salary - tax_deduction + return salary_after_tax + + def check_api_call_correctness(self, response, groundtruth): + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + if response['input']['salary'] == groundtruth['input']['salary'] and \ + response['output']['salary_after_tax'] == groundtruth['output']['salary_after_tax'] and \ + response['exception'] == groundtruth['exception']: + return True + else: + return False + + def call(self, salary): + input_parameters = {'salary': salary} + try: + salary_after_tax = self.calculate_tax_deductions(salary) + output_parameters = {'salary_after_tax': salary_after_tax} + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, + 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, + 'output': output_parameters, 'exception': None} + +if __name__ == '__main__': + tax_calculator = TaxCalculator() + response = tax_calculator.call(100000) + print(response) \ No newline at end of file diff --git a/api-bank/lv3_apis/tool_search.py b/api-bank/lv3_apis/tool_search.py new file mode 100644 index 00000000..42a8ccbe --- /dev/null +++ b/api-bank/lv3_apis/tool_search.py @@ -0,0 +1,149 @@ +from sentence_transformers import SentenceTransformer, util +import logging +logging.getLogger('sentence_transformers').setLevel(logging.WARNING) + +import json +from apis.api import API +import os + +class ToolSearcher(API): + description = 'Searches for relevant tools in library based on the keyword.' + input_parameters = { + 'keywords': {'type': 'str', 'description': 'The keyword to search for.'} + } + output_parameters = { + 'best_matchs': {'type': 'Union[List[dict], dict]', 'description': 'The best match tool(s).'}, + } + + def __init__(self, apis_dir = './lv3_apis'): + import importlib.util + + + all_apis = [] + # import all the file in the apis folder, and load all the classes + except_files = ['__init__.py', 'api.py', 'tool_search.py'] + for file in os.listdir(apis_dir): + if file.endswith('.py') and file not in except_files: + api_file = file.split('.')[0] + basename = os.path.basename(apis_dir) + module = importlib.import_module(basename + "." + api_file) + classes = [getattr(module, x) for x in dir(module) if isinstance(getattr(module, x), type)] + for cls in classes: + if issubclass(cls, API) and cls is not API: + all_apis.append(cls) + + classes = all_apis + + # # Import the module containing the API classes + # spec = importlib.util.spec_from_file_location('apis', './apis.py') + # module = importlib.util.module_from_spec(spec) + # spec.loader.exec_module(module) + + # # Get a list of all classes in the module + # classes = inspect.getmembers(module, inspect.isclass) + + def api_summery(cls): + cls_name = cls.__name__ + # split cls_name by capital letters + cls_name = ''.join([' ' + i.lower() if i.isupper() else i for i in cls_name]).strip() + return cls_name + cls.description + + # Get the description parameter for each class + apis = [] + for cls in classes: + if issubclass(cls, object) and cls is not object: + desc_for_search = api_summery(cls) + apis.append({ + 'name': cls.__name__, + 'description': cls.description, + 'input_parameters': cls.input_parameters, + 'output_parameters': cls.output_parameters, + 'desc_for_search': desc_for_search + }) + + self.apis = apis + + def check_api_call_correctness(self, response, groundtruth) -> bool: + + if response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception']: + return True + else: + return False + + def call(self, keywords): + """ + Searches for relevant tools in various libraries based on the keyword. + + Parameters: + - keywords (str): the keywords to search for. + + Returns: + - best_match (dict): the best match for the keywords. + """ + input_parameters = { + 'keywords': keywords + } + try: + best_match = self.best_match_api(keywords) + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': best_match, 'exception': None} + + + def best_match_api(self, keywords): + + model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L3-v2') + kw_emb = model.encode(keywords) + best_match = None + best_match_score = 0 + for api in self.apis: + re_emb = model.encode(api['desc_for_search']) + cos_sim = util.cos_sim(kw_emb, re_emb).item() + if cos_sim > best_match_score: + best_match = api.copy() + best_match_score = cos_sim + best_match.pop('desc_for_search') + if 'token' in best_match['input_parameters']: + return [self.get_user_token_api, best_match] + else: + return best_match + # best_match = None + # for api in self.apis: + # if api['name'] == keywords: + # best_match = api + # break + # best_match = best_match.copy() + # best_match.pop('desc_for_search') + # return best_match + + # model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L3-v2') + # kw_emb = model.encode(keywords) + + # scores = [] + # for api in self.apis: + # re_emb = model.encode(api['desc_for_search']) + # cos_sim = util.cos_sim(kw_emb, re_emb).item() + # scores.append((api, cos_sim)) + + # scores.sort(key=lambda x: x[1], reverse=True) + # api_id = input('Please select the best match from {}:\n'.format([x[0]['name'] for x in scores[:3]])) + # try: + # api_id = int(api_id) - 1 + # except: + # best_match = scores[0][0] + # for api in self.apis: + # if api['name'] == api_id: + # best_match = api + # break + # else: + # best_match = scores[int(api_id)][0] + + # best_match = best_match.copy() + # best_match.pop('desc_for_search') + # return best_match + +if __name__ == '__main__': + tool_searcher = ToolSearcher(apis_dir='./lv3_apis') + print(tool_searcher.call('add alarm')) \ No newline at end of file diff --git a/api-bank/lv3_apis/travel_status.py b/api-bank/lv3_apis/travel_status.py new file mode 100644 index 00000000..b60787e4 --- /dev/null +++ b/api-bank/lv3_apis/travel_status.py @@ -0,0 +1,56 @@ +from apis.api import API +class TravelStatus(API): + description = "API for retrieving the current travel status of each member." + input_parameters = { + 'member_name': {'type': 'str', 'description': 'Name of the member.'}, + } + output_parameters = { + 'status': {'type': 'str', 'description': 'Travel status'}, + } + + def __init__(self): + self.database = { + 'John': 'Traveling', + 'Mary': 'Working from home', + 'Peter': 'Working from office', + 'Tom': 'Traveling', + 'Jerry': 'Working from home', + 'Jack': 'Working from office', + 'Rose': 'Working from office', + 'Bob': 'Traveling', + 'Alice': 'Working from home', + 'Mike': 'Working from office', + 'Jane': 'Working from office', + } + + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): The response from the API call. + - groundtruth (dict): The groundtruth response. + + Returns: + - is_correct (bool): Whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] + + + def call(self, member_name: str) -> dict: + input_parameters = {'member_name': member_name} + try: + status = self.get_travel_status(member_name) + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, + 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': status, + 'exception': None} + + + def get_travel_status(self, member_name): + return self.database[member_name] \ No newline at end of file diff --git a/api-bank/lv3_apis/update_account_info.py b/api-bank/lv3_apis/update_account_info.py new file mode 100644 index 00000000..8a576d93 --- /dev/null +++ b/api-bank/lv3_apis/update_account_info.py @@ -0,0 +1,74 @@ +from apis.api import API + +class PersonalInfoUpdate(API): + description = "The API for updating a user's personal information and address." + input_parameters = { + 'username': {'type': 'str', 'description': 'Name of the user.'}, + 'password': {'type': 'str', 'description': 'Password of the user.'}, + 'address': {'type': 'str', 'description': 'Updated address information.'}, + } + output_parameters = { + 'status': {'type': 'str', 'description': 'Success or failure'}, + } + + + def __init__(self): + self.database = [ + (('John', '123456'), {'email': 'john@example.com', 'phone': '1234567890'}), + (('Mary', 'abcdef'), {'email': 'mary@example.com', 'phone': '0987654321'}), + (('Peter', 'qwerty'), {'email': 'peter@example.com', 'phone': '1231231234'}), + (('Tom', 'asdfgh'), {'email': 'tom@example.com', 'phone': '4564564567'}), + (('Jerry', 'zxcvbn'), {'email': 'jerry@example.com', 'phone': '7897897890'}), + ] + + + def check_api_call_correctness(self, response, groundtruth): + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['input'] == groundtruth['input'] and response['output'] == groundtruth['output'] + + + def call(self, username: str, password: str, address: dict) -> dict: + input_parameters = { + 'username': username, + 'password': password, + 'address': address + } + try: + status = self.update_personal_info(username, password, address) + output_parameters = { + 'status': status + } + except Exception as e: + exception = str(e) + return { + 'api_name': self.__class__.__name__, + 'input': input_parameters, + 'output': None, + 'exception': exception + } + else: + return { + 'api_name': self.__class__.__name__, + 'input': input_parameters, + 'output': output_parameters, + 'exception': None + } + + + def update_personal_info(self, username, password, address): + for (user, pwd), user_account in self.database: + if user == username and pwd == password: + user_account['address'] = address + return 'success' + raise Exception('User not found.') + diff --git a/api-bank/lv3_apis/user_posts.py b/api-bank/lv3_apis/user_posts.py new file mode 100644 index 00000000..91a66f6a --- /dev/null +++ b/api-bank/lv3_apis/user_posts.py @@ -0,0 +1,53 @@ +from apis.api import API + +class UserPosts(API): + description = "API to retrieve the post IDs for a specific user." + input_parameters = { + 'user_id': {'type': 'int', 'description': "User's ID."}, + } + output_parameters = { + 'post_ids': {'type': 'list', 'description': 'List of post IDs.'}, + } + + + def __init__(self): + self.database = { + 1: [1, 2, 3], + 2: [4, 5, 6], + 3: [7, 8, 9], + 4: [10, 11, 12], + 5: [13, 14, 15], + } + + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] + + + def call(self, user_id: int) -> dict: + input_parameters = { + 'user_id': user_id, + } + try: + post_ids = self.retrieve_user_post_ids(user_id) + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': {'post_ids': post_ids}, 'exception': None} + + def retrieve_user_post_ids(self, user_id: int) -> list: + return self.database[user_id] diff --git a/api-bank/lv3_apis/user_watched_movies.py b/api-bank/lv3_apis/user_watched_movies.py new file mode 100644 index 00000000..fbbc73d4 --- /dev/null +++ b/api-bank/lv3_apis/user_watched_movies.py @@ -0,0 +1,54 @@ +from apis.api import API + +class UserWatchedMovies(API): + description = "API for retrieving a user's watched movie list." + input_parameters = { + 'user_name': {'type': 'str', 'description': 'Name of the user.'}, + } + output_parameters = { + 'watched_movies': {'type': 'list', 'description': 'List of watched movies.'}, + } + + def __init__(self): + self.database = { + 'John': ['The Matrix', 'The Lord of the Rings', 'The Dark Knight'], + 'Mary': ['The Lord of the Rings', 'The Dark Knight', 'The Matrix'], + 'Peter': ['The Matrix', 'The Lord of the Rings', 'The Dark Knight'], + 'Tom': ['The Lord of the Rings', 'The Dark Knight', 'The Matrix'], + 'Jerry': ['The Matrix', 'The Lord of the Rings', 'The Dark Knight'], + } + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] + + def call(self, user_name: str) -> dict: + input_parameters = { + 'user_name': user_name + } + try: + watched_movies = self.retrieve_watched_movies(user_name) + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, + 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': watched_movies, + 'exception': None} + + def retrieve_watched_movies(self, user_name): + if user_name not in self.database: + raise Exception('User not found.') + + return self.database[user_name] + diff --git a/api-bank/lv3_apis/users_movie_preference.py b/api-bank/lv3_apis/users_movie_preference.py new file mode 100644 index 00000000..43fc119d --- /dev/null +++ b/api-bank/lv3_apis/users_movie_preference.py @@ -0,0 +1,48 @@ +from apis.api import API + +class UserMoviePreferences(API): + description = "API for retrieving user preferences for movie recommendations." + input_parameters = { + 'user_name': {'type': 'str', 'description': 'Name of the user.'}, + } + output_parameters = { + 'preferences': {'type': 'list', 'description': 'List of movie preferences.'}, + } + + def __init__(self): + self.database = { + 'John': ['Action', 'Comedy', 'Drama'], + 'Mary': ['Comedy', 'Drama', 'Romance'], + 'Peter': ['Action', 'Drama', 'Thriller'], + 'Tom': ['Action', 'Comedy', 'Drama'], + 'Jerry': ['Comedy', 'Drama', 'Romance'], + } + + def check_api_call_correctness(self, response, groundtruth) -> bool: + """ + Checks if the response from the API call is correct. + + Parameters: + - response (dict): the response from the API call. + - groundtruth (dict): the groundtruth response. + + Returns: + - is_correct (bool): whether the response is correct. + """ + assert response['api_name'] == groundtruth['api_name'], "The API name is not correct." + return response['output'] == groundtruth['output'] and response['exception'] == groundtruth['exception'] and response['input'] == groundtruth['input'] + + def call(self, user_name: str) -> dict: + input_parameters = { + 'user_name': user_name + } + try: + preferences = self.retrieve_preferences(user_name) + except Exception as e: + exception = str(e) + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': None, 'exception': exception} + else: + return {'api_name': self.__class__.__name__, 'input': input_parameters, 'output': {'preferences': preferences}, 'exception': None} + + def retrieve_preferences(self, user_name): + return self.database[user_name] diff --git a/api-bank/lv3_evaluator.py b/api-bank/lv3_evaluator.py new file mode 100644 index 00000000..bc8ed2f2 --- /dev/null +++ b/api-bank/lv3_evaluator.py @@ -0,0 +1,121 @@ +import json +from tool_manager import ToolManager +from api_call_extraction import parse_api_call, get_api_call +import logging +from rouge import Rouge + +def split_by_uppercase(s): + return ''.join([' ' + c if c.isupper() else c for c in s]).strip() + +def calculate_rouge_l_score(reference, hypothesis): + rouge = Rouge() + if hypothesis == '': + return 0 + scores = rouge.get_scores(hypothesis, reference) + rouge_l_score = scores[0]['rouge-l']['f'] + return rouge_l_score + +def test_json(): + # if_api = True + if_api = False + pred_path = 'path-to-json' + gt_path = 'test_data/level-3.json' + tool_manager = ToolManager('./lv3_apis') + with open(pred_path, 'r') as f: + preds = [json.loads(line) for line in f.readlines()] + # preds = json.load(f) + + with open(gt_path, 'r') as f: + gts = json.load(f) + + # if_api = 'API-Request:' in preds[0]['output'] + + if if_api: + total_num = len(preds) + correct_num = 0 + errored_sample_ids = [] + tool_search_error_num = 0 + else: + rougel_scores = [] + for pred_id, pred in enumerate(preds): + if if_api: + sample_id = pred['sample_id'] + # if sample_id in errored_sample_ids: + # continue + api_id = pred['api_id'] + gt = gts[sample_id]['apis'][api_id] + gt_api_name = gt['api_name'] + gt_result = gt['output'] + + pred_api_call = get_api_call(pred['pred']) + if not pred_api_call: + logging.warning('No api call found in pred: {}'.format(pred_id)) + errored_sample_ids.append(sample_id) + continue + try: + pred_api_name, pred_param_dict = parse_api_call(pred_api_call) + except Exception as e: + logging.warning('Parse api call error: {} {}'.format(str(e), pred_id)) + errored_sample_ids.append(sample_id) + continue + try: + if pred_api_name == 'ToolSearcher': + pred_param_dict['keywords'] = split_by_uppercase(pred_param_dict['keywords']) + pred_result = tool_manager.api_call(pred_api_name, **pred_param_dict) + except TypeError as e: + logging.warning('TypeError: {} {}'.format(str(e), pred_id)) + errored_sample_ids.append(sample_id) + continue + except AssertionError as e: + logging.warning('AssertionError: {} {}'.format(str(e), pred_id)) + errored_sample_ids.append(sample_id) + continue + except Exception as e: + if str(e) == 'invalid tool name.': + logging.warning('invalid tool name: {} {}'.format(str(e), pred_id)) + errored_sample_ids.append(sample_id) + continue + else: + raise e + + gt_api = tool_manager.init_tool(gt_api_name) + try: + correct = gt_api.check_api_call_correctness(pred_result, gt_result) + except KeyError: + correct = False + logging.warning('KeyError: {}'.format(pred_id)) + except AssertionError as e: + correct = False + logging.warning('AssertionError: {} {}'.format(str(e), pred_id)) + if correct: + correct_num += 1 + else: + # for test toolsearcher + errored_sample_ids.append(sample_id) + if gt_api_name != 'ToolSearcher': + pass + else: + tool_search_error_num += 1 + logging.warning('Incorrect: {}'.format(pred_id)) + else: + gt_response = pred['output'] + pred_response = pred['pred'].replace('User:', '').replace('AI:', '').strip() + rouge_l_score = calculate_rouge_l_score(gt_response, pred_response) + rougel_scores.append(rouge_l_score) + + if if_api: + print('Accuracy: {}'.format(correct_num / total_num)) + print('Total: {}'.format(total_num)) + print('Correct: {}'.format(correct_num)) + + print('Sample Accuracy: {}'.format((50 - len(set(errored_sample_ids))) / 50)) + print('Total: {}'.format(50)) + print('Correct: {}'.format(50 - len(set(errored_sample_ids)))) + + print('ToolSearcher Error Num: {}'.format(tool_search_error_num)) + else: + print('Rouge-L: {}'.format(sum(rougel_scores) / len(rougel_scores))) + + +if __name__ == '__main__': + test_json() diff --git a/api-bank/tool_manager.py b/api-bank/tool_manager.py index ca58fff8..e46e335d 100644 --- a/api-bank/tool_manager.py +++ b/api-bank/tool_manager.py @@ -4,17 +4,17 @@ import json from api_call_extraction import parse_api_call class ToolManager: - def __init__(self) -> None: + def __init__(self, apis_dir='./apis') -> None: import importlib.util all_apis = [] # import all the file in the apis folder, and load all the classes - apis_dir = './apis' except_files = ['__init__.py', 'api.py'] for file in os.listdir(apis_dir): if file.endswith('.py') and file not in except_files: api_file = file.split('.')[0] - module = importlib.import_module("apis." + api_file) + basename = os.path.basename(apis_dir) + module = importlib.import_module(f'{basename}.{api_file}') classes = [getattr(module, x) for x in dir(module) if isinstance(getattr(module, x), type)] for cls in classes: if issubclass(cls, API) and cls is not API: @@ -49,7 +49,8 @@ def __init__(self) -> None: self.apis = apis self.inited_tools = {} - self.token_checker = self.init_tool('CheckToken') + if 'CheckToken' in self.list_all_apis(): + self.token_checker = self.init_tool('CheckToken') def get_api_by_name(self, name: str): """ @@ -147,19 +148,21 @@ def api_call(self, tool_name: str, **kwargs): required_type = required_para['type'] if required_type == 'int': - assert input_value.isdigit(), 'invalid parameter type. parameter: {}'.format(input_value) + if isinstance(input_value, str): + assert input_value.isdigit(), 'invalid parameter type. parameter: {}'.format(input_value) processed_parameters[input_key] = int(input_value) elif required_type == 'float': - assert input_value.replace('.', '', 1).isdigit(), 'invalid parameter type.' + if isinstance(input_value, str): + assert input_value.replace('.', '', 1).isdigit(), 'invalid parameter type.' processed_parameters[input_key] = float(input_value) elif required_type == 'str': processed_parameters[input_key] = input_value elif required_type == 'list(str)': - input_value = input_value.replace('\'', '"') - processed_parameters[input_key] = json.loads(input_value) + # input_value = input_value.replace('\'', '"') + processed_parameters[input_key] = input_value elif required_type == 'list': - input_value = input_value.replace('\'', '"') - processed_parameters[input_key] = json.loads(input_value) + # input_value = input_value.replace('\'', '"') + processed_parameters[input_key] = input_value elif required_type == 'bool': processed_parameters[input_key] = input_value == 'True' else: