From 838bea4e387de23c096f1356f7fb41891b3a48d6 Mon Sep 17 00:00:00 2001 From: Xing Han Lu <21180505+xhluca@users.noreply.github.com> Date: Tue, 25 Jun 2024 14:20:40 -0400 Subject: [PATCH] Add `webllama.experimental` API (#12) * Introduce experimental API * Add examples and initial docs, reference docs in readme * Move readme * Update remote name * Fix cuda visible devices * Update examples/browsergym/agent.py --- README.md | 13 +- docs/README.md | 388 ++++++++++ examples/README.md | 55 ++ examples/browsergym/agent.py | 198 +++++ examples/browsergym/run_bg.py | 22 + examples/complete/run_all.py | 107 +++ examples/web_api/run_client.py | 31 + examples/web_api/run_http.py | 63 ++ modeling/README.md | 5 +- requirements-basic.txt | 3 + requirements-extra.txt | 6 + requirements.txt | 3 - setup.py | 9 +- webllama/__init__.py | 3 +- webllama/experimental/__init__.py | 1 + webllama/experimental/classes.py | 353 +++++++++ webllama/experimental/formatting.py | 191 +++++ webllama/experimental/functions.py | 702 ++++++++++++++++++ .../experimental/integrations/__init__.py | 1 + .../integrations/browsergym/__init__.py | 69 ++ .../integrations/browsergym/functions.py | 68 ++ webllama/experimental/processing.py | 392 ++++++++++ webllama/experimental/templates/__init__.py | 1 + webllama/experimental/templates/weblinx.py | 21 + webllama/experimental/web/__init__.py | 1 + webllama/experimental/web/client.py | 43 ++ webllama/experimental/web/server.py | 195 +++++ 27 files changed, 2933 insertions(+), 11 deletions(-) create mode 100644 docs/README.md create mode 100644 examples/README.md create mode 100644 examples/browsergym/agent.py create mode 100644 examples/browsergym/run_bg.py create mode 100644 examples/complete/run_all.py create mode 100644 examples/web_api/run_client.py create mode 100644 examples/web_api/run_http.py create mode 100644 requirements-basic.txt create mode 100644 requirements-extra.txt delete mode 100644 requirements.txt create mode 100644 
webllama/experimental/__init__.py create mode 100644 webllama/experimental/classes.py create mode 100644 webllama/experimental/formatting.py create mode 100644 webllama/experimental/functions.py create mode 100644 webllama/experimental/integrations/__init__.py create mode 100644 webllama/experimental/integrations/browsergym/__init__.py create mode 100644 webllama/experimental/integrations/browsergym/functions.py create mode 100644 webllama/experimental/processing.py create mode 100644 webllama/experimental/templates/__init__.py create mode 100644 webllama/experimental/templates/weblinx.py create mode 100644 webllama/experimental/web/__init__.py create mode 100644 webllama/experimental/web/client.py create mode 100644 webllama/experimental/web/server.py diff --git a/README.md b/README.md index 05ec6f6..48c7694 100644 --- a/README.md +++ b/README.md @@ -98,13 +98,20 @@ Although the 24K training examples from [`WebLINX`](https://mcgill-nlp.github.io We are working hard to make it easy for you to deploy Llama web agents to the web. We want to integrate `WebLlama` with existing deployment platforms, including Microsoft's Playwright, ServiceNow Research's BrowserGym, and other partners. +At the moment, we offer the following integrations: +* `Browsergym`: Please find more information in [`examples/README.md`](examples/README.md) and [`docs/README.md`](docs/README.md). + ## Code -The code for finetuning the model and evaluating it on the [`WebLINX`](https://mcgill-nlp.github.io/weblinx/) benchmark is available now. You can find the detailed instructions in [modeling](modeling/README.md). +The code for finetuning the model and evaluating it on the [`WebLINX`](https://mcgill-nlp.github.io/weblinx/) benchmark is available now. +* **Modeling**: You can find the detailed instructions in [modeling](modeling/README.md) for training `Llama-3-8B-Web` on the `WebLINX` dataset. 
+* **Examples**: We provide a few examples of using the `webllama` API and models, including web API, end-to-end, and BrowserGym integration. You can find them in [examples](examples/README.md). +* **App**: We provide a simple Streamlit app for visualizing the results of your model on the `WebLINX` benchmark. You can find the code in [app](app/Results.py). +* **Docs**: We provide detailed documentation for the code in [docs](docs/README.md). > 👷‍♀️ **Next steps**\ -> We are actively working on new data, evaluation, and deployment integrations at the moment, so stay tuned! +> We are actively working on new data and evaluation at the moment! If you want to help, please create an issue describing what you would like to contribute, and we will be happy to help you get started. ## Citation @@ -129,4 +136,4 @@ The code in this repository is licensed under the MIT license, unless otherwise ### How can I contribute to the project? -We are actively looking for collaborators to help us build the best Llama-3 web agents! To get started, open an issue about what you would like to contribute, and once it has been discussed, you can submit a pull request. We will also soon be announcing a Discord channel for the project, where you can ask questions and discuss with other contributors. +We are actively looking for collaborators to help us build the best Llama-3 web agents! To get started, open an issue about what you would like to contribute, and once it has been discussed, you can submit a pull request. diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000..8f2a5d2 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,388 @@ +# `webllama.experimental` API + +`webllama.experimental` is the new experimental API for working with webllama models. It will eventually be moved to `webllama` directly (once the API is deemed stable).
+ + +## Setup + +```bash +# Please choose the proper version to ensure you do not break the code +# if there are breaking changes in the future. +# Replace X.Y.Z with the version you want to pin, e.g. 0.1.0 +pip install webllama=="X.Y.Z" +``` + +You will need to download test demonstrations if you want to run the subsequent scripts that use existing weblinx demonstrations. + +```bash +mkdir -p tests/demonstrations +curl -L -o tests/demonstrations/aaabtsd.zip https://github.com/McGill-NLP/weblinx/releases/download/tests-assets/aaabtsd.zip +unzip -u tests/demonstrations/aaabtsd.zip -d tests/demonstrations +curl -L -o tests/demonstrations/aajfwoq.zip https://github.com/McGill-NLP/weblinx/releases/download/tests-assets/aajfwoq.zip +unzip -u tests/demonstrations/aajfwoq.zip -d tests/demonstrations +``` + +## Quickstart with `webllama.experimental.processing` + +To install: +```bash +pip install webllama +# if you want to install transformers, pytorch and sentence-transformers, run: +pip install webllama[extra] +``` + +First, you will need to construct your own `action_history` and `state` using `webllama.experimental.classes`: +```python +import webllama.experimental as wa + +# Create your action history and state! +action_history = [ + wa.classes.Action(...), # ... +] +state = wa.classes.State(...) +``` + +You will also need to load your `dmr` and `act_model` models.
For example, you can use `transformers` and `sentence-transformers` to load them: +```python +from sentence_transformers import SentenceTransformer +from transformers import AutoTokenizer, pipeline + +# You can choose your own DMR model, and action model +act_model = pipeline(model=action_model_name, device=0, torch_dtype="auto") +dmr = SentenceTransformer(dmr_name, device="cuda") +``` + +Now, inside a Python script, you can use the `webllama.experimental.processing` to seamlessly use `Action` and `State` with action model and DMR, and also process the output: + +```python +import webllama.experimental as wa + +# We will initialize our processor, which helps us prepare the input for action model +proc = wa.processing.WebTurnProcessor(tokenizer=act_model.tokenizer) + +# Step 1: prepare query, run DMR and prepare retrieved candidates +query_dmr = proc.prepare_dmr_query(action_history, state) +elems = proc.prepare_dmr_elements(state=state) +scores = wa.functions.compute_dmr_scores(dmr, query_dmr, elems) +top_cands, cands_uids = wa.functions.get_top_dmr_candidates(elems, scores, proc.top_k) +cands_str = proc.prepare_candidates(top_cands) + +# Step 2: format candidates, utterances, state, and previous actions +html = proc.prepare_state_html(state.html, cands_uids=cands_uids) +utterances = proc.prepare_instructor_chat(action_history, state) +prev_actions = proc.prepare_prev_actions(action_history, state) + +# Let's use the default system prompt template, but you can also use your own +sys_prompt_template: str = proc.default_system_prompt_template +sys_prompt = sys_prompt_template.format( + html=html, + utterances=utterances, + candidates=cands_str, + # ... +) +input_chat = proc.convert_to_chat_list(sys_prompt, prev_actions) + +# Use your tokenizer to convert the input to string and pass it to the action model +input_str = act_model.tokenizer.apply_chat_template(input_chat, tokenize=False) +output = act_model(input_str, ...) 
+pred_action = proc.process_action_model_output(output, state.index, elems) +a = wa.classes.Action.from_dict(pred_action) +``` + + +## End-to-end example + +Here's a full, self-contained example of how to use `webllama.experimental` to interact with a web page using a DMR model and an action model: + +```python +from functools import partial +import time +import logging + +from sentence_transformers import SentenceTransformer +from transformers import AutoTokenizer, pipeline +import weblinx as wl +import webllama.experimental as wa + +logging.getLogger("urllib3").setLevel(logging.WARNING) + +# Prerequisite: We need a `wa.classes.State` object and a list of `wa.classes.Action` objects +# To get that, we will use an example from weblinx, but it's easy to do manually (see below). + +demos = wl.list_demonstrations("tests/demonstrations") +replay = wl.Replay.from_demonstration(demos[0]) +turn = replay[26] + +format_intent_am = partial( + wa.formatting.build_formatters_action_model(), return_as=dict +) +format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr() +action_history = wa.functions.create_action_history_from_replay( + replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index +) +state = wa.classes.State( + index=turn.index, + html=turn.html, + bboxes=turn.bboxes, + viewport_height=turn.viewport_height, + viewport_width=turn.viewport_width, + type=turn.type, +) + +# Now, we can start! 
+# First, load the DMR model we will use to select candidate elements +dmr_name = "McGill-NLP/MiniLM-L6-dmr" +action_model_name = "McGill-NLP/Sheared-LLaMA-2.7B-weblinx" +tokenizer_chat_name = "McGill-NLP/Llama-2-7b-chat-weblinx" + +tokenizer_chat = AutoTokenizer.from_pretrained(tokenizer_chat_name) +act_model = pipeline(model=action_model_name, device=0, torch_dtype="auto") +dmr = SentenceTransformer(dmr_name, device="cuda") + +# We will initialize our processor, which helps us prepare the input for action model +proc = wa.processing.WebTurnProcessor(tokenizer=act_model.tokenizer, start_time=time.time()) + +# Step 1: prepare query, run DMR and prepare retrieved candidates +query_dmr = proc.prepare_dmr_query(action_history, state) +elems = proc.prepare_dmr_elements(state=state) +scores = wa.functions.compute_dmr_scores(dmr, query_dmr, elems) +top_cands, cands_uids = wa.functions.get_top_dmr_candidates(elems, scores, proc.top_k) +cands_str = proc.prepare_candidates(top_cands) + +# Step 2: format candidates, utterances, state, and previous actions +html = proc.prepare_state_html(state.html, cands_uids=cands_uids) +utterances = proc.prepare_instructor_chat(action_history, state) +prev_actions = proc.prepare_prev_actions(action_history, state) + +# Let's use the default system prompt template, but you can also use your own +sys_prompt_template: str = proc.default_system_prompt_template +sys_prompt = sys_prompt_template.format( + html=html, + num_utterances=proc.num_utterances - 1, + utterances=utterances, + height=state.viewport_height, + width=state.viewport_width, + num_prev_actions=proc.num_prev_actions, + candidates=cands_str, +) +input_chat = proc.convert_to_chat_list(sys_prompt, prev_actions) + +# We can now use the tokenizer's apply_chat_template method to convert it to a format +# that can be used by the action model +input_str = tokenizer_chat.apply_chat_template(input_chat, tokenize=False) + +# Let's now pass our input to the action model +output = act_model( 
+ input_str, + max_new_tokens=256, + return_full_text=False, + batch_size=1, + pad_token_id=act_model.tokenizer.eos_token_id, +) +pred_action = proc.process_action_model_output( + output=output, index=state.index, elems=elems +) +# optional: For certain platforms you may need to postprocess the action +pred_action = wa.integrations.browsergym.postprocess_for_browsergym(pred_action) +print(pred_action) +# You can now convert this to an Action object and add it to the action history +a = wa.classes.Action.from_dict(pred_action) +action_history.append(a) +``` + +## Tests + +To run the tests: + +```bash +python -m unittest discover -s tests +``` + +## Web API + +### Running Server + +To launch the default server: +```bash +# If you do not want to save logs, omit `--save_logs` +python -m webllama.experimental.web.server --save_logs +``` + +To create your own server, simply inherit: +```python +from webllama.experimental.web.server import Server + +from webllama.experimental.classes import Action, State + +# Assuming the classes Action, State, and other necessary imports are already defined +# as provided in your initial setup. + +# Initialize logging +logging.basicConfig(level=logging.INFO) + +class Server(Server): + # override initialize and run + def initialize(self, dmr_name, action_model_name, device, dmr_device, am_device, torch_dtype): + ...  # initialize your model here + + def run(self, action_history_json, state_json): + # ... + pred_action = { + # ... + } + return json.dumps(pred_action) +``` + +### Connecting via SSH + +To connect to the server via SSH, you can use the following command: +```bash +ssh -N -L 8450:localhost:8450 user@server + +# Example: +ssh -N -L 8450:localhost:8450 nlp-gpu-2 +``` + +### Using API + +You can send HTTP requests directly to the web server, or use the client.
+ +Example of an HTTP request in Python: + +```python +from functools import partial +import http.client +import json + +from functools import partial +import webllama.experimental as wa +import weblinx as wl + +# Prerequisite: We need a `wa.classes.State` object and a list of `wa.classes.Action` objects +demos = wl.list_demonstrations("tests/demonstrations") +replay = wl.Replay.from_demonstration(demos[0]) +turn = replay[26] + +format_intent_am = partial( + wa.formatting.build_formatters_action_model(), return_as=dict +) +format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr() +action_history = wa.functions.create_action_history_from_replay( + replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index +) +state = wa.classes.State( + index=turn.index, + html=turn.html, + bboxes=turn.bboxes, + viewport_height=turn.viewport_height, + viewport_width=turn.viewport_width, + type=turn.type, +) + +# Create a connection to the localhost on the port where your server is running +conn = http.client.HTTPConnection('localhost', 8450) + +# Prepare the POST request data +post_data = json.dumps({ + 'action_history': [action.to_dict() for action in action_history], + 'state': state.to_dict() +}) +headers = {'Content-Type': 'application/json'} + +# Send a POST request with JSON data +conn.request("POST", "/", body=post_data, headers=headers) +response = conn.getresponse() +print(f"Status: {response.status}") +print(f"Reason: {response.reason}") +print(f"Body: {response.read().decode()}") +response.close() + +# Close the connection +conn.close() +``` + +### Client + +A high-level client is provided in `webllama.experimental.web.client`.
You can use it as follows: + +```python +from functools import partial +import webllama.experimental as wa +import weblinx as wl + +# Prerequisite: We need a `wa.classes.State` object and a list of `wa.classes.Action` objects +demos = wl.list_demonstrations("tests/demonstrations") +replay = wl.Replay.from_demonstration(demos[0]) +turn = replay[26] + +format_intent_am = partial( + wa.formatting.build_formatters_action_model(), return_as=dict +) +format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr() +action_history = wa.functions.create_action_history_from_replay( + replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index +) +state = wa.classes.State( + index=turn.index, + html=turn.html, + bboxes=turn.bboxes, + viewport_height=turn.viewport_height, + viewport_width=turn.viewport_width, + type=turn.type, +) + +# Now, we can start! +pred_action = wa.web.client.get_prediction( + action_history, state, address="localhost", port=8450, max_new_tokens=128 +) +pred_action = wa.integrations.browsergym.postprocess_for_browsergym(pred_action) +print(pred_action) +a = wa.classes.Action.from_dict(pred_action) +print(a) +``` + +## Building objects + +> Note: This section is a work in progress. + +### Build `webllama.experimental.classes.Action` + +#### `say` action + +```python +utterance_instructor = wa.classes.Action( + type="chat", + intent="say", + index=2, + args=dict( + speaker="instructor", utterance="Open independent ie Website.", x=None, y=None + ), + timestamp=13.234, + tag=None, + attrs=None, +) +``` + +#### `click` action + +To be added. + +#### `load` action + +To be added. + +#### `textinput` action + +To be added. + +#### `submit` action + +To be added. + +### Build `webllama.experimental.classes.Bbox` + +To be added. + +### Build `webllama.experimental.classes.State` + +To be added. 
diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..c2dc2a5 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,55 @@ +# Examples + +### Web API and client + +You can find examples of how to use the server directly with `http.client.HTTPConnection` and through our client in [`examples/web_api/`](/examples/web_api/), respectively with `run_http.py` and `run_client.py`. You should let the server stay up for both examples. For more information, please read the Web API section of [`docs/README.md`](/docs/README.md). + +### End-to-end + +You can find an end-to-end example of using `webllama.experimental` in [`examples/complete/run_all.py`](/examples/complete): + +```bash +python examples/complete/run_all.py +``` + + +### BrowserGym integration + +We provide direct integration with BrowserGym and examples of how to use it. You can find an example at [`examples/browsergym/run_bg.py`](/examples/browsergym). + + +On the remote server (with a GPU and hosting the webllama model), run: +```bash +# transformers, sentence-transformers, pytorch, etc. +pip install -e .[extra] +``` + +First, remotely, run: + +```bash +# change if needed: +export CUDA_VISIBLE_DEVICES=0 + +python -m webllama.experimental.web.server --save_logs +``` + +Then, connect to your remote server via SSH: + +```bash +# 8450 is the default port for our server +ssh -N -L 8450:localhost:8450 "user@server" +``` + +Now, on your local machine, run: + +```bash +pip install -e .
+# browsergym integration +pip install "browsergym==0.3.*" +# install playwright +playwright install +``` + +```bash +python examples/browsergym/run_bg.py +``` diff --git a/examples/browsergym/agent.py b/examples/browsergym/agent.py new file mode 100644 index 0000000..942ee6b --- /dev/null +++ b/examples/browsergym/agent.py @@ -0,0 +1,198 @@ +from abc import ABC, abstractmethod +from copy import deepcopy +from functools import partial +import time + + +from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str +from browsergym.core.action.highlevel import HighLevelActionSet +import weblinx as wl + +import webllama.experimental as wa + +from webllama.experimental.integrations.browsergym.functions import ( + say, + click, + textinput, + load, + scroll, + wait, +) +from webllama.experimental.integrations.browsergym import replace_bid_with_wl_uid, reverse_dict, postprocess_for_browsergym + +def remap_bboxes(bboxes, attrs_map): + """ + Cleans the bboxes dictionary by replacing the keys with the new unique ids. + """ + return {attrs_map[k]: v for k, v in bboxes.items()} + +class AgentBase(ABC): + """ + A template class that defines the required signature of an agent interacting with a browsergym environment. + """ + + @abstractmethod + def reset(self, seed=None) -> None: + """ + Resets the agent. + + """ + pass + + @abstractmethod + def get_action(self, obs: dict) -> str: + """ + Updates the agent with the current observation, and returns its next action (plus an info dict, optional). + + Parameters: + ----------- + obs: dict + The current observation of the environment. + """ + pass + + def preprocess_obs(self, obs: dict) -> dict: + """Default preprocessing of the observation.""" + pass + + def get_action_mapping(self) -> callable: + """ + Returns a callable that can be used to map the agent actions to executable python code.
+ """ + return None + + +class WebLinxAgent(AgentBase): + action_history = None + + def reset(self, seed=None) -> None: + self.action_history = [] + self.messages = [] + self.start_time = time.time() + self.has_user_message = False + + @property + def num_messages(self): + return len(self.messages) + + @staticmethod + def get_bboxes(xprops): + bboxes = {} + for k in xprops: + if xprops[k]["visibility"] == 1.0: + bbox = dict(zip(["x", "y", "width", "height"], xprops[k]["bbox"])) + # add top, left, bottom, right + bbox["top"] = bbox["y"] + bbox["left"] = bbox["x"] + bbox["bottom"] = bbox["y"] + bbox["height"] + bbox["right"] = bbox["x"] + bbox["width"] + bboxes[k] = bbox + + return bboxes + + @staticmethod + def infer_viewport_from_bboxes(bboxes): + """ + DO NOT USE THIS, THIS FUNCTION IS NOT WORKING PROPERLY + """ + if not bboxes: + return 0, 0 + + x = [bboxes[k]["right"] for k in bboxes] + y = [bboxes[k]["bottom"] for k in bboxes] + + return max(x), max(y) + + def infer_from_screenshot(self, screenshot): + h, w, _ = screenshot.shape + return w, h + + @staticmethod + def get_visible(xprops): + return {k: xprops[k]["visibility"] == 1.0 for k in xprops} + + @staticmethod + def rename_uid_attributes(dom_str, new_name="data-webtasks-id", old_name="bid"): + return dom_str.replace(f"{old_name}=", f"{new_name}=") + + def get_action(self, obs: dict) -> str: + # preprocessing + obs["dom_str"] = flatten_dom_to_str(obs["dom_object"]) + obs["bboxes"] = self.get_bboxes(obs["extra_element_properties"]) + # obs["axtree_txt"] = flatten_axtree_to_str(obs["axtree_object"]) + # obs["visible"] = self.get_visible(obs["extra_element_properties"]) + + vw, vh = self.infer_from_screenshot(obs["screenshot"]) + obs['html_str_orig'] = self.rename_uid_attributes(obs['dom_str']) + + obs["html_str"], attrs_map = replace_bid_with_wl_uid(obs["dom_str"], return_mapping=True) + obs["remapped_bboxes"] = remap_bboxes(obs["bboxes"], attrs_map=attrs_map) + reverse_attrs_map = reverse_dict(attrs_map) + + 
# check if we have new messages in the chat (+1 will skip first default message) + new_messages = obs["chat_messages"][self.num_messages + 1 :] + self.messages.extend(new_messages) + + # update action history with new messages + for message in new_messages: + role = "instructor" if message["role"] == "user" else "navigator" + if role == "instructor": + self.has_user_message = True + + self.action_history.append( + wa.classes.Action( + type="chat", + index=len(self.action_history), + intent="say", + args={"utterance": message["message"], "speaker": role}, + timestamp=time.time() - self.start_time, + tag=None, + attrs=None, + ) + ) + print(f"New message by '{role}': {message['message']}") + + if not self.has_user_message: + # sleep and do nothing if no user message has been received + return "wait(2)" + + state = wa.classes.State( + index=len(self.action_history), + html=obs["html_str"], + bboxes=obs["remapped_bboxes"], + viewport_height=vh, + viewport_width=vw, + type="browser", + ) + pred_action = wa.web.client.get_prediction( + self.action_history, + state, + address="localhost", + port=8450, + max_new_tokens=128, + ) + # breakpoint() + pred_action = postprocess_for_browsergym(pred_action, uid_map=reverse_attrs_map) + # pred_action = postprocess_for_browsergym(pred_action) + + a = wa.classes.Action.from_dict(pred_action) + + # add action to action history + self.action_history.append(a) + + action_str = a.to_str() + print("Action String:", action_str) + + return action_str + + def get_action_mapping(self) -> callable: + """ + Returns a callable that can be used to map the agent actions to executable python code. 
+ """ + action_set = HighLevelActionSet( + subsets="custom", + custom_actions=[say, click, textinput, load, scroll, wait], + multiaction=False, + strict=True, + ) + return action_set.to_python_code diff --git a/examples/browsergym/run_bg.py b/examples/browsergym/run_bg.py new file mode 100644 index 0000000..cbfe2c2 --- /dev/null +++ b/examples/browsergym/run_bg.py @@ -0,0 +1,22 @@ +import gymnasium as gym +import browsergym.core # register the openended task as a gym environment +from examples.browsergym.agent import WebLinxAgent + +agent = WebLinxAgent() + +env = gym.make( + "browsergym/openended", + headless=False, + wait_for_user_message=False, + action_mapping=agent.get_action_mapping(), + task_kwargs={"start_url": "chrome://newtab"}, + # task_kwargs={"start_url": "https://en.wikipedia.org"}, +) + +agent.reset() +obs, info = env.reset() + +terminated = truncated = False +while not (terminated or truncated): + action = agent.get_action(obs) + obs, reward, terminated, truncated, info = env.step(action) diff --git a/examples/complete/run_all.py b/examples/complete/run_all.py new file mode 100644 index 0000000..23d48bf --- /dev/null +++ b/examples/complete/run_all.py @@ -0,0 +1,107 @@ +from functools import partial +import time +import logging + +from sentence_transformers import SentenceTransformer +from transformers import AutoTokenizer, pipeline +import weblinx as wl +import webllama.experimental as wa + +logging.getLogger("urllib3").setLevel(logging.WARNING) + +# Prerequisite: We need a `wa.classes.State` object and a list of `wa.classes.Action` objects +# To get that, we will use an example from weblinx, but it's easy to do manually (see below).
+ +demos = wl.list_demonstrations("tests/demonstrations") +replay = wl.Replay.from_demonstration(demos[0]) +turn = replay[26] + +format_intent_am = partial( + wa.formatting.build_formatters_action_model(), return_as=dict +) +format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr() +action_history = wa.functions.create_action_history_from_replay( + replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index +) + +state = wa.classes.State( + index=turn.index, + html=turn.html, + bboxes=turn.bboxes, + viewport_height=turn.viewport_height, + viewport_width=turn.viewport_width, + type=turn.type, +) + + +# Verifying that the to_dict and from_dict methods work as expected +act = action_history[0] +d = act.to_dict() +act2 = wa.classes.Action.from_dict(d) +assert act == act2 + +d = state.to_dict() +state2 = wa.classes.State.from_dict(d) +assert state == state2 + + +# Now, we can start! +# First, load the DMR model we will use to select candidate elements +dmr_name = "McGill-NLP/MiniLM-L6-dmr" +action_model_name = "McGill-NLP/Sheared-LLaMA-2.7B-weblinx" +tokenizer_chat_name = "McGill-NLP/Llama-2-7b-chat-weblinx" + +tokenizer = AutoTokenizer.from_pretrained(action_model_name) +tokenizer_chat = AutoTokenizer.from_pretrained(tokenizer_chat_name) +dmr = SentenceTransformer(dmr_name, device="cuda") +action_model = pipeline(model=action_model_name, device=0, torch_dtype="auto") + +# We will initialize our processor, which helps us prepare the input for action model +proc = wa.processing.WebTurnProcessor(tokenizer=tokenizer, start_time=time.time()) + +# Step 1: prepare query, run DMR and prepare retrieved candidates +query_dmr = proc.prepare_dmr_query(action_history, state) +elems = proc.prepare_dmr_elements(state=state) +scores = wa.functions.compute_dmr_scores(dmr, query_dmr, elems) +top_cands, cands_uids = wa.functions.get_top_dmr_candidates(elems, scores, proc.top_k) +cands_str = proc.prepare_candidates(top_cands) + +# Step 2: format 
candidates, utterances, state, and previous actions +html = proc.prepare_state_html(state.html, cands_uids=cands_uids) +utterances = proc.prepare_instructor_chat(action_history, state) +prev_actions = proc.prepare_prev_actions(action_history, state) + +# Let's use the default system prompt template, but you can also use your own +sys_prompt_template: str = proc.default_system_prompt_template +sys_prompt = sys_prompt_template.format( + html=html, + num_utterances=proc.num_utterances - 1, + utterances=utterances, + height=state.viewport_height, + width=state.viewport_width, + num_prev_actions=proc.num_prev_actions, + candidates=cands_str, +) +input_chat = proc.convert_to_chat_list(sys_prompt, prev_actions) + +# We can now use the tokenizer's apply_chat_template method to convert it to a format +# that can be used by the action model +input_str = tokenizer_chat.apply_chat_template(input_chat, tokenize=False) + +# Let's now pass our input to the action model +output = action_model( + input_str, + max_new_tokens=256, + return_full_text=False, + batch_size=1, + pad_token_id=tokenizer.eos_token_id, +) +pred_action = proc.process_action_model_output( + output=output, index=state.index, elems=elems +) +# optional: For certain platforms you may need to postprocess the action +pred_action = wa.integrations.browsergym.postprocess_for_browsergym(pred_action) +print(pred_action) +# You can now convert this to an Action object and add it to the action history +a = wa.classes.Action.from_dict(pred_action) +action_history.append(a) diff --git a/examples/web_api/run_client.py b/examples/web_api/run_client.py new file mode 100644 index 0000000..84c6a39 --- /dev/null +++ b/examples/web_api/run_client.py @@ -0,0 +1,31 @@ +from functools import partial +import webllama.experimental as wa +import weblinx as wl + +demos = wl.list_demonstrations("tests/demonstrations") +replay = wl.Replay.from_demonstration(demos[0]) +turn = replay[26] + +format_intent_am = partial( +
wa.formatting.build_formatters_action_model(), return_as=dict +) +format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr() +action_history = wa.functions.create_action_history_from_replay( + replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index +) +state = wa.classes.State( + index=turn.index, + html=turn.html, + bboxes=turn.bboxes, + viewport_height=turn.viewport_height, + viewport_width=turn.viewport_width, + type=turn.type, +) + +pred_action = wa.web.client.get_prediction( + action_history, state, address="localhost", port=8450, max_new_tokens=128 +) +pred_action = wa.integrations.browsergym.postprocess_for_browsergym(pred_action) +print(pred_action) +a = wa.classes.Action.from_dict(pred_action) +print(a) diff --git a/examples/web_api/run_http.py b/examples/web_api/run_http.py new file mode 100644 index 0000000..6911ed1 --- /dev/null +++ b/examples/web_api/run_http.py @@ -0,0 +1,63 @@ +from functools import partial +import http.client +import json + +import weblinx as wl +import webllama.experimental as wa + +def run_http(): + demos = wl.list_demonstrations("tests/demonstrations") + replay = wl.Replay.from_demonstration(demos[0]) + turn = replay[26] + + format_intent_am = partial( + wa.formatting.build_formatters_action_model(), return_as=dict + ) + format_intent_input_dmr, format_intent_out_dmr = wa.formatting.build_formatters_dmr() + action_history = wa.functions.create_action_history_from_replay( + replay, format_intent_input_dmr, format_intent_am, stop_at_idx=turn.index + ) + state = wa.classes.State( + index=turn.index, + html=turn.html, + bboxes=turn.bboxes, + viewport_height=turn.viewport_height, + viewport_width=turn.viewport_width, + type=turn.type, + ) + action_history_dict = [action.to_dict() for action in action_history] + state_dict = state.to_dict() + + # Create a connection to the localhost on the port where your server is running + conn = http.client.HTTPConnection('localhost', 8450) + + # Send a
request without parameters to test server response + conn.request("POST", "/", body=json.dumps({}), headers={'Content-Type': 'application/json'}) + response = conn.getresponse() + print("Test 1 - Server Initialization Check:") + print(f"Status: {response.status}") + print(f"Reason: {response.reason}") + print(f"Body: {response.read().decode()}\n") + response.close() + + # Prepare the POST request data + post_data = json.dumps({ + 'action_history': action_history_dict, + 'state': state_dict + }) + headers = {'Content-Type': 'application/json'} + + # Send a POST request with JSON data + conn.request("POST", "/", body=post_data, headers=headers) + response = conn.getresponse() + print("Test 2 - Functionality Check:") + print(f"Status: {response.status}") + print(f"Reason: {response.reason}") + print(f"Body: {response.read().decode()}") + response.close() + + # Close the connection + conn.close() + +if __name__ == "__main__": + run_http() diff --git a/modeling/README.md b/modeling/README.md index 0f27035..5b32344 100644 --- a/modeling/README.md +++ b/modeling/README.md @@ -50,12 +50,12 @@ export WEBLLAMA_PROJECT_DIR=$(pwd) You need to install the dependencies by running the following command: ```bash -pip install -r requirements.txt +pip install -e .[extra] +pip install -r modeling/requirements.txt ``` However, due to `flash-attention` requiring `torch` to be pre-installed, it has to be install right after everything else has been installed: ```bash -pip install wheel # Regular install pip install "flash-attn>=2.3.0" # IF you have limited RAM, you can try this: @@ -139,7 +139,6 @@ Behind the scene, this will use the `weblinx.eval.auto_eval_and_save` function t Note that it might be slow the first time you run, because it reads a lot of demonstrations and load millions of files. However, a demo-level cache is automatically created (see `./.cache/demonstrations`), so the next time you run it, it should be much faster. 
- ### Dense Markup Ranking (DMR) #### Train DMR diff --git a/requirements-basic.txt b/requirements-basic.txt new file mode 100644 index 0000000..a100a56 --- /dev/null +++ b/requirements-basic.txt @@ -0,0 +1,3 @@ +weblinx>=0.3.0rc1 +lxml +numpy \ No newline at end of file diff --git a/requirements-extra.txt b/requirements-extra.txt new file mode 100644 index 0000000..e5ca80d --- /dev/null +++ b/requirements-extra.txt @@ -0,0 +1,6 @@ +weblinx[eval]>=0.3.0.rc1 +streamlit +sentence-transformers +transformers +playwright +browsergym diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 0bb9b4c..0000000 --- a/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -weblinx[eval] -streamlit --r modeling/requirements.txt \ No newline at end of file diff --git a/setup.py b/setup.py index 4e9f51f..e309dd2 100644 --- a/setup.py +++ b/setup.py @@ -8,8 +8,15 @@ with open("README.md") as fp: long_description = fp.read() +with open('requirements-extra.txt') as f: + extras = f.read().splitlines() + +with open('requirements-basic.txt') as f: + install_requires = f.read().splitlines() + extras_require = { "dev": ["black"], + "extra": extras, } # Dynamically create the 'all' extra by combining all other extras extras_require["all"] = sum(extras_require.values(), []) @@ -24,7 +31,7 @@ long_description=long_description, packages=find_packages(include=[f"{package_name}*"]), package_data={}, - install_requires=['scipy', 'numpy'], + install_requires=install_requires, extras_require=extras_require, classifiers=[ "Programming Language :: Python :: 3", diff --git a/webllama/__init__.py b/webllama/__init__.py index 7152555..540f0fa 100644 --- a/webllama/__init__.py +++ b/webllama/__init__.py @@ -1 +1,2 @@ -from .version import __version__ \ No newline at end of file +from .version import __version__ +from . 
class _BBoxRequired(TypedDict):
    """Keys that every bounding box dictionary must provide."""

    x: int
    y: int
    width: float
    height: float


class BBox(_BBoxRequired, total=False):
    """
    A dictionary type describing the bounding box of an element.

    `BBox` is a `TypedDict`: calling it returns a plain `dict`, so values are read
    with `bbox["x"]` and no initialization hook is ever executed. The edge keys
    below are therefore optional keys, not computed fields; code that needs them
    when they are absent has to derive them from `x`, `y`, `width` and `height`.

    Attributes
    ----------
    x : int
        The x-coordinate of the bounding box (required).
    y : int
        The y-coordinate of the bounding box (required).
    width : float
        The width of the bounding box (required).
    height : float
        The height of the bounding box (required).
    top : float, optional
        The top position of the bounding box; equals `y` when derived.
    bottom : float, optional
        The bottom position of the bounding box; equals `y + height` when derived.
    left : float, optional
        The left position of the bounding box; equals `x` when derived.
    right : float, optional
        The right position of the bounding box; equals `x + width` when derived.
    """

    top: float
    bottom: float
    left: float
    right: float
@dataclass
class Action:
    """
    A single action in a navigation session.

    Attributes
    ----------
    type : str
        The kind of action, either "chat" or "browser".
    index : int
        Position of the action in the sequence of states/actions.
    intent : str
        What the action does (e.g., "click", "type", "scroll", "say").
    args : Dict[str, str]
        Arguments of the action, such as the unique ID of the element clicked,
        the text typed, or the message said.
    timestamp : float
        Time of the action in seconds, relative to the start time.
    tag : str, optional
        The HTML tag associated with the action (e.g., "button", "input").
    attrs : AttrsCore, optional
        Element attributes associated with the action (e.g., "class", "title",
        "href", "aria-label", "d", "src").
    """

    type: str
    index: int
    intent: str
    args: Dict[str, str]
    timestamp: float
    tag: str = None
    attrs: AttrsCore = None

    def get(self, key):
        """
        Look up one argument of the action.

        Parameters
        ----------
        key : str
            Name of the argument.

        Returns
        -------
        object or None
            The value stored in `args` under `key`, or None if it is absent.
        """
        return self.args.get(key)

    @classmethod
    def from_dict(
        cls,
        dictionary: Dict,
        included_attrs: Tuple[str] = ("class", "title", "href", "aria-label", "d", "src"),
    ) -> "Action":
        """
        Build an `Action` from a flat dictionary.

        Parameters
        ----------
        dictionary : dict
            Must contain "intent", "index" and "timestamp"; may contain "attrs"
            and "tag". Every remaining key becomes an entry of `args`. The input
            is deep-copied and never modified.
        included_attrs : tuple of str, optional
            Attribute names that are kept from "attrs"; all others are dropped.

        Returns
        -------
        Action
            The new instance. Its `type` is "chat" when the intent is "say" and
            "browser" otherwise.

        Raises
        ------
        KeyError
            If "intent", "index" or "timestamp" is missing.
        """
        remaining = deepcopy(dictionary)
        intent = remaining.pop("intent")
        index = remaining.pop("index")
        timestamp = remaining.pop("timestamp")

        attrs = remaining.pop("attrs", None)
        if attrs is not None:
            attrs = {
                name: value for name, value in attrs.items() if name in included_attrs
            }

        tag = remaining.pop("tag", None)

        if intent == "say":
            action_type = "chat"
        else:
            action_type = "browser"

        # Whatever is left after removing the known keys are the arguments.
        return cls(
            type=action_type,
            index=index,
            intent=intent,
            args=remaining,
            timestamp=timestamp,
            tag=tag,
            attrs=attrs,
        )

    def to_dict(
        self,
        include_timestamp=True,
        include_attrs=True,
        include_tag=True,
        include_index=True,
        drop_none_coords=False,
        format_timestamp_fn=None,
        ignore_args=None,
    ):
        """
        Flatten the action into a dictionary.

        The result always starts with "intent", followed by the arguments.

        Parameters
        ----------
        include_timestamp : bool
            Add the timestamp under "timestamp".
        include_attrs : bool
            Add the attributes under "attrs" (only when they are not None).
        include_tag : bool
            Add the tag under "tag" (only when it is not None).
        include_index : bool
            Add the index under "index".
        drop_none_coords : bool
            Remove the "x" and "y" entries when their value is None.
        format_timestamp_fn : callable, optional
            Called with this action; the "timestamp" entry of what it returns is
            stored instead of the raw timestamp. If None, the raw value is used.
        ignore_args : list, optional
            Argument names to leave out. If None, all arguments are included.

        Returns
        -------
        dict
            A dictionary representation of the action.
        """
        if ignore_args is None:
            selected_args = self.args
        else:
            selected_args = {
                name: value
                for name, value in self.args.items()
                if name not in ignore_args
            }

        out = {"intent": self.intent}
        out.update(selected_args)

        if include_tag and self.tag is not None:
            out["tag"] = self.tag

        if include_attrs and self.attrs is not None:
            out["attrs"] = self.attrs

        if include_timestamp:
            if format_timestamp_fn is None:
                out["timestamp"] = self.timestamp
            else:
                out["timestamp"] = format_timestamp_fn(self)["timestamp"]

        if include_index:
            out["index"] = self.index

        if drop_none_coords:
            for coord in ("x", "y"):
                if coord in out and out[coord] is None:
                    del out[coord]

        return out

    def to_str(self, **kwargs):
        """
        Render the action as a formatted string.

        Parameters
        ----------
        kwargs : dict
            Forwarded unchanged to `to_dict`.

        Returns
        -------
        str
            The output of `weblinx.utils.format.format_output_dictionary` applied
            to the `to_dict` result, with "intent" as the function key.
        """
        as_dict = self.to_dict(**kwargs)
        return format_output_dictionary(as_dict, function_key="intent", return_as=str)

    def items(self):
        """
        Return the key/value pairs of the action, mimicking `weblinx.Turn.items()`.

        Returns
        -------
        ItemsView
            Items of the `to_dict` result built with the timestamp included, and
            with attrs, tag, index and None-valued "x"/"y" coordinates left out.
        """
        return self.to_dict(
            include_timestamp=True,
            include_attrs=False,
            include_tag=False,
            include_index=False,
            drop_none_coords=True,
        ).items()
def build_formatters_dmr():
    """
    Builds the two intent formatters used with the DMR (Dense Markup Ranking) model.

    Both formatters are `functools.partial` wrappers around
    `weblinx.utils.format.format_intent_automatically`:

    - The *input* formatter customizes click, change, hover, submit, text_input
      and tab intents. Element-based intents include the element (without its
      text, restricted to the "class", "title", "href", "aria-label", "d" and
      "src" attributes) and the timestamp. It is bound with `return_as=str`.
    - The *output* formatter customizes change, click, load, say, scroll, submit
      and text_input intents. Elements are formatted with `include_text=False`
      and `include_attrs=False`, argument values are capped at 200 characters,
      and timestamps are omitted for load, scroll and say.

    Returns
    -------
    tuple of functions
        A 2-tuple `(format_intent_input, format_intent_out)`: one function for
        formatting input intents and one for formatting output intents.

    Examples
    --------

    ```python
    format_intent_input, format_intent_out = build_formatters_dmr()
    ```

    """
    # First, the formatters for the input (query text)
    format_element_input = partial(
        wlf.format_element,
        include_text=False,
        include_attrs=("class", "title", "href", "aria-label", "d", "src"),
    )
    format_click_input = partial(
        wlf.format_click,
        formatters=(wlf.format_mouse_xy, format_element_input, wlf.format_timestamp),
    )
    format_change_input = partial(
        wlf.format_change,
        formatters=(
            partial(wlf.format_arg_item, name="value"),
            format_element_input,
            wlf.format_timestamp,
        ),
    )
    format_hover_input = partial(
        wlf.format_hover,
        formatters=(wlf.format_mouse_xy, format_element_input, wlf.format_timestamp),
    )

    format_submit_input = partial(
        wlf.format_submit, formatters=(format_element_input, wlf.format_timestamp)
    )

    format_text_input_input = partial(
        wlf.format_text_input,
        formatters=(
            partial(wlf.format_arg_item, name="text"),
            format_element_input,
            wlf.format_timestamp,
        ),
    )

    format_intent_input = partial(
        wlf.format_intent_automatically,
        format_click=format_click_input,
        format_change=format_change_input,
        format_hover=format_hover_input,
        format_submit=format_submit_input,
        format_text_input=format_text_input_input,
        format_tab=wlf.format_tab,
        return_as=str,
    )

    # Second, the formatters for the output (prediction text)
    format_element_out = partial(
        wlf.format_element,
        # Only want the tag
        include_text=False,
        include_attrs=False,
    )

    format_click_out = partial(wlf.format_click, formatters=(wlf.format_mouse_xy,))
    format_text_input_out = partial(
        wlf.format_text_input,
        formatters=(
            partial(wlf.format_arg_item, name="text", max_length=200),
            format_element_out,
            wlf.format_target_bbox,
        ),
    )
    format_change_out = partial(
        wlf.format_change,
        formatters=(
            partial(wlf.format_arg_item, name="value", max_length=200),
            format_element_out,
            wlf.format_target_bbox,
        ),
    )
    format_submit_out = partial(
        wlf.format_submit, formatters=(format_element_out, wlf.format_target_bbox)
    )
    format_load_out = partial(
        wlf.format_load,
        include_transition=False,
        include_timestamp=False,
        max_length=200,
    )
    format_scroll_out = partial(wlf.format_scroll, include_timestamp=False)

    format_say_out = partial(wlf.format_say, include_timestamp=False)

    format_intent_out = partial(
        wlf.format_intent_automatically,
        format_change=format_change_out,
        format_click=format_click_out,
        format_load=format_load_out,
        format_say=format_say_out,
        format_scroll=format_scroll_out,
        format_submit=format_submit_out,
        format_text_input=format_text_input_out,
    )

    return format_intent_input, format_intent_out
+ + +def _shorten(s, max_length=100, side="center", ellipsis="..."): + """ + Shortens a string to a specified maximum length, adding an ellipsis if necessary. + + Parameters + ---------- + s : str + The string to be shortened. + max_length : int, optional + The maximum length of the shortened string, including the ellipsis. Default is 100. + side : str, optional + The side from which to shorten the string. Options are "center", "left", and "right". Default is "center". + ellipsis : str, optional + The string to use as an ellipsis. Default is "...". + + Returns + ------- + str + The shortened string. + """ + if max_length is None: + return s + + if len(s) <= max_length: + return s + + max_length = max_length - len(ellipsis) + + if side == "right": + s = s[:max_length] + ellipsis + elif side == "left": + s = ellipsis + s[-max_length:] + elif side == "center": + s = s[: max_length // 2] + ellipsis + s[-max_length // 2 :] + else: + raise ValueError(f"Invalid side: {side}") + + return s + + +def _format_attrs(attrs): + """ + Formats a dictionary of attributes into a string. + + Parameters + ---------- + attrs : dict + The dictionary of attributes to format. + + Returns + ------- + str + The formatted attributes as a string. + """ + return " ".join([f"{k!s}={v!r}" for k, v in attrs.items()]) + + +def _extract_output_str_from_pipeline_output(output): + """ + Extracts the output string from a pipeline output. + + Parameters + ---------- + output : str or dict or list + The output from the pipeline, which can be a string, a dictionary, or a list. + + Returns + ------- + str + The extracted output string. 
+ """ + if isinstance(output, str): + output_str = output + elif isinstance(output, dict): + if "generated_text" not in output: + raise ValueError("Output dictionary does not have 'generated_text' key.") + output_str = output["generated_text"] + + elif isinstance(output, list): + if len(output) == 0: + raise ValueError("Output list is empty, cannot be processed.") + if len(output) > 1: + raise ValueError( + "Output list has more than one element, cannot be processed." + ) + o = output[0] + + if isinstance(o, str): + output_str = o + elif isinstance(o, dict): + if "generated_text" not in o: + raise ValueError( + "Output dictionary does not have 'generated_text' key." + ) + output_str = o["generated_text"] + else: + raise ValueError( + f"Output list has element of type {type(o)}, cannot be processed." + ) + + return output_str + + +def represent_element_as_dict( + element, + bbox, + root_tree, + max_text_length=200, + max_attr_length=100, + max_child_depth=2, + return_attrs=True, +): + """ + Format an lxml element into a dictionary of strings. + + Parameters + ---------- + element : lxml.html.Element + The element to format. + bbox : dict + The bounding box of the element. + root_tree : lxml.html.ElementTree + The root tree of the document. + max_text_length : int, optional + The maximum length of the text. Default is 200. + max_attr_length : int, optional + The maximum length of the attributes. Default is 100. + max_child_depth : int, optional + The maximum depth of children to include. Default is 2. + return_attrs : bool, optional + Whether to return the attributes. Default is True. + + Returns + ------- + dict + A dictionary representation of the element. 
def calculate_remaining_tokens_before_candidates(
    html: str,
    utterance_context: str,
    prev_actions_as_chat: List[Dict[str, str]],
    tokenizer: "AutoTokenizer",
    max_html_tokens: int,
    max_utterance_tokens: int,
    max_prev_turns_tokens: int,
    tokenizer_chat: "AutoTokenizer" = None,
):
    """
    Calculates how many tokens of the HTML, utterance and previous-turn budgets
    are left unused, so that they can be given to the candidates.

    Parameters
    ----------
    html : str
        The HTML content.
    utterance_context : str
        The context of the current utterance.
    prev_actions_as_chat : list of dict
        A list of previous actions formatted as chat.
    tokenizer : AutoTokenizer
        The tokenizer used to count the HTML and utterance tokens.
    max_html_tokens : int
        The maximum number of HTML tokens.
    max_utterance_tokens : int
        The maximum number of utterance tokens.
    max_prev_turns_tokens : int
        The maximum number of previous turn tokens.
    tokenizer_chat : AutoTokenizer, optional
        The tokenizer whose chat template is used to count the previous-turn
        tokens. Default is None, in which case `tokenizer` is used.

    Returns
    -------
    int
        The sum of the three unused budgets. The value is negative when the
        inputs exceed their combined budgets.
    """
    if tokenizer_chat is None:
        tokenizer_chat = tokenizer

    # The previous turns are counted after applying the chat template, with an
    # empty system message placed in front of them.
    num_prev_turns_tokens = len(
        tokenizer_chat.apply_chat_template(
            [{"role": "system", "content": ""}, *prev_actions_as_chat], tokenize=True
        )
    )
    num_html_tokens = len(tokenizer.tokenize(html))
    num_utter_tokens = len(tokenizer.tokenize(utterance_context))

    remain_html_tokens = max_html_tokens - num_html_tokens
    remain_utter_tokens = max_utterance_tokens - num_utter_tokens
    remain_prev_turns_tokens = max_prev_turns_tokens - num_prev_turns_tokens

    # The caller adds this unused length to its maximum number of candidate tokens.
    return remain_html_tokens + remain_utter_tokens + remain_prev_turns_tokens
def prepare_query_for_dmr(
    action_history: List[Action],
    state: State,
    format_intent,
    turn_sep=" ; ",
    num_prev_turns=5,
    num_utterances=5,
    return_str=True,
):
    """
    Builds the query given to the DMR model for the current turn.

    Parameters
    ----------
    action_history : list of Action
        The actions taken so far.
    state : State
        The current state; its viewport size is written into the query.
    format_intent : callable
        Formatter applied to the intents of the previous turns.
    turn_sep : str, optional
        Separator placed between previous turns. Default is " ; ".
    num_prev_turns : int, optional
        How many previous turns to include. Default is 5.
    num_utterances : int, optional
        How many utterances to include. Default is 5.
    return_str : bool, optional
        If True (default), return one combined string. If False, return the
        2-tuple `(utterance_context, prev_turns_text)` instead.

    Returns
    -------
    str or tuple of str
        The formatted query, either as a single string or as a tuple of strings.

    Notes
    -----
    The combined query is made of three parts, in this order: the viewport
    size, the instructor utterances (formatted by `format_utterances`), and the
    previous turns (formatted by `format_prev_turns`).
    """
    previous_turns = format_prev_turns(
        replay=action_history,
        turn=state,
        format_intent=format_intent,
        turn_sep=turn_sep,
        num_prev_turns=num_prev_turns,
    )
    chat_turns = find_turns_with_instructor_chat(
        action_history, state, num_prev_turns=num_prev_turns
    )
    utterances = format_utterances(chat_turns, num_utterances=num_utterances)

    if return_str:
        viewport_part = (
            f"Viewport(height={state.viewport_height}, width={state.viewport_width}) ---- "
        )
        utterance_part = f"Instructor Utterances: {utterances} ---- "
        turns_part = f"Previous Turns:{previous_turns}"
        return viewport_part + utterance_part + turns_part

    return utterances, previous_turns
def cos_sim(a, b):
    """
    Computes the cosine similarity between two matrices.

    Parameters
    ----------
    a : torch.Tensor
        The first matrix. Anything that is not already a tensor is converted
        with `torch.tensor`; a 1-D input is treated as a single row.
    b : torch.Tensor
        The second matrix, handled the same way as `a`.

    Returns
    -------
    torch.Tensor
        A matrix with cosine similarities between all pairs of rows in a and b.
        It always has two dimensions: (rows of a, rows of b).

    Raises
    ------
    ImportError
        If PyTorch cannot be imported.

    Notes
    -----
    Taken from: https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/util.py
    Use of this function is subject to the following license: https://github.com/UKPLab/sentence-transformers/blob/master/LICENSE
    """
    try:
        import torch
    except ImportError as exc:
        # Only a failed import should be reported as "PyTorch missing"; a bare
        # `except` would also swallow KeyboardInterrupt / SystemExit.
        raise ImportError("This function requires PyTorch to be installed.") from exc

    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    # Promote vectors to one-row matrices so that `torch.mm` accepts them.
    if len(a.shape) == 1:
        a = a.unsqueeze(0)

    if len(b.shape) == 1:
        b = b.unsqueeze(0)

    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))


