Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/performance test #850

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion openadapt/app/dashboard/api/recordings.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def attach_routes(self) -> APIRouter:
def get_recordings() -> dict[str, list[Recording]]:
"""Get all recordings."""
session = crud.get_new_session(read_only=True)
recordings = crud.get_all_recordings(session)
recordings = crud.get_recordings(session)
return {"recordings": recordings}

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion openadapt/app/tray.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def populate_menu(self, menu: QMenu, action: Callable, action_type: str) -> None
action_type (str): The type of action to perform ["visualize", "replay"]
"""
session = crud.get_new_session(read_only=True)
recordings = crud.get_all_recordings(session)
recordings = crud.get_recordings(session)

self.recording_actions[action_type] = []

Expand Down
11 changes: 7 additions & 4 deletions openadapt/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Configuration module for OpenAdapt."""


from enum import Enum
from typing import Any, ClassVar, Type, Union
import json
Expand Down Expand Up @@ -33,6 +32,7 @@
CAPTURE_DIR_PATH = (DATA_DIR_PATH / "captures").absolute()
VIDEO_DIR_PATH = DATA_DIR_PATH / "videos"
DATABASE_LOCK_FILE_PATH = DATA_DIR_PATH / "openadapt.db.lock"
DB_FILE_PATH = (DATA_DIR_PATH / "openadapt.db").absolute()

