Getting Started with Matplotlib

Matplotlib is Python's most widely used plotting library. Its `pyplot` module provides a MATLAB-like interface for creating static, animated, and interactive visualizations.

Basic Setup and Imports



python

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Global defaults for every figure created below. rcParams changes persist for
# the rest of the session (use `plt.rc_context` to scope them instead).
plt.rcParams.update({
    'figure.figsize': (10, 6),  # default (width, height) in inches
    'font.size': 12,
    'axes.grid': True,          # draw a grid on every new Axes
    'grid.alpha': 0.3,          # keep that grid faint
})

# Sample data shared by the examples below
np.random.seed(42)  # fixed seed so the noise is reproducible
n_points = 100
x = np.linspace(0, 10, n_points)
y1 = np.sin(x)
y2 = np.cos(x)
# Noise length follows x, so changing n_points cannot cause a shape mismatch.
y3 = np.sin(x) + 0.1 * np.random.randn(x.size)

Your First Plot



python

# Simple line plot, built through the object-oriented interface:
# plt.subplots() returns the Figure together with its single Axes.
fig, ax = plt.subplots()
ax.plot(x, y1)
ax.set_title('Simple Sine Wave')
ax.set_xlabel('X values')
ax.set_ylabel('Y values')
plt.show()

Essential Plot Types

Line Plots



python

# Overlay several series on one set of axes, each with its own styling
plt.figure(figsize=(12, 6))

series = [
    (y1, dict(label='sin(x)', linewidth=2, color='blue')),
    (y2, dict(label='cos(x)', linewidth=2, color='red', linestyle='--')),
    (y3, dict(label='sin(x) + noise', alpha=0.7, color='green')),
]
for y_values, line_style in series:
    plt.plot(x, y_values, **line_style)

# Title, axis labels, legend and grid
plt.title('Trigonometric Functions', fontsize=16, fontweight='bold')
plt.xlabel('X values', fontsize=14)
plt.ylabel('Y values', fontsize=14)
plt.legend(loc='upper right', fontsize=12)
plt.grid(True, alpha=0.3)

# Pin the visible range: the full x span, with y padded beyond [-1, 1]
# so the noisy curve is not clipped
plt.xlim(0, 10)
plt.ylim(-1.5, 1.5)

plt.tight_layout()
plt.show()

Bar Charts



python

# Sample data for bar charts (sales in thousands per product)
categories = ['Product A', 'Product B', 'Product C', 'Product D', 'Product E']
sales_q1 = [23, 45, 56, 78, 32]
sales_q2 = [34, 52, 48, 82, 29]

# Simple bar chart; the category strings become the x tick labels
plt.figure(figsize=(10, 6))
plt.bar(categories, sales_q1, color='skyblue', alpha=0.8)
plt.title('Q1 Sales by Product')
plt.xlabel('Products')
plt.ylabel('Sales (thousands)')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Grouped bar chart: numeric positions let each quarter be shifted by half a
# bar width so the two bars sit side by side
x_pos = np.arange(len(categories))
width = 0.35

plt.figure(figsize=(12, 6))
bars1 = plt.bar(x_pos - width/2, sales_q1, width, label='Q1', color='lightblue')
bars2 = plt.bar(x_pos + width/2, sales_q2, width, label='Q2', color='orange')

# Value labels on top of each bar. bar_label offsets the text by `padding`
# points, so the gap stays constant whatever the data's scale (a fixed offset
# in data units, such as height + 1, does not).
for container in (bars1, bars2):
    plt.bar_label(container, padding=3)

plt.title('Quarterly Sales Comparison', fontsize=16, fontweight='bold')
plt.xlabel('Products')
plt.ylabel('Sales (thousands)')
plt.xticks(x_pos, categories)  # replace numeric positions with product names
plt.legend()
plt.tight_layout()
plt.show()

Scatter Plots



python