def compute_dmr_scores(dmr, query, cands, batch_size=16):
    """
    Computes the DMR scores for candidate elements.

    Parameters
    ----------
    dmr : object
        The DMR model. It must expose
        `encode(texts, batch_size=..., show_progress_bar=...)`, as a
        `SentenceTransformer` does.
    query : str
        The query string, returned by `prepare_query_for_dmr`.
    cands : list of dict
        The candidate elements, returned by `prepare_dmr_elements`. Only the
        "doc" key of each candidate is read here.
    batch_size : int, optional
        The batch size for encoding. Default is 16.

    Returns
    -------
    list of float
        The DMR scores for the candidate elements, one per candidate and in the
        same order as `cands`. The result is always a list, including when there
        is a single candidate; it is empty when `cands` is empty.

    Examples
    --------

    ```python
    from sentence_transformers import SentenceTransformer
    from transformers import AutoTokenizer

    dmr = SentenceTransformer(dmr_name, device="cuda")
    tokenizer = AutoTokenizer.from_pretrained(act_model)
    proc = wa.processing.WebTurnProcessor(tokenizer=tokenizer)

    query_dmr = proc.prepare_dmr_query(action_history, state)
    elems = proc.prepare_dmr_elements(state=state)
    scores = wa.functions.compute_dmr_scores(dmr, query_dmr, elems)
    ```
    """
    cands_docs = [cand["doc"] for cand in cands]
    if not cands_docs:
        # Nothing to rank, so there is no need to run the encoder at all.
        return []

    # The query and the documents are encoded in one call; row 0 is the query.
    encoded = dmr.encode(
        [query] + cands_docs, batch_size=batch_size, show_progress_bar=False
    )
    query_vector, doc_vectors = encoded[0], encoded[1:]

    # `cos_sim` returns a (1, num_candidates) matrix. Squeeze only the query
    # axis: a bare `squeeze()` would also drop the candidate axis when there is
    # exactly one candidate, and `tolist()` would then return a float.
    scores = cos_sim(query_vector, doc_vectors).cpu().squeeze(0).tolist()

    return scores


