diff --git a/.github/workflows/cml.yaml b/.github/workflows/cml.yaml new file mode 100644 index 0000000..6a059c4 --- /dev/null +++ b/.github/workflows/cml.yaml @@ -0,0 +1,20 @@ +name: CML +on: [push] +jobs: + train-and-report: + runs-on: ubuntu-latest + container: docker://ghcr.io/iterative/cml:0-dvc2-base1 + steps: + - uses: actions/checkout@v3 + - name: Train model + env: + REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + pip install -r requirements.txt + python dummy-evaluation.py + + # Create CML report + cat metrics.txt >> report.md + echo '![](./metrics.png "Violin Plot of Metrics")' >> report.md + cml comment create report.md + \ No newline at end of file diff --git a/dummy-evaluation.py b/dummy-evaluation.py new file mode 100644 index 0000000..31fdd08 --- /dev/null +++ b/dummy-evaluation.py @@ -0,0 +1,20 @@ +import numpy as np +import pandas as pd +import plotly.graph_objects as go +import plotly.io as pio + +metrics = {"answer_relevancy", "answer_correctness", "context_precision"} +dummy_data = {metric: np.random.rand(100) for metric in metrics} +df = pd.DataFrame(dummy_data) + +with open("metrics.txt", "w") as f: + for col in df: + f.write(f"{col}: {df[col].mean()}\n") + +pio.templates.default = "gridon" +fig = go.Figure() +metrics = [metric for metric in df.columns.to_list() if metric not in ["question", "ground_truth", "answer", "contexts"]] +for metric in metrics: + fig.add_trace(go.Violin(y=df[metric], name=metric, points="all", box_visible=True, meanline_visible=True)) +fig.update_yaxes(range=[-0.02,1.02]) +fig.write_image("metrics.png") diff --git a/requirements.txt b/requirements.txt index 5c1bed9..bb697d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,88 +1,6 @@ -aiohappyeyeballs==2.4.0 -aiohttp==3.10.5 -aiosignal==1.3.1 -annotated-types==0.7.0 -anyio==4.4.0 -appdirs==1.4.4 -asttokens==2.4.1 -attrs==24.2.0 -certifi==2024.7.4 -charset-normalizer==3.3.2 -comm==0.2.2 -dataclasses-json==0.6.7 -datasets==2.21.0 -debugpy==1.8.5 -decorator==5.1.1 -dill==0.3.8 -distro==1.9.0 -executing==2.0.1 -filelock==3.15.4 -frozenlist==1.4.1 -fsspec==2024.6.1 -greenlet==3.0.3 -h11==0.14.0 -httpcore==1.0.5 -httpx==0.27.0 -huggingface-hub==0.24.6 -idna==3.7 -ipykernel==6.29.5 -ipython==8.26.0 -jedi==0.19.1 -jiter==0.5.0 -jsonpatch==1.33 -jsonpointer==3.0.0 -jupyter_client==8.6.2 -jupyter_core==5.7.2 -langchain==0.2.14 -langchain-community==0.2.12 -langchain-core==0.2.33 -langchain-openai==0.1.22 -langchain-text-splitters==0.2.2 -langsmith==0.1.99 -marshmallow==3.21.3 -matplotlib-inline==0.1.7 -multidict==6.0.5 -multiprocess==0.70.16 -mypy-extensions==1.0.0 -nest-asyncio==1.6.0 -numpy==1.26.4 -openai==1.41.1 -orjson==3.10.7 -packaging==24.1 -pandas==2.2.2 -parso==0.8.4 -pexpect==4.9.0 -platformdirs==4.2.2 -plotly==5.23.0 -prompt_toolkit==3.0.47 -psutil==6.0.0 -ptyprocess==0.7.0 -pure_eval==0.2.3 -pyarrow==17.0.0 -pydantic==2.8.2 -pydantic_core==2.20.1 -Pygments==2.18.0 -pysbd==0.3.4 -python-dateutil==2.9.0.post0 -pytz==2024.1 -PyYAML==6.0.2 -pyzmq==26.1.1 -ragas==0.1.10 -regex==2024.7.24 -requests==2.32.3 -six==1.16.0 -sniffio==1.3.1 -SQLAlchemy==2.0.32 -stack-data==0.6.3 -tenacity==8.5.0 -tiktoken==0.7.0 -tornado==6.4.1 -tqdm==4.66.5 -traitlets==5.14.3 -typing-inspect==0.9.0 -typing_extensions==4.12.2 -tzdata==2024.1 -urllib3==2.2.2 -wcwidth==0.2.13 -xxhash==3.5.0 -yarl==1.9.4 +plotly +pandas +numpy +kaleido +#dvc +#dvc[gdrive]