# Random scatter data: y follows a noisy linear trend in x, while each
# point's colour value and marker area are drawn independently.
n = 100
x_scatter = np.random.randn(n)
noise = np.random.randn(n)
y_scatter = 2 * x_scatter + noise
colors = np.random.rand(n)
sizes = 1000 * np.random.rand(n)

fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(x_scatter, y_scatter, c=colors, s=sizes, alpha=0.6,
                     cmap='viridis', edgecolors='black', linewidth=0.5)

ax.set_title('Scatter Plot with Color and Size Mapping', fontsize=16)
ax.set_xlabel('X values')
ax.set_ylabel('Y values')

# Colour bar keyed to the `colors` values, with its label rotated 270 degrees
cbar = fig.colorbar(scatter, ax=ax)
cbar.set_label('Color Scale', rotation=270, labelpad=20)

ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

Histograms and Distributions



python

# Two normal samples with different means and spreads
data1 = np.random.normal(100, 15, 1000)
data2 = np.random.normal(80, 20, 1000)

plt.figure(figsize=(12, 5))

# Subplot 1: Simple histogram
plt.subplot(1, 2, 1)
plt.hist(data1, bins=30, alpha=0.7, color='blue', edgecolor='black')
plt.title('Distribution 1')
plt.xlabel('Values')
plt.ylabel('Frequency')

# Subplot 2: Overlapped histograms.
# Use one set of bin edges for both samples. With `bins=30` each call would
# choose its own edges (different widths and offsets), so the two datasets'
# bar heights would not be directly comparable.
shared_bins = np.histogram_bin_edges(np.concatenate([data1, data2]), bins=30)
plt.subplot(1, 2, 2)
plt.hist(data1, bins=shared_bins, alpha=0.5, label='Dataset 1', color='blue')
plt.hist(data2, bins=shared_bins, alpha=0.5, label='Dataset 2', color='red')
plt.title('Overlapped Distributions')
plt.xlabel('Values')
plt.ylabel('Frequency')
plt.legend()

plt.tight_layout()
plt.show()

Advanced Customization

Custom Styling and Themes



python

# Switch to the seaborn look. plt.style.use() updates the global rcParams,
# so every figure created after this line keeps the style.
plt.style.use('seaborn-v0_8')

# Custom colour palette (hex RGB strings)
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7']

# One styled example of each basic plot type on a 2x2 grid
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

# Data layers
ax1.plot(x, y1, color=colors[0], linewidth=3, marker='o', markersize=4)
ax2.bar(range(len(sales_q1)), sales_q1, color=colors[:len(sales_q1)])
ax3.scatter(x_scatter, y_scatter, c=colors[2], alpha=0.6, s=50)
ax4.hist(data1, bins=20, color=colors[3], alpha=0.7, edgecolor='white')

# Shared text styling: bold panel titles plus plain axis labels
panel_labels = (
    (ax1, 'Styled Line Plot', 'X values', 'Y values'),
    (ax2, 'Styled Bar Chart', 'Products', 'Sales'),
    (ax3, 'Styled Scatter Plot', 'X values', 'Y values'),
    (ax4, 'Styled Histogram', 'Values', 'Frequency'),
)
for panel, heading, x_text, y_text in panel_labels:
    panel.set_title(heading, fontsize=14, fontweight='bold')
    panel.set_xlabel(x_text)
    panel.set_ylabel(y_text)

# Panel-specific touches: dotted grid and light background on the line plot,
# rotated product names under the bars
ax1.grid(True, linestyle=':', alpha=0.7)
ax1.set_facecolor('#f8f9fa')
ax2.set_xticks(range(len(categories)))
ax2.set_xticklabels(categories, rotation=45)

plt.tight_layout()
plt.show()

Annotations and Text



python

# Create plot with annotations
plt.figure(figsize=(12, 8))

# Plot the data
plt.plot(x, y1, 'b-', linewidth=2, label='sin(x)')
plt.plot(x, y2, 'r--', linewidth=2, label='cos(x)')