def get_top_dmr_candidates(candidate_elements, scores, top_k):
    """
    Gets the top DMR candidates based on scores.

    Parameters
    ----------
    candidate_elements : list of dict
        The candidate elements. Each one must have a "uid" key.
    scores : list of float
        The scores for the candidate elements, aligned with
        `candidate_elements` by position.
    top_k : int
        The number of top candidates to return.

    Returns
    -------
    tuple of (list of dict, list of str)
        The top candidates, best score first, and their unique identifiers.
        Candidates with equal scores keep their original relative order.
    """
    indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    top_cands = [candidate_elements[i] for i in indices[:top_k]]
    cands_uids = [cand["uid"] for cand in top_cands]

    return top_cands, cands_uids


def infer_tag_and_attrs(
    args, elems, accepted=("class", "title", "href", "aria-label", "d", "src")
):
    """
    Infers the tag and attributes for an element based on its arguments. This is an internal function
    used by the `process_action_model_output`.

    Parameters
    ----------
    args : dict
        The arguments for the element. Only the "uid" key is consulted.
    elems : list of dict
        The elements to infer from. Each one must have a "uid" key; the matching
        element must also have "elem_dict" (with a "tag" entry) and "attrs".
    accepted : tuple of str, optional
        The accepted attributes. Default is ("class", "title", "href", "aria-label", "d", "src").

    Returns
    -------
    tuple of (str, dict)
        The inferred tag and attributes. The result is `(None, {})` when `args`
        has no "uid" or when no element carries that uid.
    """
    if "uid" not in args:
        return None, {}

    uids_to_elems = {elem["uid"]: elem for elem in elems}
    uid = args["uid"]
    if uid not in uids_to_elems:
        return None, {}

    matched = uids_to_elems[uid]
    elem_dict = matched["elem_dict"]
    # Keep only the attributes listed in `accepted`.
    attrs = {k: v for k, v in matched["attrs"].items() if k in accepted}
    tag = elem_dict["tag"]

    return tag, attrs


