王美洁

Python 科学计算加速的几种方法

python 作为一种解释型语言,通常运行速度较慢,特别是在科学计算领域。为了提高运行效率,通常需要使用一些加速方法。本文将简单介绍如何使用这些方法加速 Python 科学计算。

首先,我们定义一个 Mandelbrot 函数,用于演示加速方法。

import numpy as np
import time

def MandNumba(ext, max_steps, Nx, Ny):
	data = np.ones((Nx, Ny)) * max_steps
	for i in range(Nx):
		for j in range(Ny):
			x = ext[0] + (ext[1] - ext[0]) * i / (Nx - 1.)
			y = ext[2] + (ext[3] - ext[2]) * j / (Ny - 1.)
			z0 = x + y * 1j
			z = 0j
			for itr in range(max_steps):
				if abs(z) > 2.:
					data[j, i] = itr
					break
				z = z * z + z0
	return data
Nx = 1000
Ny = 1000
max_steps = 1000  # 50

ext = [-2, 1, -1, 1]

t0 = time.time()
data = MandNumba(np.array(ext), max_steps, Nx, Ny)
t1 = time.time()
print('clock time: ', t1 - t0)

clock time: 29.383343935012817

可以看到不进行任何加速需要 29.38s ,非常慢。

以下是几种加速该函数的Python科学计算方法:

  1. Numba JIT 并行加速
  2. 多进程并行
  3. NumPy 向量化

还有其他加速方法,如 CythonNumba CUDA 等,这些方法需要GPU硬件支持,这里不做介绍。

1. Numba JIT 并行加速

Numba 是一个开源的 JIT 编译器,可以将 Python 代码编译成机器码,从而提高运行速度。Numba 支持并行计算,可以通过 @numba.njit(parallel=True) 修饰器实现并行计算。

这里解释一下什么是修饰器,在 Python 中,修饰器(Decorator)是一种设计模式,用于在不修改原始函数代码的情况下,动态地给函数添加新的功能。修饰器本质上是一个函数,它接收一个函数作为参数,并返回一个新的函数。

在这里,@numba.njit(parallel=True) 修饰器将 MandNumba_parallel 函数编译成机器码,并实现并行计算。

import numpy as np
import numba
import time

@numba.njit(parallel=True)
def MandNumba_parallel(ext, max_steps, Nx, Ny):
	data = np.ones((Ny, Nx), dtype=np.int32) * max_steps
	for i in numba.prange(Nx):  # 并行外层循环
		for j in range(Ny):
			x = ext[0] + (ext[1] - ext[0]) * i / (Nx - 1)
			y = ext[2] + (ext[3] - ext[2]) * j / (Ny - 1)
			z0 = x + y * 1j
			z = 0j
			for itr in range(max_steps):
				if abs(z) > 2:
					data[j, i] = itr
					break
				z = z * z + z0
	return data
Nx = 1000
Ny = 1000
max_steps = 1000  # 50

ext = [-2, 1, -1, 1]

t0 = time.time()
data = MandNumba_parallel(np.array(ext), max_steps, Nx, Ny)
t1 = time.time()
print('clock time: ', t1 - t0)

clock time: 0.8841660022735596

这里只花了 0.88s ,加速了约 33 倍。

2. 多进程并行

多进程并行是一种利用多核 CPU 并行计算的方法,适合 CPU 密集型任务。

就是让每个人算一行,然后把结果合并起来。这是最朴素的并行计算方法,也是最容易实现的方法。

Python 提供了 multiprocessing 模块,可以很方便地实现多进程并行计算。

由于 multiprocessing 模块的限制,被并行计算的函数不能是类的成员函数,也不能是全局函数。因此,我们需要将 compute_row 函数单独放在一个文件中,然后通过 multiprocessing 模块调用。

compute_one_row.py 文件内容:

import numpy as np

def compute_row(ext, max_steps, Nx, Ny, row):
    result = np.empty(Ny, dtype=np.int64)
    for j in range(Ny):
        x = ext[0] + (ext[1] - ext[0]) * row / (Nx - 1.)
        y = ext[2] + (ext[3] - ext[2]) * j / (Ny - 1.)
        z0 = x + y * 1j
        z = 0j
        for itr in range(max_steps):
            if abs(z) > 2.:
                result[j] = itr
                break
            z = z * z + z0
        else:
            result[j] = max_steps
    return result
import numpy as np
import multiprocessing as mp
from compute_one_row import compute_row
import time

