Skip to content

Commit

Permalink
Introduce orientation to bar plot (ResidentMario#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrugman committed Jul 3, 2021
1 parent 8c31763 commit 6ad031d
Showing 1 changed file with 69 additions and 16 deletions.
85 changes: 69 additions & 16 deletions missingno/missingno.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,13 @@ def matrix(df,
return ax0


def bar(df, figsize=(24, 10), fontsize=16, labels=None, log=False, color='dimgray', inline=False,
filter=None, n=0, p=0, sort=None, ax=None):
def bar(df, figsize=None, fontsize=16, labels=None, log=False, color='dimgray', inline=False,
filter=None, n=0, p=0, sort=None, ax=None, orientation=None):
"""
A bar chart visualization of the nullity of the given DataFrame.
:param df: The input DataFrame.
:param log: Whether or not to display a logorithmic plot. Defaults to False (linear).
:param log: Whether or not to display a logarithmic plot. Defaults to False (linear).
:param filter: The filter to apply to the heatmap. Should be one of "top", "bottom", or None (default).
:param n: The cap on the number of columns to include in the filtered DataFrame.
:param p: The cap on the percentage fill of the columns in the filtered DataFrame.
Expand All @@ -219,21 +219,36 @@ def bar(df, figsize=(24, 10), fontsize=16, labels=None, log=False, color='dimgra
:param labels: Whether or not to display the column names. Would need to be turned off on particularly large
displays. Defaults to True.
:param color: The color of the filled columns. Default to the RGB multiple `(0.25, 0.25, 0.25)`.
:param orientation: The way the bar plot is oriented. Defaults to vertical if there are less than or equal to 50
columns and horizontal if there are more.
:return: If `inline` is False, the underlying `matplotlib.figure` object. Else, nothing.
"""
df = nullity_filter(df, filter=filter, n=n, p=p)
df = nullity_sort(df, sort=sort, axis='rows')
nullity_counts = len(df) - df.isnull().sum()

if orientation is None:
if len(df.columns) > 50:
orientation = 'left'
else:
orientation = 'bottom'

if ax is None:
ax1 = plt.gca()
if figsize is None:
if len(df.columns) <= 50 or orientation == 'top' or orientation == 'bottom':
figsize = (25, 10)
else:
figsize = (25, (25 + len(df.columns) - 50) * 0.5)
else:
ax1 = ax
figsize = None # for behavioral consistency with other plot types, re-use the given size

(nullity_counts / len(df)).plot.bar(
figsize=figsize, fontsize=fontsize, log=log, color=color, ax=ax1
)
plot_args = {'figsize': figsize, 'fontsize': fontsize, 'log': log, 'color': color, 'ax': ax1}
if orientation == 'bottom':
(nullity_counts / len(df)).plot.bar(**plot_args)
else:
(nullity_counts / len(df)).plot.barh(**plot_args)

axes = [ax1]

Expand All @@ -247,23 +262,61 @@ def bar(df, figsize=(24, 10), fontsize=16, labels=None, log=False, color='dimgra
if not log:
ax1.set_ylim([0, 1])
ax2.set_yticks(ax1.get_yticks())
ax2.set_yticklabels([int(n*len(df)) for n in ax1.get_yticks()], fontsize=fontsize)
else:
# For some reason when a logarithmic plot is specified `ax1` always contains two more ticks than actually
# appears in the plot. The fix is to ignore the first and last entries. Also note that when a log scale
# is used, we have to make it match the `ax1` layout ourselves.
ax2.set_yscale('log')
ax2.set_ylim(ax1.get_ylim())
ax2.set_yticklabels([int(n*len(df)) for n in ax1.get_yticks()], fontsize=fontsize)
ax2.set_yticklabels([int(n*len(df)) for n in ax1.get_yticks()], fontsize=fontsize)

# Create the third axis, which displays columnar totals above the rest of the plot.
ax3 = ax1.twiny()
axes.append(ax3)
ax3.set_xticks(ax1.get_xticks())
ax3.set_xlim(ax1.get_xlim())
ax3.set_xticklabels(nullity_counts.values, fontsize=fontsize, rotation=45, ha='left')
else:
ax1.set_xticks([])

# Create the third axis, which displays columnar totals above the rest of the plot.
ax3 = ax1.twiny()
axes.append(ax3)
ax3.set_xticks(ax1.get_xticks())
ax3.set_xlim(ax1.get_xlim())
ax3.set_xticklabels(nullity_counts.values, fontsize=fontsize, rotation=45, ha='left')
# Create the numerical ticks.
ax2 = ax1.twinx()

axes.append(ax2)
if not log:
# Width
ax1.set_xlim([0, 1])

# Bottom
ax2.set_xticks(ax1.get_xticks())
ax2.set_xticklabels([int(n*len(df)) for n in ax1.get_xticks()], fontsize=fontsize)

# Right
ax2.set_yticks(ax1.get_yticks())
ax2.set_yticklabels(nullity_counts.values, fontsize=fontsize, ha='left')
else:
# For some reason when a logarithmic plot is specified `ax1` always contains two more ticks than actually
# appears in the plot. The fix is to ignore the first and last entries. Also note that when a log scale
# is used, we have to make it match the `ax1` layout ourselves.
ax1.set_xscale('log')
ax1.set_xlim(ax1.get_xlim())

# Bottom
ax2.set_xticks(ax1.get_xticks())
ax2.set_xticklabels([int(n*len(df)) for n in ax1.get_xticks()], fontsize=fontsize)

# Right
ax2.set_yticks(ax1.get_yticks())
ax2.set_yticklabels(nullity_counts.values, fontsize=fontsize, ha='left')

# Create the third axis, which displays columnar totals above the rest of the plot.
ax3 = ax1.twiny()

axes.append(ax3)
ax3.set_yticks(ax1.get_yticks())
if log:
ax3.set_xscale('log')
ax3.set_xlim(ax1.get_xlim())
ax3.set_ylim(ax1.get_ylim())

ax3.grid(False)

for ax in axes:
Expand Down

0 comments on commit 6ad031d

Please sign in to comment.