Mastering Matplotlib: Easy Plotting Tips and Common Pitfalls Explained
We live in an era where everything in the world, and the world itself, is explained by data. We have to learn how to visualize data to be able to depict our thoughts and knowledge.
When it comes to creating plots in Python, Matplotlib stands out as one of the most popular tools. However, using it efficiently can be a bit tricky. This article is here to help, breaking down Matplotlib's features in a simple and practical way.
Matplotlib gives us two main ways to make plots. The first is the functional way, great for quick visualizations, especially in places like Jupyter Notebooks. The second is the object-oriented way, which is super useful for more complex plots. Personally, I like the second way more because it gives us better control over our plots, and it's easier to understand.
In this article, I'll introduce both methods, but we'll mostly focus on the object-oriented way. This choice allows us to explore different figures and features hands-on, getting a real feel for what Matplotlib can do.
So, let's dive in and uncover the magic of Matplotlib, from everyday 2D charts to 3D plots. Scatter plots, bar charts, quiver plots, polar plots: each method has its own role in making your data come to life. Whether you're showing density distributions, comparing datasets, or creating awesome 3D visuals, Matplotlib has almost everything you need. Let's see how we can make the most of this powerful library for all your plotting adventures!
Functional approach
The functional approach in Matplotlib involves using the 𝐩𝐲𝐩𝐥𝐨𝐭 interface, which relies on a global state to configure and create plots. This method offers a simple way to generate basic plots by directly calling functions from the 𝐩𝐲𝐩𝐥𝐨𝐭 module.
Consider the following example:
import matplotlib.pyplot as plt
import numpy as np
# Generating sample data
x = np.linspace(0, 10, 100)
y = np.sin(x)
# Using the functional approach (plt.plot())
plt.plot(x, y, label='Sine Curve')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.title('Functional Approach')
plt.legend()
plt.show()
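Each pyplot call in this example (plt.plot(), plt.xlabel(), plt.ylabel(), plt.title(), plt.legend()) acts on the "current" figure and axes that pyplot tracks behind the scenes, creating them on first use. plt.show() then renders the result.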
In this functional approach, the 𝐩𝐲𝐩𝐥𝐨𝐭 interface abstracts much of the underlying complexity, making it suitable for quick and simple visualizations. However, this approach might be limited in handling more intricate plot configurations or managing multiple subplots. In contrast, the object-oriented approach offers finer control over individual plot elements, making it more flexible for complex layouts and detailed customizations. Object-oriented methods promote clearer code structure, enhancing readability and maintainability (at least for those with a programming background), making them preferable for intricate visualizations and larger projects.
Object-oriented method
In the object-oriented approach, we work explicitly with Figure objects and the subplots (Axes objects) they contain.
The following picture shows a 2x2 grid of subplots within a single figure, with each subplot in a different cell of the grid. They display plots of sin(x), cos(x), tan(x), and x^2.
Creating such a figure with the functional (pyplot) interface produces code that is hard to maintain and extend. I always recommend using the object-oriented interface instead.
To create a new figure object, we can use 𝐩𝐥𝐭.𝐟𝐢𝐠𝐮𝐫𝐞(), which initializes a blank canvas ready for plotting. We then add individual subplots to it with the 𝐚𝐝𝐝_𝐬𝐮𝐛𝐩𝐥𝐨𝐭() method.
# Creating a figure object explicitly
fig = plt.figure()
# Adding a subplot to figure using add_subplot()
ax = fig.add_subplot(2,2,1)
# 2,2,1 means 2 rows, 2 columns, first subplot
ax.plot(x, np.sin(x))
ax.set_title('Sine Curve')
ax = fig.add_subplot(2,2,2)
ax.plot(x, np.cos(x))
ax.set_title('Cos Curve')
ax = fig.add_subplot(2,2,3)
ax.plot(x, np.tan(x))
ax.set_title('Tan Curve')
ax = fig.add_subplot(2,2,4)
ax.plot(x, np.power(x,2))
ax.set_title('x^2')
Alternatively, we can use the 𝐩𝐥𝐭.𝐬𝐮𝐛𝐩𝐥𝐨𝐭() and 𝐩𝐥𝐭.𝐬𝐮𝐛𝐩𝐥𝐨𝐭𝐬() convenience functions. Both help you build a grid of subplots in a single figure, but they work differently. 𝐬𝐮𝐛𝐩𝐥𝐨𝐭() adds and returns one subplot per call. 𝐬𝐮𝐛𝐩𝐥𝐨𝐭𝐬() creates a figure and all of its subplots at once and returns both. First, let's look at an example of the first function:
# Adding a single subplot using subplot()
ax = plt.subplot(2,2,1)
# 2,2,1 means 2 rows, 2 columns, first subplot
ax.plot(x, np.sin(x))
ax.set_title('Sine Curve')
ax = plt.subplot(2,2,2)
ax.plot(x, np.cos(x))
ax.set_title('Cos Curve')
ax = plt.subplot(2,2,3)
ax.plot(x, np.tan(x))
ax.set_title('Tan Curve')
ax = plt.subplot(2,2,4)
ax.plot(x, np.power(x,2))
ax.set_title('x^2')
As you can see in the sample code, 𝐬𝐮𝐛𝐩𝐥𝐨𝐭() is very similar to 𝐚𝐝𝐝_𝐬𝐮𝐛𝐩𝐥𝐨𝐭(). However, 𝐬𝐮𝐛𝐩𝐥𝐨𝐭() adds the subplot to the current figure without returning that figure, so we have no direct handle on it (we would have to fetch it with plt.gcf()). 𝐬𝐮𝐛𝐩𝐥𝐨𝐭𝐬(), on the other hand, returns both the figure and an array of all the subplots:
# Adding four subplots using subplots() method
fig, axes = plt.subplots(2,2)
# 2,2 means 2 rows, 2 columns
axes[0][0].plot(x, np.sin(x))
axes[0][0].set_title('Sine Curve')
axes[0][1].plot(x, np.cos(x))
axes[0][1].set_title('Cos Curve')
axes[1][0].plot(x, np.tan(x))
axes[1][0].set_title('Tan Curve')
axes[1][1].plot(x, np.power(x,2))
axes[1][1].set_title('x^2')
In summary, 𝐩𝐥𝐭.𝐬𝐮𝐛𝐩𝐥𝐨𝐭𝐬() is a convenient way to create a figure and several subplots in one call, while 𝐩𝐥𝐭.𝐬𝐮𝐛𝐩𝐥𝐨𝐭() returns only one subplot per call.
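Because subplots() returns the subplots as a NumPy array, a common shortcut is to loop over axes.flat instead of indexing each subplot by hand. Here is a minimal sketch that reproduces the same 2x2 figure. The figsize value and the tight_layout() call are my additions:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
# (y-values, title) pairs, in the order the subplots should be filled
curves = [
    (np.sin(x), "Sine Curve"),
    (np.cos(x), "Cos Curve"),
    (np.tan(x), "Tan Curve"),
    (np.power(x, 2), "x^2"),
]

fig, axes = plt.subplots(2, 2, figsize=(8, 6))
# axes is a 2 x 2 NumPy array of Axes; .flat walks it row by row
for ax, (y, title) in zip(axes.flat, curves):
    ax.plot(x, y)
    ax.set_title(title)
fig.tight_layout()  # keep titles and tick labels from overlapping
```

This approach scales to any grid size without changing the loop body.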
Customizing figures
If you use 𝐩𝐥𝐭.𝐟𝐢𝐠𝐮𝐫𝐞() or 𝐩𝐥𝐭.𝐬𝐮𝐛𝐩𝐥𝐨𝐭𝐬(), you have access to the figure object and can customize it.
𝐟𝐢𝐠.𝐬𝐮𝐛𝐩𝐥𝐨𝐭𝐬_𝐚𝐝𝐣𝐮𝐬𝐭() adjusts the spacing between subplots within a figure. By specifying parameters like left, right, top, bottom, wspace, and hspace, you can control the distance between subplots horizontally (wspace) and vertically (hspace), as well as adjust the margins (left, right, top, bottom). This helps manage the layout and alignment of subplots within the figure.
fig.subplots_adjust(hspace=0.4, wspace=0.3)
There are several other figure customization functions that help modify various aspects of the figure. Some common ones include:
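set_size_inches() (width and height in inches), suptitle() (a title above all subplots), set_facecolor() (the background colour), tight_layout() (automatic spacing that keeps labels from overlapping), and savefig() (export to a file). The sketch below applies these methods to a simple two-panel figure. The sizes, colour, and title are arbitrary values chosen for illustration:

```python
import io

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
fig, axes = plt.subplots(1, 2)
axes[0].plot(x, np.sin(x))
axes[1].plot(x, np.cos(x))

fig.set_size_inches(8, 4)        # width, height in inches
fig.suptitle("Sine and cosine")  # one title for the whole figure
fig.set_facecolor("whitesmoke")  # background colour behind all subplots
fig.tight_layout()               # adjust spacing so labels don't overlap

# Save to an in-memory buffer; pass a filename such as "figure.png" to write a file instead
buffer = io.BytesIO()
fig.savefig(buffer, format="png", dpi=100)
```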
Customizing Plots
Customizing Matplotlib plots involves various techniques, including using styles and themes to alter the overall appearance. Annotations can be added to highlight specific points, and the library supports various axis scales, including logarithmic. You can customize grid lines and add axis labels for clarity. When dealing with multiple datasets, incorporating legends helps in distinguishing them.
Styling plots
Plots have some styling parameters (color, line style, marker style, line width, marker size, labels, and legends) that help differentiate and enhance the visual representation of the plotted data, making it easier to interpret and understand the relationships between the sets of data. For example, the following plot illustrates y = x^2 and y = x^3 by adjusting these parameters:
x = np.linspace(-2, 2, 20)
fig = plt.figure()
ax = fig.add_subplot(1,2,1)
ax.plot(x, x**2, color='blue', linestyle='--', linewidth=1, marker='o', markersize=2, label='y = x^2')
ax.plot(x, x**3, color='red', linestyle='dotted', linewidth=1, marker='*', markersize=2, label='y = x^3')
ax.legend()
Scaling and presenting multiple datasets
Now I want to illustrate an annoying problem that comes up when you work with data. Can you guess what you will see if you run the following code?
x = np.linspace(-100, 100, 20)
fig = plt.figure()
ax = fig.add_subplot(1,2,1)
ax.plot(x, x**2, color='blue', linestyle='--', linewidth=1, marker='o', markersize=2, label='y = x^2')
ax.plot(x, x**3, color='red', linestyle='dotted', linewidth=1, marker='*', markersize=2, label='y = x^3')
ax.legend()
This code is very similar to the previous one. I only expanded the range of x values, but the result is surprising:
You will just see a horizontal blue line for y = x^2! The issue arises due to the vast difference in the growth rates between the functions y = x^2 and y = x^3 within the specified range of x-values.
When the two curves share a linear y-axis, y = x^3 grows far faster than y = x^2 over the range x = -100 to 100. At x = 100, x^2 is 10,000 while x^3 is 1,000,000, so the y-axis must span roughly -1,000,000 to 1,000,000. The whole x^2 curve therefore occupies only about 0.5% of the plot's height. The y = x^3 curve dominates the plot, and the slower-growing y = x^2 curve is squeezed into what looks like a flat line near zero.
There are several ways to handle this. We could simply plot each dataset separately, but that makes it harder to compare the y values at the same x. Another option is to use a logarithmic scale for the y-axis. Because x^3 is negative for negative x, a plain "log" scale cannot show half of the data. We therefore use "symlog" (symmetric log), which handles negative values as well:
ax.set_yscale("symlog")
The advantage of this method is its simplicity: a single command changes the scale, and now we can see that x^2 is also growing. However, although a logarithmic scale displays a wide range of values more clearly, it does not solve the problem completely. The curves' shapes are distorted, and equal vertical distances no longer correspond to equal differences in value. With symlog, the region near zero is even drawn on a linear scale, so the axis mixes two different scales.
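Here is a self-contained sketch of the symlog version. The linthresh keyword sets the half-width of the linear region around zero. It is available in Matplotlib 3.3 and later (older releases spelled it linthreshy). The value 1 is an arbitrary choice:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-100, 100, 20)
fig, ax = plt.subplots()
ax.plot(x, x**2, color="blue", linestyle="--", marker="o", markersize=2, label="y = x^2")
ax.plot(x, x**3, color="red", linestyle="dotted", marker="*", markersize=2, label="y = x^3")
# "log" would hide every negative value of x**3; "symlog" mirrors the log scale
# for negative values and stays linear inside (-linthresh, linthresh)
ax.set_yscale("symlog", linthresh=1)
ax.legend()
```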
In my view, the best way is to use a secondary Y-axis. If the datasets share the same x-axis but have vastly different scales on the y-axis, you can use a secondary y-axis. This approach allows you to plot datasets with different scales on the same plot while maintaining clarity.
x = np.linspace(-100, 100, 20)
fig = plt.figure(figsize=(10,5))
ax1 = fig.add_subplot(1,2,1)
ax1.plot(x, x**2, color='blue', linestyle='--', linewidth=1, marker='o', markersize=2, label='y = x^2')
ax1.set_ylabel("y=x^2")
ax1.legend( loc="upper left")
ax2 = ax1.twinx()
ax2.plot(x, x**3, color='red', linestyle='dotted', linewidth=1, marker='*', markersize=2, label='y = x^3')
ax2.set_ylabel("y=x^3")
ax2.legend(loc="upper right")
Each dataset is plotted separately on its respective y-axis, allowing clear visualization of both datasets' trends without one overshadowing the other due to different scales.
This approach effectively visualizes datasets with different scales on the same x-axis, aiding in comparing their trends while preserving their individual characteristics. Adjust the parameters as needed for your specific datasets and visualization requirements.
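One refinement to the twin-axes example is to merge the two legends into one. Each Axes can report its own handles and labels, so we can combine them and draw a single legend on the twin axis, which is rendered on top. Colouring each y-label to match its line is my addition. A minimal sketch:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-100, 100, 20)
fig, ax1 = plt.subplots(figsize=(6, 4))
ax1.plot(x, x**2, color="blue", linestyle="--", marker="o", markersize=2, label="y = x^2")
ax1.set_ylabel("y = x^2", color="blue")

ax2 = ax1.twinx()  # second y-axis sharing ax1's x-axis
ax2.plot(x, x**3, color="red", linestyle="dotted", marker="*", markersize=2, label="y = x^3")
ax2.set_ylabel("y = x^3", color="red")

# Collect lines and labels from both axes and draw one legend on the top-most axes
handles1, labels1 = ax1.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
legend = ax2.legend(handles1 + handles2, labels1 + labels2, loc="upper left")
```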
Using annotations
Sometimes it's essential to draw attention to specific data points or areas within a plot. Annotations in Matplotlib are a valuable tool for this. They let you add text, arrows, or markers at specific locations to emphasize critical data points. Below is an example that uses annotations to highlight points on a plot:
ax = plt.subplot()
x = np.arange(-180,180,1)
ax.plot(x,np.sin(x/180*np.pi), label="Sin(x)" , color="blue")
ax.plot(x,np.cos(x/180*np.pi), label="Cos(x)" , color="green")
ax.set_xticks(np.arange(-180,200,45))
ax.legend()
ax.grid()
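# Label at (45, -0.4) with an arrow to the intersection at x = 45°, where sin = cos ≈ 0.707
ax.annotate("Sin x = Cos x", xy=(45, 0.7), xytext=(45, -0.4), arrowprops=dict(arrowstyle='->'), ha="center")
# Arrow with empty text, starting near the same label and pointing to the intersection at x = -135°
ax.annotate("", xy=(-135, -0.7), xytext=(8, -0.4), arrowprops=dict(arrowstyle='->'), ha="center")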
Common Plotting Functions
Now it's time to introduce various plotting types. Matplotlib isn't just about plotting mathematical functions. It offers a host of other helpful methods for handling categorical data, creating bar plots, histograms, scatter plots, and more.
Matplotlib is a versatile library that offers various plotting functions. Initially, I faced confusion due to the multitude of options available, making it challenging to create meaningful plots. To simplify the process, I organized the functions based on the types of data they are best suited for. Understanding these categorized groups can assist in selecting the most suitable function for your specific data.
Methods for Numerical Data:
Methods for Categorical Data:
Methods for Specialized Data:
Note that certain functions, such as boxplot(), bar(), or hist(), can handle both numerical and categorical data, depending on how you use them and what input you pass.
plot() and scatter()
Both 𝐩𝐥𝐨𝐭() and 𝐬𝐜𝐚𝐭𝐭𝐞𝐫() are used with numerical data, presenting relationships between X and Y variables. However, their fundamental differences lie in how they visualize this relationship:
𝐩𝐥𝐨𝐭() typically creates line-based plots, emphasizing the connected nature of the data points. It's often used to display trends, sequences, or continuous data, showcasing the overall pattern between data points through lines or markers connected by default.
𝐬𝐜𝐚𝐭𝐭𝐞𝐫() focuses on individual data points, emphasizing the distinct nature of each point. It doesn't connect points with lines by default, presenting data as separate markers. This function is commonly utilized to explore correlations, clusters, or distributions within a dataset, especially when highlighting individual data points is essential.
For example, this code generates 100 values of y that are linearly related to x with a certain amount of random noise added.
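mu_x = 50
mu_y = 30
sigma_x = 10
sigma_y = 15
correlation_coefficient = 0.7
# Generate two independent standard normal samples
z1 = np.random.normal(0, 1, 100)
z2 = np.random.normal(0, 1, 100)
# x has mean mu_x and standard deviation sigma_x
x = mu_x + sigma_x * z1
# Mixing z1 and z2 gives y mean mu_y, standard deviation sigma_y,
# and an expected correlation with x equal to correlation_coefficient
y = mu_y + sigma_y * (correlation_coefficient * z1 + np.sqrt(1 - correlation_coefficient**2) * z2)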
We can show the correlation between the two datasets with 𝐬𝐜𝐚𝐭𝐭𝐞𝐫() and draw the regression line (also known as the model line or least-squares line) with 𝐩𝐥𝐨𝐭():
from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(x.reshape(-1, 1), y)
slope = model.coef_[0]
intercept = model.intercept_
ax = plt.subplot()
ax.scatter(x,y)
x1 = x.min()
x2 = x.max()
ax.plot([x1,x2],[slope*x1+intercept, slope*x2+intercept],color="red")
I believe the plot is inspiring and self-explanatory! It's always rewarding when visualizations effectively communicate insights and relationships within data!
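If you'd rather not depend on scikit-learn just to draw a straight line, NumPy's np.polyfit with deg=1 performs the same ordinary least-squares fit. The sketch below uses a seeded generator and a simple illustrative relation between x and y. Both are my choices for reproducibility, not the data from above:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=0)  # seeded so the example is reproducible
x = rng.normal(50, 10, 100)
y = 0.7 * x + rng.normal(0, 10, 100)  # illustrative linear relation plus noise

# Least-squares fit of a degree-1 polynomial: returns [slope, intercept]
slope, intercept = np.polyfit(x, y, deg=1)

fig, ax = plt.subplots()
ax.scatter(x, y)
x_line = np.array([x.min(), x.max()])
ax.plot(x_line, slope * x_line + intercept, color="red")
```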
hist()
As we saw, 𝐩𝐥𝐨𝐭() and 𝐬𝐜𝐚𝐭𝐭𝐞𝐫() visualize relationships between two numerical datasets. 𝐩𝐥𝐨𝐭() connects points with lines to show how paired values change together, and 𝐬𝐜𝐚𝐭𝐭𝐞𝐫() draws each pair as an individual point. 𝐡𝐢𝐬𝐭(), by contrast, needs only one dataset. It divides the range of values into intervals (bins) and shows how many values fall into each one. This makes it the tool of choice for understanding the distribution of a single set of numerical data.
uniform = np.random.random(10000) * 8 - 4
ax = plt.subplot()
ax.hist(uniform, alpha=0.5, bins=30)
normal = np.random.normal(0,1,10000)
ax.hist(normal, color="green", alpha=0.5, bins=30)
The code draws two histograms on one plot. The x-axis represents the range of values, and the y-axis shows how often values occur in each bin. Transparency (alpha) lets the two histograms overlap while both stay visible. One shows a uniform distribution between -4 and +4, and the other shows a normal (Gaussian) distribution centered at 0.
Working with data always presents challenges, because it often involves unforeseen problems and complexities. Data is rarely as tidy and straightforward as in the previous example. For instance, when you compare two datasets, a common issue is that one contains far more samples than the other. Representing such data improperly can mislead readers.
Look at the following example:
uniform = np.random.random(10000) * 8 - 4
fig , axes = plt.subplots(1,2)
fig.set_figwidth(10)
normal = np.random.normal(0,1,1000)
axes[0].hist(uniform, alpha=0.5, bins=30)
axes[0].hist(normal, color="green", alpha=0.5, bins=30)
axes[0].set_ylabel("Frequency")  # counts are on the y-axis
axes[1].hist(uniform, alpha=0.5, bins=30, density=True)
axes[1].hist(normal, color="green", alpha=0.5, bins=30, density=True)
axes[1].set_ylabel("Density")
In the left plot, a potential misconception stems from the uniform distribution sample having a substantially larger size (10,000) compared to the normal distribution (1,000). Overlooking this sample size difference may lead the reader to misinterpret the plot, incorrectly inferring that the uniform distribution has more samples around the mean than the normal distribution. In reality, the perceived dissimilarity is a consequence of the unequal sample sizes, introducing the possibility of misinterpretations regarding the distribution characteristics of the two datasets. Emphasizing and considering sample sizes is crucial when interpreting visualizations to prevent such misconceptions.
Normalization choices in histograms, such as density=True, change what the y-axis means. With density=True, each bar's height is its count divided by (total number of samples × bin width), so each histogram's bars have a total area of 1. The plot then shows relative frequencies instead of raw counts. This makes visual comparisons fair between datasets of different sizes and lets you compare the distribution shapes accurately.
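Another detail matters here. With bins=30, each call to hist() chooses its own bin edges, so the two histograms use bars of different widths. Passing one shared array of edges makes the bars line up. The sketch below does that and also checks the "area equals 1" property with np.histogram. The seed and the edge array are my choices. Samples outside the edges are simply left out of the histogram:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=1)
uniform = rng.uniform(-4, 4, 10000)
normal = rng.normal(0, 1, 1000)

# One set of bin edges for both samples, so bars line up and have equal widths
bins = np.linspace(-4, 4, 31)

fig, ax = plt.subplots()
ax.hist(uniform, bins=bins, alpha=0.5, density=True, label="uniform (n=10000)")
ax.hist(normal, bins=bins, color="green", alpha=0.5, density=True, label="normal (n=1000)")
ax.set_ylabel("Density")
ax.legend()

# density=True divides each count by (total count * bin width),
# so the bar areas of each histogram add up to 1
heights, edges = np.histogram(normal, bins=bins, density=True)
area = np.sum(heights * np.diff(edges))
```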
boxplot()
The 𝐛𝐨𝐱𝐩𝐥𝐨𝐭 primarily focuses on displaying statistical parameters and summarizing the distribution's key features, such as median, quartiles, and outliers, rather than explicitly illustrating the data's underlying distribution function.
Unlike histograms, which provide a visual representation of the data's shape and frequency distribution, boxplots prioritize conveying statistical summary measures and identifying variability between groups or categories within the dataset. While histograms offer insights into the data's distributional shape, boxplots excel in highlighting central tendencies and spread, making them complementary visualization tools for different analytical purposes.
A box plot does reflect some shape information. For example, the interquartile range (IQR) of the standard normal sample from the previous section is about 1.35, because its quartiles lie near ±0.67. That is far narrower than the IQR of 4 for the uniform sample on [-4, 4], since normal values cluster around the mean (about 68% lie within one standard deviation). Still, the box plot alone does not show the bell-shaped curve that characterizes a normal distribution.
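Here is a minimal boxplot() sketch that compares the two distributions from the histogram section and computes their interquartile ranges. I set the tick labels with set_xticks()/set_xticklabels() to avoid the boxplot labels keyword, which was renamed in recent Matplotlib versions. The seed and the sample sizes are my choices:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=2)
uniform = rng.uniform(-4, 4, 10000)
normal = rng.normal(0, 1, 10000)

fig, ax = plt.subplots()
parts = ax.boxplot([uniform, normal])  # one box per dataset, at x = 1 and x = 2
ax.set_xticks([1, 2])
ax.set_xticklabels(["Uniform(-4, 4)", "Normal(0, 1)"])

# Interquartile range = 75th percentile - 25th percentile
iqr_uniform = np.subtract(*np.percentile(uniform, [75, 25]))
iqr_normal = np.subtract(*np.percentile(normal, [75, 25]))
```

Expect the uniform box to be roughly three times taller than the normal one. The normal sample also shows many points beyond the whiskers, which boxplot() draws as outliers.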
bar() and barh()
Many people confuse bar charts with histograms. The confusion often arises because both charts use bars to represent data. Additionally, some bar charts may have numerical data on the x-axis, which can resemble a histogram. However, the key difference lies in the nature of the data and the purpose of the chart.
Tips to avoid confusion:
import random
import matplotlib.pyplot as plt
# Define genders and cities
genders = ["Male", "Female"]
cities = ["New York", "London", "Tokyo", "Paris", "Berlin"]
# Generate 100 random customers
customers = []
for _ in range(100):
    # Randomly choose a gender
    gender = random.choice(genders)
    # Generate random customer data
    customer = {
        "id": random.randint(1, 10000),
        "name": f"Customer-{random.randint(1, 1000)}",
        "gender": gender,
        "age": random.randint(18, 80),
        "city": random.choice(cities),
    }
    customers.append(customer)
# Count males and females per city
city_data = {}
for person in customers:
    city = person["city"]
    gender = person["gender"]
    if city not in city_data:
        city_data[city] = {"Male": 0, "Female": 0}
    city_data[city][gender] += 1
# Plot a grouped bar chart
cities = list(city_data.keys())
males = [city_data[city]["Male"] for city in cities]
females = [city_data[city]["Female"] for city in cities]
# With align="edge", a negative width draws the bar to the left of the tick and a
# positive width to the right, so each city's two bars sit side by side without overlapping
plt.bar(cities, males, label='Male', width=-0.4, align="edge")
plt.bar(cities, females, label='Female', width=0.4, align="edge")
plt.xlabel('City')
plt.ylabel('Count')
plt.title('Count of Males and Females in Each City')
plt.legend()
plt.show()
The code generates 100 random customers with gender, age, name, and city attributes, counts the males and females in each city, and draws a grouped bar chart of those counts. The align and width parameters of bar() control where each city's male and female bars are placed relative to the city's tick.
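The section title also mentions barh(), the horizontal counterpart of bar(). It takes the category positions first and the bar lengths second, and long category names stay readable on the y-axis. Here is a minimal sketch with hypothetical customer cities generated the same way as above. The seed and the sorting are my choices:

```python
import random
from collections import Counter

import matplotlib.pyplot as plt

random.seed(0)  # reproducible illustration
cities = ["New York", "London", "Tokyo", "Paris", "Berlin"]
# Hypothetical customer cities, drawn at random as in the example above
customer_cities = [random.choice(cities) for _ in range(100)]
counts = Counter(customer_cities)

labels = sorted(counts, key=counts.get)  # smallest count first...
values = [counts[c] for c in labels]

fig, ax = plt.subplots()
bars = ax.barh(labels, values)  # ...so the largest bar ends up at the top
ax.set_xlabel("Count")
ax.set_title("Customers per city")
```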
pie()
Pie charts are effective for showcasing categorical proportions, making it easy to visualize how individual categories contribute to the whole.
# Count customers per city
city_data = {}
for person in customers:
    city = person["city"]
    if city not in city_data:
        city_data[city] = 0
    city_data[city] += 1
# pie() expects a sequence of numbers, so convert the dict views to lists
plt.pie(list(city_data.values()), labels=list(city_data.keys()))
plt.title('Proportion of customers in each city')
plt.legend()
plt.show()
This code generates a pie chart displaying the distribution of customers in different cities.
polar()
A polar plot represents data in a circular coordinate system, where angles and distances from the center (radius) display relationships. It visualizes information radially, often used for cyclic or periodic data representations like angles, direction, or periodic patterns.
A classic and beautiful example of a polar plot is the rose curve, also known as the "rhodonea curve." This curve creates a symmetric and aesthetically pleasing pattern.
theta = np.linspace(0, 2*np.pi, 1000)
n = 6  # Controls the petals: r = cos(n*theta) has n petals for odd n and 2n petals for even n
r = np.cos(n*theta) # Equation for a rose curve
plt.figure(figsize=(6, 6))
plt.polar(theta, r)
plt.title(f'Rose Curve (n={n})')
plt.show()
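This code snippet draws the rose curve r = cos(6θ). Mathematically, r = cos(nθ) has n petals when n is odd and 2n when n is even, so this curve has 12 petals. Be aware that Matplotlib does not reflect negative radii through the origin. It moves the radial origin to the smallest r value instead, so the rendered figure shows only 6 lobes.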
I wrote this code so you can compare different presentations of the same data:
ax1 = plt.subplot(1,2,1 , projection = "polar")
ax1.plot(theta, r)
ax1.set_title(f'Rose Curve (n={n})')
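ax1.set_position((0.05, 0.1, 0.4, 0.8))  # [left, bottom, width, height] in figure coordinates
ax2 = plt.subplot(1,2,2)
ax2.plot(theta, r)
ax2.set_title(f'Plot (n={n})')
ax2.set_position((0.55, 0.1, 0.4, 0.8))  # keep left + width <= 1 so the axes stay inside the figure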
You may have noticed that I didn't call ax.polar() in this code. polar() exists only as a pyplot function. Axes objects have no such method, so calling it raises:
AttributeError: 'Axes' object has no attribute 'polar'
Instead, I passed 𝐩𝐫𝐨𝐣𝐞𝐜𝐭𝐢𝐨𝐧="𝐩𝐨𝐥𝐚𝐫" when creating the subplot and then used the ordinary plot() method. The projection parameter sets the coordinate system of a subplot. With projection="polar", ax1 becomes a polar Axes, and its plot() method treats the first argument as the angle and the second as the radius.
In the polar plot (left), the rose curve shows its symmetric petal pattern, although it can take a moment to interpret. In the Cartesian plot (right), the same data is simply r plotted against θ, which gives an ordinary cosine wave. That makes individual values easy to read, but the circular symmetry is lost.
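Because Matplotlib does not reflect negative radii, the polar plots above do not draw the true 12-petal rhodonea for n = 6. One way to get it is to convert each point with a negative radius (r, θ) into the equivalent point (|r|, θ + π) before plotting, and to pin the radial origin at 0 with set_rmin(0). This is a sketch of one option, not the only possible approach:

```python
import matplotlib.pyplot as plt
import numpy as np

n = 6
theta = np.linspace(0, 2 * np.pi, 1000)
r = np.cos(n * theta)

# A point with negative radius (r, θ) is the same point as (|r|, θ + π).
# Matplotlib does not apply this rule itself, so convert before plotting.
theta_plot = np.where(r < 0, theta + np.pi, theta)
r_plot = np.abs(r)

fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(6, 6))
ax.plot(theta_plot, r_plot)
ax.set_rmin(0)  # keep the radial origin at r = 0
ax.set_title(f"Rose curve r = cos({n}θ)")

# Each sign change of r separates two petals of the traced curve
petals = np.count_nonzero(np.diff(np.sign(r)) != 0)
```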
quiver()
The 𝐪𝐮𝐢𝐯𝐞𝐫() method is primarily designed for visualizing 2D vector fields. It plots arrows on a flat plane, representing the magnitude and direction of vectors at each data point.
Unlike many tutorials, I begin by specifying custom input parameters to illustrate how quiver works, because the default values are somewhat confusing to explain. The following example code draws three vectors in Cartesian space. It is not an extraordinary plot; it simply shows standard vectors in the Cartesian system you know from school.
ax = plt.subplot()
ax.quiver([-1,0,1],[0,0,0],[-1,1,1],[1,-1,1],scale=1, scale_units="xy")
ax.set_aspect("equal")
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
ax.grid()
Let's break down each part of the code:
ax.quiver([-1, 0, 1], [0, 0, 0], [-1, 1, 1], [1, -1, 1], scale=1, scale_units="xy")
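This line generates the quiver plot. The call has the form quiver(X, Y, U, V). X and Y hold the x and y coordinates of each arrow's starting point, here (-1, 0), (0, 0), and (1, 0). U and V hold the vector components along x and y, here (-1, 1), (1, -1), and (1, 1). Setting scale=1 together with scale_units="xy" makes each arrow's length equal to the vector's length in data units.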
If we do not set the scale and scale_units parameters, we will encounter a surprising result because quiver uses a built-in algorithm to calculate the best values for scale. For example, if we do not set either the scale or scale_units parameters, the result will be:
For some scenarios, this visualization might be appropriate, but in our case, we want to see the vectors in accordance with grid lines.
Let's go through the rest of the code, because it is also essential for a clean and correct plot:
In summary, the code creates a quiver plot with three vectors, sets an equal aspect ratio, limits the plot to a specific range in both x and y directions, and adds a grid for better visualization.
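The example above relies on set_aspect("equal") to make the arrows point exactly at their tips. Matplotlib's quiver documentation suggests combining angles="xy", scale_units="xy", and scale=1 so that each arrow runs from (x, y) to (x + u, y + v) in data coordinates, whatever the aspect ratio. The sketch below leaves the aspect ratio at its default on purpose and marks the expected tips in red so you can check the result visually:

```python
import matplotlib.pyplot as plt

# Tails (x, y) and components (u, v) of the three vectors from the example above
x, y = [-1, 0, 1], [0, 0, 0]
u, v = [-1, 1, 1], [1, -1, 1]

fig, ax = plt.subplots()
# angles="xy" + scale_units="xy" + scale=1: each arrow runs from (x, y) to (x + u, y + v)
q = ax.quiver(x, y, u, v, angles="xy", scale_units="xy", scale=1)
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
ax.grid()

# Mark the expected arrow tips
tips_x = [xi + ui for xi, ui in zip(x, u)]
tips_y = [yi + vi for yi, vi in zip(y, v)]
ax.scatter(tips_x, tips_y, color="red", zorder=3)
```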
What I showed was only meant to help you understand the parameters. A quiver plot is rarely used for just a few vectors. It is commonly employed in physics and electronics to show a vector field over a whole plane, such as airflow, electric fields, or any other quantity that has a direction and a magnitude at each point. Adding many vectors one by one is not efficient, so we use meshgrid. Here is an example:
# Creating x and y arrays
x = np.arange(0, 2, 0.2)
y = np.arange(0, 2, 0.2)
# Creating u and v components using meshgrid function
X, Y = np.meshgrid(x, y)
u = np.cos(X)*Y
v = np.sin(Y)*Y
# creating plot
fig, ax = plt.subplots(figsize =(14, 8))
ax.quiver(X, Y, u, v)
ax.set_xticks([])
ax.set_yticks([])
ax.set_ylim([-0.3, 2.3])
ax.set_xlim([-0.3, 2.3])
ax.set_aspect('equal')
In the given code, np.meshgrid(x, y) is used to create a grid of points in the form of coordinate matrices X and Y. np.meshgrid(x, y) takes the 1-dimensional arrays x and y and returns two 2-dimensional arrays (X and Y). These arrays represent the grid of points where vectors will be plotted.
The u and v components of the vectors are then calculated based on the values of X and Y. In this example, the vectors have components related to trigonometric functions and the values of X and Y.
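With many arrows, it often helps to colour them by their magnitude. quiver() accepts an optional fifth positional argument, C, whose values are mapped to colours through cmap. The sketch below reuses the same u and v formulas. I build the grid with np.linspace instead of np.arange, and the colormap and colorbar are my additions:

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1.8, 10)
y = np.linspace(0, 1.8, 10)
X, Y = np.meshgrid(x, y)  # 10 x 10 grid of arrow positions
u = np.cos(X) * Y
v = np.sin(Y) * Y

magnitude = np.hypot(u, v)  # length of each vector, sqrt(u**2 + v**2)

fig, ax = plt.subplots(figsize=(8, 6))
# The optional fifth argument C sets each arrow's colour via the colormap
q = ax.quiver(X, Y, u, v, magnitude, cmap="viridis")
fig.colorbar(q, ax=ax, label="magnitude")
ax.set_aspect("equal")
```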
imshow(), hexbin() , matshow()
All three methods—𝐢𝐦𝐬𝐡𝐨𝐰(), 𝐡𝐞𝐱𝐛𝐢𝐧(), and 𝐦𝐚𝐭𝐬𝐡𝐨𝐰()—are designed for visualizing 2D data or arrays and allow customization through the colormap (cmap) parameter for color mapping. However, their primary applications differ:
In summary, although all three methods share a common purpose of visualizing 2D data, they excel in specific applications, providing tailored features for diverse data types and visualization requirements.
The following example illustrates how 𝐢𝐦𝐬𝐡𝐨𝐰() works:
from PIL import Image
img = Image.open("pic.jpeg")
data = np.array(img)
red_channel = np.zeros_like(data)
red_channel[:,:,0] = data[:,:,0]
green_channel = np.zeros_like(data)
green_channel[:,:,1] = data[:,:,1]
blue_channel = np.zeros_like(data)
blue_channel[:,:,2] = data[:,:,2]
fig, axes = plt.subplots(2,2)
axes[0,0].imshow(data)
axes[0,1].imshow(red_channel)
axes[1,0].imshow(green_channel)
axes[1,1].imshow(blue_channel)
The code opens an image, separates its RGB channels (red, green, and blue), and displays each channel as well as the original image in a 2x2 grid using Matplotlib's 𝐢𝐦𝐬𝐡𝐨𝐰().
While 𝐢𝐦𝐬𝐡𝐨𝐰() is commonly used to display images, it can be applied to various types of data, not limited to RGB images. You can use it to visualize grayscale images, heatmaps, 2D arrays, or any data where a color mapping can be meaningful. It's a versatile function in Matplotlib suitable for displaying a wide range of visualizations beyond just photographs.
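The RGB example above zeroes two channels to isolate the third. Another option is to pass a single channel on its own. imshow() treats an (H, W) array as scalar data and colours it through a colormap, while it displays an (H, W, 3) array as an RGB picture. Because no image file is assumed here, this sketch builds a small hypothetical gradient image in code:

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for an RGB photo: a 100 x 150 image with three uint8 channels
height, width = 100, 150
rows, cols = np.mgrid[0:height, 0:width]
image = np.zeros((height, width, 3), dtype=np.uint8)
image[..., 0] = (255 * cols / (width - 1)).astype(np.uint8)   # red grows left to right
image[..., 1] = (255 * rows / (height - 1)).astype(np.uint8)  # green grows top to bottom
image[..., 2] = 128                                           # constant blue

red = image[:, :, 0]  # 2-D array: imshow treats it as scalar data, not as colour

fig, axes = plt.subplots(1, 2, figsize=(8, 3))
axes[0].imshow(image)                                    # (H, W, 3): shown as an RGB picture
im = axes[1].imshow(red, cmap="Reds", vmin=0, vmax=255)  # (H, W): mapped through a colormap
fig.colorbar(im, ax=axes[1])
```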
For instance, the next code generates two random datasets (x and y) with 500 points each, drawn from normal distributions. It then creates a side-by-side comparison of a scatter plot (on the left) and an 𝐢𝐦𝐬𝐡𝐨𝐰() plot (on the right) using Matplotlib.
x = np.random.normal(10,10,500)
y = np.random.normal(5,5,500)
# Create a scatter plot
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.scatter(x, y, c='blue', alpha=0.7)
plt.title('Scatter Plot')
# Create an imshow plot
plt.subplot(1, 2, 2)
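H, xedges, yedges = np.histogram2d(x, y, bins=10)
# histogram2d puts x along the rows of H, so transpose it; origin="lower" and extent
# make the image's axes match the scatter plot's x and y values
plt.imshow(H.T, origin="lower", extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], aspect="auto")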
plt.title('Imshow Plot')
The 𝐢𝐦𝐬𝐡𝐨𝐰() plot visualizes the density of points in the 2D space, providing a different perspective than the scatter plot, which shows individual data points.
𝐦𝐚𝐭𝐬𝐡𝐨𝐰() is very similar to 𝐢𝐦𝐬𝐡𝐨𝐰(), with a few differences.
matshow() is a thin wrapper around imshow() with defaults tuned for matrices. It moves the x-axis ticks to the top, places ticks at integer row and column indices, and uses nearest-neighbour interpolation so each cell is drawn as a sharp square. When called as plt.matshow(), it also opens a new figure. With 𝐢𝐦𝐬𝐡𝐨𝐰(), you would set these options yourself if you need them.
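A typical matshow() use is a correlation matrix. The sketch below builds hypothetical data with four columns, where columns 0 and 1 are deliberately correlated. The colormap and the fixed -1 to 1 colour range are my choices:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=3)
data = rng.normal(size=(200, 4))
data[:, 1] += data[:, 0]  # make columns 0 and 1 correlated on purpose

corr = np.corrcoef(data, rowvar=False)  # 4 x 4 correlation matrix

fig, ax = plt.subplots()
im = ax.matshow(corr, cmap="coolwarm", vmin=-1, vmax=1)  # ticks on top, one cell per entry
fig.colorbar(im, ax=ax)
```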
Choose 𝐢𝐦𝐬𝐡𝐨𝐰() if:
Use 𝐦𝐚𝐭𝐬𝐡𝐨𝐰() when:
In the previous examples, we saw how 𝐢𝐦𝐬𝐡𝐨𝐰() and 𝐦𝐚𝐭𝐬𝐡𝐨𝐰() display 2D arrays. If you have raw data and want a heatmap of its distribution, you first need a 2D histogram, which is why I used 𝐧𝐮𝐦𝐩𝐲.𝐡𝐢𝐬𝐭𝐨𝐠𝐫𝐚𝐦𝟐𝐝 before displaying the result with these functions. For a more automatic approach, 𝐡𝐞𝐱𝐛𝐢𝐧() is a convenient alternative. It bins the points into hexagons itself, so you don't need the extra histogram step:
x = np.random.randn(1000)
y = np.random.randn(1000)
# Create a hexagonal binning plot
plt.hexbin(x, y, gridsize=20, cmap='viridis')
plt.colorbar(label='Count') # Add a colorbar for reference
# Set labels and title
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.title('Hexagonal Binning Plot')
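hexbin() returns a collection whose array holds one count per hexagon. With mincnt=1, hexagons that contain no points are not drawn at all, instead of being painted in the lowest colour. Here is a short object-oriented sketch with seeded random data. The mincnt value is my choice:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=4)
x = rng.standard_normal(1000)
y = rng.standard_normal(1000)

fig, ax = plt.subplots()
# mincnt=1 leaves empty hexagons undrawn
hb = ax.hexbin(x, y, gridsize=20, cmap="viridis", mincnt=1)
fig.colorbar(hb, ax=ax, label="Count")

counts = hb.get_array()  # one count per drawn hexagon
```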
3D Visualization
For 3D plotting, we can use ordinary methods like scatter() and plot() on an Axes created with the 𝐩𝐫𝐨𝐣𝐞𝐜𝐭𝐢𝐨𝐧="𝟑𝐝" parameter, which adds the third dimension. In addition to these general-purpose methods, there are functions tailored for 3D plotting, such as plot_surface for 3D surface plots, plot_wireframe for 3D wireframe plots, scatter3D for 3D scatter plots, and bar3d for 3D bar plots. These specialized functions provide more control and options for creating visually appealing and informative 3D visualizations.
Let's start by using the scatter() method in 3D space:
ax = plt.subplot(projection="3d")
x = np.random.random(20)
y = np.random.random(20)
z = np.random.random(20)
ax.scatter(x,y,z)
This code creates a 3D scatter plot with randomly generated data.
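The same 3D Axes also accepts the ordinary plot() method with a third coordinate, and scatter() can colour points by a value through c and cmap. Here is a small sketch that uses a helix chosen purely for illustration:

```python
import matplotlib.pyplot as plt
import numpy as np

t = np.linspace(0, 4 * np.pi, 200)
x, y, z = np.cos(t), np.sin(t), t / (4 * np.pi)  # a helix rising from z = 0 to z = 1

fig = plt.figure()
ax = fig.add_subplot(projection="3d")  # same 3D Axes as plt.subplot(projection="3d")
ax.plot(x, y, z, color="gray")         # ordinary plot(), now with a z argument
# Every 10th point, coloured by its height
points = ax.scatter(x[::10], y[::10], z[::10], c=z[::10], cmap="viridis")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
```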
bar3d()
You can use 𝐛𝐚𝐫𝟑𝐝() when you have two categorical variables and want to show how a quantity varies across their combinations. For example, suppose you have data on sales quantity (z-axis) across different months (x-axis) and different regions (y-axis):
from scipy.stats import multivariate_normal
# Illustrative "sales" surface: a bivariate normal density centered at month 8, region 4.
# A covariance matrix must be symmetric and positive definite.
mvn = multivariate_normal([8, 4], [[4, 1], [1, 2]])
months, regions = np.meshgrid(np.linspace(1, 12, 12), np.linspace(1, 5, 5))
points = np.c_[months.ravel(), regions.ravel()]
sales_quantity = mvn.pdf(points)
fig = plt.figure()
ax = plt.axes(projection='3d')
# bar3d(x, y, z, dx, dy, dz): bars start at z = 0, are 1 unit wide and deep, and sales_quantity tall
ax.bar3d(months.ravel(), regions.ravel(), np.zeros_like(months).ravel(),
         np.ones_like(months).ravel(), np.ones_like(months).ravel(), sales_quantity.ravel())
ax.set_title('Sales quantity')
ax.set_xlabel("Time (Month1-12)")
ax.set_ylabel("Regions 1-5")
The code generates a 3D bar chart that uses a bivariate normal distribution to model sales quantity over time (months) and regions. A bivariate normal distribution (𝐦𝐮𝐥𝐭𝐢𝐯𝐚𝐫𝐢𝐚𝐭𝐞_𝐧𝐨𝐫𝐦𝐚𝐥) is created with mean [8, 4] and covariance matrix [[4, 1], [1, 2]]. A covariance matrix must be symmetric and positive definite. Then 𝐧𝐩.𝐦𝐞𝐬𝐡𝐠𝐫𝐢𝐝 creates arrays for months (1-12) and regions (1-5), and the 𝐧𝐩.𝐜_ helper combines the flattened meshgrid arrays into a single array of (month, region) points. Next, the probability density function (pdf) of the distribution is evaluated at each point. Finally, 𝐚𝐱.𝐛𝐚𝐫𝟑𝐝() draws the chart. The x and y positions of the bars come from the meshgrid arrays, each bar starts at z = 0, and the bar heights are the calculated sales quantities.
The resulting plot visually represents the sales quantity over time (months) and regions as a 3D bar chart. Each bar's height corresponds to the sales quantity at a specific combination of time and region according to the bivariate normal distribution.
plot_surface()
𝐩𝐥𝐨𝐭_𝐬𝐮𝐫𝐟𝐚𝐜𝐞() is another method of a 3D Axes created with projection='3d'. 𝐛𝐚𝐫𝟑𝐝() suits categorical data shown as bars in 3D space, whereas 𝐩𝐥𝐨𝐭_𝐬𝐮𝐫𝐟𝐚𝐜𝐞() is ideal for visualizing continuous functions or datasets as a 3D surface.
from scipy.stats import multivariate_normal
mvn = multivariate_normal([1,2],[[8,6],[6,7]])  # covariance must be symmetric and positive definite
x , y = np.meshgrid(np.linspace(-10,10,100),np.linspace(-10,10,100))
points = np.c_[x.ravel(), y.ravel()]
z = mvn.pdf(points)
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_surface(x,y,z.reshape(100, 100), cmap='viridis',\
edgecolor='green')
ax.set_title('Surface plot')
The code generates a 3D surface plot of the probability density function (PDF) of a bivariate normal distribution with mean vector [1, 2]. A 3D plot suits this case because the PDF depends on two variables, x and y, and the surface shows the density as height above each (x, y) point. Note that the PDF is evaluated on the flattened list of grid points, so z comes back as a 1-D array. reshape(100, 100) restores the grid shape that plot_surface() expects.
plot_wireframe()
plot_wireframe is very similar to plot_surface(). It is ideal for representing the overall structure of a surface. This method uses lines to connect points on the surface without filling in the areas between them, making it suitable for emphasizing the surface's structure. This is an example:
x , y = np.meshgrid(np.linspace(-10,10,100),np.linspace(-10,10,100))
z = x**2+y**2
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_wireframe(x, y, z, color='green')  # z is already a 100x100 grid; a wireframe uses one line colour, not a colormap
ax.set_title('Wireframe plot')
plt.show()
This code generates a 3D wireframe plot of a surface defined by the equation z = x**2 + y**2. Here's the result plot:
Conclusion
In conclusion, this article on Matplotlib covers various plotting techniques and functions for both 2D and 3D visualizations. I started by introducing the functional and object-oriented APIs, highlighting the benefits of the latter for intricate visualizations. The article delves into creating figures, subplots, and customizing them using Matplotlib.
I aimed to provide detailed explanations and examples for common 2D plotting functions like plot, scatter, bar, pie, and more. Special attention is given to histograms, boxplots, and annotations for effective data representation. The distinction between imshow(), hexbin(), and matshow() is clarified, offering insights into their specific use cases.
In the realm of 3D plotting, I introduced scatter plots, bar charts using bar3d(), and surface plots using plot_surface() and plot_wireframe(). I demonstrated how to leverage these functions for visualizing data that spans three dimensions.
I didn't mean to write a complete guide or a manual. I just wanted to share some ideas about data visualization with Matplotlib, along with some challenges I ran into and their solutions, to help you find your way around. You will certainly encounter difficulties and ideas that I haven't mentioned in this article. Still, I honestly believe that the examples, code snippets, and practical tips here make this article a valuable resource for anyone working with Matplotlib for data visualization. Thanks for reading this guide!