def MandelMultiProcess(ext, max_steps, Nx, Ny):
    data = np.ones((Nx, Ny), dtype=np.int64) * max_steps
    with mp.Pool(processes=mp.cpu_count()) as pool:
        results = [pool.apply_async(compute_row, (ext, max_steps, Nx, Ny, i)) for i in range(Nx)]
        for i in range(Nx):
            data[i, :] = results[i].get()
    return data
Nx = 1000
Ny = 1000
max_steps = 1000  # 50

ext = [-2, 1, -1, 1]

t0 = time.time()
data = MandelMultiProcess(np.array(ext), max_steps, Nx, Ny)
t1 = time.time()
print('clock time: ', t1 - t0)

clock time: 3.7328829765319824

我使用的是 Macbook Pro M2 Pro 版本,有6个性能核,4个能效核,总共10个核。所耗费的时间是 3.73s ,加速了约 8 倍。

3. NumPy 向量化

NumPy 是 Python 科学计算的基础库,提供了很多高效的数学函数和运算符。NumPy 的向量化操作可以大大提高运算速度。

下面的代码是将 MandNumba 函数向量化,使用 NumPy 的数组运算代替循环运算。

import numpy as np


def MandNumba_vectorized(ext, max_steps, Nx, Ny):
	x = np.linspace(ext[0], ext[1], Nx)
	y = np.linspace(ext[2], ext[3], Ny)
	X, Y = np.meshgrid(x, y)
	Z0 = X + Y * 1j
	Z = np.zeros_like(Z0)
	data = np.full(Z0.shape, max_steps, dtype=int)

	mask = np.ones_like(Z0, dtype=bool)
	for itr in range(max_steps):
		Z[mask] = Z[mask] ** 2 + Z0[mask]
		escaped = (np.abs(Z) > 2) & mask
		data[escaped] = itr
		mask[escaped] = False
		if not mask.any():
			break
	return data
Nx = 1000
Ny = 1000
max_steps = 1000  # 50

ext = [-2, 1, -1, 1]
t0 = time.time()
data = MandNumba_vectorized(np.array(ext), max_steps, Nx, Ny)
t1 = time.time()
print('clock time: ', t1 - t0)

clock time: 5.844282150268555

这里花费了 5.84s ,加速了约 5 倍。

性能对比说明:

  1. Numba JIT:通过即时编译优化循环,通常可获得10-100倍加速
  2. 多进程:适合CPU密集型任务,但进程间通信可能成为瓶颈
  3. 向量化:对小规模计算友好,但内存消耗随问题规模平方增长 可以看到,Numba JIT 是最快的,其本质是将 Python 代码编译成机器码,作弊变成C++fortran,消除了 Python 解释器的性能瓶颈。

最后感谢 DeepseekGithub Copilot 的帮助。

并行后的 mandf-dynamic.py 代码如下:

#%matplotlib notebook
import numpy as np
import pylab as plt
import time
import numba

@numba.njit(parallel=True)
def MandNumba(ext, max_steps, Nx, Ny):
    data = np.ones((Nx, Ny), dtype=np.int32) * max_steps
    for i in range(Nx):
        for j in range(Ny):
            x = ext[0] + (ext[1] - ext[0]) * i / (Nx - 1.)
            y = ext[2] + (ext[3] - ext[2]) * j / (Ny - 1.)
            z0 = x + y * 1j
            z = 0j
            for itr in range(max_steps):
                if abs(z) > 2.:
                    data[j, i] = itr
                    break
                z = z * z + z0
    return data

def ax_update(ax):  # actual plotting routine
    ax.set_autoscale_on(False) # Otherwise, infinite loop
    # Get the range for the new area
    xstart, ystart, xdelta, ydelta = ax.viewLim.bounds
    xend = xstart + xdelta
    yend = ystart + ydelta
    ext=np.array([xstart,xend,ystart,yend])
    data = MandNumba(ext, max_steps, Nx, Ny) # actually producing new fractal

    # Update the image object with our new data and extent
    im = ax.images[-1]  # take the latest object
    im.set_data(data)   # update it with new data
    im.set_extent(ext)           # change the extent
    ax.figure.canvas.draw_idle() # finally redraw

if __name__ == '__main__':
    Nx = 1000
    Ny = 1000
    max_steps = 1000 # 50

    ext = [-2,1,-1,1]

    t0 = time.time()
    data = MandNumba(np.array(ext), max_steps, Nx, Ny)
    t1 = time.time()
    print('clock time: ', t1-t0)

    fig,ax=plt.subplots(1,1)
    ax.imshow(data, extent=ext,aspect='equal',origin='lower',cmap='plasma')

    ax.callbacks.connect('xlim_changed', ax_update)
    ax.callbacks.connect('ylim_changed', ax_update)
    plt.show()

clock time: 1.4002361297607422