标签 Matplotlib 下的文章

目录

  1. 库的概览与核心价值
  2. 环境搭建与"Hello, World"
  3. 核心概念解析
  4. 实战演练:分析电影评分趋势
  5. 最佳实践与常见陷阱
  6. 进阶指引

1. 库的概览与核心价值

想象一下,你手头有一份包含一百万条销售数据的 Excel 表格,密密麻麻的数字堆叠在一起,让你头晕眼花。你需要找出旺季和淡季的趋势,对比不同产品的销售表现,但这些冰冷的数据就像沉默的密码,让你难以快速洞察其中的规律。这就是数据可视化的痛点——没有图形,数据就是一堆难以理解的数字。

Matplotlib 正是为解决这个核心问题而生的强大工具。它就像一位精通绘画的数据翻译官,能将枯燥的数据转化为直观、生动的图表,让你一眼看出数据背后的故事。在 Python 数据科学生态中,NumPy 负责数值计算,Pandas 处理结构化数据,而 Matplotlib 则承担着将数据"可视化呈现"的关键使命,三者共同构成了数据分析的三剑客。

那么,为什么需要专门的 Matplotlib,而不是直接用 Excel 或其他工具呢?关键在于它的三个独特优势

  • 无缝集成MatplotlibNumPyPandas 完美兼容,你可以直接读取 DataFrame 或数组进行绘图,无需繁琐的数据导出导入
  • 高度可定制:从坐标轴刻度、图例位置到颜色、字体、线型,每一个细节都可以精细控制,满足论文发表、专业汇报的苛刻要求
  • 生态基石:作为 Python 可视化的开山鼻祖,它不仅是独立工具,更是 SeabornPlotly 等高级库的基础,学会了它,后续学习会更轻松

一句话总结:Matplotlib 让数据"说话",让复杂的规律变得一目了然,是每位数据分析师必备的看家本领。

2. 环境搭建与"Hello, World"

安装说明

安装 Matplotlib 非常简单,推荐使用 pipconda

# 使用 pip 安装(推荐)
pip install matplotlib numpy

# 使用 conda 安装
conda install matplotlib numpy

注意Matplotlib 通常与 NumPy 配合使用,建议同时安装。如果安装过程中遇到权限问题,可以尝试使用 --user 参数(pip)或创建虚拟环境。

最简示例

让我们用最经典的"正弦曲线"作为入门案例,只需 5 行代码就能画出一张漂亮的图表:

import matplotlib.pyplot as plt
import numpy as np

# 1. 准备数据:x从0到2π,取100个点
x = np.linspace(0, 2 * np.pi, 100)
y = np.sin(x)

# 2. 创建画布和绘图区域,并绘制曲线
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(x, y)

# 3. 添加标题和标签
ax.set_title("正弦函数图像")
ax.set_xlabel("x值(弧度)")
ax.set_ylabel("sin(x)")

# 4. 显示图表
plt.show()

逐行解释

  • 第1-2行:导入 pyplot 子模块(简写为 plt)和 NumPypyplotMatplotlib 的高级接口,提供了类似 MATLAB 的绘图函数,是日常绘图最常用的模块。
  • 第4行np.linspace(0, 2*np.pi, 100) 生成从 0 到 2π 的 100 个等间距点,这是 NumPy 的核心函数,非常适合生成连续变化的 x 轴数据。
  • 第5行np.sin(x) 计算 x 数组中每个元素的正弦值,返回对应的 y 数组。NumPy 的数学运算会自动应用到数组的每个元素,无需循环。
  • 第8行plt.subplots(figsize=(8, 4)) 同时创建 Figure(画布)和 Axes(坐标轴)对象。figsize 参数设置画布大小为 8 英寸宽、4 英寸高。推荐使用 subplots() 而非单独创建,因为它更高效且符合面向对象风格。
  • 第9行ax.plot(x, y)Axes 对象上绘制折线图。这是最核心的绘图函数,将 x 和 y 数组连接成一条平滑的曲线。
  • 第12-14行set_title()set_xlabel()set_ylabel() 分别设置图表标题、x 轴标签和 y 轴标签。所有以 set_ 开头的方法都是在配置 Axes 的属性。
  • 第17行plt.show() 弹出窗口显示图表。在 Jupyter Notebook 中,可以省略这行代码直接在单元格中显示。

