Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sdomanskyi committed Nov 26, 2020
1 parent 499b562 commit c3d93f7
Show file tree
Hide file tree
Showing 9 changed files with 422 additions and 40 deletions.
4 changes: 4 additions & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
- 1.3.7
* Added a function to import data from kallisto-bustools and cellranger
* Updated documentation

- 1.3.6
* Added quick-demo materials

Expand Down
125 changes: 97 additions & 28 deletions DigitalCellSorter/VisualizationFunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def saveFigure(self, fig, saveDir, label = 'Figure', extension = 'png', dpi = 30
# MatPlotLib-powered figures

@tryExcept
def makeHeatmapGeneExpressionPlot(self, df = None, genes = None, normalize = True, subtract = False, saveExcel = True, nameToAppend = 'heatmap', plotBy = 'cluster', figsize = (8, 4), convertGenes = False, orderGenes=False, orderClusters=False, dpi = 300, extension = 'png', fontsize = 10, **kwargs):
def makeHeatmapGeneExpressionPlot(self, df = None, genes = None, normalize = True, logScale=False, subtract = False, saveExcel = True, nameToAppend = 'heatmap', plotBy = 'cluster', figsize = (8, 4), convertGenes = False, orderGenes=False, orderClusters=False, dpi = 300, extension = 'png', fontsize = 10, labelsFontsize = 10, **kwargs):

'''Make heatmap gene expression plot from a provided gene expression matrix.
Expand Down Expand Up @@ -179,47 +179,96 @@ def makeHeatmapGeneExpressionPlot(self, df = None, genes = None, normalize = Tru
DCS.makeHeatmapGeneExpressionPlot()
'''

lengthListGenes = []
labelsListGenes = []

if not genes is None:
if type(genes) in [list, np.array, tuple]:
if type(genes) in [list, np.ndarray, tuple]:
isNegativeGenes = [True if gene[-1]=='-' else False for gene in genes]
genes = [gene[:-1] if gene[-1]=='-' else gene for gene in genes]

if convertGenes:
genes = np.unique(self.gnc.Convert(genes, 'alias', 'hugo', returnUnknownString=False))
genes = self.gnc.Convert(genes, 'alias', 'hugo', returnUnknownString=False)


elif type(genes) in [dict]:
isNegativeGenes = dict()
for key in genes.keys():
isNegativeGenes[key] = [True if gene[-1]=='-' else False for gene in genes[key]]
genes[key] = [gene[:-1] if gene[-1]=='-' else gene for gene in genes[key]]

if convertGenes:
for key in genes.keys():
genes[key] = self.gnc.Convert(genes[key], 'alias', 'hugo', returnUnknownString=False)

lengthListGenes = []
listGenes = []
for key in genes.keys():
listGenes.extend(genes[key])
lengthListGenes.append(len(genes[key]))
labelsListGenes.append(key)

genes = listGenes
print('Length of genes lists:', lengthListGenes)

else:
if self.verbose >= 1:
print('Plotting all expressed genes not supported. Provide a smaller list of genes')

return

lengthListGenes = []
labelsListGenes = []

if df is None:
if self.df_expr is None:
self.loadExpressionData()

if self.df_expr is None:
return

common = pd.Index(genes).intersection(self.df_expr.index) #.drop_duplicates()
targetIndex = self.df_expr.index

if type(genes) in [list, np.ndarray, tuple]:
ind = np.isin(genes, targetIndex)
common = np.array(genes)[ind]
isNegativeGenes = np.array(isNegativeGenes)[ind]

elif type(genes) in [dict]:
common = []
temp_negative = []
for key in genes.keys():
ind = np.isin(genes[key], targetIndex)
temp_common = np.array(genes[key])[ind]
isNegativeGenes_common = np.array(isNegativeGenes[key])[ind]

if len(temp_common) > 0:
common.extend(temp_common)
temp_negative.extend(isNegativeGenes_common)
lengthListGenes.append(len(temp_common))
labelsListGenes.append(key)

isNegativeGenes = np.array(temp_negative)

else:
return

df = self.df_expr.loc[common].copy()

else:
common = pd.Index(genes).intersection(df.index).drop_duplicates()
targetIndex = df.index

if type(genes) in [list, np.ndarray, tuple]:
ind = np.isin(genes, targetIndex)
common = np.array(genes)[ind]
isNegativeGenes = np.array(isNegativeGenes)[ind]

elif type(genes) in [dict]:
common = []
temp_negative = []
for key in genes.keys():
ind = np.isin(genes[key], targetIndex)
temp_common = np.array(genes[key])[ind]
isNegativeGenes_common = np.array(isNegativeGenes[key])[ind]

if len(temp_common) > 0:
common.extend(temp_common)
temp_negative.extend(isNegativeGenes_common)
lengthListGenes.append(len(temp_common))
labelsListGenes.append(key)

isNegativeGenes = np.array(temp_negative)

else:
return

df = df.loc[common]

counts = df.loc[[df.index[0]]].groupby(axis=1, level=plotBy).count()
Expand All @@ -235,6 +284,10 @@ def makeHeatmapGeneExpressionPlot(self, df = None, genes = None, normalize = Tru
df.iloc[i,:] -= np.min(df.iloc[i,:])

df.iloc[i,:] /= np.max(df.iloc[i,:])

if logScale:
df += 1.
df = np.log(df)

if orderGenes:
df = df.iloc[scipy.cluster.hierarchy.dendrogram(scipy.cluster.hierarchy.linkage(df, 'ward'), no_plot=True, get_leaves=True)['leaves']]
Expand All @@ -252,6 +305,13 @@ def makeHeatmapGeneExpressionPlot(self, df = None, genes = None, normalize = Tru
ax.imshow(df.T.values[1:,:], cmap='Blues', interpolation='None', aspect='auto',
extent=(-0.5, df.shape[0] - 0.5, df.shape[1] - 0.5, +0.5))

data = df.T.values[1:,:].copy()
data[:, ~isNegativeGenes] = np.nan
data = np.ma.masked_where(np.isnan(data), data)

ax.imshow(data, cmap='Reds', interpolation='None', aspect='auto',
extent=(-0.5, df.shape[0] - 0.5, df.shape[1] - 0.5, +0.5))

ax.imshow(df.T.values[:1,:], cmap='Reds', interpolation='None', aspect='auto',
extent=(-0.5, df.shape[0] - 0.5, -0.5, +0.5))

Expand All @@ -262,7 +322,11 @@ def makeHeatmapGeneExpressionPlot(self, df = None, genes = None, normalize = Tru
for label, value in zip(labelsListGenes, lengthListGenes):
currPosition += value
ax.axvline(x=currPosition - 0.5, c='k', lw=1)
ax.text(currPosition - 0.5*value - 0.5, df.shape[1], label, fontsize=10, c='k', ha='center', va='bottom')
ax.text(currPosition - 0.5*value - 0.5, df.shape[1], label, fontsize=labelsFontsize, c='k', ha='center', va='top')

df_temp = pd.DataFrame(index=df.columns[1:], columns=labelsListGenes)
df_temp.index = pd.MultiIndex.from_tuples(df_temp.index).get_level_values(0)[::-1]
df_temp.to_excel(os.path.join(self.saveDir, self.dataName + '-' + nameToAppend + '.xlsx'))

ax.set_xticks(range(df.shape[0]))
ax.set_yticks(range(df.shape[1]))
Expand All @@ -271,7 +335,7 @@ def makeHeatmapGeneExpressionPlot(self, df = None, genes = None, normalize = Tru
ylabels[0] = 'Mean across all cells'

ax.set_xticklabels(df.index, rotation=90, fontsize=fontsize)
ax.set_yticklabels(ylabels, rotation=0, fontsize=fontsize)
ax.set_yticklabels(ylabels, rotation=0, fontsize=1.2*fontsize)

ax.set_xlim([-0.5, df.shape[0] - 0.5])
ax.set_ylim([-0.5, df.shape[1] - 0.5])
Expand Down Expand Up @@ -950,7 +1014,7 @@ def add_colorbar(fig, labels, cmap = matplotlib.colors.LinearSegmentedColormap.f
return fig

@tryExcept
def makeStackedBarplot(self, clusterName = None, legendStyle = False, includeLowQC = True, dpi = 300, extension = 'png', **kwargs):
def makeStackedBarplot(self, clusterName = None, legendStyle = False, includeLowQC = True, fontsize = 12, dpi = 300, extension = 'png', **kwargs):

'''Produce stacked barplot with cell fractions
Expand Down Expand Up @@ -1072,7 +1136,7 @@ def get_stacked_data_and_colors(saveDir):
centers[i] = centers[i + 1] + step

for i in range(len(centers)):
ax.text(1.3, centers[i], '%s%% ' % (fractions[i]) + labels[i], fontsize=12, va='center', ha='left')
ax.text(1.3, centers[i], '%s%% ' % (fractions[i]) + labels[i], fontsize=fontsize, va='center', ha='left')
ax.plot([0.65, 1.2], [centers_orig[i], centers[i]], c='k', lw=0.75, clip_on=False)

plt.xlim((-0.5, len(df_Main.columns) - 0.5))
Expand Down Expand Up @@ -2004,8 +2068,8 @@ def add_colorbar(fig, labels, cmap = matplotlib.colors.LinearSegmentedColormap.f

ax.plot(attrs2D.T[0], attrs2D.T[1], '*', ms=14, color='k', alpha=1.0, zorder=-10**7, clip_on=False)
for attr in range(attrs2D.T[0].shape[0]):
temp_texts = ax.text(attrs2D.T[0][attr], attrs2D.T[1][attr], attrs_names[attr], fontsize=fontsize, ha='left',va='center', zorder=10 ** 10, clip_on=False)
temp_texts.set_path_effects([path_effects.Stroke(linewidth=1, foreground='white'), path_effects.Normal()])
temp_texts = ax.text(attrs2D.T[0][attr], attrs2D.T[1][attr], attrs_names[attr], fontsize=fontsize, fontweight=550, ha='left',va='center', zorder=10 ** 10, clip_on=False)
temp_texts.set_path_effects([path_effects.Stroke(linewidth=2.5, foreground='white'), path_effects.Normal()])
texts.append(temp_texts)

if adjustText:
Expand All @@ -2023,7 +2087,7 @@ def add_colorbar(fig, labels, cmap = matplotlib.colors.LinearSegmentedColormap.f
# Plotly-powered figures

@tryExcept
def makeSankeyDiagram(self, df, colormapForIndex = None, colormapForColumns = None, linksColor = 'rgba(100,100,100,0.6)', title = '', attemptSavingHTML = False, quality = 4, width = 400, height = 400, border = 20, nameAppend = '_Sankey_diagram'):
def makeSankeyDiagram(self, df, colormapForIndex = None, colormapForColumns = None, linksColor = 'rgba(100,100,100,0.6)', title = '', attemptSavingHTML = False, quality = 4, width = 400, height = 400, border = 20, nodeLabelsFontSize = 15, nameAppend = '_Sankey_diagram'):

'''Make a Sankey diagram, also known as 'river plot' with two groups of nodes
Expand All @@ -2049,6 +2113,9 @@ def makeSankeyDiagram(self, df, colormapForIndex = None, colormapForColumns = No
quality: int, Default 4
Proportional to the resolution of the figure to save
nodeLabelsFontSize: int, Default 15
Font size for node labels
nameAppend: str, Default '_Sankey_diagram'
Name to append to the figure file
Expand Down Expand Up @@ -2098,9 +2165,11 @@ def makeSankeyDiagram(self, df, colormapForIndex = None, colormapForColumns = No
newColor = ','.join(nodeColors[sources[i]].split(',')[:3] + ['0.6)'])
colorscales[i] = dict(label=labels[i], colorscale=[[0, newColor], [1, newColor]])

fig = go.Figure(data=[go.Sankey(valueformat = '', valuesuffix = '',
node = dict(pad = 20, thickness = 40, line = dict(color = 'white', width = 0.5), label = nodeLabels, color = nodeColors,),
link = dict(source = sources, target = targets, value = values, label = labels, colorscales = colorscales, hoverinfo='all'))]) #line ={'color':'rgba(255,0,0,0.8)', 'width':0.1}
fig = go.Figure(data=[go.Sankey(valueformat = '', valuesuffix = '', textfont = dict(color = 'rgb(255,0,0)', size = nodeLabelsFontSize, family = 'Arial'),
node = dict(pad = 20, thickness = 40, line = dict(color = 'white', width = 0.0), label = nodeLabels, color = nodeColors,
), # hoverlabel=dict(bordercolor = 'yellow')
link = dict(source = sources, target = targets, value = values, label = labels, colorscales = colorscales, hoverinfo='all'),
)],) #line ={'color':'rgba(255,0,0,0.8)', 'width':0.1}

if not title is None:
fig.update_layout(title_text=title, font_size=10)
Expand Down
Loading

0 comments on commit c3d93f7

Please sign in to comment.