Skip to content

Commit

Permalink
more snake_case
Browse files Browse the repository at this point in the history
  • Loading branch information
wspr committed Mar 14, 2024
1 parent 1ae9275 commit f00e5c8
Showing 1 changed file with 64 additions and 65 deletions.
129 changes: 64 additions & 65 deletions ausankey/ausankey.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def sankey(
data = pandas dataframe of labels and weights in alternating columns
colorDict = Dictionary of colors to use for each label
{'label':'color'}
leftLabels = order of the left labels in the diagram
rightLabels = order of the right labels in the diagram
left_labels = order of the left labels in the diagram
right_labels = order of the right labels in the diagram
aspect = vertical extent of the diagram in units of horizontal extent
rightColor = If true, each strip in the diagram will be be colored
according to its left label
Expand Down Expand Up @@ -217,22 +217,22 @@ def _sankey(
))

if labelOrder is not None:
leftLabels = list(labelOrder[ii])
rightLabels = list(labelOrder[ii+1])
left_labels = list(labelOrder[ii])
right_labels = list(labelOrder[ii+1])
else:
leftLabels = list(wgt[0].keys())
rightLabels = list(wgt[2].keys())
left_labels = list(wgt[0].keys())
right_labels = list(wgt[2].keys())

# check labels
check_data_matches_labels(
leftLabels, left, 'left')
left_labels, left, 'left')
check_data_matches_labels(
rightLabels, right, 'right')
right_labels, right, 'right')

# check colours
allLabels = pd.Series(np.r_[left.unique(), right.unique()]).unique()
all_labels = pd.Series(np.r_[left.unique(), right.unique()]).unique()