预期输出:运行后会弹出一个窗口,展示一条波浪状的正弦曲线,x 轴范围是 0 到 2π,y 轴范围是 -1 到 1,曲线从原点出发,先上升到 1(π/2 处),下降到 -1(3π/2 处),最后回到 0(2π 处)。

解决中文显示问题

Matplotlib 默认不支持中文,会导致中文显示为方块。需要在导入后添加以下配置:

import matplotlib.pyplot as plt
import matplotlib

# 设置中文字体(Windows 用 SimHei,Mac 用 Arial Unicode MS)
plt.rcParams['font.sans-serif'] = ['SimHei']
# 解决负号显示为方块的问题
plt.rcParams['axes.unicode_minus'] = False

3. 核心概念解析

理解 Matplotlib 的核心概念是掌握它的关键。新手容易混淆的主要是以下四个对象,它们之间的关系就像画画工具的层级:

3.1 Figure(画布)

Figure 是整个图表的容器,相当于一张白纸或画框。一个 Figure 可以包含多个 Axes(子图),它负责管理整个图像的尺寸、背景色、边框等全局属性。你可以把 Figure 想象成一个画板,所有的图表元素都画在这个画板上。

fig = plt.figure(figsize=(10, 6), facecolor='lightgray')

3.2 Axes(坐标轴/子图)

Axes 是实际绘图的区域,每个 Axes 都包含独立的坐标系(x 轴、y 轴)、标题、标签、图例等元素。一个 Figure 可以有多个 Axes(比如 2×2 的子图布局),但每个 Axes 只能属于一个 Figure。你可以把 Axes 想象成画板上的一个画框,具体的线条、点、文字都画在这个画框里。

fig, ax = plt.subplots()  # 创建包含一个 Axes 的 Figure
fig, axs = plt.subplots(2, 2)  # 创建包含 2×2 个 Axes 的 Figure

3.3 Axis(坐标轴对象)

每个 Axes 包含两个(或 3D 图中的三个)Axis 对象,分别代表 x 轴和 y 轴。Axis 负责控制刻度(ticks)、刻度标签(tick labels)、坐标轴范围(limits)等。比如 x 轴的刻度位置是 0、π/2、π、3π/2、2π,刻度标签就是对应的数字。

ax.set_xlim(0, 10)  # 设置 x 轴范围
ax.set_xticks([0, 5, 10])  # 设置 x 轴刻度位置
ax.set_xticklabels(['起点', '中点', '终点'])  # 设置刻度标签

3.4 Artist(艺术家对象)

Artist 是所有可见元素的统称,包括线条(Line2D)、文本(Text)、矩形(Rectangle)、图例(Legend)等。FigureAxesAxis 本身也是 Artist。当调用 plt.show()plt.savefig() 时,所有 Artist 会被渲染到画布上。

line, = ax.plot([1, 2, 3], [4, 5, 6])  # line 是一个 Line2D Artist
title = ax.set_title("标题")  # title 是一个 Text Artist

核心概念关系图

以下 Mermaid 图表展示了这些核心对象之间的层次关系:

graph TD
    A[Figure<br/>画布容器] --> B[Axes<br/>绘图区域1]
    A --> C[Axes<br/>绘图区域2]
    A --> D[Axes<br/>绘图区域N]
    B --> E[Axis X<br/>X轴对象]
    B --> F[Axis Y<br/>Y轴对象]
    B --> G[Line2D<br/>线条]
    B --> H[Text<br/>标题/标签]
    B --> I[Legend<br/>图例]
    E --> J[刻度]
    E --> K[刻度标签]
    F --> L[刻度]
    F --> M[刻度标签]

