Skip to content

Commit

Permalink
Update: seaborn plot dependency optional
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanHeng committed Jun 1, 2024
1 parent 361fcc2 commit 5928e38
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 21 deletions.
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages

VERSION = '0.42.9'
VERSION = '0.42.10'
DESCRIPTION = 'Machine Learning project startup utilities'
LONG_DESCRIPTION = 'My commonly used utilities for machine learning projects'

Expand All @@ -13,7 +13,7 @@
description=DESCRIPTION,
long_description=LONG_DESCRIPTION,
url='https://github.com/StefanHeng/stef-util',
download_url='https://github.com/StefanHeng/stef-util/archive/refs/tags/v0.42.9Y.tar.gz',
download_url='https://github.com/StefanHeng/stef-util/archive/refs/tags/v0.42.10.tar.gz',
packages=find_packages(),
include_package_data=True,
install_requires=[
Expand All @@ -24,7 +24,8 @@
],
extras_require={
'legacy_styling': ['sty', 'colorama'],
'plot': ['matplotlib', 'seaborn'],
'plot': ['matplotlib'],
'plot-optional': ['seaborn'],
'machine_learning': ['scikit-learn'],
'deep_learning': ['spacy', 'torch', 'transformers>=4.33.2', 'sentence-transformers', 'tensorboard'],
'optional': ['pygments', 'pyinstrument']
Expand Down
3 changes: 2 additions & 1 deletion stefutil/packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
'installed_packages'
]

_PKGS_PLT = ['matplotlib', 'seaborn']
# _PKGS_PLT = ['matplotlib', 'seaborn'] # make seaborn optional
_PKGS_PLT = ['matplotlib']
_PKGS_ML = ['scikit-learn']
_PKGS_DL = ['torch', 'tensorboard', 'transformers', 'sentence-transformers', 'spacy']

Expand Down
40 changes: 23 additions & 17 deletions stefutil/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from stefutil.prettier import s, ca, get_logger, Timer
from stefutil.container import df_col2cat_col
from stefutil.packaging import _use_plot, _use_ml
from stefutil.packaging import installed_packages, _use_plot, _use_ml


__all__ = []
Expand All @@ -25,7 +25,6 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Ellipse
from matplotlib.colors import to_rgba
import matplotlib.colors as colors
Expand All @@ -46,8 +45,10 @@ def set_plot_style():
r'\usepackage{sansmath}', # render math sans-serif
r'\sansmath'
]))
sns.set_style('darkgrid')
sns.set_context(rc={'grid.linewidth': 0.5})
if 'seaborn' in installed_packages():
import seaborn as sns
sns.set_style('darkgrid')
sns.set_context(rc={'grid.linewidth': 0.5})

LN_KWARGS = dict(marker='o', ms=0.3, lw=0.25) # matplotlib line plot default args

Expand Down Expand Up @@ -77,6 +78,7 @@ def vals2colors(
If given, reduce visual spread/difference of colors
Intended for a less drastic color at the extremes
"""
import seaborn as sns
vals = np.asarray(vals)
cmap = sns.color_palette(color_palette, as_cmap=True)
mi, ma = np.min(vals), np.max(vals)
Expand Down Expand Up @@ -117,6 +119,8 @@ def barplot(
show: bool = True,
**kwargs
):
import seaborn as sns

ca(bar_orient=orient)
if data is not None:
df = data
Expand Down Expand Up @@ -192,21 +196,23 @@ def vector_projection_plot(
Given vectors grouped by key, plot projections of vectors into 2D space
Intended for plotting embedding space of SBert sentence representations
:param name2vectors: 2D vectors grouped by setup name
:param tsne_args: Arguments for TSNE dimensionality reduction
:param tight_fig_size: If true, resize the figure to fit the axis range
:param key_name: column name for setup in the internal dataframe
:param ellipse: If true, plot confidence ellipse for each setup
:param ellipse_std: Number of standard deviations for ellipse
:param scatter_ms: Base marker size for scatter plot
Will be scaled by number of samples in each group
:param scatter_kwargs: arguments for scatter plot
:param ax: matplotlib axes object to plot on
:param title: plot title
:param verbose: If true, prints status to logger
:param logger: logger
:param name2vectors: 2D vectors grouped by setup name.
:param tsne_args: Arguments for TSNE dimensionality reduction.
:param tight_fig_size: If true, resize the figure to fit the axis range.
:param key_name: column name for setup in the internal dataframe.
:param ellipse: If true, plot confidence ellipse for each setup.
:param ellipse_std: Number of standard deviations for ellipse.
:param scatter_ms: Base marker size for scatter plot.
Will be scaled by number of samples in each group.
:param scatter_kwargs: arguments for scatter plot.
:param ax: matplotlib axes object to plot on.
:param title: plot title.
:param verbose: If true, prints status to logger.
:param logger: logger.
"""
from sklearn.manifold import TSNE # lazy import to save time
import seaborn as sns

vects = np.concatenate(list(name2vectors.values()), axis=0)
tsne_args_ = dict(n_components=2, perplexity=50, random_state=42)
tsne_args_.update(tsne_args or dict())
Expand Down

0 comments on commit 5928e38

Please sign in to comment.