# Annotations: an arrow from the label position (xytext) to the highlighted
# data point (xy); both are given in data coordinates.
plt.annotate('Maximum', xy=(np.pi/2, 1), xytext=(2, 1.3),
             arrowprops=dict(arrowstyle='->', color='black', lw=1.5),
             fontsize=12, ha='center')

plt.annotate('Zero crossing', xy=(np.pi, 0), xytext=(4, 0.5),
             arrowprops=dict(arrowstyle='->', color='red', lw=1.5),
             fontsize=12, ha='center')

# Text box positioned in axes coordinates (0-1), top-aligned near the upper
# edge. The backslashes are doubled so mathtext receives \sin and \cos:
# backslash-s and backslash-c are not valid Python escape sequences and
# trigger a SyntaxWarning (DeprecationWarning before Python 3.12).
textstr = 'Mathematical Functions:\n$y = \\sin(x)$\n$y = \\cos(x)$'
props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
plt.text(0.7, 0.95, textstr, transform=plt.gca().transAxes, fontsize=10,
         verticalalignment='top', bbox=props)

plt.title('Annotated Plot', fontsize=16, fontweight='bold')
plt.xlabel('X values')
plt.ylabel('Y values')
# Autoscaling only considers the plotted lines (about -1..1), not annotation
# text, so the 'Maximum' label placed at y=1.3 would sit outside the axes.
# Use the same y-range as the line-plot example.
plt.ylim(-1.5, 1.5)
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

Subplots and Layout



python

# Complex subplot layout on a 3x3 GridSpec
fig = plt.figure(figsize=(16, 10))

# Equal row heights (the default, spelled out for illustration); the first
# column is twice as wide as each of the other two.
gs = fig.add_gridspec(3, 3, height_ratios=[1, 1, 1], width_ratios=[2, 1, 1])

# Large plot spanning the top-left 2x2 cells
ax_main = fig.add_subplot(gs[0:2, 0:2])
ax_main.plot(x, y1, 'b-', linewidth=2, label='sin(x)')
ax_main.plot(x, y2, 'r--', linewidth=2, label='cos(x)')
ax_main.set_title('Main Plot', fontsize=16, fontweight='bold')
ax_main.legend()
ax_main.grid(True, alpha=0.3)

# Horizontal histogram in the top-right cell
ax_hist = fig.add_subplot(gs[0, 2])
ax_hist.hist(data1, bins=20, orientation='horizontal', color='lightblue', alpha=0.7)
ax_hist.set_title('Distribution')

# Bar chart in the middle-right cell
ax_bar = fig.add_subplot(gs[1, 2])
ax_bar.bar(range(3), [10, 15, 12], color=['red', 'green', 'blue'])
ax_bar.set_title('Bar Chart')

# Scatter spanning the whole bottom row. The colour is given explicitly: the
# module-level `colors` name is rebound by earlier examples (a float array in
# the scatter section, a hex palette in the styling section), so indexing it
# here depended on which of those examples had been run last.
ax_bottom = fig.add_subplot(gs[2, :])
ax_bottom.scatter(x_scatter, y_scatter, alpha=0.6, c='#FF6B6B')
ax_bottom.set_title('Bottom Scatter Plot', fontsize=14)
ax_bottom.set_xlabel('X values')

plt.tight_layout()
plt.show()

Real-World Example: Sales Dashboard



python

# Create a comprehensive sales dashboard
np.random.seed(42)  # reseed so the daily figures are reproducible

# Generate sample sales data
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
products = ['Product A', 'Product B', 'Product C', 'Product D']

# Create data (sales in thousands of dollars, margins in percent)
monthly_sales = [45, 52, 48, 61, 58, 67]
product_sales = [120, 95, 87, 110]
daily_sales = np.random.normal(50, 10, 30)
profit_margin = [15, 12, 18, 22, 20, 25]

# Create dashboard
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