这个图清晰地展示了:

  • Figure 是最顶层容器,可以包含多个 Axes
  • 每个 Axes 包含 Axis 对象和具体的 Artist 元素
  • Axis 负责刻度和标签管理
  • 所有的 Artist 最终渲染到 Figure

记住一句话:我们绘图时,先创建 Figure,再在 Figure 上添加 Axes,最后在 Axes 上调用绘图方法(如 plot()scatter()bar()),然后通过 set_xxx() 方法配置样式,最后用 plt.show()plt.savefig() 展示或保存图表。

4. 实战演练:分析电影评分趋势

需求分析

假设我们有一份电影数据集,包含电影类型、评分、上映年份等信息。我们需要分析不同类型电影的平均评分趋势,找出评分最高和最低的电影类型,并用可视化方式展示结果。这个任务涉及数据统计、多系列折线图绘制、图例和标签设置等核心技能。

方案设计

我们将按以下步骤实现:

  1. 生成模拟数据(包含电影类型、评分、年份)
  2. 按类型和年份分组计算平均评分
  3. 使用 Matplotlib 绘制多系列折线图,每种类型一条曲线
  4. 添加图例、标题、标签,美化图表样式
  5. 保存为高清图片

这个案例将练习以下核心功能:DataFrame 分组统计、subplots 多图布局、plot 折线图、图例和标签设置、样式定制、图片保存。

完整代码实现

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# ===== 步骤1:生成模拟数据 =====
np.random.seed(42)  # 确保结果可复现

# 电影类型列表
genres = ['剧情', '动作', '喜剧', '科幻', '恐怖', '爱情']
n_movies = 1000  # 总电影数

# 生成随机数据
data = {
    'genre': np.random.choice(genres, n_movies),
    'year': np.random.randint(2010, 2024, n_movies),
    'rating': np.random.uniform(3.0, 9.0, n_movies)  # 评分3.0-9.0
}
df = pd.DataFrame(data)

# ===== 步骤2:数据统计 =====
# 按类型和年份分组,计算平均评分
grouped = df.groupby(['genre', 'year'])['rating'].mean().reset_index()

# 将数据转换为更适合绘图的格式:每种类型一个 Series
pivot_data = grouped.pivot(index='year', columns='genre', values='rating')

# ===== 步骤3:创建图表 =====
fig, ax = plt.subplots(figsize=(12, 6))

# 为每种类型绘制一条曲线,使用不同颜色和标记
colors = plt.cm.tab10(np.linspace(0, 1, len(genres)))
markers = ['o', 's', '^', 'D', 'v', 'p']

for i, genre in enumerate(genres):
    if genre in pivot_data.columns:
        ax.plot(pivot_data.index, pivot_data[genre],
                color=colors[i],
                marker=markers[i],
                markersize=6,
                linewidth=2,
                label=genre)

# ===== 步骤4:美化图表 =====
ax.set_title('2010-2023年各类型电影平均评分趋势',
             fontsize=16, pad=20)
ax.set_xlabel('年份', fontsize=12)
ax.set_ylabel('平均评分', fontsize=12)

# 设置 x 轴刻度为每年一个
ax.set_xticks(range(2010, 2024))
ax.set_xticklabels([str(year) for year in range(2010, 2024)],
                   rotation=45, ha='right')

# 设置 y 轴范围,突出差异
ax.set_ylim(3.0, 9.0)
ax.grid(True, linestyle='--', alpha=0.3)

# 添加图例
ax.legend(loc='upper left', fontsize=10, ncol=3)

# 添加参考线(平均分)
avg_rating = df['rating'].mean()
ax.axhline(y=avg_rating, color='red', linestyle=':',
           linewidth=1.5, label=f'总体平均分 ({avg_rating:.2f})')

