热力图

🔖 python
🔖 visualization
Author

Guangyao Zhao

Published

Apr 22, 2023

导入第三方库并生成示例数据

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Sample data: 3 observations (rows "a".."c") of 5 variables (columns "A".."E").
# A fixed seed makes every run draw the same numbers, so the figures below are
# reproducible instead of changing on each execution.
# NOTE: with only 3 observations the pairwise correlations are extremely noisy
# (values close to +/-1 are common); use more rows for a more realistic demo.
rng = np.random.default_rng(seed=0)
data = pd.DataFrame(data=rng.standard_normal((3, 5)),
                    index=list("abc"),
                    columns=list("ABCDE"))
# Pairwise (Pearson, the pandas default) correlation between the columns:
# a symmetric 5x5 matrix whose index and columns are both "A".."E".
corr_matrix = data.corr()

方法 1：使用 matplotlib 的 imshow()

# Method 1: matplotlib's imshow()
fig, axes = plt.subplots()
abs_corr = corr_matrix.abs()  # plot |r|: strength of correlation, ignoring sign
# Pin the colour scale to [0, 1], the full range of |r|, so colours mean the
# same thing on every run (Method 2 below uses the same vmin/vmax).
im = axes.imshow(abs_corr, vmin=0, vmax=1)
# imshow() knows nothing about labels: put one tick at each cell centre and
# label it. Columns run along x; rows (the index) run along y.
# set_xticks() + set_xticklabels() works on every matplotlib version, unlike
# passing labels positionally to set_xticks() (only supported from 3.5 on).
axes.set_xticks(range(len(abs_corr.columns)))
axes.set_xticklabels(abs_corr.columns, rotation=90)
axes.set_yticks(range(len(abs_corr.index)))
axes.set_yticklabels(abs_corr.index)
# Hide the tick marks themselves but keep their labels.
axes.tick_params(axis="both", which="both", bottom=False, left=False)
cbar = fig.colorbar(im, ax=axes)  # draw the colorbar next to the heatmap
cbar.ax.tick_params(axis="both", which="both", direction="out")
cbar.ax.tick_params(axis="both", which="major", length=6)

方法 2：使用 seaborn 的 heatmap()

# Method 2: seaborn's heatmap()
fig, axes = plt.subplots()
# heatmap() draws into `axes` and returns that same Axes, so no alias is kept.
sns.heatmap(
    data=corr_matrix.abs(),  # plot |r|
    ax=axes,  # draw into our own Axes rather than pyplot's "current" one
    annot=True,  # write the value inside each cell
    fmt=".2f",  # ... with two decimals
    cbar_kws=dict(location="top"),  # horizontal colorbar above the heatmap
    cmap="YlGn",
    vmin=0,  # vmin/vmax pin both the colour mapping and the colorbar
    vmax=1,  # range to [0, 1], the full range of |r|
)
# Hide the tick marks but keep their labels, then rotate the x labels.
# Using the explicit Axes (instead of plt.xticks) avoids depending on which
# Axes pyplot considers "current" after the colorbar was created.
axes.tick_params(axis="both", which="both", left=False, bottom=False)
axes.tick_params(axis="x", labelrotation=60)

# heatmap() returns the Axes, not the Colorbar; the Colorbar is reachable
# through the mesh the heatmap drew, the first collection on the Axes.
cbar = axes.collections[0].colorbar
cbar.ax.tick_params(which="both", direction="out")
cbar.ax.tick_params(which="major", length=6)
# No explicit colorbar limits are needed: vmin/vmax above already fix them.
plt.show()  # display every open figure (both methods)