Advanced Seaborn (SNS) plotting and labelling techniques

If you are anything like me, you are probably using seaborn (SNS) to obtain good-looking plots for reporting purposes while minimising the amount of coding required. Also, if you are anything like me, you probably want to do more with SNS than most online tutorials and courses show, and you have probably already found yourself adapting code from Stack Overflow to suit your needs on a case-by-case basis.

The beauty of SNS is its simplicity and power: good-looking visualisations, sometimes with no more than three lines of code. However, manipulating the output can be challenging, especially if you want to do non-standard things such as labelling some but not all points or bars, changing tick labels on the fly, or adding information from a dataframe that is not already captured by the x, y or hue parameters. These things could be done manually in PowerPoint or other software before adding the plots to the report, but that is time-consuming and error-prone, and if you work in business, you know there is little time for inefficiency.

So I decided to put together this little guide to compile some of the most interesting and useful things I have managed to do with SNS beyond what most online courses and tutorials cover. If you think something useful is missing, drop me a line!

The data I will be using here comes from Kaggle and corresponds to the video game sales dataset by Gregory Smith and the Supermarket Sales dataset by Aung Pyae, as well as some other datasets that can be loaded directly from SNS. The data is only as accurate as the source states, and I take no responsibility for its accuracy (after all, this is a tutorial on plotting, not on data accuracy).

Performed with Python: 3.10.6; Seaborn: 0.11.2; Matplotlib: 3.5.3; Pandas: 1.4.4; SciPy: 1.9.1; and Numpy: 1.23.2

Setting the overall theme by overriding the defaults

Please change the default theme to something that conveys professionalism and gives your readers a better experience while interpreting your data. Compare SNS's defaults to some of the best data visualisations on the internet (examples here or here). True, most of those cannot be reproduced with SNS, but if you look past that, you will see some common trends: very few have grids, very few have background colours that do not blend in with the report (or, more commonly, they are plotted over background images), nearly all of them use custom colour palettes instead of the defaults, very few have ticks on the axes, and so on.

The first step to get consistency and professionalism is to set up the overall theme. Here you can set the palette as well.

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
import seaborn as sns

import warnings
warnings.filterwarnings("ignore")

We set the aesthetics using sns.set_theme() (sns.set() is an alias for it), and the documentation lists everything that can be changed. Most of the aesthetics we are interested in are controlled through the style parameter. Running sns.axes_style() without arguments returns the current style parameters, which shows you the keys that can be changed (alternatively, check the relevant documentation). The keys are mostly self-explanatory, so no further explanation is required. I will set them to the style I use throughout this tutorial.

In [3]:
sns.set_theme(
    context='talk',
    font_scale=0.9,
    palette = ['#0F00F5', '#3061FF', '#9AB1FF', '#CDD9FF', '#E6ECFF','#E5E5E5',
               '#B6BBCB', '#878B98','#696A6F','#292A2E'],
    style = {
        'axes.facecolor': '#FFFFFF',
        'axes.edgecolor': '#000000',
        'legend.edgecolor': '#FFFFFF',
        'axes.grid': False,
        'axes.axisbelow': 'line',
        'axes.labelcolor': 'black',
        'figure.facecolor': '#FFFFFF',
        'grid.color': '#b0b0b0',
        'grid.linestyle': '-',
        'text.color': 'black',
        'xtick.color': 'black',
        'ytick.color': 'black',
        'xtick.direction': 'out',
        'ytick.direction': 'out',
        'patch.edgecolor': '#FFFFFF',
        'patch.force_edgecolor': True,
        'image.cmap': 'viridis',
        'font.family': ['sans-serif'],
        'font.sans-serif': 'Helvetica Neue',
        'xtick.bottom': False,
        'xtick.top': False,
        'ytick.left': False,
        'ytick.right': False,
        'axes.spines.left': False,
        'axes.spines.bottom': False,
        'axes.spines.right': False,
        'axes.spines.top': False
    }
)
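If you only want a given style for a single figure rather than for the whole session, `sns.axes_style()` also works as a context manager: the style dict is applied on entry and the previous settings are restored on exit. A minimal sketch; the overrides below are arbitrary examples, not part of the theme above:

```python
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# With no arguments, axes_style() returns the current style parameters as a dict
current_style = sns.axes_style()
print(sorted(current_style)[:5])

top_before = mpl.rcParams["axes.spines.top"]

# Passing a dict applies it only inside the `with` block
with sns.axes_style({"axes.spines.top": False, "axes.spines.right": False}):
    top_inside = mpl.rcParams["axes.spines.top"]
    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot([0, 1, 2], [1, 3, 2])

# The previous value is restored once the block exits
top_after = mpl.rcParams["axes.spines.top"]
```

This keeps one-off styling from leaking into the rest of the report.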

Check that your palette works (for example, that the colours are distinguishable from each other if you are using them for different groups) by simply displaying it:

In [4]:
sns.color_palette()

Barplots with styled axes and bar annotations

Let us style a barplot, since it is probably the most widely used plot ever. Using the games dataset, I am interested in the game that sold the most worldwide each year between 2000 and 2015. The basic plot would look like this:

In [5]:
#get data
games_data  = pd.read_csv('data/vgsales.csv')
games_data['Year'] = pd.to_numeric(games_data['Year'], downcast='integer')

#prepare data for plotting
#sum only Global_Sales so that text columns are not aggregated
sales_by_ypn = games_data.groupby(['Year', 'Platform', 'Name'])['Global_Sales'].sum().reset_index()
sales_by_ypn['Share_of_global_inyear'] = sales_by_ypn['Global_Sales'] * 100 / sales_by_ypn.groupby('Year')['Global_Sales'].transform('sum')
sales_by_ypn['Year_rank'] = sales_by_ypn.groupby('Year')['Share_of_global_inyear'].rank(ascending=False)

selected_years = range(2000,2016)
winner_per_year = sales_by_ypn[(sales_by_ypn['Year_rank']==1) & (sales_by_ypn['Year'].isin(selected_years))].copy()
winner_per_year['Year'] = pd.to_numeric(winner_per_year['Year'], downcast='integer')

#basic plot
fig, axes = plt.subplots(1,1,figsize=(10,7))
sns.barplot(
    data=winner_per_year,
    x='Year',
    y='Share_of_global_inyear',
    ax=axes,
    color=sns.color_palette()[0]
).set_title('Best Performing Game Per Year by Share of Global Sales\n(2000-2015)\n', fontsize=15, weight='bold')
plt.xticks(rotation=90);

But in this case the plot does not really need either axis label, since the title conveys both, so we can safely get rid of them. Additionally, since Share_of_global_inyear is a percentage, it would look much better with the % symbol attached to the tick labels (and you probably do not need a label on every tick, so we will blank out some of them). Finally, the plot would not be very informative without the game name (and platform) attached to each bar. The game names are long, so this makes the plot a bit busy and forces us to expand the y-axis limit, but I hope you can see past that.

In [6]:
fig, axes = plt.subplots(1,1,figsize=(10,7))
sns.barplot(
    data=winner_per_year,
    x='Year',
    y='Share_of_global_inyear',
    ax=axes,
    color=sns.color_palette()[0]
).set_title('Best Performing Game Per Year by Share of Global Sales\n(2000-2015)\n', fontsize=15, weight='bold')

#style the y tick labels
ylabels = axes.get_yticks()
axes.set_yticks(ticks=ylabels, labels = [f'{i:.0f}%' if i%4 == 0 else '' for i in ylabels])

#the xtick labels can be styled as well on the fly
xlabels = [f'\'{str(i)[-2:]}' for i in winner_per_year['Year']]
axes.set_xticklabels(xlabels)

#prepare and assign bar labels
barlabels = [f'{name} ({platform})' for name, platform in zip(winner_per_year['Name'], winner_per_year['Platform'])]
axes.bar_label(axes.containers[-1], label_type='edge', labels=barlabels, color='k', size=10, padding=5, rotation=90)

plt.ylabel('')
plt.xlabel('')
plt.ylim(0,20);

Much better.
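As an alternative to rewriting tick labels by hand, matplotlib's tick locators and formatters can do the same job and keep working if the limits change. A sketch using `MultipleLocator` (one tick every 4 units) and `PercentFormatter` (appends the % sign); the toy data stands in for `winner_per_year` and is made up for illustration:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.ticker import MultipleLocator, PercentFormatter

# Hypothetical shares (already in the 0-100 range), standing in for winner_per_year
shares = pd.DataFrame({"Year": [2000, 2001, 2002], "Share": [12.5, 8.0, 15.25]})

fig, ax = plt.subplots(figsize=(6, 4))
sns.barplot(data=shares, x="Year", y="Share", color="#0F00F5", ax=ax)

ax.set_ylim(0, 20)
# One tick every 4 units instead of blanking labels one by one
ax.yaxis.set_major_locator(MultipleLocator(4))
# xmax=100 because the values are already percentages; decimals=0 gives e.g. "8%"
ax.yaxis.set_major_formatter(PercentFormatter(xmax=100, decimals=0))
ax.set(xlabel="", ylabel="")
```

Because the formatter is attached to the axis, the labels stay correct even if you change the limits afterwards.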

Alternatively, when you have less information to show (say, you are only interested in one game across the years), you can simply get rid of the y tick labels and show the relevant information as bar labels. It looks much more professional this way:

In [7]:
#Need for Speed: Most Wanted
nfsmw = sales_by_ypn[sales_by_ypn['Name']=='Need for Speed: Most Wanted'].copy()

#I'll modify the year and platform data for the sake of the example
nfsmw['Year'] = range(2000,2012)
nfsmw['Platform'] = 'PS2'

fig, axes = plt.subplots(1,1,figsize=(10,7))
sns.barplot(
    data=nfsmw,
    x='Year',
    y='Share_of_global_inyear',
    ax=axes,
    color=sns.color_palette()[0]
).set_title('Share of Global Sales for NFS: Most Wanted (PS2)\n(2000-2011)\n', fontsize=15, weight='bold')

#style the y tick labels
ylabels = axes.get_yticks()
axes.set_yticks(ticks=ylabels, labels = ['' for i in ylabels])

#the xtick labels can be styled as well on the fly
xlabels = [f'\'{str(i)[-2:]}' for i in nfsmw['Year']]
axes.set_xticklabels(xlabels)

#prepare and assign bar labels
barlabels = [f'{i:.2f}%' for i in nfsmw['Share_of_global_inyear']]
axes.bar_label(axes.containers[-1], label_type='edge', labels=barlabels, color='k', size=12, padding=5, rotation=0)

plt.ylabel('');
plt.xlabel('');
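When every bar gets a numeric label, `bar_label()` can format the values itself through its `fmt` argument, so there is no need to build the label list first. A minimal sketch with made-up values, using matplotlib's `ax.bar()` directly (seaborn's barplot leaves the same kind of `BarContainer` in `axes.containers`):

```python
import matplotlib.pyplot as plt

# Hypothetical yearly shares (in %) for a single title
years = ["'00", "'01", "'02", "'03"]
shares = [1.25, 2.5, 0.75, 3.5]

fig, ax = plt.subplots(figsize=(6, 4))
container = ax.bar(years, shares, color="#0F00F5")

# fmt is applied to each bar's value; '%%' produces a literal percent sign
texts = ax.bar_label(container, fmt="%.2f%%", label_type="edge", padding=5, size=12)

# The value labels make the y tick labels redundant
ax.set_yticks([])
```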

Labelling only certain key points

Sometimes the story flows better when there is less information on the plot, for example if we want to convey how much sales grew between two particular points. Labelling all points in this case distracts from the message, so it is better to label only the key points. This example uses a lineplot, but the technique works on other plots as well.

In [8]:
#Sales for Sony's PSX Games in North America between 1994 and 2000
na_sales_psx = games_data.groupby(['Year', 'Platform'])['NA_Sales'].sum().reset_index()
na_sales_psx['Year'] = pd.to_numeric(na_sales_psx['Year'], downcast='integer')

selected_years = range(1994,2001)
na_sales_psx = na_sales_psx[(na_sales_psx['Platform']=='PS') & na_sales_psx['Year'].isin(selected_years)].reset_index(drop=True)

fig, axes = plt.subplots(1,1,figsize=(10,5))
sns.lineplot(
    data=na_sales_psx,
    x='Year',
    y='NA_Sales',
    marker='o',
    ax=axes,
    color=sns.color_palette()[0]
).set_title('Total Game Sales for Sony\'s PSX in North America\n(1994-2000)\n', fontsize=15, weight='bold')

#style the y tick labels
ylabels = axes.get_yticks()
axes.set_yticks(ticks=ylabels, labels = [f'${i:.0f}M' if i % 20 == 0 else '' for i in ylabels])

# select the years where the labels will be visible, then loop over the points and label only those years
label_years = [1994, 1998]
for year, sales in zip(na_sales_psx['Year'], na_sales_psx['NA_Sales']):
    if year in label_years:
        plt.text(x = year + 0.25, #offset the text so it's not on top of the point
                 y = sales,
                 s = f'${sales:.2f}M',
                 color='k',
                 size= 12
                 )

plt.ylabel('');
plt.xlabel('');

That way, the story of how sales for PSX games grew from $1.76M in 1994 to an all-time high of $83.22M in 1998 is nicely conveyed without distractions.
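Shifting the text by a fixed amount in data units (`year + 0.25`) works, but the visual gap changes with the axis range. A possible alternative is `ax.annotate()` with the offset given in points, which keeps the gap constant. A sketch with hypothetical values:

```python
import matplotlib.pyplot as plt

# Hypothetical yearly values; only the years in label_years get a label
years = [1994, 1995, 1996, 1997, 1998]
sales = [2.0, 10.0, 25.0, 50.0, 80.0]
label_years = {1994, 1998}

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(years, sales, marker="o")

for year, value in zip(years, sales):
    if year in label_years:
        # xytext is an offset in points from the data point, so the gap
        # stays the same regardless of the axis scale or figure size
        ax.annotate(f"${value:.2f}M", xy=(year, value), xytext=(8, 0),
                    textcoords="offset points", va="center", size=12)
```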

Again, this technique works really well when there are many points and you need to draw attention to only some of them, for example to highlight only the model x task combinations in the GLUE dataset that meet or exceed a performance threshold. Note that the approach differs from the line and bar plots: a heatmap takes a 2D array of annotations rather than a list.

In [9]:
glue = sns.load_dataset("glue").pivot(index="Model", columns="Task", values="Score")

hm_labels = np.array([[i if i >= 90 else '' for i in glue.values[j]] for j in range(len(glue.values))])

fig, axes = plt.subplots(1,1,figsize=(8,8))
sns.heatmap(glue, 
            cmap='Blues',
            ax=axes,
            annot=hm_labels, 
            fmt='', 
            annot_kws={'size':12},
            cbar_kws={'label': 'Performance Index', 'location':'bottom'}
).set_title('Model x Task Performance Heatmap\n', size=18, weight='bold');
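The 2D label array can also be built in one vectorised step with `np.where`, which keeps the labels aligned with the data by construction. A sketch on a small made-up score table (the real example uses the GLUE pivot table); the threshold of 90 matches the one above:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical model x task scores
scores = pd.DataFrame(
    [[91.2, 85.0, 72.3], [78.4, 95.1, 90.0]],
    index=["Model A", "Model B"],
    columns=["Task 1", "Task 2", "Task 3"],
)

# Same shape as the data: the score as text where it passes the threshold, '' elsewhere
hm_labels = np.where(scores >= 90, scores.round(1).astype(str), "")

fig, ax = plt.subplots(figsize=(5, 3))
# fmt='' because the annotations are already strings
sns.heatmap(scores, annot=hm_labels, fmt="", cmap="Blues", ax=ax)
```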

And if you need to personalise the heatmap further to make it more impactful, or to remove labels that some stakeholders will not be interested in (or for whatever other change you need to make), heatmap() has parameters to modify these directly. The only catch is that the colorbar has to be retrieved separately (through axes.collections[0].colorbar) before you can modify it intuitively:

In [10]:
fig, axes = plt.subplots(1,1,figsize=(8,8))
sns.heatmap(glue, 
            cmap='Blues',
            ax=axes,
            annot=hm_labels, 
            fmt='', 
            annot_kws={'size':12},
            cbar_kws={'label': 'Performance Index', 'location':'bottom'},
            xticklabels=[f'Task\n{i}' for i in range(len(glue.columns.tolist()))]
).set_title('Model x Task Performance Heatmap\n', size=18, weight='bold');

axes.set(xlabel='', ylabel='');

cbar = axes.collections[0].colorbar
cbar.set_ticks([np.min(glue.values), np.max(glue.values)])
cbar.set_ticklabels(['Worst', 'Best'])
cbar.set_label('Performance Index', labelpad=-45, color='w', weight='bold')

Better/custom legends

The default legend comes with a box that, in my opinion, is not needed if the legend is kept away from the bars and points in the plot (as it should be). It also looks better without the box, so we can get rid of it by calling .get_frame().set_linewidth(0) on the legend.

In [11]:
sm_sales_data = pd.read_csv('data/supermarket_sales - Sheet1.csv')
sm_sales_data['Branch'] = sm_sales_data['Branch'].map({'A': 'Inner City', 'B': 'Regional', 'C': 'Other'})

fig, axes = plt.subplots(1,1,figsize=(10,7))

sns.barplot(
    data=sm_sales_data,
    x='Branch',
    y='Unit price',
    hue='Customer type',
    errwidth=2,
    ax=axes
).set_title('Average Unit Price ($USD) Per Branch and Customer Type\n\n', size=18, weight='bold');

ylabels = axes.get_yticks()
axes.set_yticks(ticks=ylabels, labels=[f'${i:.0f}' if i % 20 == 0 else '' for i in ylabels])

plt.ylabel('');
plt.xlabel('');
plt.legend(bbox_to_anchor=(0.7,1.1), ncol=2).get_frame().set_linewidth(0)
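Seaborn 0.11.2 also ships `sns.move_legend()`, which re-creates an existing legend at a new location and forwards extra keyword arguments (such as `frameon=False`, which hides the box entirely) to `Axes.legend()`. A sketch with toy data:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data with a hue column
df = pd.DataFrame({
    "Branch": ["Inner City", "Inner City", "Regional", "Regional"],
    "Customer type": ["Member", "Normal", "Member", "Normal"],
    "Unit price": [56.2, 54.1, 55.7, 55.0],
})

fig, ax = plt.subplots(figsize=(6, 4))
sns.barplot(data=df, x="Branch", y="Unit price", hue="Customer type", ax=ax)

# Move the legend above the axes, spread it over two columns and drop the frame
sns.move_legend(ax, "lower center", bbox_to_anchor=(0.5, 1.02), ncol=2, frameon=False)
```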

Sometimes you may need to create a custom legend (and/or merge legends), for example when you are overlaying multiple SNS plots or using twinx() to get a second y axis. In these cases it is sometimes better to draw the legend from scratch to avoid issues, especially if the second plot does not generate a legend by default. This example uses a box plot to draw a line for the average gross income per branch (data that is not otherwise on the plot) on top of a bar plot:

In [12]:
###Plotting the average gross income per branch on top of the previous plot, keeping legend as is
fig, axes = plt.subplots(1,2,figsize=(15,5))

sns.barplot(data=sm_sales_data, x='Branch', y='Unit price', hue='Customer type', errwidth=2, ax=axes[0], palette=sns.color_palette()[2:])\
    .set_title('Average Unit Price ($USD) Per Branch and Customer Type\n\n\n\n', size=14, weight='bold');

ylabels = axes[0].get_yticks()
axes[0].set_yticks(ticks=ylabels, labels=[f'${i:.0f}' if i % 20 == 0 else '' for i in ylabels])

sns.boxplot(data=sm_sales_data.groupby('Branch')['gross income'].mean().reset_index(), x='Branch', y='gross income', ax=axes[0])

axes[0].set_ylabel('');
axes[0].set_xlabel('');

#Default legend
axes[0].legend(bbox_to_anchor=(0.8,1.15), ncol=2).get_frame().set_linewidth(0)


###As above, but creating the legend from scratch
sns.barplot(data=sm_sales_data, x='Branch', y='Unit price', hue='Customer type', errwidth=2, ax=axes[1], palette=sns.color_palette()[2:])\
    .set_title('Average Unit Price ($USD) Per Branch and Customer Type\n\n\n\n', size=14, weight='bold');

ylabels = axes[1].get_yticks()
axes[1].set_yticks(ticks=ylabels, labels=[f'${i:.0f}' if i % 20 == 0 else '' for i in ylabels])

sns.boxplot(data=sm_sales_data.groupby('Branch')['gross income'].mean().reset_index(), x='Branch', y='gross income', ax=axes[1])

axes[1].set_ylabel('');
axes[1].set_xlabel('');


#Custom legend here
legend_elements = [
    Line2D([0], [0], color='k', lw=1, label='Avg Branch\nGross Income'),
    Patch(facecolor=sns.color_palette()[2], edgecolor=sns.color_palette()[2], label='Member'),
    Patch(facecolor=sns.color_palette()[3], edgecolor=sns.color_palette()[3], label='Normal')
]
axes[1].legend(handles = legend_elements, bbox_to_anchor=(1,1), ncol=1, fontsize=14, handleheight=0.5).get_frame().set_linewidth(0)
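With `twinx()`, each axis keeps its own legend entries, so one way to merge them is to collect handles and labels from both axes with `get_legend_handles_labels()` and draw a single legend. A sketch with hypothetical values, using plain matplotlib calls so the handles come straight from the plotted artists:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 4))
# Hypothetical data: bars on the left axis, a line on a twin right axis
ax.bar([0, 1, 2], [55, 56, 54], color="#9AB1FF", label="Avg unit price")
ax2 = ax.twinx()
ax2.plot([0, 1, 2], [15.2, 14.9, 16.1], color="k", marker="o", label="Avg gross income")

# Collect handles and labels from both axes and draw one legend
handles1, labels1 = ax.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
legend = ax.legend(handles1 + handles2, labels1 + labels2,
                   loc="upper left", bbox_to_anchor=(1.1, 1), frameon=False)
```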

Working with labels and hue

One of the most powerful features of SNS, in my opinion, is the hue parameter, which allows us to split the data based on another column. This is where SNS truly shines, especially when paired with a visually appealing colour palette. When it comes to assigning labels, however, things can get messy, especially when we want to label bars with data not inherently contained in the plot (i.e. in none of x, y or hue), and even more so when order or hue_order are used. When hue is at work, each hue group becomes its own container, and each container has its own data, labels and so forth. So to assign labels we have to be mindful of this and loop through the containers where applicable.

The example below illustrates 4 scenarios:

  • Plot 1: We assign labels depicting the information contained in the y component of the bar (in case we want to get rid of the y axis, for example). This is fairly straightforward as we get the values directly from the container and then they can be edited on the fly if need be.

  • Plot 2: We assign labels from another column (num_custs, the number of customers for each branch and gender/customer-type group). This works when the dataframe is already ordered the way we want and the plot comes out with the x and hue order we want. This may not always be the case, of course, and we might want to use order and/or hue_order to get the plot we want.

  • Plot 3: As with plot 2, but this time we use order and hue_order. The labels are assigned exactly as in plot 2, and are therefore wrong. Why? Because we are simply passing the labels in dataframe order, and seaborn has no way of knowing that they no longer correspond to the groups/hues.

  • Plot 4: As with plot 3, but now we correct the order of the labels by matching both the hue and the value (the y value of each bar) on the fly, and extracting the label from the column we want (again num_custs). Note that matching on the y value assumes that no two bars in the same hue group have exactly the same value.

In [13]:
sm_sales_data['gender_cust_type'] = [f'{gender} {cust}' for gender, cust in zip(sm_sales_data['Gender'], sm_sales_data['Customer type'])]

#note: this sums 'Unit price' per group, so 'Total sales' here is the summed unit price
total_sales_branch_cust = sm_sales_data.groupby(['Branch', 'gender_cust_type'])['Unit price'].sum().reset_index(name='Total sales')

num_custs = sm_sales_data.groupby(['Branch', 'gender_cust_type']).size().reset_index(name='num_custs')

total_sales_branch_numcust = total_sales_branch_cust.merge(num_custs, on=['Branch', 'gender_cust_type'])

print(total_sales_branch_numcust)
fig, axes = plt.subplots(2,2,figsize=(16,12))

## Plot 1
sns.barplot(
    data=total_sales_branch_numcust,
    x='Branch',
    y='Total sales',
    hue='gender_cust_type',
    ax=axes[0,0]
).set_title('Plot 1\n\n')
axes[0,0].legend(bbox_to_anchor=(0.6,0.35)).get_frame().set_linewidth(0)

#assign bar labels to plot 1
for container in axes[0,0].containers:
    container.datavalues = [i/1000 for i in container.datavalues]
    axes[0,0].bar_label(container, 
                  labels = [f'{i:.2f}k' for i in container.datavalues],
                  rotation=90,
                  padding=5)

#################    
## Plot 2
sns.barplot(
    data=total_sales_branch_numcust,
    x='Branch',
    y='Total sales',
    hue='gender_cust_type',
    ax=axes[0,1]
).set_title('Plot 2\n\n')
axes[0,1].legend(bbox_to_anchor=(0.6,0.35)).get_frame().set_linewidth(0)

#assign bar labels to plot 2
for container in axes[0,1].containers:
    hue = container.get_label()
    bar_labels = total_sales_branch_numcust[total_sales_branch_numcust['gender_cust_type']==hue]['num_custs']
    axes[0,1].bar_label(container, 
                  labels = [f'{i} custs' for i in bar_labels],
                  rotation=90,
                  padding=5)

#################
## Plot 3
branch_order = ['Regional', 'Inner City', 'Other']
hue_order = ['Female Member', 'Male Member', 'Female Normal', 'Male Normal']

sns.barplot(
    data=total_sales_branch_numcust,
    x='Branch',
    y='Total sales',
    hue='gender_cust_type',
    order = branch_order,
    hue_order=hue_order,
    ax=axes[1,0]
).set_title('Plot 3\n\n')
axes[1,0].legend(bbox_to_anchor=(0.6,0.35)).get_frame().set_linewidth(0)

#assign bar labels to plot 3
for container in axes[1,0].containers:
    hue = container.get_label()
    bar_labels = total_sales_branch_numcust[total_sales_branch_numcust['gender_cust_type']==hue]['num_custs']
    axes[1,0].bar_label(container, 
                  labels = [f'{i} custs' for i in bar_labels],
                  rotation=90,
                  padding=5)
    
#################
## Plot 4
sns.barplot(
    data=total_sales_branch_numcust,
    x='Branch',
    y='Total sales',
    hue='gender_cust_type',
    order = branch_order,
    hue_order=hue_order,
    ax=axes[1,1]
).set_title('Plot 4\n\n')
axes[1,1].legend(bbox_to_anchor=(0.6,0.35)).get_frame().set_linewidth(0)

#assign bar labels to plot 4
for container in axes[1,1].containers:
    hue = container.get_label()
    values = container.datavalues
    bar_labels = [total_sales_branch_numcust[(total_sales_branch_numcust['gender_cust_type']==hue) &\
                                              (total_sales_branch_numcust['Total sales']==value)]['num_custs'].tolist()[0] for value in values
                                              ] #this is the critical line
    axes[1,1].bar_label(container, 
                  labels = [f'{i} custs' for i in bar_labels],
                  rotation=90,
                  padding=5)

plt.subplots_adjust(hspace=0.5, wspace=0.2)
        Branch gender_cust_type  Total sales  num_custs
0   Inner City    Female Member      4469.45         80
1   Inner City    Female Normal      4560.42         81
2   Inner City      Male Member      4612.04         87
3   Inner City      Male Normal      4983.58         92
4        Other    Female Member      5609.23         96
5        Other    Female Normal      4272.90         82
6        Other      Male Member      4336.31         73
7        Other      Male Normal      4349.32         77
8     Regional    Female Member      4479.46         85
9     Regional    Female Normal      4295.78         77
10    Regional      Male Member      4653.21         80
11    Regional      Male Normal      5050.43         90
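A possibly more robust variant of plot 4 avoids matching on float values altogether: pivot the label column into a lookup table (x categories as rows, hue levels as columns), then read it in the same order the bars are drawn. This sketch assumes, as in the examples above, that seaborn creates one container per hue level in `hue_order`, with the bars inside each container following `order`; the toy data is a made-up subset in the same shape:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical subset in the same long format as total_sales_branch_numcust
df = pd.DataFrame({
    "Branch": ["Inner City", "Inner City", "Regional", "Regional"],
    "cust_type": ["Member", "Normal", "Member", "Normal"],
    "Total sales": [4469.45, 4560.42, 4479.46, 4295.78],
    "num_custs": [80, 81, 85, 77],
})
order = ["Regional", "Inner City"]
hue_order = ["Normal", "Member"]

# Lookup table: rows are x categories, columns are hue levels
lookup = df.pivot(index="Branch", columns="cust_type", values="num_custs")

fig, ax = plt.subplots(figsize=(6, 4))
sns.barplot(data=df, x="Branch", y="Total sales", hue="cust_type",
            order=order, hue_order=hue_order, ax=ax)

all_labels = []
# Pair each container with its hue level, then read labels in the x order
for hue_level, container in zip(hue_order, ax.containers):
    labels = [f"{n} custs" for n in lookup.loc[order, hue_level]]
    ax.bar_label(container, labels=labels, rotation=90, padding=5)
    all_labels.append(labels)
```

Since the labels come from the same `order` and `hue_order` passed to the plot, they cannot drift out of sync with the bars.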

Stacked Bar Chart with Annotations

As far as I am aware, Seaborn does not provide stacked bar charts natively, but there are a couple of workarounds for different cases. For simple stacked bar charts with different heights (say, we are simply plotting the breakdown of total revenue by product), we can plot with the pandas API and let our seaborn configuration provide the style. Granted, this does not use seaborn per se, but the result still has a seaborn feel to it. The second approach uses histplot() and is useful when the stacked bars all add up to the same total (say, a percentage breakdown where the elements in each bar add up to 100%).

In both cases we can also draw attention to a particular product we are interested in. I will create some random game sales data and plot both approaches while highlighting one game (in this case, 'Grand Chef Auto'). Note that in the second case, the order of the items needs to be reversed (see the hue_order argument) for the text to be assigned to the correct bar segment.

In [14]:
np.random.seed(0)

#Create some random games sales data
games_sales = pd.DataFrame({
    'Name': ['Super Kung Fu II', 'Luigi\'s Cottage', 'Grand Chef Auto', 'Wario Kart DX'] * 3,
    'Year': ['1999'] * 4 + ['2000'] * 4 + ['2001'] * 4,
    'Sales': np.random.randint(10, 1000, 12)
})

#Calculate the percentual share of each game per year
games_sales['Share'] = games_sales['Sales'] * 100 / games_sales.groupby('Year')['Sales'].transform(np.sum)

# Normal Stacked Bar Chart
fig, axes = plt.subplots(1,1,figsize=(7,5))
games_sales.pivot(index='Year', columns='Name', values='Sales').plot(kind='bar', stacked=True, color=sns.color_palette()[:4], ax=axes, lw=0);
axes.set_title('Sales by Year')
legend = axes.get_legend()
legend.set_bbox_to_anchor((1, 1))
legend.set_title('Game')
legend.get_frame().set_linewidth(0)
plt.xticks(rotation=0);
plt.xlabel('');

#Add notation for the game of interest:

for name, year, sales in zip(games_sales['Name'], games_sales['Year'], games_sales['Sales']):
    if name == 'Grand Chef Auto':
        plt.text(x = games_sales['Year'].unique().tolist().index(year), 
                 y = sales * .5, #middle of the bar; this only works because 'Grand Chef Auto' is the bottom segment (first column after pivoting)
                 s = f'${sales:.0f}',
                 color='w',
                 size= 12,
                 horizontalalignment='center',
                 verticalalignment='center',
                 weight='bold'
                 )

plt.ylabel('');
plt.xlabel('');
In [86]:
#Percent stacked bar chart
fig, axes = plt.subplots(1,1,figsize=(7,5))

games_names = games_sales['Name'].unique().tolist()
sns.histplot(
    games_sales, 
    x='Year', 
    hue='Name', 
    hue_order=games_names[::-1],
    weights='Share', 
    multiple='stack', 
    palette=sns.color_palette()[:4][::-1],
    shrink=0.8,
    lw=0
    ).set_title('Share of Sales by Year\n')

# Fix the legend so it's not on top of the bars.
legend = axes.get_legend()
legend.set_bbox_to_anchor((1, 1))
legend.set_title('Game')
legend.get_frame().set_linewidth(0)

#Fix the y ticks so they show the percentage
plt.ylim(0,100)
axes.set_yticks(ticks=axes.get_yticks(), labels=[f'{i:.0f}%' for i in axes.get_yticks()])

#Add the annotation for the game of interest
idx = games_names.index('Grand Chef Auto')
values = games_sales[games_sales['Name']==games_names[idx]]['Share'].tolist()
axes.bar_label(axes.containers[idx], 
               labels = [f'{v:.1f}%' for v in values],
               rotation=0,
               size=14,
               color='k',
               weight='bold',
               label_type='center',
               padding=0
               )

plt.xlabel('')
plt.ylabel('');
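If you prefer to stay with the pandas approach for the 100% version too, one option is to normalise each row of the pivot table by its own total before plotting. A sketch with hypothetical numbers:

```python
import pandas as pd

# Hypothetical wide table: one row per year, one column per game
wide = pd.DataFrame(
    {"Grand Chef Auto": [100, 200], "Luigi's Cottage": [300, 600]},
    index=["1999", "2000"],
)

# Divide each row by its own total so that every bar adds up to 100
shares = wide.div(wide.sum(axis=1), axis=0) * 100

ax = shares.plot(kind="bar", stacked=True, rot=0, lw=0)
ax.set_ylim(0, 100)
```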

Pyramid Chart

Sadly, Seaborn does not offer pyramid charts natively, but since a pyramid chart is simply two horizontal bar charts back to back, we can build one with barplot() by negating one half of the data. You can annotate it as well and clear the x tick labels, since we are passing the information as labels on each bar anyway; the labels can also carry extra information as usual (in this case, the proportion of the segment relative to the whole population).

In [83]:
#Create the data

ages_groups = ['0-4','5-9','10-14','15-19','20-24','25-29','30-34','35-39','40-44','45-49','50-54','55-59',
                             '60-64','65-69','70-74','75-79','80-84','85-89','90-94','95-99','100+']

ages = pd.DataFrame({'Age': reversed(ages_groups), 
                    'Male': reversed([-49228000, -61283000, -64391000, -52437000, -42955000, -44667000, -31570000, -23887000, -22390000, -20971000, 
                             -17685000, -15450000, -13932000, -11020000, -7611000, -4653000, -1952000, -625000, -116000, -14000, -1000]), 
                    'Female': reversed([52367000, 64959000, 67161000, 55388000, 45448000, 47129000, 33436000, 26710000, 25627000, 23612000, 20075000, 16368000, 
                               14220000, 10125000, 5984000, 3131000, 1151000, 312000, 49000, 4000, 2000])})

#Put the data in the right shape
ages_melt = ages.melt(id_vars='Age', var_name='Gender', value_name='Count')
ages_melt['Count'] = ages_melt['Count'] / 10
ages_melt['Prop'] = abs(ages_melt['Count'] * 100 / abs(ages_melt['Count']).sum())
ages_melt['Age'] = pd.Categorical(ages_melt['Age'], categories=reversed(ages_groups), ordered=True)

# Pyramid plot
fig, axes = plt.subplots(1,1,figsize=(10,10))
colors = reversed(sns.color_palette()[1:3])

for color, gender in zip(colors, ages_melt['Gender'].unique()):
    sns.barplot(data=ages_melt.loc[ages_melt['Gender']==gender, :], x='Count', y='Age', color=color, label=gender, lw=0)\
        .set_title('Population Distribution by Age\nin Wakanda\n')
    
def format_count(num, perc):
    #format the absolute count as plain, thousands (k) or millions (M), followed by the share of the population
    num = abs(num)
    if num < 10**3:
        return f'{num:.0f}, {perc:.4f}%'
    if num < 10**6:
        return f'{num/1000:.1f}k, {perc:.2f}%'
    return f'{num/10**6:.1f}M, {perc:.1f}%'

for container in axes.containers:
    gender = container.get_label()
    subset = ages_melt[ages_melt['Gender']==gender]
    barlabels = [format_count(num, perc) for num, perc in zip(subset['Count'], subset['Prop'])]
    axes.bar_label(container, 
                   label_type='edge', 
                   labels=barlabels, 
                   padding=10,
                   color='k',
                   size=12);

plt.legend(bbox_to_anchor=(1,1.1)).get_frame().set_linewidth(0)
plt.xlabel('')
plt.ylabel('')
plt.xlim(-10**7, 10**7);
axes.set_xticklabels('');
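If you would rather keep the x axis visible, the negated side can be displayed as positive values with a `FuncFormatter`. A minimal sketch of the same back-to-back idea with made-up counts for three age bands:

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.ticker import FuncFormatter

# Hypothetical population counts, oldest band first so it appears at the top
ages = pd.DataFrame({
    "Age": ["40+", "20-39", "0-19"],
    "Male": [1_500_000, 2_500_000, 3_000_000],
    "Female": [1_700_000, 2_400_000, 2_800_000],
})

fig, ax = plt.subplots(figsize=(6, 3))
# Negate one side so the two horizontal bar charts grow away from zero
sns.barplot(x=-ages["Male"], y=ages["Age"], color="#3061FF", label="Male", ax=ax)
sns.barplot(x=ages["Female"], y=ages["Age"], color="#9AB1FF", label="Female", ax=ax)

# Show magnitudes on the x axis so the negated side does not read as negative
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, pos: f"{abs(x) / 1e6:.1f}M"))
ax.set(xlabel="", ylabel="")
ax.legend(frameon=False)
```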

Plotting over background image

Many stakeholders like some 'extra touch' on their presentations, like having their corporate logos as backgrounds on some plots in flagship reports. Why not make them happy? Adding a background image is easy, and pyplot allows us to control the order and transparency of the images, which can be easily integrated with any seaborn plot as well, to make the latter shine brighter.

In [17]:
ps_years = range(2000,2016)
all_ps = games_data[(games_data['Platform'].isin(['PS', 'PS2', 'PS3', 'PS4', 'PSV', 'PSP'])) & (games_data['Year'].isin(ps_years))].reset_index(drop=True)
all_ps_byyear = all_ps.groupby('Year')[['NA_Sales', 'EU_Sales', 'JP_Sales']].sum().reset_index()
all_ps_byyear = pd.melt(all_ps_byyear, id_vars='Year', value_vars=['NA_Sales','EU_Sales','JP_Sales'], value_name='Sales', var_name='Country')
all_ps_byyear['Country'] = all_ps_byyear['Country'].map({'NA_Sales': 'North America', 'EU_Sales': 'Europe', 'JP_Sales': 'Japan'})

fig, axes = plt.subplots(1,1,figsize=(10,7))
sns.lineplot(data=all_ps_byyear, x='Year', y='Sales', hue='Country', ax=axes, zorder=1).set_title('PlayStation Games Sales over Time\nAll Consoles, 2000-2015\n\n')

#Add image
img = plt.imread('images/ps_bg.png')
plt.imshow(img, zorder=0, aspect='auto', alpha=0.3, extent=[axes.get_xlim()[0], #left
                                                axes.get_xlim()[1], #right
                                                axes.get_ylim()[0], #bottom
                                                axes.get_ylim()[1] #top
                                                ]
          )

axes.grid(axis='y', linewidth=0.5)
plt.legend(title='', bbox_to_anchor=(0.7,0.95)).get_frame().set_linewidth(0)
plt.xlabel('')
plt.ylabel('Sales in $USD (Millions)');

Using LaTeX without installing it?

Well, not quite "using" LaTeX per se, but this is a way to avoid installing LaTeX if you only use it sparingly, and it also lets you leverage online tools to create the equations you want to add to your plots. The idea in this example is to use LaTeX equations as the x and y labels but, to be honest, this method works for adding any image to the plot and placing it outside the plotting area (otherwise, you can simply use the previous method and change the zorder if you want the image on top of the plot instead of behind it).

If you use an online tool to generate and export equations as images (like this one), or you generate the equation images some other way, you can import them and overlay them onto the seaborn plot using OffsetImage and AnnotationBbox (image rotation, if needed, can be handled with SciPy's ndimage):

In [18]:
from scipy import ndimage
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

#Create some data and plot it as usual
t = np.linspace(0.0, 1.0, 100)
s = np.cos(4 * np.pi * t) + 2

fig, axes = plt.subplots(1,1, figsize=(10, 7))
sns.lineplot(x=t, y=s, ax=axes, zorder=0).set_title('Custom Function')

#Import/rotate images
t_img = plt.imread('images/t(s)2.png')
cos_img = plt.imread('images/cos2.png')
rotated_cos_img = ndimage.rotate(cos_img, 90)

#Plot the images
im = OffsetImage(rotated_cos_img, zoom=0.5) #zoom will tell you how big/small the image will appear
im2 = OffsetImage(t_img, zoom=0.5)

y_latex = AnnotationBbox(im, (-0.1, 0.5), xycoords='axes fraction', frameon=False) #the tuple indicates the coordinates for the image
x_latex = AnnotationBbox(im2, (0.5, -0.1), xycoords='axes fraction', frameon=False)

axes.add_artist(y_latex);
axes.add_artist(x_latex);


# Some extra fanciness because why not
ylabels = axes.get_yticks()
axes.set_yticks(ticks=ylabels, labels = [i if i%1 == 0 else '' for i in ylabels])

axes.spines.left.set_visible(True)
axes.spines.left.set_bounds(s.min(), s.max())

axes.spines.bottom.set_visible(True)
axes.spines.bottom.set_bounds(t.min(), t.max())
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Thanks for reading and, again, if you think there are some interesting things that could be added to this list that are not easily covered by most tutorials/courses online, drop me a line via the main page!