# ===== 步骤5:保存和显示 =====
plt.tight_layout()  # 自动调整布局,避免标签被截断
plt.savefig('movie_rating_trend.png', dpi=300, bbox_inches='tight')
print("图表已保存为 movie_rating_trend.png")
plt.show()

运行说明

  1. 将上述代码保存为 movie_analysis.py 文件
  2. 确保已安装依赖:pip install matplotlib numpy pandas
  3. 运行命令:python movie_analysis.py
  4. 程序会弹出窗口显示图表,并在当前目录下生成 movie_rating_trend.png 高清图片

结果展示

生成的图表将展示:

  • 6条折线:每种电影类型一条曲线,用不同颜色和标记区分
  • x 轴:2010-2023 年,每年一个刻度,标签旋转 45 度避免重叠
  • y 轴:评分范围 3.0-9.0,突出评分差异
  • 红色虚线:总体平均分参考线,便于对比
  • 图例:显示所有类型和参考线,位于左上角,分 3 列排列
  • 网格线:浅灰色虚线,辅助读取数据

这个案例展示了 Matplotlib 的核心能力:数据处理与可视化的无缝结合、多系列图表绘制、样式精细控制、专业级图表输出。掌握了这些技能,你就能应对大多数数据可视化任务。

5. 最佳实践与常见陷阱

常见错误及规避方法

错误1:混淆 FigureAxes

问题描述:直接使用 plt.plot() 绘图,却不知道"画在哪个 Axes 上",导致多图布局混乱。

# ❌ 错误做法:使用 pyplot 状态机,难以控制
plt.plot(x, y1)  # 自动创建 fig1 和 ax1
plt.figure()     # 新建 fig2
plt.plot(x, y2)  # 画在 fig2 的 ax2 上,但 ax1 无法再修改
# ✅ 正确做法:手动创建 Axes,精准控制
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(x, y1)
ax1.set_title('图表1')
ax2.plot(x, y2)
ax2.set_title('图表2')

原因plt 是便捷接口,会自动创建和管理对象,但复杂绘图时容易失控。面向对象风格更清晰、更可控。

错误2:保存图表的顺序错误

问题描述:先 plt.show()plt.savefig(),保存的是空白图片!

# ❌ 错误做法
plt.show()           # 弹出窗口并释放资源
plt.savefig('plot.png')  # 此时 Figure 已为空,保存空白
# ✅ 正确做法
plt.savefig('plot.png', dpi=300, bbox_inches='tight')  # 先保存
plt.show()            # 再显示

原因plt.show() 会弹出窗口并释放绘图资源,之后再调用 savefig()Figure 已为空。必须先保存再显示。

错误3:中文显示乱码

问题描述:图表中的中文显示为方块,无法识别。

# ❌ 错误做法:未配置字体
plt.title('电影评分趋势')  # 显示为方块
# ✅ 正确做法:配置中文字体
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']  # Windows 用黑体
# plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']  # Mac 用这个
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示为方块
plt.title('电影评分趋势')  # 正确显示中文

原因Matplotlib 默认字体不支持中文,axes.unicode_minus 也需设置为 False 否则负号会乱码。

错误4:误解 figsize 的单位

问题描述:以为 figsize=(8, 4) 表示 8 像素×4 像素,结果图片太小。

# ❌ 错误理解
fig = plt.figure(figsize=(8, 4))  # 不是 8×4 像素!
# ✅ 正确理解
fig = plt.figure(figsize=(8, 4), dpi=100)  # 8英寸×4英寸,dpi=100,实际是 800×400 像素
# 想要 800×400 像素,要么设置 dpi=100,要么设置 figsize=(8, 4) 且 dpi=100

原因figsize 的单位是"英寸",而非像素。最终像素数 = figsize × dpi(默认 dpi=100)。