def create_action_history_from_replay(
    replay, format_intent_input_dmr, format_intent_am, stop_at_idx=None
):
    """
    Creates an action history from a replay of actions. The replay must be from
    WebLINX, or a strictly compatible format.

    Parameters
    ----------
    replay : list of Action
        The replay of actions. Each turn must expose `timestamp` and `index`.
    format_intent_input_dmr : callable
        The function to format DMR input intents. It is called as
        `format_intent_input_dmr(turn, return_as=dict)`; the "tag", "x", "y" and
        "attrs" entries of its result are copied into the action (None when absent).
    format_intent_am : callable
        The function to format action model intents. It is called as
        `format_intent_am(turn, return_as=dict)` and provides the base
        dictionary of the action.
    stop_at_idx : int, optional
        The index to stop at in the replay. Default is None, which means the
        whole replay; larger values are capped to `len(replay)`.

    Returns
    -------
    list of Action
        The created action history.
    """
    num_turns = len(replay)
    limit = num_turns if stop_at_idx is None else min(stop_at_idx, num_turns)

    def _turn_to_action(turn):
        # The DMR formatter runs first, then the action-model formatter.
        dmr_input = format_intent_input_dmr(turn, return_as=dict)
        dictionary = format_intent_am(turn, return_as=dict)
        for key in ("tag", "x", "y", "attrs"):
            dictionary[key] = dmr_input.get(key, None)
        dictionary["timestamp"] = turn.timestamp
        dictionary["index"] = turn.index
        return Action.from_dict(dictionary)

    return [_turn_to_action(turn) for turn in replay[:limit]]