# 1. Monthly sales trend: line with markers plus a shaded area underneath
ax1.plot(months, monthly_sales, marker='o', linewidth=3, markersize=8,
         color='#2E86AB', markerfacecolor='#F24236')
ax1.fill_between(months, monthly_sales, alpha=0.3, color='#2E86AB')
ax1.set_title('Monthly Sales Trend', fontsize=16, fontweight='bold', pad=20)
ax1.set_ylabel('Sales (thousands $)', fontsize=12)
ax1.grid(True, alpha=0.3)
ax1.set_ylim(0, max(monthly_sales) * 1.1)  # headroom for the value labels

# Value labels 10 points above each marker; the categorical months sit at
# x positions 0..len(months)-1
for i, v in enumerate(monthly_sales):
    ax1.annotate(f'${v}k', (i, v), textcoords="offset points",
                 xytext=(0, 10), ha='center', fontweight='bold')

# 2. Product sales comparison
ax2.barh(products, product_sales, color=['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4'])
ax2.set_title('Sales by Product', fontsize=16, fontweight='bold', pad=20)
ax2.set_xlabel('Sales (thousands $)', fontsize=12)

# Value labels just past the end of each bar
for i, value in enumerate(product_sales):
    ax2.text(value + 2, i, f'${value}k', va='center', fontweight='bold')

# 3. Daily sales distribution with the mean marked
mean_daily = np.mean(daily_sales)
ax3.hist(daily_sales, bins=15, color='#FFEAA7', edgecolor='black', alpha=0.7)
ax3.axvline(mean_daily, color='red', linestyle='--', linewidth=2,
            label=f'Mean: ${mean_daily:.1f}k')
ax3.set_title('Daily Sales Distribution (Last 30 Days)', fontsize=16,
              fontweight='bold', pad=20)
ax3.set_xlabel('Daily Sales (thousands $)', fontsize=12)
ax3.set_ylabel('Frequency', fontsize=12)
ax3.legend()

# 4. Sales (bars, left axis) vs profit margin (line, right twin axis)
ax4_twin = ax4.twinx()
ax4.bar(months, monthly_sales, alpha=0.6, color='#74A9CF', label='Sales')
ax4_twin.plot(months, profit_margin, color='#E74C3C', marker='o',
              linewidth=3, markersize=8, label='Profit Margin %')

ax4.set_title('Sales vs Profit Margin', fontsize=16, fontweight='bold', pad=20)
ax4.set_ylabel('Sales (thousands $)', fontsize=12, color='#74A9CF')
ax4_twin.set_ylabel('Profit Margin (%)', fontsize=12, color='#E74C3C')
ax4.tick_params(axis='y', labelcolor='#74A9CF')
ax4_twin.tick_params(axis='y', labelcolor='#E74C3C')

# Each Axes only reports its own artists, so merge both handle lists into a
# single legend
lines1, labels1 = ax4.get_legend_handles_labels()
lines2, labels2 = ax4_twin.get_legend_handles_labels()
ax4.legend(lines1 + lines2, labels1 + labels2, loc='upper left')

# Add the figure title first, then let tight_layout arrange the axes in the
# area below it (rect = left, bottom, right, top as figure fractions). A title
# placed at y > 1 would fall outside the figure and be clipped on screen.
# The data covers January-June, i.e. the first half of the year (H1).
fig.suptitle('Sales Performance Dashboard - H1 2024',
             fontsize=20, fontweight='bold')
fig.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()

Performance Tips and Best Practices

Efficient Plotting



python

