Skip to content

Commit

Permalink
Merge pull request #2 from jdvala/add-rich-text
Browse files Browse the repository at this point in the history
Add rich based console output
  • Loading branch information
jdvala authored Feb 19, 2022
2 parents 87e97ef + 34f704c commit 84245ea
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 117 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ dask-worker-space
# virtualenv
.venv
venv/
env/
ENV/

# Spyder project settings
Expand Down
75 changes: 39 additions & 36 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,42 +90,45 @@ lazy_text = LazyTextPredict(
models = lazy_text.fit(x_train, x_test, y_train, y_test)


Label Analysis
| Classes | Weights |
|--------------------:|---------------------:|
| tech | 0.8725490196078431 |
| politics | 1.1528497409326426 |
| sport | 1.0671462829736211 |
| entertainment | 0.8708414872798435 |
| business | 1.1097256857855362 |

Result Analysis
| Model | Accuracy | Balanced Accuracy | F1 Score | Custom Metric Score | Time Taken |
| ----------------------------: | -------------------:| -------------------:| -------------------:| -------------------:| -------------------:|
| AdaBoostClassifier | 0.7260479041916168 | 0.717737172132769 | 0.7248335989941609 | NA | 1.829047679901123 |
| BaggingClassifier | 0.8817365269461078 | 0.8796633962363677 | 0.8814695332332374 | NA | 3.5215072631835938 |
| BernoulliNB | 0.9535928143712575 | 0.9505929193425733 | 0.9533647387436917 | NA | 0.020041465759277344|
| CalibratedClassifierCV | 0.9760479041916168 | 0.9760018220340847 | 0.9755904096436046 | NA | 0.4990670680999756 |
| ComplementNB | 0.9760479041916168 | 0.9752329192546583 | 0.9754237510855159 | NA | 0.013598203659057617|
| DecisionTreeClassifier | 0.8532934131736527 | 0.8473956671194278 | 0.8496464898940103 | NA | 0.478792667388916 |
| DummyClassifier | 0.2155688622754491 | 0.2 | 0.07093596059113301 | NA | 0.008046865463256836|
| ExtraTreeClassifier | 0.7275449101796407 | 0.7253518459908658 | 0.7255575847020816 | NA | 0.026398658752441406|
| ExtraTreesClassifier | 0.9655688622754491 | 0.9635363285903302 | 0.9649837485086689 | NA | 1.6907336711883545 |
| GradientBoostingClassifier | 0.9565868263473054 | 0.9543725191544354 | 0.9554606292723953 | NA | 39.16400766372681 |
| KNeighborsClassifier | 0.938622754491018 | 0.9370053693959814 | 0.9367294513157219 | NA | 0.14803171157836914 |
| LinearSVC | 0.9745508982035929 | 0.974262691599302 | 0.9740343976103922 | NA | 0.10053229331970215 |
| LogisticRegression | 0.968562874251497 | 0.9668995859213251 | 0.9678778814908909 | NA | 2.9565982818603516 |
| LogisticRegressionCV | 0.9715568862275449 | 0.9708896757262861 | 0.971147482393915 | NA | 109.64091444015503 |
| MLPClassifier | 0.9760479041916168 | 0.9753381642512078 | 0.9752912960666735 | NA | 35.64296746253967 |
| MultinomialNB | 0.9700598802395209 | 0.9678795721187026 | 0.9689200656860745 | NA | 0.024427413940429688|
| NearestCentroid | 0.9520958083832335 | 0.9499045135454718 | 0.9515097876015481 | NA | 0.024636268615722656|
| NuSVC | 0.9670658682634731 | 0.9656159420289855 | 0.9669719954040374 | NA | 8.287142515182495 |
| PassiveAggressiveClassifier | 0.9775449101796407 | 0.9772388820754925 | 0.9770812340935414 | NA | 0.10332632064819336 |
| Perceptron | 0.9775449101796407 | 0.9769254658385094 | 0.9768161404324825 | NA | 0.07216000556945801 |
| RandomForestClassifier | 0.9625748502994012 | 0.9605135542632081 | 0.9624462948504477 | NA | 1.2427525520324707 |
| RidgeClassifier | 0.9775449101796407 | 0.9769254658385093 | 0.9769176825464448 | NA | 0.17272400856018066 |
| SGDClassifier | 0.9700598802395209 | 0.9695007868373973 | 0.969787370271274 | NA | 0.13134551048278809 |
| SVC | 0.9715568862275449 | 0.9703778467908902 | 0.9713021262026043 | NA | 8.388679027557373 |
Label Analysis
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓
┃ Classes ┃ Weights ┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩
│ business │ 0.8725490196078431 │
│ sport │ 1.1528497409326426 │
│ politics │ 1.0671462829736211 │
│ entertainment │ 0.8708414872798435 │
│ tech │ 1.1097256857855362 │
└───────────────┴────────────────────┘
Result Analysis
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Accuracy ┃ Balanced Accuracy ┃ F1 Score ┃ Custom Metric Score ┃ Time Taken ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ AdaBoostClassifier │ 0.7260479041916168 │ 0.717737172132769 │ 0.7248335989941609 │ NA │ 1.4244091510772705 │
│ BaggingClassifier │ 0.8817365269461078 │ 0.8796633962363677 │ 0.8814695332332374 │ NA │ 2.422576904296875 │
│ BernoulliNB │ 0.9535928143712575 │ 0.9505929193425733 │ 0.9533647387436917 │ NA │ 0.015914201736450195 │
│ CalibratedClassifierCV │ 0.9760479041916168 │ 0.9760018220340847 │ 0.9755904096436046 │ NA │ 0.36926722526550293 │
│ ComplementNB │ 0.9760479041916168 │ 0.9752329192546583 │ 0.9754237510855159 │ NA │ 0.009947061538696289 │
│ DecisionTreeClassifier │ 0.8532934131736527 │ 0.8473956671194278 │ 0.8496464898940103 │ NA │ 0.34440088272094727 │
│ DummyClassifier │ 0.2155688622754491 │ 0.2 │ 0.07093596059113301 │ NA │ 0.005555868148803711 │
│ ExtraTreeClassifier │ 0.7275449101796407 │ 0.7253518459908658 │ 0.7255575847020816 │ NA │ 0.018934965133666992 │
│ ExtraTreesClassifier │ 0.9655688622754491 │ 0.9635363285903302 │ 0.9649837485086689 │ NA │ 1.2101161479949951 │
│ GradientBoostingClassifier │ 0.9550898203592815 │ 0.9526333887196529 │ 0.9539060578037555 │ NA │ 30.256237030029297 │
│ KNeighborsClassifier │ 0.938622754491018 │ 0.9370053693959814 │ 0.9367294513157219 │ NA │ 0.12071108818054199 │
│ LinearSVC │ 0.9745508982035929 │ 0.974262691599302 │ 0.9740343976103922 │ NA │ 0.11713886260986328 │
│ LogisticRegression │ 0.968562874251497 │ 0.9668995859213251 │ 0.9678778814908909 │ NA │ 0.8916082382202148 │
│ LogisticRegressionCV │ 0.9715568862275449 │ 0.9708896757262861 │ 0.971147482393915 │ NA │ 37.82431483268738 │
│ MLPClassifier │ 0.9760479041916168 │ 0.9753381642512078 │ 0.9752912960666735 │ NA │ 30.700589656829834 │
│ MultinomialNB │ 0.9700598802395209 │ 0.9678795721187026 │ 0.9689200656860745 │ NA │ 0.01410818099975586 │
│ NearestCentroid │ 0.9520958083832335 │ 0.9499045135454718 │ 0.9515097876015481 │ NA │ 0.018617868423461914 │
│ NuSVC │ 0.9670658682634731 │ 0.9656159420289855 │ 0.9669719954040374 │ NA │ 6.941549062728882 │
│ PassiveAggressiveClassifier │ 0.9775449101796407 │ 0.9772388820754925 │ 0.9770812340935414 │ NA │ 0.05249309539794922 │
│ Perceptron │ 0.9775449101796407 │ 0.9769254658385094 │ 0.9768161404324825 │ NA │ 0.030637741088867188 │
│ RandomForestClassifier │ 0.9625748502994012 │ 0.9605135542632081 │ 0.9624462948504477 │ NA │ 0.9921820163726807 │
│ RidgeClassifier │ 0.9775449101796407 │ 0.9769254658385093 │ 0.9769176825464448 │ NA │ 0.09582686424255371 │
│ SGDClassifier │ 0.9700598802395209 │ 0.9695007868373973 │ 0.969787370271274 │ NA │ 0.04686570167541504 │
│ SVC │ 0.9715568862275449 │ 0.9703778467908902 │ 0.9713021262026043 │ NA │ 6.64256477355957 │
└─────────────────────────────┴────────────────────┴────────────────────┴─────────────────────┴─────────────────────┴──────────────────────┘
```

The result of each estimator is stored in `models`, which is a list; each trained estimator is also returned and can be used for further analysis.
Expand Down
2 changes: 1 addition & 1 deletion requirements/prod.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
scikit-learn==1.0.1
tqdm==4.62.3
rich==11.2.0
pandas==1.3.5
2 changes: 1 addition & 1 deletion src/lazytext/_repo_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "0.1.dev1+dirty"
version = "0.0.2.dev2+dirty"
86 changes: 47 additions & 39 deletions src/lazytext/create_table.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,70 @@
import typing as tt

from rich.console import Console
from rich.table import Table


def create_table(results: tt.Dict, label_analysis: tt.Dict = None):
"""Create summary table for all the results.
Example:
Results:
```
| Model | Accuracy | Balanced Accuracy | Time Taken |
| -------------------: | -------------------: | -------------------: | -------------------: |
| MultinomialNB | 0.641908620301598 | 0.62653122841884 | 0.03511333465576172 |
Label Analysis
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓
┃ Classes ┃ Weights ┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩
│ entertainment │ 0.8725490196078431 │
│ sport │ 1.1528497409326426 │
│ business │ 1.0671462829736211 │
│ tech │ 0.8708414872798435 │
│ politics │ 1.1097256857855362 │
└───────────────┴────────────────────┘
Result Analysis
┏━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Accuracy ┃ Balanced Accuracy ┃ F1 Score ┃ Custom Metric Score ┃ Time Taken ┃
┡━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ SVC │ 0.23502994011976047 │ 0.2184410069760388 │ 0.10772097568020063 │ NA │ 6.808304071426392 │
└───────┴─────────────────────┴────────────────────┴─────────────────────┴─────────────────────┴───────────────────┘
```
Label Analysis
Args:
results: Dictionary of all the results
label_analysis: Analysis of the labels
"""
result_format = "| {:<30}| {:<20}| {:<20}| {:<20}| {:<20}| {:<20}|"
label_format = "| {:<20}| {:<20} |"

# Class weights
console = Console()
if label_analysis:
print("\n Label Analysis")
print(label_format.format("Classes", "Weights"))
print("|--------------------:|---------------------:|")
label_table = Table(title="Label Analysis")
label_table.add_column("Classes", justify="left", style="cyan", no_wrap=True)
label_table.add_column("Weights", justify="left", style="magenta", no_wrap=True)

for name, weight in label_analysis.items():
print(label_format.format(name, weight))

# Result
print("\n Result Analysis")

print(
result_format.format(
"Model",
"Accuracy",
"Balanced Accuracy",
"F1 Score",
"Custom Metric Score",
"Time Taken",
)
)
print(
result_format.format(
"----------------------------:",
"-------------------:",
"-------------------:",
"-------------------:",
"-------------------:",
"-------------------:",
)
label_table.add_row(str(name), str(weight))

console.print(label_table)

result_table = Table(title="Result Analysis")
result_table.add_column("Model", justify="left", style="cyan", no_wrap=True)
result_table.add_column("Accuracy", justify="left", style="magenta", no_wrap=True)
result_table.add_column(
"Balanced Accuracy", justify="left", style="green", no_wrap=True
)
result_table.add_column("F1 Score", justify="left", style="red", no_wrap=True)
result_table.add_column("Custom Metric Score", justify="left", style="yellow")
result_table.add_column("Time Taken", justify="left", style="white")

for result in results:
temp = []
for key, value in result.items():
temp.append(value)
print(
result_format.format(temp[0], temp[1], temp[2], temp[3], temp[4], temp[5])

result_table.add_row(
str(temp[0]),
str(temp[1]),
str(temp[2]),
str(temp[3]),
str(temp[4]),
str(temp[5]),
)
print("\n")

console.print(result_table)
Loading

0 comments on commit 84245ea

Please sign in to comment.