missing = [label for label in allLabels if label not in colorDict]
missing = [label for label in all_labels if label not in colorDict]
if missing:
msg = (
"The colorDict parameter is missing "
Expand All @@ -244,94 +244,93 @@ def _sankey(
# Determine sizes of individual strips
barSizeLeft = {}
barSizeRight = {}
for leftLabel in leftLabels:
barSizeLeft[leftLabel] = {}
barSizeRight[leftLabel] = {}
for rightLabel in rightLabels:
ind = (left == leftLabel) & (right == rightLabel)
barSizeLeft[leftLabel][rightLabel] = left_weight[ind].sum()
barSizeRight[leftLabel][rightLabel] = right_weight[ind].sum()
for left_label in left_labels:
barSizeLeft[left_label] = {}
barSizeRight[left_label] = {}
for right_label in right_labels:
ind = (left == left_label) & (right == right_label)
barSizeLeft[left_label][right_label] = left_weight[ind].sum()
barSizeRight[left_label][right_label] = right_weight[ind].sum()

# Determine positions of left label patches and total widths
leftWidths = {}
for i, leftLabel in enumerate(leftLabels):
left_widths = {}
for i, left_label in enumerate(left_labels):
myD = {}
myD['left'] = left_weight[left == leftLabel].sum()
myD['left'] = left_weight[left == left_label].sum()
if i == 0:
myD['bottom'] = voffset[ii]
else:
myD['bottom'] = (
leftWidths[leftLabels[i-1]]['top'] + barGap*plot_height
left_widths[left_labels[i-1]]['top'] + barGap*plot_height
)
myD['top'] = myD['bottom'] + myD['left']
leftWidths[leftLabel] = myD
left_widths[left_label] = myD

# Determine positions of right label patches and total widths
rightWidths = {}
for i, rightLabel in enumerate(rightLabels):
right_widths = {}
for i, right_label in enumerate(right_labels):
myD = {}
myD['right'] = right_weight[right == rightLabel].sum()
myD['right'] = right_weight[right == right_label].sum()
if i == 0:
myD['bottom'] = voffset[ii+1]
else:
myD['bottom'] = (
rightWidths[rightLabels[i-1]]['top'] + barGap * plot_height
right_widths[right_labels[i-1]]['top'] + barGap * plot_height
)
myD['top'] = myD['bottom'] + myD['right']
rightWidths[rightLabel] = myD
right_widths[right_label] = myD

# horizontal extents of flows in each subdiagram
xMax = sub_width
barW = barWidth*xMax
xLeft = barW + labelWidth*xMax + ii*(xMax+barW)
xRight = xLeft + xMax
x_bar_width = barWidth*sub_width
x_left = x_bar_width + labelWidth*sub_width + ii*(sub_width+x_bar_width)
x_right = x_left + sub_width

# Draw bars and their labels
if ii == 0: # first time
for leftLabel in leftLabels:
lbot = leftWidths[leftLabel]['bottom']
lll = leftWidths[leftLabel]['left']
for left_label in left_labels:
lbot = left_widths[left_label]['bottom']
lll = left_widths[left_label]['left']
ax.fill_between(
xLeft+[-barW, 0],
x_left+[-x_bar_width, 0],
2*[lbot],
2*[lbot + lll],
color=colorDict[leftLabel],
color=colorDict[left_label],
alpha=1,
lw=0,
snap=True,
)
ax.text(
xLeft - (labelGap+barWidth)*xMax,
x_left - (labelGap+barWidth)*sub_width,
lbot + 0.5*lll,
labelDict.get(leftLabel, leftLabel),
labelDict.get(left_label, left_label),
{'ha': 'right', 'va': 'center'},
fontsize=fontsize
)
for rightLabel in rightLabels:
rbot = rightWidths[rightLabel]['bottom']
rrr = rightWidths[rightLabel]['right']
for right_label in right_labels:
rbot = right_widths[right_label]['bottom']
rrr = right_widths[right_label]['right']
ax.fill_between(
xRight+[0, barW],
x_right+[0, x_bar_width],
2*[rbot],
[rbot + rrr],
color=colorDict[rightLabel],
color=colorDict[right_label],
alpha=1,
lw=0,
snap=True,
)
if ii < num_flow-1: # inside labels
ax.text(
xRight + (labelGap+barWidth)*xMax,
x_right + (labelGap+barWidth)*sub_width,
rbot + 0.5*rrr,
labelDict.get(rightLabel, rightLabel),
labelDict.get(right_label, right_label),
{'ha': 'left', 'va': 'center'},
fontsize=fontsize
)
if ii == num_flow-1: # last time
ax.text(
xRight + (labelGap+barWidth)*xMax,
x_right + (labelGap+barWidth)*sub_width,
rbot + 0.5*rrr,
labelDict.get(rightLabel, rightLabel),
labelDict.get(right_label, right_label),
{'ha': 'left', 'va': 'center'},
fontsize=fontsize
)
Expand All @@ -341,9 +340,9 @@ def _sankey(

# leftmost title
if ii == 0:
xt = xLeft - xMax*barWidth/2
xt = x_left - x_bar_width/2
if titleSide in ("top", "both"):
yt = titleGap * plot_height + leftWidths[leftLabel]['top']
yt = titleGap * plot_height + left_widths[left_label]['top']
va = 'bottom'
ax.text(
xt, yt, titles[ii],
Expand All @@ -362,9 +361,9 @@ def _sankey(
)

# all other titles
xt = xRight + xMax*barWidth/2
xt = x_right + x_bar_width/2
if (titleSide == "top") | (titleSide == "both"):
yt = titleGap * plot_height + rightWidths[rightLabel]['top']
yt = titleGap * plot_height + right_widths[right_label]['top']

ax.text(
xt, yt, titles[ii+1],
Expand All @@ -384,17 +383,17 @@ def _sankey(
# Plot strips
num_div = 20
num_arr = 50
for leftLabel in leftLabels:
for rightLabel in rightLabels:
for left_label in left_labels:
for right_label in right_labels:

if not any(
(left == leftLabel) & (right == rightLabel)):
(left == left_label) & (right == right_label)):
continue

lbot = leftWidths[leftLabel]['bottom']
rbot = rightWidths[rightLabel]['bottom']
lbar = barSizeLeft[leftLabel][rightLabel]
rbar = barSizeRight[leftLabel][rightLabel]
lbot = left_widths[left_label]['bottom']
rbot = right_widths[right_label]['bottom']
lbar = barSizeLeft[left_label][right_label]
rbar = barSizeRight[left_label][right_label]

# Create array of y values for each strip, half at left value,
# half at right, convolve
Expand All @@ -408,13 +407,13 @@ def _sankey(

# Update bottom edges at each label
# so next strip starts at the right place
leftWidths[leftLabel]['bottom'] += lbar
rightWidths[rightLabel]['bottom'] += rbar
left_widths[left_label]['bottom'] += lbar
right_widths[right_label]['bottom'] += rbar

xx = np.linspace(xLeft, xRight, len(ys_d))
cc = combineColours(
colorDict[leftLabel],
colorDict[rightLabel], len(ys_d))
xx = np.linspace(x_left, x_right, len(ys_d))
cc = combine_colours(
colorDict[left_label],
colorDict[right_label], len(ys_d))

for jj in range(len(ys_d)-1):
ax.fill_between(
Expand Down Expand Up @@ -447,7 +446,7 @@ def check_data_matches_labels(labels, data, side):
raise LabelMismatchError(side, msg)


def combineColours(c1, c2, num_col):
def combine_colours(c1, c2, num_col):

colorArrayLen = 4
# if not [r,g,b,a] assume a hex string like "#rrggbb":
Expand Down

0 comments on commit f00e5c8

Please sign in to comment.