# Efficient plotting for large datasets
def efficient_large_plot():
    """Scatter 100k random points on a single figure and show it.

    ``rasterized=True`` asks Matplotlib to render the scatter collection as a
    bitmap when the figure is written to a vector format (PDF/SVG/EPS), while
    axes and text stay vector. Tiny, mostly transparent markers keep the
    dense cloud legible. Draws from NumPy's global random state.
    """
    point_count = 100000
    xs = np.random.randn(point_count)
    ys = np.random.randn(point_count)

    plt.figure(figsize=(10, 6))
    plt.scatter(xs, ys, alpha=0.1, s=1, rasterized=True)
    plt.title('Large Dataset (100k points)', fontsize=14)
    plt.xlabel('X values')
    plt.ylabel('Y values')

    # Shrink padding so labels and title fit inside the figure
    plt.tight_layout()
    plt.show()

# Not run by default (draws 100k points); uncomment to try it.
# efficient_large_plot()

# Scoped styling plus explicit figure cleanup
def plot_with_context():
    """Plot sin(x) with a temporary style, show it, then release the figure.

    ``plt.style.context`` applies the seaborn style only inside the ``with``
    block; the previous rcParams are restored on exit. Memory is reclaimed by
    closing the figure, because pyplot keeps every open figure alive until it
    is closed.
    """
    with plt.style.context('seaborn-v0_8'):
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.plot(x, y1)
        ax.set_title('Context Manager Example')
        plt.show()
        # Close this specific figure. A bare plt.close() closes whichever
        # figure is current, which need not be `fig`.
        plt.close(fig)

# Batch processing multiple plots
def create_multiple_plots():
    """Draw sin(x), cos(x) and noisy sin(x) in three side-by-side panels.

    All panels share the global ``x`` grid; each gets its own title, colour
    and a faint grid.
    """
    _, axes = plt.subplots(1, 3, figsize=(15, 5))

    panels = (
        (y1, 'sin(x)', 'blue'),
        (y2, 'cos(x)', 'red'),
        (y3, 'noisy sin(x)', 'green'),
    )

    for idx, (y_values, label, line_color) in enumerate(panels):
        panel_ax = axes[idx]
        panel_ax.plot(x, y_values, color=line_color, linewidth=2)
        panel_ax.set_title(label)
        panel_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

create_multiple_plots()

Common Customization Patterns



python

# Define reusable styling functions
def apply_custom_style(ax, title, xlabel, ylabel):
    """Apply the house style to *ax* and return it for chaining.

    Sets a bold, padded title and both axis labels, turns on a faint dotted
    grid, and hides the top and right spines.
    """
    ax.set_title(title, fontsize=14, fontweight='bold', pad=15)
    for set_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        set_label(text, fontsize=12)
    ax.grid(True, alpha=0.3, linestyle=':')
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    return ax

def create_publication_ready_plot():
    """Create and show a publication-ready plot of sin(x) and cos(x).

    The publication rcParams are applied with ``plt.rc_context``. They hold
    only while this figure is built, laid out and shown, instead of
    permanently changing the session-wide defaults for every later plot.
    """
    publication_rc = {
        'font.size': 12,
        'axes.linewidth': 1.5,
        'axes.spines.top': False,
        'axes.spines.right': False,
        'xtick.major.size': 7,
        'xtick.minor.size': 4,
        'ytick.major.size': 7,
        'ytick.minor.size': 4,
    }

    with plt.rc_context(publication_rc):
        fig, ax = plt.subplots(figsize=(8, 6))

        # Plot data
        ax.plot(x, y1, linewidth=2.5, color='#1f77b4', label='sin(x)')
        ax.plot(x, y2, linewidth=2.5, color='#ff7f0e', linestyle='--', label='cos(x)')

        # Title, labels, dotted grid, hidden top/right spines
        apply_custom_style(ax, 'Trigonometric Functions', 'x', 'f(x)')

        # Legend with an opaque-ish white frame
        legend = ax.legend(loc='upper right', frameon=True, fancybox=True,
                           shadow=True, fontsize=11)
        legend.get_frame().set_facecolor('white')
        legend.get_frame().set_alpha(0.9)

        # Layout and display stay inside the context so they see the same
        # rcParams the figure was created with.
        plt.tight_layout()
        plt.show()

create_publication_ready_plot()