Skip to content

Commit

Permalink
cosmetic improvements to README
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed Jun 27, 2024
1 parent cc22379 commit ef3e7af
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,44 @@

To install
<span style="color: #4D4352; font-family: 'arial narrow', arial, sans-serif;">
quanda
</span>:
quanda</span>:

```setup
pip install git+https://github.com/dilyabareeva/quanda.git
```

## Usage

An excerpt from `tutorials/usage_testing.py`:
```python
Excerpts from `tutorials/usage_testing.py`:

<details>
<summary><b><big>Step 1. Import library components</big></b></summary>

```python
from src.explainers.wrappers.captum_influence import captum_similarity_explain
from src.metrics.localization.identical_class import IdenticalClass
from src.metrics.randomization.model_randomization import (
ModelRandomizationMetric,
)
from src.metrics.unnamed.top_k_overlap import TopKOverlap
```

<details>

# define explanation parameters
<summary><b><big>Step 2. Define explanation parameters</big></b></summary>

```python
explain = captum_similarity_explain
explain_fn_kwargs = {"layers": "avgpool"}
model_id = "default_model_id"
cache_dir = "./cache"
```

<details>

# initialize metrics
<summary><b><big>Step 3. Initialize metrics</big></b></summary>

```python
model_rand = ModelRandomizationMetric(
model=model,
train_dataset=train_set,
Expand All @@ -58,13 +70,17 @@ model_rand = ModelRandomizationMetric(
id_class = IdenticalClass(model=model, train_dataset=train_set, device=DEVICE)

top_k = TopKOverlap(model=model, train_dataset=train_set, top_k=1, device="cpu")
```

<details>
<summary><b><big>Step 4. Iterate over test set and feed tensor batches first to explain, then to metric</big></b></summary>

# iterate over test set and feed tensor batches first to explain, then to metric
```python
for i, (data, target) in enumerate(tqdm(test_loader)):
data, target = data.to(DEVICE), target.to(DEVICE)

# some metrics have an explain_update() method in addition to update():
model_rand.update(data)
model_rand.explain_update(data)

# metrics that do not generate explanations only have an update() method:
tda = explain(
Expand All @@ -79,7 +95,4 @@ for i, (data, target) in enumerate(tqdm(test_loader)):
model_rand.update(data, tda)
id_class.update(target, tda)
top_k.update(target)



```

0 comments on commit ef3e7af

Please sign in to comment.