最佳实践建议

  1. 优先使用面向对象接口:虽然 plt.plot() 更简洁,但复杂场景(如多子图、自定义样式)必须用 fig, ax = plt.subplots() 面向对象风格。
  2. 统一配置字体和样式:在脚本开头一次性设置 rcParams,避免每个图表都重复配置。
  3. 养成使用 tight_layout() 的习惯:自动调整子图间距,避免标签被截断。
  4. 合理设置 dpi 参数:保存图片时 dpi=300 适合打印,dpi=150 适合屏幕显示,dpi=72 适合网页。
  5. 利用 colormaps 自动生成配色:不要手动指定颜色列表(如 ['red', 'blue', 'green']),用 plt.cm.tab10plt.cm.viridis 生成专业配色。
  6. 保存图片时使用 bbox_inches='tight':自动裁剪空白边距,让图片更紧凑。
  7. 多图布局时用 subplots_adjust 微调:当 tight_layout() 不能满足需求时,手动调整 left, right, top, bottom, wspace, hspace 参数。
  8. 避免使用过时的 API:如 plt.axes() 已被 plt.subplots() 替代,plt.hold() 已在新版本中移除。

6. 进阶指引

掌握了基础用法后,你可以继续探索 Matplotlib 的高级功能和生态系统:

高级功能

  • 多子图复杂布局:使用 plt.subplot_mosaic() 创建非网格状布局(如左大右小、上一下三等)
  • 3D 可视化:使用 mpl_toolkits.mplot3d 绘制三维曲面图、散点图
  • 动画制作:使用 matplotlib.animation 模块制作动态图表,展示数据变化过程
  • 交互式可视化:结合 ipywidgets 在 Jupyter Notebook 中实现滑块、下拉框等交互控件

生态扩展

  • Seaborn:基于 Matplotlib 的高级库,提供更简洁的 API 和更美观的默认样式,适合快速生成统计图表
  • Plotly:专注于交互式可视化,生成的图表支持缩放、拖拽、悬停查看数据,适合网页展示
  • Cartopy:地理数据可视化,支持地图投影、地理坐标转换等

学习资源

学习路径建议

  1. 第一阶段(1-2周):熟练掌握折线图、柱状图、散点图、饼图、直方图 5 种基础图表
  2. 第二阶段(2-3周):学会多子图布局、样式定制、图例标签设置
  3. 第三阶段(3-4周):尝试 3D 可视化、动画制作、交互式图表
  4. 第四阶段(持续):结合实际项目(如个人数据分析、Kaggle 比赛),在实战中积累经验

记住:Matplotlib 的核心是"多动手实践"。找一份真实数据(如公开数据集、个人消费记录),尝试用不同图表展示,逐步掌握参数调整和样式优化。从基础图表到专业可视化,Matplotlib 能伴随你从数据分析新手成长为可视化高手。

Matrix 首页推荐 

Matrix 是少数派的写作社区,我们主张分享真实的产品体验,有实用价值的经验与思考。我们会不定期挑选 Matrix 最优质的文章,展示来自用户的最真实的体验和观点。 

文章代表作者个人观点,少数派仅对标题和排版略作修改。


最近尝试用 Python 和 Matplotlib 从零手写复刻了一下 Pluribus 的片头。先看看效果:

1. 前言