def process_action_model_output(output, index, elems, start_time):
    """
    Processes the output of an action model to create a predicted action.

    Parameters
    ----------
    output : str or dict or list
        The output from the action model.
    index : int
        The index of the action.
    elems : list of dict
        The elements to process; used to look up the tag and attributes of the
        element referenced by the predicted "uid", if any.
    start_time : float
        The start time of the action. The "timestamp" of the result is
        `time.time() - start_time`.

    Returns
    -------
    dict
        The predicted action, with the keys "index", "intent", "timestamp",
        "tag" and "attrs", plus every parsed argument whose name does not clash
        with one of those keys.
    """
    output_str = _extract_output_str_from_pipeline_output(output)
    intent, args = parse_predicted_output_string(output_str)
    args = sanitize_args(args)
    tag, attrs = infer_tag_and_attrs(args=args, elems=elems)

    pred_action = {
        "index": index,
        "intent": intent,
        "timestamp": time.time() - start_time,
        "tag": tag,
        "attrs": attrs,
    }
    # Parsed arguments never override the fields set above.
    for key, value in args.items():
        pred_action.setdefault(key, value)

    return pred_action
def postprocess_for_browsergym(action, uid_map=None):
    """
    Adapts a predicted action so that it can be executed by BrowserGym.

    Parameters
    ----------
    action : dict
        The predicted action. It must have an "intent" key and may have a "uid"
        key. The dictionary passed in is left unmodified.
    uid_map : dict, optional
        Mapping from the uid found in the action (as a string) to the uid that
        should be sent to BrowserGym. Default is None, meaning no remapping.

    Returns
    -------
    dict
        A deep copy of `action` in which "uid" (when present) is a string,
        remapped through `uid_map` when it appears there, and in which "scroll"
        actions always have "x" and "y" (0 when they were missing).
    """
    uid_map = {} if uid_map is None else uid_map

    # Copy before touching anything, so the caller's dictionary is not changed.
    action = deepcopy(action)

    if "uid" in action:
        # if uid is a int, we need to convert it to a string
        uid = str(action["uid"])
        action["uid"] = uid_map.get(uid, uid)

    if action["intent"] == "scroll":
        action.setdefault("x", 0)
        action.setdefault("y", 0)

    return action


