Matplotlib is one of the most commonly used tools for plotting in Python. In this article, we’ll look at how to use matplotlib to create some basic plots, such as line plots, pie charts, scatter plots, bar plots, histograms, and box plots.
So, let’s get started with the line plot.
Line Plots:
This is the most basic type of plot. A line plot is typically used to plot the relationship between two numerical variables.
Below is a code snippet for plotting the number of ice creams sold during a week.
|
import pandas as pd
import matplotlib.pyplot as plt

# number of ice creams sold during a week
ice_cream = [35, 33, 65, 44, 75, 88, 101]

plt.plot(ice_cream)
plt.show()
|
A line plot, simply put, connects consecutive data points with straight line segments.
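When plot() receives a single list, Matplotlib uses the list positions 0, 1, 2, ... as the x values. The following is a minimal sketch of that behavior, assuming we also want day names on the x-axis and a marker on every point. The days list, the marker choice, and the Agg backend line are illustrative additions, not part of the original snippet.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

ice_cream = [35, 33, 65, 44, 75, 88, 101]
days = ["Mon", "Tues", "Wed", "Thur", "Fri", "Sat", "Sun"]  # assumed labels

# Only y values are given, so x defaults to the positions 0..6
line_default, = plt.plot(ice_cream)

# Start a second figure and pass explicit x values, with a marker on each point
plt.figure()
line_days, = plt.plot(days, ice_cream, marker="o")

# In an interactive session, plt.show() would display both figures
```

The segments between the markers are the straight lines described above.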
Pie Chart:
A pie chart is a circular figure that depicts the percentage or proportion of data.
To create a pie chart, we can use Matplotlib’s pie() function.
The following code snippet shows the market share of browsers worldwide. The stats are taken from this website.
|
x = [64.73, 18.43, 3.37, 3.36, 10.11]
labels = ['Chrome', 'Safari', 'Edge', 'Firefox', 'Others']

plt.pie(x, labels=labels, autopct='%.1f%%')
plt.show()
|
The autopct argument in the function pie() is used to show the percentage values on the pie chart.
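To see what the format string does, here is a small sketch. It applies '%.1f%%' to one value by hand and then reads back the percentage texts that pie() creates. The variable names formatted and percent_labels are illustrative, and the Agg backend line is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

x = [64.73, 18.43, 3.37, 3.36, 10.11]
labels = ["Chrome", "Safari", "Edge", "Firefox", "Others"]

# '%.1f' keeps one decimal place and '%%' prints a literal percent sign
formatted = "%.1f%%" % 64.73

# With autopct set, pie() returns the wedges, the label texts and the percentage texts
wedges, texts, autotexts = plt.pie(x, labels=labels, autopct="%.1f%%")
percent_labels = [t.get_text() for t in autotexts]
```

Because the values here already add up to 100, each percentage label matches its input value rounded to one decimal place.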
If we want the slices to be separated, we can use the explode argument of the pie() function.
To explode all the slices, use the following setting.
|
plt.pie(x, labels=labels, autopct='%.1f%%', explode=[0.1] * 5)
plt.show()
|
It’s worth noting that the length of the list provided to explode should be the same as the number of categories.
If we need to highlight the market share of a particular browser, say Safari, we can use the following.
|
plt.pie(x, labels=labels, autopct='%.1f%%', explode=[0, 0.3, 0, 0, 0])
plt.show()
|
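Rather than typing the explode list by hand, it can be derived from the labels, which also keeps its length in step with the number of categories. This is only a sketch. The highlight variable is a hypothetical name introduced for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

x = [64.73, 18.43, 3.37, 3.36, 10.11]
labels = ["Chrome", "Safari", "Edge", "Firefox", "Others"]

highlight = "Safari"  # hypothetical: the slice we want to pull out

# One offset per label: 0.3 for the highlighted slice, 0 for the rest
explode = [0.3 if label == highlight else 0 for label in labels]

plt.pie(x, labels=labels, autopct="%.1f%%", explode=explode)
```

Changing highlight to another label moves the offset without touching the rest of the code.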
Scatter Plot:
Another popular plot is the scatter plot. It plots the relationship between two numeric features in a data set.
Each data point has an x and a y coordinate and is represented by a dot.
The following scatter plot shows the relationship between the experience and salary of people.
The CSV file can be downloaded here.
|
import pandas as pd
import matplotlib.pyplot as plt

sal = pd.read_csv('/Salary_Data.csv')
sal.head()
|
The following are the first 5 entries from the dataset.
Now let’s draw the scatter plot.
|
experience = sal['YearsExperience']
salary = sal['Salary']

plt.scatter(experience, salary)
plt.show()
|
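If the CSV file is not at hand, the same call can be tried on a small in-memory DataFrame. The numbers below are invented purely for illustration and are not taken from the salary dataset. Only the column names follow the snippet above.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for Salary_Data.csv (values are made up)
sal = pd.DataFrame({
    "YearsExperience": [1.1, 2.0, 3.2, 4.5, 5.9],
    "Salary": [39000, 43000, 55000, 61000, 81000],
})

# Each row becomes one dot at (experience, salary)
points = plt.scatter(sal["YearsExperience"], sal["Salary"])
plt.xlabel("Years of experience")
plt.ylabel("Salary")
```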
Bar Plot:
Bar charts are helpful when comparing data. They use rectangular bars to compare values across different categories.
We’ll use the same ice cream sales data that we used for the line plot.
|
import pandas as pd
import matplotlib.pyplot as plt

ice_cream = [35, 33, 65, 44, 75, 88, 101]
days = ['Mon', 'Tues', 'Wed', 'Thur', 'Fri', 'Sat', 'Sun']

plt.bar(days, ice_cream)
plt.show()
|
To plot a horizontal bar graph we can use the barh() function.
|
plt.barh(days, ice_cream)
plt.show()
|
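Bar charts are often easier to read when each bar shows its value. One way to do that, sketched below, is the bar_label() function, which is available in Matplotlib 3.4 and later. Whether the original article intended this is not stated, so treat it as an optional extra.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

ice_cream = [35, 33, 65, 44, 75, 88, 101]
days = ["Mon", "Tues", "Wed", "Thur", "Fri", "Sat", "Sun"]

# bar() returns a container holding one rectangle per category
bars = plt.bar(days, ice_cream)

# Write each bar's height just above it (needs Matplotlib 3.4+)
value_labels = plt.bar_label(bars)

heights = [bar.get_height() for bar in bars]
```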
Histogram Plot:
A histogram graphically depicts the distribution of numerical data. The range of values is divided into equal-sized bins. The height of each bin represents the frequency of values in that bin.
The following is an example of plotting the distribution of salaries (the same data that we used for the scatter plot).
To draw the histogram we’ll make use of the hist() function in matplotlib. It groups the data points into bins and plots the frequency of each bin as a bar.
|
import pandas as pd
import matplotlib.pyplot as plt

sal = pd.read_csv('/Salary_Data.csv')

experience = sal['YearsExperience']
salary = sal['Salary']

plt.hist(salary, bins=7)
plt.show()
|
You can also change the number of bins using the bins argument. It also accepts a list of bin edges if you need bins of specific widths.
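To make the effect of bins concrete, the sketch below bins the ice cream sales from earlier in the article (used here only because the salary CSV is not available in this snippet) and reads back the counts that hist() returns. The explicit edges 0, 50, 100, 150 are an arbitrary choice for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

ice_cream = [35, 33, 65, 44, 75, 88, 101]

# An integer asks for that many equal-width bins between the min and max
counts, edges, patches = plt.hist(ice_cream, bins=3)

# A list gives the bin edges explicitly (here: 0-50, 50-100, 100-150)
plt.figure()
counts_custom, edges_custom, _ = plt.hist(ice_cream, bins=[0, 50, 100, 150])
```

Three bins need four edges, and the counts always add up to the number of data points that fall inside the edges.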
Box Plot:
A box plot, also known as a box-and-whisker plot, helps us study the distribution of the data. It is a convenient way to visualize the spread and skew of the data.
It is created by plotting the five-number summary of the dataset: minimum, first quartile, median, third quartile, and maximum.
If you’re curious about the five-number summary, I have written an article on How to Interpret Box Plots. Check it out.
The following code snippet plots the boxplot for the salary feature.
|
import pandas as pd
import matplotlib.pyplot as plt

sal = pd.read_csv('/content/Salary_Data.csv')

experience = sal['YearsExperience']
salary = sal['Salary']

plt.boxplot(salary)
plt.show()
|
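The five-number summary mentioned above can also be computed directly. The sketch below uses NumPy’s percentile() on the ice cream sales from earlier in the article, since the salary CSV is not available here. Note that, by default, Matplotlib draws the whiskers to the most extreme data points within 1.5 times the interquartile range of the box and plots anything beyond that as individual outlier points, so the whisker ends equal the minimum and maximum only when there are no such outliers.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

ice_cream = [35, 33, 65, 44, 75, 88, 101]

# Minimum, first quartile, median, third quartile, maximum
minimum, q1, median, q3, maximum = np.percentile(ice_cream, [0, 25, 50, 75, 100])
iqr = q3 - q1  # height of the box

# boxplot() returns a dict of the drawn lines, including the median line
box = plt.boxplot(ice_cream)
median_drawn = box["medians"][0].get_ydata()[0]
```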
Customizing Plots to Make Them More Readable:
Now let’s see how we can improve the readability of our plots.
First, we’ll start with adding titles to our plots. We’ll use the line plot.
We can add a title to our plot by calling the title() function. Similarly, the xlabel() and ylabel() functions can be used to add x and y labels.
|
import pandas as pd
import matplotlib.pyplot as plt

ice_cream = [35, 33, 65, 44, 75, 88, 101]

plt.plot(ice_cream)
plt.title('Ice Cream Sales')
plt.xlabel('Day')
plt.ylabel('# ice creams sold')
plt.show()
|
To add a legend to the plot, we must first pass a value to the label argument of the plot() function. Next, we need to call the legend() function from matplotlib.
|
plt.plot(ice_cream, label='No. Of Ice Creams sold')
plt.title('Ice Cream Sales')
plt.xlabel('Day')
plt.ylabel('# ice creams sold')
plt.legend()
plt.show()
|
We can also change the location of the legend by passing a value to the loc argument of the plt.legend() function.
|
plt.legend(loc='lower right')
|
The loc argument accepts the following values:
- best
- upper right
- upper left
- lower left
- lower right
- right
- center left
- center right
- lower center
- upper center
- center
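A legend becomes more useful once the plot has more than one labeled element. The sketch below adds a dashed line for the weekly average, which is an illustrative addition computed from the same sales list, and places the legend with one of the loc values listed above.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

ice_cream = [35, 33, 65, 44, 75, 88, 101]
mean_sales = sum(ice_cream) / len(ice_cream)

plt.plot(ice_cream, label="Ice creams sold")
# Horizontal reference line at the average; its label also appears in the legend
plt.axhline(mean_sales, linestyle="--", label="Weekly average")

legend = plt.legend(loc="upper left")
legend_entries = [text.get_text() for text in legend.get_texts()]
```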
Now let’s say we have the names of the days on the x-axis instead of numbers. The resultant plot would look like the following.
We can see that the x-axis is a little congested. To make it more readable, we can rotate the tick labels by passing the rotation parameter to the xticks() function.
The following code will rotate them vertically.
|
plt.xticks(rotation='vertical')
plt.show()
|
We can also pass any number as a value to the argument rotation. If we pass 45 as the value to the argument rotation, it will rotate the labels by 45 degrees.
Similarly, we can use yticks() to rotate the labels on the y-axis.
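Putting the pieces together, the following sketch plots the sales against day names and rotates the x-axis labels by 45 degrees. The days list is the one used in the bar plot section. Reading back the rotation from the returned labels is only there to show that the setting was applied.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

ice_cream = [35, 33, 65, 44, 75, 88, 101]
days = ["Mon", "Tues", "Wed", "Thur", "Fri", "Sat", "Sun"]

plt.plot(days, ice_cream)

# xticks() returns the tick locations and the label objects it updated
locs, tick_labels = plt.xticks(rotation=45)
rotations = [label.get_rotation() for label in tick_labels]
```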
All these techniques can be used for other plots as well.
Now let’s see how we can draw multiple plots on the same figure.
Plotting Multiple Plots:
Multiple plots are arranged on an m x n grid in a figure, where m denotes the number of rows and n denotes the number of columns.
Matplotlib’s subplot() function can be used to create multiple plots on a single figure.
We’ll use the Iris data set to plot the distribution of different features using a histogram. You can download the data set from here.
The following code snippet will plot the histogram for all four features in the Iris dataset, i.e., SepalLength, SepalWidth, PetalLength, and PetalWidth.
|
import pandas as pd
import matplotlib.pyplot as plt

iris = pd.read_csv('/content/Iris.csv')

plt.subplot(2, 2, 1)
plt.title('Sepal Length')
plt.hist(iris['SepalLengthCm'])

plt.subplot(2, 2, 2)
plt.title('Sepal Width')
plt.hist(iris['SepalWidthCm'])

plt.subplot(2, 2, 3)
plt.title('Petal Length')
plt.hist(iris['PetalLengthCm'])

plt.subplot(2, 2, 4)
plt.title('Petal Width')
plt.hist(iris['PetalWidthCm'])

plt.show()
|
The subplot() function is the only thing that is new to us. So, let’s try to understand what the numbers inside the subplot() signify.
The first two values in the plt.subplot(2,2,1) denote the grid size, the first being the value of m (row size) and the second being the value of n (column size).
The third value denotes where we want to place the plot on the grid. Cells are numbered starting from 1, left to right and then top to bottom, so a value of 1 places the plot in the first (top-left) cell of the grid.
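The numbering can be checked with a short loop. This sketch creates an empty 2 x 2 grid and writes each cell’s index, row, and column into its title. The positions dictionary and the divmod arithmetic are illustrative helpers, not something Matplotlib requires.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

rows, cols = 2, 2
fig = plt.figure()
positions = {}

for index in range(1, rows * cols + 1):
    # Indices run left to right, then top to bottom, starting at 1
    row, col = divmod(index - 1, cols)
    positions[index] = (row + 1, col + 1)
    plt.subplot(rows, cols, index)
    plt.title(f"index {index}: row {row + 1}, column {col + 1}")
```

Index 3, for example, lands in the second row and first column.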
However, one problem with our multiplot figure is that the titles are overlapping and are difficult to read.
We can use the tight_layout() method to make the subplots more spaced out.
|
# to make the plots spaced out
plt.tight_layout()
plt.show()
|
The titles in the above figure are more readable than in the previous one.
Saving the Plot:
To save the output we can use the savefig() method. We just need to pass the name for the file.
The following code snippet will save the output with the name output.png.
|
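import pandas as pd
import matplotlib.pyplot as plt

# number of ice creams sold for the week
ice_cream = [35, 33, 65, 44, 75, 88, 101]

plt.plot(ice_cream)

# save before calling show(): once the figure window is closed, savefig() would write an empty image
plt.savefig('output.png')
plt.show()
|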
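savefig() also takes optional arguments that are often useful. The sketch below sets the resolution with dpi and trims surrounding whitespace with bbox_inches='tight'. The file format is inferred from the extension of the file name. Writing into a temporary directory is just an assumption made so the sketch does not clutter the working directory.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

ice_cream = [35, 33, 65, 44, 75, 88, 101]

plt.plot(ice_cream)
plt.title("Ice Cream Sales")

# Hypothetical output location: a fresh temporary directory
out_path = os.path.join(tempfile.mkdtemp(), "output.png")

# dpi controls the resolution; bbox_inches='tight' crops extra margins
plt.savefig(out_path, dpi=150, bbox_inches="tight")
```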
You can find the complete code in this GitHub repo.