STOP_STRS = [
"oa.stop",
Expand Down Expand Up @@ -124,7 +124,8 @@ class SegmentationAdapter(str, Enum):

# Database
DB_ECHO: bool = False
DB_URL: ClassVar[str] = f"sqlite:///{(DATA_DIR_PATH / 'openadapt.db').absolute()}"
DB_FILE_PATH: str = str(DB_FILE_PATH)
DB_URL: ClassVar[str] = f"sqlite:///{DB_FILE_PATH}"

# Error reporting
ERROR_REPORTING_ENABLED: bool = True
Expand Down Expand Up @@ -428,11 +429,13 @@ def show_alert() -> None:
"""Show an alert to the user."""
msg = QMessageBox()
msg.setIcon(QMessageBox.Warning)
msg.setText("""
msg.setText(
"""
An error has occurred. The development team has been notified.
Please join the discord server to get help or send an email to
[email protected]
""")
"""
)
discord_button = QPushButton("Join the discord server")
discord_button.clicked.connect(
lambda: webbrowser.open("https://discord.gg/yF527cQbDG")
Expand Down
26 changes: 23 additions & 3 deletions openadapt/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,22 +281,27 @@ def delete_recording(session: SaSession, recording: Recording) -> None:
delete_video_file(recording_timestamp)


def get_all_recordings(session: SaSession) -> list[Recording]:
def get_recordings(session: SaSession, max_rows=None) -> list[Recording]:
"""Get all recordings.

Args:
session (sa.orm.Session): The database session.
max_rows: The number of recordings to return, starting from the most recent.
Defaults to all if max_rows is not specified.

Returns:
list[Recording]: A list of all original recordings.
"""
return (
query = (
session.query(Recording)
.filter(Recording.original_recording_id == None) # noqa: E711
.order_by(sa.desc(Recording.timestamp))
.all()
)

if max_rows:
query = query.limit(max_rows)
return query.all()


def get_all_scrubbed_recordings(
session: SaSession,
Expand Down Expand Up @@ -352,6 +357,21 @@ def get_recording(session: SaSession, timestamp: float) -> Recording:
return session.query(Recording).filter(Recording.timestamp == timestamp).first()


def get_recordings_by_desc(session: SaSession, description_str: str) -> list[Recording]:
"""Get recordings by task description.
Args:
session (sa.orm.Session): The database session.
task_description (str): The task description to search for.
Returns:
list[Recording]: A list of recordings whose task descriptions contain the given string.
"""
return (
session.query(Recording)
.filter(Recording.task_description.contains(description_str))
.all()
)


BaseModelType = TypeVar("BaseModelType")


Expand Down
132 changes: 132 additions & 0 deletions openadapt/scripts/generate_db_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import os
from sqlalchemy import create_engine, inspect
from sqlalchemy.orm import sessionmaker
from openadapt.db.db import Base
from openadapt.config import DATA_DIR_PATH, PARENT_DIR_PATH, RECORDING_DIR_PATH
import openadapt.db.crud as crud
from loguru import logger


def get_session():
db_url = RECORDING_DIR_PATH / "recording.db"
print(f"Database URL: {db_url}")
engine = create_engine(f"sqlite:///{db_url}")
# SessionLocal = sessionmaker(bind=engine)
Base.metadata.create_all(bind=engine)
session = crud.get_new_session(read_only=True)
print("Database connection established.")
Animesh404 marked this conversation as resolved.
Show resolved Hide resolved
return session, engine


def check_tables_exist(engine):
inspector = inspect(engine)
tables = inspector.get_table_names()
expected_tables = [
"recording",
"action_event",
"screenshot",
"window_event",
"performance_stat",
"memory_stat",
]
for table_name in expected_tables:
table_exists = table_name in tables
logger.info(f"{table_name=} {table_exists=}")
return tables


def fetch_data(session):
# get the most recent three recordings
recordings = crud.get_recordings(session, max_rows=3)
recording_ids = [recording.id for recording in recordings]

action_events = []
screenshots = []
window_events = []
performance_stats = []
memory_stats = []

for recording in recordings:
action_events.extend(crud.get_action_events(session, recording))
screenshots.extend(crud.get_screenshots(session, recording))
window_events.extend(crud.get_window_events(session, recording))
performance_stats.extend(crud.get_perf_stats(session, recording))
memory_stats.extend(crud.get_memory_stats(session, recording))

data = {
"recordings": recordings,
"action_events": action_events,
"screenshots": screenshots,
"window_events": window_events,
"performance_stats": performance_stats,
"memory_stats": memory_stats,
}

# Debug prints to verify data fetching
print(f"Recordings: {len(data['recordings'])} found.")
print(f"Action Events: {len(data['action_events'])} found.")
print(f"Screenshots: {len(data['screenshots'])} found.")
print(f"Window Events: {len(data['window_events'])} found.")
print(f"Performance Stats: {len(data['performance_stats'])} found.")
print(f"Memory Stats: {len(data['memory_stats'])} found.")

return data


def format_sql_insert(table_name, rows):
if not rows:
return ""

columns = rows[0].__table__.columns.keys()
sql = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES\n"
values = []

for row in rows:
row_values = [getattr(row, col) for col in columns]
row_values = [
f"'{value}'" if isinstance(value, str) else str(value)
for value in row_values
]
values.append(f"({', '.join(row_values)})")

sql += ",\n".join(values) + ";\n"
return sql


def dump_to_fixtures(filepath):
session, engine = get_session()
check_tables_exist(engine)
data = fetch_data(session)

with open(filepath, "a", encoding="utf-8") as file:
if data["recordings"]:
file.write("-- Insert sample recordings\n")
file.write(format_sql_insert("recording", data["recordings"]))

if data["action_events"]:
file.write("-- Insert sample action_events\n")
file.write(format_sql_insert("action_event", data["action_events"]))

if data["screenshots"]:
file.write("-- Insert sample screenshots\n")
file.write(format_sql_insert("screenshot", data["screenshots"]))

if data["window_events"]:
file.write("-- Insert sample window_events\n")
file.write(format_sql_insert("window_event", data["window_events"]))

if data["performance_stats"]:
file.write("-- Insert sample performance_stats\n")
file.write(format_sql_insert("performance_stat", data["performance_stats"]))

if data["memory_stats"]:
file.write("-- Insert sample memory_stats\n")
file.write(format_sql_insert("memory_stat", data["memory_stats"]))
print(f"Data appended to {filepath}")


if __name__ == "__main__":

fixtures_path = PARENT_DIR_PATH / "tests/assets/fixtures.sql"

dump_to_fixtures(fixtures_path)
4 changes: 2 additions & 2 deletions openadapt/scripts/reset_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

def reset_db() -> None:
"""Clears the database by removing the db file and running a db migration."""
if os.path.exists(config.DB_FPATH):
os.remove(config.DB_FPATH)
if os.path.exists(config.DB_FILE_PATH):
os.remove(config.DB_FILE_PATH)

# Prevents duplicate logging of config values by piping stderr
# and filtering the output.
Expand Down
12 changes: 12 additions & 0 deletions openadapt/window/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys

from loguru import logger
import pywinauto

from openadapt.config import config

Expand Down Expand Up @@ -67,6 +68,17 @@ def get_active_window_state(read_window_data: bool) -> dict | None:
return None


def get_active_window() -> pywinauto.application.WindowSpecification:
"""Get the active window object.

Returns:
pywinauto.application.WindowSpecification: The active window object.
"""
app = pywinauto.application.Application(backend="uia").connect(active_only=True)
window = app.top_window()
return window.wrapper_object()


def get_active_element_state(x: int, y: int) -> dict | None:
"""Get the state of the active element at the specified coordinates.

Expand Down
88 changes: 88 additions & 0 deletions tests/openadapt/test_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pytest
import time
from loguru import logger
import logging
from openadapt.db.crud import (
get_recordings_by_desc,
get_new_session,
)
from openadapt.replay import replay
from openadapt.window import (
get_active_window,
)

# logging to a txt file
logging.basicConfig(
level=logging.INFO,
filename="test_results.txt",
filemode="w",
format="%(asctime)s | %(levelname)s | %(message)s",
)


# parametrized tests
@pytest.mark.parametrize(
"task_description, replay_strategy, expected_value, instructions",
[
("test_calculator", "VisualReplayStrategy", "6", " "),
("test_calculator", "VisualReplayStrategy", "8", "calculate 9-8+7"),
# ("test_spreadsheet", "NaiveReplayStrategy"),
# ("test_powerpoint", "NaiveReplayStrategy")
],
)
def test_replay(task_description, replay_strategy, expected_value, instructions):
# Get recordings which contain the string "test_calculator"
session = get_new_session(read_only=True)
recordings = get_recordings_by_desc(session, task_description)

assert (
len(recordings) > 0
), f"No recordings found with task description: {task_description}"
recording = recordings[0]

result = replay(
strategy_name=replay_strategy,
recording=recording,
instructions=instructions,
)
assert result is True, f"Replay failed for recording: {recording.id}"

def find_display_element(element, timeout=10):
"""Find the display element within the specified timeout.

Args:
element: The parent element to search within.
timeout: The maximum time to wait for the element (default is 10 seconds).

Returns:
The found element.

Raises:
TimeoutError: If the element is not found within the specified timeout.
"""
end_time = time.time() + timeout
Animesh404 marked this conversation as resolved.
Show resolved Hide resolved
while time.time() < end_time:
elements = element.descendants(control_type="Text")
for elem in elements:
if elem.element_info.name.startswith(
"Display is"
): # Target the display element
return elem
time.sleep(0.5)
Animesh404 marked this conversation as resolved.
Show resolved Hide resolved
raise TimeoutError("Display element not found within the specified timeout")

active_window = get_active_window()
element = find_display_element(active_window)
value = element.element_info.name[-1]

element_value = value
assert (
element_value == expected_value
), f"Value mismatch: expected '{expected_value}', got '{element_value}'"

result_message = f"Value match: '{element_value}' == '{expected_value}'"
logging.info(result_message)


if __name__ == "__main__":
pytest.main()
3 changes: 3 additions & 0 deletions tests/openadapt/test_results.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
2024-07-12 23:11:21,853 | INFO | Value match: '6' == '6'
2024-07-12 23:11:45,640 | INFO | Value match: '8' == '8'
2024-07-12 23:11:46,388 | INFO | HTTP Request: POST https://api.anthropic.com/v1/messages "HTTP/1.1 401 Unauthorized"
Animesh404 marked this conversation as resolved.
Show resolved Hide resolved
Loading