def generate_uuid(old_attr_name):
    """
    Generates a random identifier following the template "xxxxxxxx-xxxx-4xxx",
    where every "x" is replaced by a random lowercase hexadecimal digit.

    Parameters
    ----------
    old_attr_name : str
        Unused; it is only there because `replace_bid_with_wl_uid` calls its
        `generate_fn` with the old attribute name.

    Returns
    -------
    str
        The generated identifier (18 characters).
    """

    def replace_char(c):
        r = random.randint(0, 15)
        # "y" follows the UUID v4 variant convention; the template used below
        # contains no "y", so only the "x" branch is exercised.
        v = r if c == "x" else (r & 0x3 | 0x8)
        return format(v, "x")

    uuid_template = "xxxxxxxx-xxxx-4xxx"
    return "".join(replace_char(c) if c in "xy" else c for c in uuid_template)


def reverse_dict(mapping):
    """
    Returns a new dictionary mapping each value of `mapping` to its key. When
    several keys share a value, the last one iterated wins.
    """
    return {v: k for k, v in mapping.items()}


def replace_bid_with_wl_uid(
    dom_str,
    new_attr_name="data-webtasks-id",
    old_attr_name="bid",
    generate_fn=generate_uuid,
    return_mapping=False,
):
    """
    Replaces the bid attributes in the dom string with a new attribute name and a new unique id.

    Parameters
    ----------
    dom_str : str
        The HTML to process.
    new_attr_name : str, optional
        The attribute that receives the new id. Default is "data-webtasks-id".
    old_attr_name : str, optional
        The attribute to replace. Default is "bid".
    generate_fn : callable, optional
        generate_fn must be a function that takes the old attribute name and returns a new unique id.
        Default is `generate_uuid`.
    return_mapping : bool, optional
        Whether to also return the mapping from old values to new ids.
        Default is False.

    Returns
    -------
    str or tuple of (str, dict)
        The processed HTML, and the old-value -> new-id mapping when
        `return_mapping` is True.
    """
    html_parsed = lxml.html.fromstring(dom_str)

    # Select the elements once, using the configured attribute name, and reuse
    # the list for the rewrite below. Selecting "//*[@bid]" a second time would
    # ignore `old_attr_name` whenever it is not "bid".
    elements = html_parsed.xpath(f"//*[@{old_attr_name}]")

    new_attr_mapping = {
        str(elem.get(old_attr_name)): generate_fn(old_attr_name)
        for elem in elements
        if elem.get(old_attr_name) is not None
    }

    # remap the attributes from bid="key" to data-webtasks-id="value"
    for elem in elements:
        elem.set(new_attr_name, new_attr_mapping[str(elem.get(old_attr_name))])
        elem.attrib.pop(old_attr_name)

    html_processed_str = lxml.html.tostring(html_parsed).decode("utf-8")

    if return_mapping:
        return html_processed_str, new_attr_mapping
    else:
        return html_processed_str


# Actions from webllama/experimental/integrations/browsergym/functions.py.
# They read the module-level `page` and `send_message_to_user` globals of that
# file, both of which are initialised to None there.


def say(utterance: str, *args, **kwargs):
    """
    Sends a message to the user.

    Extra positional and keyword arguments are accepted and ignored.

    Examples:
        say("Based on the results of my search, the city was built in 1751.")
    """
    send_message_to_user(utterance)


def click(uid: str, *args, **kwargs):
    """
    Click an element.

    The element is looked up on `page` with `get_elem_by_bid`. Extra positional
    and keyword arguments are accepted and ignored.

    Examples:
        click('51')
    """
    elem = get_elem_by_bid(page, uid)
    elem.click()