最近看了 Apple TV 的一部剧叫 Pluribus。我很喜欢这部剧,原因有二:

  • 它核心概念里的 "Joining" 和 《EVA》里的 「人类补完计划」 非常像,很对我的胃口;
  • 剧情探讨了人类和 AI 的关系,也是我最近一直在深度思考的问题 <(")

除去剧情,我特别喜欢它的片头。极简但非常抓眼球,完全就是我的菜。Apple TV 的片头通常都很复杂且暗示剧情走向(比如《人生切割术》或者《羊毛战记》),但这一个很特别。这也是我第一次觉得「哎,这个我好像能用代码写出来」的片头 :>

2. 粒子系统 (Particle System)

因为我从来没碰过粒子系统,对计算机视觉也知之甚少,所以上手第一步就是先读几篇文章。下面这两个资源对拆解概念非常有帮助:

简单来说,我只需要一堆点,然后追踪它们的物理状态:位置、速度和加速度。

class Particle:
    def __init__(self, pos: (int, int), 
                 velocities: (int, int), 
                 accelerations: (int, int)):
        self.pos = pos
        self.vel = velocities
        self.acc = accelerations

套用高中物理学过的标准公式:

写个函数来更新这些值:

def pos_update(dot, dt):
    dot.pos = (
        dot.pos[0] + dot.vel[0] * dt,
        dot.pos[1] + dot.vel[1] * dt
        )
    dot.vel = (dot.vel[0] + dot.acc[0] * dt,
                dot.vel[1] + dot.acc[1] * dt)

对每个点跑这个循环,就能得到一个基础的粒子系统(渲染代码略过不表,不过这里有个很好的 matplotlib 动画教程)。

最后,给每个点加点随机力。假设质量(m)为 1,根据 F=ma,我们可以直接把随机值加到加速度上:

def force_apply(p: Particle):
    p.acc = (
        p.acc[0] + random.randint(-2, 2), 
        p.acc[1] + random.randint(-2, 2)
        )

def dots_update(dots, dt):
    for dot in dots:
        pos_update(dot, dt)
        force_apply(dot)
    return

初始化网格里的点之后,大概长这样:

3. 背景点 (Background-dots)

把片头看了五遍以后,我发现里面的点可以分为三类,各个击破:

  • 背景点 (Background-dots)
  • 圆圈点 (Circle-dots)
  • 文字点 (Text-dots)

对于背景点,简单的随机运动看着不自然。如果你仔细看(现在是第六遍了 :D),会发现它们之间是有交互的。基本上就是太近了会推开,太远了会拉近。我发现 Lennard-Jones 势能完美描述了这个行为:

简单说就是距离太近会排斥,距离远了(但在范围内)会吸引。就像下图这个曲线。(我是从这个博客学来的)。

实现起来也很简单,遍历每一对点应用这个力就行,复杂度是 O(n^2)。

def lj_force(p1, p2):
    dx = p1.pos[0] - p2.pos[0]
    dy = p1.pos[1] - p2.pos[1]
    dis = (dx**2 + dy**2) ** 0.5

    dx_dir = dx / dis
    dy_dir = dy / dis

    u = min(10, 4 * EPI * ((SIGMA/dis)**12 - (SIGMA/dis)**6))

    dx_acc = u * dx_dir / 1
    dy_acc = u * dy_dir / 1

    p1.acc = (p1.acc[0]+dx_acc, p1.acc[1]+dy_acc)
    p2.acc = (p2.acc[0]-dx_acc, p2.acc[1]-dy_acc)

加上 LJ 势能后的效果如下。能明显看到点之间相互作用产生的复杂运动。

4. 圆圈点 (Circle-dots)

加圆圈点之前,先快速复习一下如何在粒子系统中定义方向和距离。(记得的同学可以跳过 :O)

基本上给定一个角度 θ∈[0,2π) 我们可以得到方向的单位向量 ​dir_x=cos(θ) dir_y=sin(θ)​。给定两个点,我们可以得到从 p1​ 到 p2​ 的方向:

要得到方向(单位向量),我们用差值除以距离:

加圆圈点很容易。给个初始速度,按 2π(360度)均匀分布方向就行。

def add_wave(dots):
    for i in range(WAVE_DOTS_NUM):
        angle = 2 * math.pi * i / WAVE_DOTS_NUM
        
        pos = (WAVE_ORIGIN[0] + math.cos(angle)*5, 
            WAVE_ORIGIN[1] + math.sin(angle)*5)
        
        vx = WAVE_SPEED * math.cos(angle)
        vy = WAVE_SPEED * math.sin(angle)
            
        dots.append(Particle(pos, velocities=(vx, vy)))
  • 碰撞问题: 但这里有个坑。因为我们加了 LJ 力,背景点会和圆圈点互怼。圆圈扩大的时候,撞到背景点会被推歪,形状就散了。
  • 解决方案: 我的解法简单粗暴:给 Particle 类加个 mass(质量)属性。让圆圈点比背景点重得多,它们惯性就大,不容易被推跑。

更新物理计算遵循牛顿第二定律 (a=F/m)。基本就是更新速度的时候,把累计的力(加速度)除以质量:

def pos_update(dot, dt):
    dot.pos = (
        dot.pos[0] + dot.vel[0] * dt,
        dot.pos[1] + dot.vel[1] * dt
        )
    dot.vel = (
        dot.vel[0] + dot.acc[0] * dt / dot.mass,
        dot.vel[1] + dot.acc[1] * dt / dot.mass
        )

对比一下(左:无质量,右:有质量)。

加了质量以后看着舒服多了吧?能明显看到圆圈点把背景点推开,自己还能保持队形。

5. 文字点 (Text-dots)

用点渲染文字不难。找个字体(我用的 Arial)画出来,然后提取像素位置就行。

def get_text_draw(text = TEXT, font_path = FONT_PATH):
    mask_img = Image.new("L", (WIDTH, LENGTH), 0)
    draw = ImageDraw.Draw(mask_img)
    font = ImageFont.truetype(font_path, 35)

    bbox = draw.textbbox((0, 0), text, font=font)
    text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]
    draw.text(((WIDTH - text_w) // 2, (LENGTH - text_h) // 2 - 5), text, fill=255, font=font)
    y_coords, x_coords = np.where(np.array(mask_img)[::-1] > 128)
    return x_coords, y_coords

难点在于做那个「指纹」图案。仔细看原片,它像个波浪,稍微有点不规则。为了简单,我用 sine wave 模拟:

基本上就是根据距离中心的远近推拉这些点。调整频率能搞出不同的环形图案。下图是 freq={1,4,7} 的效果。

def set_fingerprint(x, y, freq = RADIAL_FREQ, strengh = RADIAL_STRENGTH):
    dx = x_coords - RADIAL_ORIGIN[0]
    dy = y_coords - RADIAL_ORIGIN[1]

    dist = np.sqrt(dx**2 + dy**2)
    angle = np.arctan2(dy, dx)

    push = np.sin(dist * freq) * strengh

    x_new = x_coords + (np.cos(angle) * push)
    y_new = y_coords + (np.sin(angle) * push)
    return x_new, y_new

如下是从点 P(25,42) 发起正弦波应用到文字的效果。

其实调这个波的参数花了我好久。试了各种组合,最后选了个看着最舒服的。^_^

把所有东西合在一起,就有了第一版片头!8)

6. 性能优化 (Performance Optimization)

先停一下。目前渲染60帧要跑6分钟。感觉我在浪费生命等它跑完 :( 是时候做点优化了。

6.1 空间哈希 (Spatial Hashing)

前面说了,瓶颈在物理交互计算,复杂度 O(n2)。加上文字点和不断生成的圆圈点,数量轻松上千,意味着每帧要做 10^6 次距离检测。

我的解法是用空间哈希(分桶),把空间划成网格,只计算相邻网格里粒子的 LJ 力。灵感来自第 3 节的公式:距离 ≥3σ 时势能几乎归零。

我用哈希表记录每个点属于哪个格子:

def _bin_coords(pos):
    return int(pos[0]) // BIN_SIZE, int(pos[1]) // BIN_SIZE

def _build_bins(dots):
    bins = {}
    for idx, p in enumerate(dots):
        bx, by = _bin_coords(p.pos)
        if 0 <= bx < BIN_XNUM and 0 <= by < BIN_YNUM:
            bins.setdefault((bx, by), []).append(idx)
    return bins

这一改,速度提升了 5 倍,渲染时间从 6 分 10 秒降到了 1 分 06 秒。

(虽然我知道用树结构——类似二叉索引树——动态维护位置能把复杂度降到 O(nlogn),毕竟最近在刷 LeetCode。但网格法目前够用了。)

6.2 生命周期管理

另一个优化是控制点的生命周期。圆圈点飞出屏幕(「越界」)后就不用算了。我加了个定期清理。这对减少内存占用很有效,之前内存都飙到 10GB 了。

def prune_dots(dots, circles, margin=50):
    alive_dots = []
    alive_circles = []

    for dot, circle in zip(dots, circles):
        x, y = dot.pos
        if -margin < x < WIDTH + margin and -margin < y < LENGTH + margin:
            alive_dots.append(dot)
            alive_circles.append(circle)
        else:
            circle.remove()

    dots[:] = alive_dots
    circles[:] = alive_circles

我很确定用内存池(链表+哈希表)能做到 O(1) 的插入删除,但对于这个项目有点杀鸡用牛刀了 :/

7. 视觉打磨 (Visual Optimization)

接下来打磨一下视觉效果。

7.1 文字形状

第一个问题是文字时间长了会「糊」掉或者散架。因为点挤得太紧,LJ 势能把它们推开了,导致我们(搞了半天的)指纹纹理丢了。

解决办法很简单:加个锚点力 (Anchor Force)。就像个弹簧,点飘太远了就把它拽回原位。我还加了点阻尼(摩擦力)防止它震荡个没完。

def anchor_force(p):
    dx = p.anchor[0] - p.pos[0]
    dy = p.anchor[1] - p.pos[1]
    dis = (dx**2 + dy**2) ** 0.5
    dx_dir = dx / dis
    dy_dir = dy / dis

    f = dis * ANCHOR_STRENGH

    damping_fx = -p.vel[0] * DAMPING
    damping_fy = -p.vel[1] * DAMPING

    p.acc = (
        p.acc[0] + (f * dx_dir + damping_fx) * random.randrange(5, 10) / 10, 
        p.acc[1] + (f * dy_dir + damping_fy) * random.randrange(5, 10) / 10
        )

7.2 呼吸与循环

另一个改进是给背景点加个「呼吸」效果,大小有节奏地缩放。给每个粒子加个相位属性,用正弦波更新就行。

最后,为了防止背景点飞出屏幕,我做了个屏幕循环 (Screen wrapping)。点从右边出去,就从左边回来。

def pos_update(dot, dt):
    dot.pos = (
        dot.pos[0] + dot.vel[0] * dt,
        dot.pos[1] + dot.vel[1] * dt
    )
    dot.vel = (
        dot.vel[0] + dot.acc[0] * dt / dot.mass,
        dot.vel[1] + dot.acc[1] * dt / dot.mass
    )
    dot.acc = (0, 0)

    dot.phase = (dot.phase + PHASE_INCREMENT) % (2 * math.pi)
    sine_wave = (math.sin(dot.phase) + 1) / 2

    if dot.type == 0:
        ## Keep background dots
        dot.vel = (dot.vel[0] * DECAY_RATIO, dot.vel[1] * DECAY_RATIO)
        dot.pos = (dot.pos[0] % WIDTH, dot.pos[1] % LENGTH)
        ## Change their size periodically
        dot.radius = 0.5 * (0.4 + 0.6 * sine_wave)
    else:
        dot.radius = 0.5 * (0.9 + 0.1 * sine_wave)

效果图解:

当然你也可以把文字换成任何你想要的:

8. 总结

这其实是我第一次尝试写粒子系统。本来计划在剧终(圣诞节)前搞定,但我高估了自己旅行时的精力和专注度。说实话,理解原理并实现它确实花了我不少时间。

相比之下,我看很多人用 Gemini 生成那种酷炫的 web 端 3D 粒子系统。跟那些比,我这个可能显得简陋甚至有点「丑」。但对我来说,从零构建的这个过程要更 enjoyable,虽然这肯定不是最高效的方法。最后,我觉得这种感觉大概也就是《Pluribus》想表达的东西吧。 :V