Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

Python——strip函数用法注意事项


整体说明

  • strip() 方法容易误解,使用时可能会发生未知问题,所以要小心
  • strip() 方法会从字符串的开头和结尾移除指定的字符,但它是逐个字符进行匹配的,而不是匹配整个子字符串
    • 即把输入参数当做一个字符集合(注意是一个字符集合而不是一个字符串)
    • 后续的文字只要在这个集合内都会被清除掉,知道遇到第一个不在这个字符集合的字符为止

具体示例分析

  • 下面是一个简单示例:

    1
    2
    string = "SYSTEM:You are an AI assistant. You will be given a task.  "
    print(string.strip('SYSTEM:'))
    • 上面的句子输出是 "ou are an AI assistant. You will be given a task. "(注意:"SYSTEM:Y" 都被删除了,不仅仅是 "SYSTEM:")
  • 具体来说,当执行 strip('SYSTEM:') 时:

    • strip() 会把 'SYSTEM:' 看作是一个字符集合 :{'S', 'Y', 'S', 'T', 'E', 'M', ':'}
    • 从字符串开头开始,逐个检查字符是否在这个集合中:
      • 'S' 在集合中,移除
      • 'Y' 在集合中,移除
      • 'S' 在集合中,移除
      • 'T' 在集合中,移除
      • 'E' 在集合中,移除
      • 'M' 在集合中,移除
      • ':' 在集合中,移除
      • 'Y' 在集合中,移除(这里是 “You” 的 ‘Y’)
      • 'o' 不在集合中,停止移除,后续的字符都不再移除
  • 所以最终结果是 "ou are an AI assistant. You will be given a task. "

  • 如果你想移除特定的前缀字符串,应该使用:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    # 方法1:使用 removeprefix() (Python 3.9+)
    result = string.removeprefix('SYSTEM:')

    # 方法2:使用字符串切片
    if string.startswith('SYSTEM:'):
    result = string[7:] # 'SYSTEM:' 长度为7

    # 方法3:使用 replace() (但要小心,会替换所有匹配项)
    result = string.replace('SYSTEM:', '', 1) # 只替换第一个
  • 这样就能得到正确的结果:"You are an AI assistant. You will be given a task. "

Python——signal.signal函数处理OS信号


整体说明

  • signal.signal() 是 Python 中用于设置信号处理方式的核心函数
  • signal.signal()主要用于捕获和处理操作系统发送的各种信号(如中断信号、子进程状态变化信号等)

signal.signal 函数说明

  • 基本语法如下:

    1
    2
    3
    import signal

    signal.signal(signalnum, handler)
  • 参数说明:

    • signalnum:需要处理的信号类型(如 signal.SIGINT 表示键盘中断,signal.SIGCHLD 表示子进程状态变化等)
    • handler:信号处理函数(或特殊常量),用于定义收到信号后的行为

信号类型(signalnum)

  • 常见的信号常量包括:
    • signal.SIGINT:程序终止(Interrupt)信号,通常由 Ctrl+C 触发
    • signal.SIGTERM:终止请求,默认可以被捕获和处理
    • signal.SIGCHLD:子进程状态改变(终止、暂停等)时触发
    • signal.SIGKILL:强制终止信号(无法被捕获或忽略)
    • signal.SIGSTOP:暂停信号(无法被捕获或忽略)
  • 完整信号列表可通过 signal.__dict__ 查看
    • 也可以参考 Python 官方文档: signal — Set handlers for asynchronous events

处理函数(handler)

  • handler 可以是以下三种类型之一:

  • 自定义函数类型:

    • 函数需接收两个参数:signum(信号编号)和 frame(当前栈帧对象,可选)
    • 使用示例:
      1
      2
      3
      4
      5
      6
      def handle_sigint(signum, frame):
      print(f"\n收到信号 {signum},程序即将退出")
      exit(0)

      # 注册 SIGINT 信号的处理函数
      signal.signal(signal.SIGINT, handle_sigint)
  • 使用默认处理方式 ,即取值 signal.SIG_DFL

    • 此时表示使用默认处理方式(系统预设行为,如 SIGINT 默认终止程序)
    • 使用示例:
      1
      2
      # 恢复 SIGINT 的默认处理(取消自定义函数)
      signal.signal(signal.SIGINT, signal.SIG_DFL)
  • 使用忽略该信号处理方式 ,即取值 signal.SIG_IGN

    • 表示忽略该信号(不做任何处理)
    • 使用示例:
      1
      2
      # 忽略 SIGINT 信号(Ctrl+C 无效)
      signal.signal(signal.SIGINT, signal.SIG_IGN)

使用注意事项

  • 信号处理函数应尽量简单,避免包含复杂逻辑(如 I/O 操作、长时间阻塞等),否则可能导致程序不稳定
  • 部分信号(如 SIGKILL、SIGSTOP)无法被捕获或忽略,用于强制终止/暂停程序
  • 信号处理函数通常在主线程中执行,多线程程序中需谨慎处理信号,避免线程安全问题
  • 部分信号(如 SIGCHLD)在 Windows 系统中可能不支持,使用时需注意平台兼容性

Python——tqdm库使用


整体说明

  • Tqdm 是一个快速、可扩展的 Python 进度条库,可以轻松地为你的循环添加一个智能进度条,让你直观地了解任务的执行进度
  • Tqdm 的名字来源于阿拉伯语 “taqaddum”,意为“进展”
  • Tqdm 可以通过一行 pip 指令安装:
    1
    pip install tqdm

Tqdm 的最常用用法(自动控制)

  • Tqdm 最核心的用法就是将可迭代对象包装在 tqdm() 函数中

    1
    2
    3
    4
    5
    6
    7
    from tqdm import tqdm
    import time

    # 循环 100 次,每次暂停 0.01 秒
    for i in tqdm(range(100)):
    time.sleep(0.01)
    # 100%|██████████| 100/100 [00:01<00:00, 83.01it/s]
    • 运行这段代码,你会看到一个实时的进度条,显示循环的完成百分比、已完成的迭代次数、总迭代次数、每秒迭代次数以及预计剩余时间

Tqdm 的常用参数

  • desc: 给进度条添加一个描述性前缀

    1
    2
    3
    for i in tqdm(range(100), desc="Processing data"):
    time.sleep(0.01)
    # Processing data: 100%|██████████| 100/100 [00:01<00:00, 82.70it/s]
  • total: 当可迭代对象没有 __len__ 方法时,你可以手动指定总迭代次数

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    import random

    # 假设有一个生成器
    def my_generator():
    for _ in range(100):
    yield random.randint(1, 10)

    # 因为生成器没有长度,所以需要指定 total
    for item in tqdm(my_generator(), total=100):
    time.sleep(0.01)
    # 100%|██████████| 100/100 [00:01<00:00, 81.47it/s]
  • unit: 设置迭代单位,例如 'B'(字节)、'it'(迭代)

    1
    2
    3
    4
    # 模拟文件下载进度条
    for i in tqdm(range(1024), unit='B', unit_scale=True, desc="Downloading file"):
    time.sleep(0.001)
    # Downloading file: 100%|██████████| 1.02k/1.02k [00:12<00:00, 83.3B/s]
    • unit_scale=True 会自动将单位转换为 K、M、G 等,让显示更友好
  • ncols: 设置进度条的宽度,可以是一个整数或 None(自动适应终端宽度)

    1
    2
    3
    for i in tqdm(range(100), ncols=80): # 固定宽度为 80 个字符
    time.sleep(0.01)
    # 100%|█████████████████████████████████████████| 100/100 [00:01<00:00, 83.97it/s]
  • inital: 设置初始进度,用于从中间恢复任务的场景,一般用于手动控制的场景,下面会介绍


Tqdm 手动控制用法

  • 在某些情况下,可能无法直接将可迭代对象传递给 tqdm,例如当你需要在一个循环中分步更新进度时
  • 这时可以使用 tqdm 的上下文管理器或手动控制

使用上下文管理器 (推荐)

  • 使用 with 语句可以确保进度条在循环结束后正确关闭

    1
    2
    3
    4
    5
    6
    7
    with tqdm(total=100, desc="Manual loop") as pbar:
    for i in range(100):
    # 你的任务代码
    time.sleep(0.01)
    # 手动更新进度条
    pbar.update(1)
    # Manual loop: 100%|██████████| 100/100 [00:01<00:00, 82.50it/s]
    • pbar.update(n) 会将进度条前进 n 步,注意超过以后继续更新会导致超出部分显示异常(不会报错)
      1
      2
      3
      4
      5
      6
      7
      with tqdm(total=100, desc="Manual loop") as pbar:
      for i in range(100):
      # 你的任务代码
      time.sleep(0.01)
      # 手动更新进度条
      pbar.update(2)
      # Manual loop: 200it [00:01, 165.95it/s]
  • 若使用 initial 参数,则代码示例如下:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    from tqdm import tqdm
    import time

    with tqdm(initial=50, total=100, desc="Manual loop") as pbar:
    for i in range(50):
    # 你的任务代码
    time.sleep(0.01)
    # 手动更新进度条
    pbar.update(1)
    # 开始就是下面这样
    # Manual loop: 50%|█████ | 50/100 [00:00<?, ?it/s]
    # 最终变成这样:
    # Manual loop: 100%|██████████| 100/100 [00:00<00:00, 81.43it/s]

手动创建和关闭

  • 如果不能使用上下文管理器,可以手动创建和关闭进度条
    1
    2
    3
    4
    5
    6
    7
    8
    pbar = tqdm(total=100)
    for i in range(100):
    # 你的任务代码
    time.sleep(0.01)
    # 更新进度条
    pbar.update(1)
    # 任务完成后手动关闭进度条
    pbar.close()

Tqdm 高级用法

tqdm.notebook for Jupyter/IPython(暂未测试)

  • 如果在 Jupyter Notebook 或 IPython 环境中,可以使用 tqdm.notebook 模块,它会生成一个更美观的 HTML 进度条
    1
    2
    3
    4
    5
    from tqdm.notebook import tqdm
    import time

    for i in tqdm(range(100)):
    time.sleep(0.01)

tqdm.pandas for Pandas(暂未测试)

  • Tqdm 可以轻松地与 Pandas 的 apply、groupby 等方法结合,为数据处理过程添加进度条
    1
    2
    3
    4
    5
    6
    7
    8
    9
    import pandas as pd
    from tqdm.pandas import tqdm # 注意这里导入的是 tqdm.pandas 包

    tqdm.pandas(desc="Processing DataFrame")

    df = pd.DataFrame({'a': range(100000)})

    # 使用 progress_apply 替代 apply
    df['b'] = df['a'].progress_apply(lambda x: x * 2)

嵌套进度条

  • 当需要为嵌套循环添加进度条时,可以将内部循环的 tqdm 实例作为外部 tqdm 的子项

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    for i in tqdm(range(5), desc="Outer loop"):
    for j in tqdm(range(10), desc=f"Inner loop {i}", leave=False):
    time.sleep(0.01)
    # Outer loop: 0%| | 0/5 [00:00<?, ?it/s]
    # Inner loop 0: 0%| | 0/10 [00:00<?, ?it/s]
    # Inner loop 0: 90%|█████████ | 9/10 [00:00<00:00, 82.94it/s]
    # Outer loop: 20%|██ | 1/5 [00:00<00:00, 8.23it/s]
    # Inner loop 1: 0%| | 0/10 [00:00<?, ?it/s]
    # Inner loop 1: 90%|█████████ | 9/10 [00:00<00:00, 84.44it/s]
    # Outer loop: 40%|████ | 2/5 [00:00<00:00, 8.31it/s]
    # Inner loop 2: 0%| | 0/10 [00:00<?, ?it/s]
    # Inner loop 2: 90%|█████████ | 9/10 [00:00<00:00, 82.45it/s]
    # Outer loop: 60%|██████ | 3/5 [00:00<00:00, 8.26it/s]
    # Inner loop 3: 0%| | 0/10 [00:00<?, ?it/s]
    # Inner loop 3: 90%|█████████ | 9/10 [00:00<00:00, 80.50it/s]
    # Outer loop: 80%|████████ | 4/5 [00:00<00:00, 8.16it/s]
    # Inner loop 4: 0%| | 0/10 [00:00<?, ?it/s]
    # Inner loop 4: 80%|████████ | 8/10 [00:00<00:00, 79.85it/s]
    # Outer loop: 100%|██████████| 5/5 [00:00<00:00, 8.15it/s]
    • 注意 leave=False 参数,它会确保内部进度条在完成后立即消失,避免屏幕杂乱

Python——全局解释器锁


什么是 GIL?

  • 全局解释器锁(Global Interpreter Lock,GIL)是 Python 解释器(如 CPython)中的一个机制,它确保在同一时刻只有一个线程执行 Python 字节码
  • GIL 的存在意味着:即使在多核处理器上 ,多个线程也无法真正并行执行 Python 代码

GIL 存在的原因

  • GIL 的存在主要是为了保护 Python 解释器的内部状态,避免多线程同时修改导致的数据竞争问题
  • 由于 Python 的内存管理不是线程安全的,GIL 提供了一种简单的解决方案来保证线程安全

GIL对多线程的影响

  • GIL对多线程程序的影响取决于线程执行的任务类型:
    • CPU密集型任务 :这类任务主要是进行大量的计算,需要频繁使用CPU
      • GIL使得,即使使用多线程,多个 CPU 核心也无法同时执行 Python 代码
      • 多线程在 CPU 密集型任务上的表现可能比单线程更差,因为线程切换会带来额外的开销
    • I/O密集型任务 :这类任务主要是进行输入输出操作,如网络请求、文件读写等
      • 在执行 I/O 操作时,线程会释放 GIL,允许其他线程执行
      • 多线程在I/O密集型任务上可以显著提高性能

如何规避 GIL 的限制

  • 使用多进程(常用) :multiprocessing模块允许创建多个进程,每个进程都有自己的 Python 解释器和 GIL,因此可以真正并行执行 CPU 密集型任务
  • 使用 C 扩展(不常用) :将关键部分的代码用 C 语言实现,并在 C 扩展中释放 GIL。这样可以让C代码在多核处理器上并行执行

Python——判断未知源的编码类型

有时候遇到一个文件,而我们并不知道它是什么编码方式编码的,本文给出了一些判断未知文件编码方式的方法


使用chardet包


在程序中判断

  • 安装Chardet包

    1
    pip install chardet
  • 使用Chardet包做判断

    1
    2
    3
    4
    5
    6
    import urllib
    rawdata = urllib.urlopen('http://yahoo.co.jp/').read()
    import chardet
    print chardet.detect(rawdata)
    # Output:
    {'confidence': 0.99, 'language': '', 'encoding': 'utf-8'}

更多高级使用方法可参考chardet文档


直接使用命令判断

  • 安装chardetect工具

    1
    pip install chardet
  • 使用chardetect命令

    1
    2
    3
    4
    # 检测test.txt文件的编码方式
    chardetect test.txt
    # Output:
    test-chardetect.txt: ascii with confidence 1.0

Ubuntu——ElasticSearch安装与配置(logstash)

ElasticSearch的安装与基本配置


安装ElasticSearch

  • 下载ElasticSearch deb安装包
  • 安装ElasticSearch
    • Ubuntu中默认安装路径为/usr/share/elasticsearch/
  • 配置ES为一个服务
    1
    2
    sudo /bin/systemctl daemon-reload
    sudo /bin/systemctl enable elasticsearch.service

启动ElasticSearch

1
sudo systemctl start elasticsearch.service

启动相关问题

  • 设置远程访问:
    • 修改config/elasticsearch.yml中network.host : 0.0.0.0
  • 可能遇到的问题1:

[1]: max virtual memory areas vm.max_map_count [65530] is too low, increase to at least [262144]

  • 解决方案:https://www.cnblogs.com/chenjiangbin/p/12060899.html
  • 可能遇到的问题2:

[1]: max file descriptors [4096] for elasticsearch process is too low, increase to at least [65536]

  • 解决方案:https://blog.csdn.net/python36/article/details/84257343
  • 可能遇到的问题3:

[1]: the default discovery settings are unsuitable for production use; at least one of [discovery.seed_hosts, discovery.seed_providers, cluster.initial_master_nodes] must be configured

  • 解决方案:https://www.cnblogs.com/hellxz/p/11057234.html

关闭ElasticSearch

1
sudo systemctl stop elasticsearch.service

ElasticSearch的使用

ES的使用可从官网上查看,点击下一步接着可看完整个流程
*操作ES接口时,建议使用postman *

查看索引

1
GET /_cat/indices?v

health status index uuid pri rep docs.count docs.deleted store.size pri.store.size

创建索引

1
2
# 创建一个名字为customer的Index
PUT /customer?pretty

添加数据

给Index添加Document数据

1
2
3
4
5
6
# 向customer Index中添加id为1的Document
# 内容为{"name": "John Doe"}
PUT /customer/_doc/1?pretty
{
"name": "John Doe"
}

如果是使用Postman,这里选择json类型即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
{
"_index" : "customer",
"_type" : "_doc",
"_id" : "1",
"_version" : 1,
"result" : "created",
"_shards" : {
"total" : 2,
"successful" : 1,
"failed" : 0
},
"_seq_no" : 0,
"_primary_term" : 1
}

查询数据

1
2
# 从customer中查询id为1的数据
GET /customer/_doc/1?pretty
1
2
3
4
5
6
7
8
{
"_index" : "customer",
"_type" : "_doc",
"_id" : "1",
"_version" : 1,
"found" : true,
"_source" : { "name": "John Doe" }
}

删除索引

1
2
# 删除名为customer的索引
DELETE /customer?pretty

修改数据

1
2
3
4
5
6
7
# 向customer Index中添加id为1的Document
# 如果目标Document已经存在,则修改目标Document为指定的数据
# 内容为{"name": "John Doe"}
PUT /customer/_doc/1?pretty
{
"name": "Joe Doe"
}

使用logstash同步数据

6.2.4同步mysql数据到ES

安装logstash

  • 下载logstash 6.2.4
    • 为了方便配置管理,建议下载zip或者tar.gz版本
  • 解压到指定文件夹,建议在相关项目下创建ElasticSearch文件夹,并存储以下数据
    • logstash 解压文件夹logstash-6-2-4
    • 新建文件: logstash更新 mysql 数据库索引到 ES 时的配置文件,一般命名为mysql.config [后面会给出详细内容]
    • 下载jdbc库: mysql-connector-java.jar [用于连接数据库]
      • 这里下载Platform Independent的压缩包版本解压即可找到需要的jar包

配置文件及jdbc连接库

  • 新建一个名为mysql.config的文件,内容为

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    # logstash-6.2.2\bin\logstash.bat -f mysql.config
    # logstash-6.2.4/bin/logstash -f mysql.config
    input {
    jdbc {
    # mysql jdbc connection string to our backup databse
    jdbc_connection_string => "jdbc:mysql://127.0.0.1:3306/docker_manager?serverTimezone=UTC"
    # the user we wish to excute our statement as
    jdbc_user => "docker"
    jdbc_password => "123456"
    # the path to our downloaded jdbc driver
    jdbc_driver_library => "/home/jiahong/Workspace/IdeaProjects/DockerManagerSystem/elastic-search/mysql-connector-java-8.0.11.jar"
    # the name of the driver class for mysql
    jdbc_driver_class => "Java::com.mysql.jdbc.Driver"
    # jdbc_paging_enabled => "true"
    # jdbc_page_size => "50000"
    # statement_filepath => "jdbc.sql"
    schedule => "* * * * *"
    # type => "jdbc"
    statement => "SELECT * FROM docker_manager.DockerManager_docker WHERE id > :sql_last_value"
    use_column_value => true
    tracking_column => "id"
    }
    }

    output {
    elasticsearch {
    hosts => "127.0.0.1:9200"
    index => "docker"
    document_id => "%{id}"
    }
    }
    • 关于mysql.config,一般需要配置的地方为:
      • jdbc_connection_string: 注意docker_manager为数据库名,后面的时区参数有时也需要修改,关于时区的问题可我的博客
        Linux——Logstash时区问题
      • jdbc_driver_library: 这里需指定到对应的jdbc连接库
      • statement: 数据库查询语句
      • jdbc_user: 数据库用户名l
      • jdbc_password: MySQL用户名对应的密码
      • hosts: 输出到ES地址
      • index: ES服务器的Index(相当于MySQL中的数据库)

利用logstash同步mysql数据库数据到ES

  • 启动同步操作
    1
    2
    3
    4
    # 进入之前创建用于解压logstash源文件的目录下
    cd ElasticSearch
    # 启动同步
    ./logstash-6.2.4/bin/logstash -f mysql.config

一旦同步开始,如果不关闭进程,那么将一直自动同步,建议同步完成(当输出不再变化)后关闭进程

  • 查看ES中相应的Index是否已经被更新

Ubuntu——屏幕亮度调节

Ubuntu屏幕亮度调节


Flux

  • Flux是一款非常好用的屏幕调整软件,不仅可以调整屏幕亮度,还能调整色调等
  • 但是,许多Ubuntu版本下,Flux都不能使用,还会影响apt-get update命令,一般情况下,都会将其从表中删除
    • 如果Flux可以使用,建议优先使用Flux

xrandr调整屏幕亮度

  • 展示已有屏幕

    1
    xrandr -q
    • 或者可以使用更简洁的版本
      1
      xrandr -q | grep " connected"
  • 调整指定屏幕的亮度

    1
    xrandr --output HDMI-0 --brightness 0.5
    • 其中HDMI-0是屏幕名称,0.5是亮度参数,大于1时屏幕会发白,注意不要调整的太亮
  • 说明:xrandr修改的屏幕亮度在重启电脑后不会保存,如果想要每次开机自动调整,可以在开机启动添加脚本

MySQL——Ubuntu18.04安装及配置

解决Ubuntu18.04安装Ubuntu后普通用户没有权限登录问题


安装

  • 安装命令
    1
    2
    3
    4
    5
    6
    7
    8
    # 安装mysql服务
    sudo apt-get install mysql-server
    # 安装客户端
    sudo apt install mysql-client
    # 安装依赖
    sudo apt install libmysqlclient-dev
    # 检查状态
    sudo netstat -tap | grep mysql

远程链接设置

默认远程链接问题

Ubuntu18.04 安装mysql后,普通用户没有连接mysql数据库权限(local或者remote均没有权限)

  • mysql -u root或者直接用使用MySQL Workbench连接均失败,显示如下错误

    ERROR 1045: Access denied for user: ‘root@localhost’ (Using password: YES)

  • 使用sudo mysql -u root能正确连接

打开远程链接数据

(删除之前的root并重新创建账户)

  • 登录

    1
    sudo mysql -u root
  • 然后查看当前用户

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    select User,Host from mysql.user;
    # Output:
    +------------------+-----------+
    | User | Host |
    +------------------+-----------+
    | admin | localhost |
    | debian-sys-maint | localhost |
    | magento_user | localhost |
    | mysql.sys | localhost |
    | root | localhost |
  • 删除root账号

    1
    drop user 'root'@'localhost';
  • 创建新的root账号

    1
    create user 'root'@'%' identified by '123456';
  • 为新账号授权

    1
    2
    grant all privileges on *.* to 'root'@'%' with grant option;
    flush privileges;
  • 退出

    1
    exit;
  • 以普通用户身份登录

    1
    mysql -u root -p

Python——反编译(Disassemble)与字节码(Bytecode)

为了知道Python代码底层都做了哪些操作,我们常常需要反编译Python代码以获得Python的字节码
我们可以获得: classes, methods, functions, or code 的字节码


获取字节码的方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 比较`[]`和`list()`两者的不同
from dis import dis

# test case 1
dis("[]")
# Output:
1 0 BUILD_LIST 0
2 RETURN_VALUE

# test case 2
dis("list()")
# Output:
1 0 LOAD_NAME 0 (list)
2 CALL_FUNCTION 0
4 RETURN_VALUE

**由上述输出可知,`list()` 比 `[]` 会多执行一行字节码`LODA_NAME`**

Python——分布式编程之mpi4py使用


整体说明

  • 分布式系统 是由一组通过网络相互连接的独立计算节点(或计算机)组成的集合,这些节点协同工作以实现一个共同的目标,实现高效的计算和处理
  • MPI(Message Passing Interface,消息传递接口) 是一个跨语言的并行计算通信协议,也是一个应用程序接口(API)标准,允许程序在分布式内存系统中高效地交换数据
    • MPI 定义了一套标准的库函数和语义规则,允许在非共享内存环境下的多个进程(通常运行在不同的处理器或计算节点上)通过发送和接收消息进行通信和协作
  • mpi4py 库是一个构建在 MPI 之上的 Python 库 ,主要使用 Cython 编写
    • Cython 的目标是让 Python 代码具备 C 语言的高性能,它是 Python 的一个超集,既支持 Python 语法,又能调用 C 函数、定义 C 类型变量,从而优化 Python 代码的性能
  • mpi4py 库以面向对象的方式提供了在 Python 环境下调用 MPI 标准的编程接口,这些接口构建在 MPI-2 C++ 编程接口基础之上,与 C++ 的 MPI 编程接口类似
  • mpi4py 库实现了很多 MPI 标准中的接口,包括点对点通信、集合通信、阻塞/非阻塞通信、组间通信等
    • 可以在不同进程间传递任何可被 pickle 序列化的内置和用户自定义 Python 对象

安装 mpi4py

Mac 环境安装

  • 安装 Homebrew(若已安装可跳过):

    1
    /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
  • 安装 Open MPI(建议使用 Open MPI,因为它在 Mac 上通常更容易配置):

    1
    brew install open-mpi
    • 这一步安装依赖比较多,需要一些时间
    • 安装完成后,你可以通过运行 mpiexec --version 来验证 Open MPI 是否正确安装
  • 安装 mpi4py:

    1
    pip install mpi4py
    • (常见报错解法)如果你的系统中有多个 MPI 版本,或者 pip 无法找到正确的 MPI 库,你可能需要设置 MPICC 环境变量来指定 MPI 编译器。例如:

      1
      2
      export MPICC=/opt/homebrew/bin/mpicc  # 根据你的 Open MPI 安装路径调整
      pip install mpi4py
    • 在安装 mpi4py 后运行以下命令来验证安装:

      1
      python -c "import mpi4py; print(mpi4py.__version__)"

Ubuntu 环境安装(待补充)


mpi4py 的使用

mpi4py 的操作总结(通信原语)

  • mpi4py提供了并行计算所需的各种通信操作,主要分为两类:
  • 点对点通信(Point-to-Point Communication) :
    • 在两个特定进程间进行数据交换
    • 主要操作:Send/send, Recv/recv, Isend/isend, Irecv/irecv, Sendrecv/sendrecv等
  • 集体通信(Collective Communication) :
    • 涉及通信组(communicator)中的所有进程
    • 提供更高效的全局数据操作
  • 不同集体通信操作对比
    操作类型 通信模式 数据流向 结果分布 典型应用场景
    Barrier/barrier 同步 无数据传输 无 进程同步
    Bcast/bcast 一对多 根 -> 所有 所有进程相同 分发配置参数
    Scatter/scatter 一对多 根 -> 各部分 每个进程不同 数据并行分解
    Gather/gather 多对一 所有 -> 根 仅根进程有结果 结果收集
    Allgather/allgather 多对多 所有 -> 所有 所有进程相同 全局信息共享
    Alltoall/alltoall 多对多 所有↔所有 每个进程不同 矩阵转置
    Reduce/reduce 多对一 所有 -> 规约 -> 根 仅根进程有结果 全局统计计算
    Allreduce/allreduce 多对多 所有 -> 规约 -> 所有 所有进程相同 需要全局结果的并行计算
    Scan/scan 多对多 前缀累积 每个进程不同 累积计算
    Reduce_scatter 多对多 先reduce再scatter 每个进程不同
  • 在很多地方中,表述也使用类似 All-Gather, All-to-All, All-Reduce 等来表达
  • All-Gather 和 All-to-All 的区别:
    • All-Gather 和 All-to-All 都是多对多发送数据
    • 发送数据上来看:
      • All-Gather 中,从任意进程的视角看,向不同进程发送的数据是相同的;
      • All-to-All 中则向不同进程发送不同数据,默认数据的维度和 world_size 相同,每个 receive_rank 进程会得到其他进程 send_data[receive_rank] 的数据
    • 结果上来看:All-Gather 操作后,各个进程最终的数据是相同的;All-to-All 操作后,不同进程最终的数据不同

mpi4py 使用演示

  • 以下 mpi4py 示例均以小 Pythonic 风格(小写) 为例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    from mpi4py import MPI
    import numpy as np
    import time

    def main():
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    print(f"进程 {rank} 启动,共有 {size} 个进程")

    # 1. barrier - 进程同步
    comm.Barrier() # 与 comm.barrier() 等价

    # 2. bcast - 广播数据
    if rank == 0:
    broadcast_data = {"message": "这是来自 rank=0 的 broadcast 消息", "value": 123}
    else:
    broadcast_data = None
    broadcast_data = comm.bcast(broadcast_data, root=0)
    print(f"进程 {rank} 收到 bcast 数据: {broadcast_data}")

    # 3. scatter - 分发数据
    if rank == 0:
    scatter_data = [f"from-0-index-{i}" for i in range(size)]
    else:
    scatter_data = None
    data = comm.scatter(scatter_data, root=0)
    print(f"进程 {rank} 收到 scatter 数据: {data}")

    comm.Barrier()
    # 4. gather - 收集数据
    if rank == 0:
    print(f"==开始测试更复杂的原语:==")
    comm.Barrier()
    local_data = np.arange(size * rank, size * (rank + 1), dtype=np.int32)
    print(f"进程 {rank} 原始数据: {local_data}")
    gathered_data = comm.gather(local_data, root=0)
    print(f"进程 {rank} 收到 gather 数据: {gathered_data}")

    # 5. allgather - 全收集数据
    all_gathered_data = comm.allgather(local_data)
    print(f"进程 {rank} 收到 allgather 数据: {all_gathered_data}")

    # 6. alltoall - 全对全数据交换
    recv_data = comm.alltoall(local_data)
    print(f"进程 {rank} 收到 alltoall 数据: {recv_data}")

    # 7. reduce - 规约操作
    recv_data = comm.reduce(local_data, op=MPI.SUM, root=0)
    print(f"进程 {rank} 收到 reduce 数据: {recv_data}")

    # 8. allreduce - 全局规约操作
    recv_data = comm.allreduce(local_data, op=MPI.SUM)
    print(f"进程 {rank} 收到 allreduce 数据: {recv_data}")

    time.sleep(0.1)
    # 高阶,任意维度的 allreduce 操作
    comm.Barrier()
    if rank == 0:
    print(f"==测试任意维度的 allreduce 操作:==")
    comm.Barrier()
    all_reduce_local_data = np.array([rank*i for i in range(size+3)], dtype=np.int32)
    print(f"进程 {rank} 原始 all_reduce_local_data 数据: {all_reduce_local_data}")
    recv_data = comm.allreduce(all_reduce_local_data, op=MPI.SUM) # 输入可以是任意维度
    print(f"进程 {rank} 收到 allreduce 数据: {recv_data}")
    comm.Barrier()

    time.sleep(0.1)
    # 9. scan - 前缀积累,按照 rank 累积
    recv_data = comm.scan(local_data, op=MPI.SUM)
    print(f"进程 {rank} 收到 scan 数据: {recv_data}")
    comm.Barrier()

    time.sleep(0.1)
    # 增加:reduce_scatter - reduce_scatter规约操作
    if rank == 0:
    print(f"==开始 reduce_scatter 演示==")
    local_data = np.array([rank*i for i in range(size)])
    reduce_scatter_recv_data = np.array(0, local_data.dtype) # 必须将发送数据和接受数据的类型对齐,否则会出现类型不匹配的错误
    print(f"进程 {rank} local_data 原始数据: {local_data}")
    comm.Barrier()
    print(f"进程 {rank} reduce_scatter_recv_data 原始数据: {reduce_scatter_recv_data}")
    comm.Reduce_scatter(local_data, reduce_scatter_recv_data, op=MPI.SUM)
    print(f"进程 {rank} 收到 Reduce_scatter 数据: {reduce_scatter_recv_data}")
    comm.Barrier()

    time.sleep(0.1)
    # 点对点通信函数演示
    if rank == 0:
    print(f"==开始点对点通信演示==")
    # 10. send 和 recv - 阻塞式发送接收
    if rank == 0:
    dest = 1
    message = f"from {rank}, 你好,进程 1!"
    comm.send(message, dest=dest)
    print(f"进程 {rank} 发送消息到进程 {dest}")

    source = 1
    recv_msg = comm.recv(source=source)
    print(f"进程 {rank} 从进程 {source} 收到消息: {recv_msg}")

    elif rank == 1:
    source = 0
    recv_msg = comm.recv(source=source)
    print(f"进程 {rank} 从进程 {source} 收到消息: {recv_msg}")

    dest = 0
    message = f"from {rank}, 你好,进程 0!"
    comm.send(message, dest=dest)
    print(f"进程 {rank} 发送消息到进程 {dest}")

    comm.Barrier()
    # 11. isend 和 irecv - 非阻塞式发送接收
    if rank == 2:
    dest = 3
    message = "非阻塞消息"
    req = comm.isend(message, dest=dest)

    source = 3
    req_recv = comm.irecv(source=source)

    # 可以在这里执行其他计算
    req.wait() # 等待发送完成
    print(f"进程 {rank} 非阻塞发送完成")

    recv_msg = req_recv.wait() # 等待接收完成
    print(f"进程 {rank} 收到非阻塞消息: {recv_msg}")

    elif rank == 3:
    source = 2
    req_recv = comm.irecv(source=source)

    dest = 2
    message = "非阻塞回复"
    req = comm.isend(message, dest=dest)

    # 可以在这里执行其他计算
    recv_msg = req_recv.wait() # 等待接收完成
    print(f"进程 {rank} 收到非阻塞消息: {recv_msg}")

    req.wait() # 等待发送完成
    print(f"进程 {rank} 非阻塞发送完成")

    comm.Barrier()
    # 12. sendrecv - 同时发送和接收
    send_val = -rank * 100
    dest = (rank + 1) % size # 发送给下一个
    source = (rank - 1) if rank - 1 >= 0 else size - 1 # 接收自来自上一个的

    # 上面的实现是一个圆环的通信方式,节点之间消息是依次流转的
    print(f"rank={rank}, dest={dest}, source={source}")

    # 注意:一定要对齐,A.dest = B,则必须有 B.source = A,否则下面的语句会卡死(一直等待)
    recv_val = comm.sendrecv(send_val, dest=dest, source=source)
    print(f"进程 {rank} 发送 {send_val} 到进程 {dest},从进程 {source} 接收 {recv_val}")

    # 所有演示操作做完,最终同步
    comm.Barrier()
    if rank == 0:
    print("\n所有进程完成演示")

    if __name__ == "__main__":
    main()

    # 进程 0 启动,共有 4 个进程
    # 进程 2 启动,共有 4 个进程
    # 进程 3 启动,共有 4 个进程
    # 进程 1 启动,共有 4 个进程
    # 进程 0 收到 bcast 数据: {'message': '这是来自 rank=0 的 broadcast 消息', 'value': 123}
    # 进程 1 收到 bcast 数据: {'message': '这是来自 rank=0 的 broadcast 消息', 'value': 123}
    # 进程 2 收到 bcast 数据: {'message': '这是来自 rank=0 的 broadcast 消息', 'value': 123}
    # 进程 3 收到 bcast 数据: {'message': '这是来自 rank=0 的 broadcast 消息', 'value': 123}
    # 进程 1 收到 scatter 数据: from-0-index-1
    # 进程 3 收到 scatter 数据: from-0-index-3
    # 进程 0 收到 scatter 数据: from-0-index-0
    # ==开始测试更复杂的原语:==
    # 进程 2 收到 scatter 数据: from-0-index-2
    # 进程 0 原始数据: [0 1 2 3]
    # 进程 1 原始数据: [4 5 6 7]
    # 进程 3 原始数据: [12 13 14 15]
    # 进程 1 收到 gather 数据: None
    # 进程 3 收到 gather 数据: None
    # 进程 2 原始数据: [ 8 9 10 11]
    # 进程 2 收到 gather 数据: None
    # 进程 0 收到 gather 数据: [array([0, 1, 2, 3], dtype=int32), array([4, 5, 6, 7], dtype=int32), array([ 8, 9, 10, 11], dtype=int32), array([12, 13, 14, 15], dtype=int32)]
    # 进程 0 收到 allgather 数据: [array([0, 1, 2, 3], dtype=int32), array([4, 5, 6, 7], dtype=int32), array([ 8, 9, 10, 11], dtype=int32), array([12, 13, 14, 15], dtype=int32)]
    # 进程 1 收到 allgather 数据: [array([0, 1, 2, 3], dtype=int32), array([4, 5, 6, 7], dtype=int32), array([ 8, 9, 10, 11], dtype=int32), array([12, 13, 14, 15], dtype=int32)]
    # 进程 3 收到 allgather 数据: [array([0, 1, 2, 3], dtype=int32), array([4, 5, 6, 7], dtype=int32), array([ 8, 9, 10, 11], dtype=int32), array([12, 13, 14, 15], dtype=int32)]
    # 进程 2 收到 allgather 数据: [array([0, 1, 2, 3], dtype=int32), array([4, 5, 6, 7], dtype=int32), array([ 8, 9, 10, 11], dtype=int32), array([12, 13, 14, 15], dtype=int32)]
    # 进程 1 收到 alltoall 数据: [1, 5, 9, 13]
    # 进程 0 收到 alltoall 数据: [0, 4, 8, 12]
    # 进程 2 收到 alltoall 数据: [2, 6, 10, 14]
    # 进程 3 收到 alltoall 数据: [3, 7, 11, 15]
    # 进程 3 收到 reduce 数据: None
    # 进程 1 收到 reduce 数据: None
    # 进程 2 收到 reduce 数据: None
    # 进程 0 收到 reduce 数据: [24 28 32 36]
    # 进程 1 收到 allreduce 数据: [24 28 32 36]
    # 进程 0 收到 allreduce 数据: [24 28 32 36]
    # 进程 3 收到 allreduce 数据: [24 28 32 36]
    # 进程 2 收到 allreduce 数据: [24 28 32 36]
    # ==测试任意维度的 allreduce 操作:==
    # 进程 0 原始 all_reduce_local_data 数据: [0 0 0 0 0 0 0]
    # 进程 2 原始 all_reduce_local_data 数据: [ 0 2 4 6 8 10 12]
    # 进程 3 原始 all_reduce_local_data 数据: [ 0 3 6 9 12 15 18]
    # 进程 1 原始 all_reduce_local_data 数据: [0 1 2 3 4 5 6]
    # 进程 0 收到 allreduce 数据: [ 0 6 12 18 24 30 36]
    # 进程 2 收到 allreduce 数据: [ 0 6 12 18 24 30 36]
    # 进程 1 收到 allreduce 数据: [ 0 6 12 18 24 30 36]
    # 进程 3 收到 allreduce 数据: [ 0 6 12 18 24 30 36]
    # 进程 3 收到 scan 数据: [24 28 32 36]
    # 进程 0 收到 scan 数据: [0 1 2 3]
    # 进程 1 收到 scan 数据: [ 4 6 8 10]
    # 进程 2 收到 scan 数据: [12 15 18 21]
    # 进程 3 local_data 原始数据: [0 3 6 9]
    # 进程 2 local_data 原始数据: [0 2 4 6]
    # 进程 1 local_data 原始数据: [0 1 2 3]
    # ==开始 reduce_scatter 演示==
    # 进程 0 local_data 原始数据: [0 0 0 0]
    # 进程 0 reduce_scatter_recv_data 原始数据: 0
    # 进程 1 reduce_scatter_recv_data 原始数据: 0
    # 进程 2 reduce_scatter_recv_data 原始数据: 0
    # 进程 3 reduce_scatter_recv_data 原始数据: 0
    # 进程 0 收到 Reduce_scatter 数据: 0
    # 进程 3 收到 Reduce_scatter 数据: 18
    # 进程 2 收到 Reduce_scatter 数据: 12
    # 进程 1 收到 Reduce_scatter 数据: 6
    # ==开始点对点通信演示==
    # 进程 0 发送消息到进程 1
    # 进程 1 从进程 0 收到消息: from 0, 你好,进程 1!
    # 进程 1 发送消息到进程 0
    # 进程 0 从进程 1 收到消息: from 1, 你好,进程 0!
    # 进程 2 非阻塞发送完成
    # 进程 3 收到非阻塞消息: 非阻塞消息
    # 进程 3 非阻塞发送完成
    # 进程 2 收到非阻塞消息: 非阻塞回复
    # rank=2, dest=3, source=1
    # rank=3, dest=0, source=2
    # rank=0, dest=1, source=3
    # rank=1, dest=2, source=0
    # 进程 1 发送 -100 到进程 2,从进程 0 接收 0
    # 进程 0 发送 0 到进程 1,从进程 3 接收 -300
    # 进程 2 发送 -200 到进程 3,从进程 1 接收 -100
    # 进程 3 发送 -300 到进程 0,从进程 2 接收 -200
    #
    # 所有进程完成演示
  • 启动上述代码的命令为:

    1
    mpiexec -n 4 python example.py

附录:传输不同维度的数据

  • 在上述 allgather, alltoall, allreduce 等操作中,同一个进程传输给其他进程的数据不一定要维度完全相等,甚至类型也不一定要相同,只需要能够做对应的 MPI op 就可以
  • 示例(某个元素改成列表):
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    from mpi4py import MPI
    import numpy as np
    import time

    from numba.cuda import local

    def main():
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    print(f"进程 {rank} 启动,共有 {size} 个进程")

    comm.Barrier()

    local_data = [
    [0, 1, 2, 3],
    [4, 5, 6, 7],
    [ 8, 9, [10,18], 11],
    [12, 13, 14, 15],
    ]
    local_data = local_data[rank]
    print(f"进程 {rank} 原始数据: {local_data}")
    comm.Barrier()
    time.sleep(0.1)

    # allgather
    all_gathered_data = comm.allgather(local_data)
    print(f"进程 {rank} 收到 allgather 数据: {all_gathered_data}")
    comm.Barrier()
    time.sleep(0.1)

    # alltoall
    recv_data = comm.alltoall(local_data)
    print(f"进程 {rank} 收到 alltoall 数据: {recv_data}")
    comm.Barrier()
    time.sleep(0.1)

    # areduce
    recv_data = comm.reduce(local_data, op=MPI.SUM, root=0)
    print(f"进程 {rank} 收到 reduce 数据: {recv_data}")
    comm.Barrier()
    time.sleep(0.1)

    # allreduce - 全局规约操作
    recv_data = comm.allreduce(local_data, op=MPI.SUM)
    print(f"进程 {rank} 收到 allreduce 数据: {recv_data}")
    comm.Barrier()
    time.sleep(0.1)

    if __name__ == "__main__":
    main()

    # 进程 2 启动,共有 4 个进程
    # 进程 3 启动,共有 4 个进程
    # 进程 1 启动,共有 4 个进程
    # 进程 0 启动,共有 4 个进程
    # 进程 0 原始数据: [0, 1, 2, 3]
    # 进程 1 原始数据: [4, 5, 6, 7]
    # 进程 3 原始数据: [12, 13, 14, 15]
    # 进程 2 原始数据: [8, 9, [10, 18], 11]
    # 进程 3 收到 allgather 数据: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, [10, 18], 11], [12, 13, 14, 15]]
    # 进程 1 收到 allgather 数据: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, [10, 18], 11], [12, 13, 14, 15]]
    # 进程 0 收到 allgather 数据: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, [10, 18], 11], [12, 13, 14, 15]]
    # 进程 2 收到 allgather 数据: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, [10, 18], 11], [12, 13, 14, 15]]
    # 进程 1 收到 alltoall 数据: [1, 5, 9, 13]
    # 进程 0 收到 alltoall 数据: [0, 4, 8, 12]
    # 进程 3 收到 alltoall 数据: [3, 7, 11, 15]
    # 进程 2 收到 alltoall 数据: [2, 6, [10, 18], 14]
    # 进程 3 收到 reduce 数据: None
    # 进程 2 收到 reduce 数据: None
    # 进程 0 收到 reduce 数据: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, [10, 18], 11, 12, 13, 14, 15]
    # 进程 1 收到 reduce 数据: None
    # 进程 2 收到 allreduce 数据: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, [10, 18], 11, 12, 13, 14, 15]
    # 进程 0 收到 allreduce 数据: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, [10, 18], 11, 12, 13, 14, 15]
    # 进程 1 收到 allreduce 数据: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, [10, 18], 11, 12, 13, 14, 15]
    # 进程 3 收到 allreduce 数据: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, [10, 18], 11, 12, 13, 14, 15]

附录:导入 mpi4py 包的方式

  • 必须使用:

    1
    from mpi4py import MPI
  • 不能使用:

    1
    2
    import mpi4py
    MPI = mpi4py.MPI
  • 两种写法看似相同, 实际上不一样,第二种方法报错为 AttributeError: module 'mpi4py' has no attribute 'MPI'

  • 这是因为 第二种方法 MPI 环境未初始化 ,mpi4py 要求在使用 MPI 功能前必须先初始化,而import mpi4py不会自动完成这一步

不同导入方式的差别

  • 假设有两种导入方式:

    1
    2
    3
    4
    5
    6
    # 方式一
    from packagea import A

    # 方式二
    import packagea
    A = packagea.A
  • 如果 packagea 是一个包(包含 __init__.py),但 __init__.py 中未导入或暴露 A,方式二将无法通过包名直接访问 A

  • 当 packagea 的 __init__.py 未显式导入或暴露 A 时,方式一(from packagea import A)仍能成功的原因与 Python 的导入机制和包结构有关

  • Python 在执行 from packagea import A 时,会按以下步骤查找 A:

    • 1) 检查 packagea.__init__.py :若 __init__.py 中定义或导入了 A,直接使用
    • 2) 搜索子模块 :若 __init__.py 未包含 A,Python 会在 packagea 目录下查找是否存在 A.py 或 A/__init__.py
    • 3) 递归子模块 :若仍未找到,Python 会尝试递归导入子模块中的 A(例如 packagea.module_a.A),但需显式指定路径(如 from packagea.module_a import A)

mpi4py 两种导入结果不同的原因

  • 方式一成功的原因 :当执行 from mpi4py import MPI 时:
    • 1) Python 加载 mpi4py 包,执行 __init__.py
    • 2) __init__.py 注册元路径导入器(_mpiabi._install_finder())
    • 3) Python 发现 __all__ 中包含 MPI,但 __init__.py 中未显式定义
    • 4) 触发元路径导入器,根据系统 MPI 环境选择并加载对应的 MPI.so 文件(注意名称不一定是这个,可能是根据不同系统命名的文件)
      • 注:已确认在 ~/anaconda3/envs/xxx/lib/python3.10/site-packages/mpi4py/ 路径下不存在 MPI.so 文件
      • 虽然 mpi4py 目录下没有直接名为 MPI.so 的文件,但实际存在 MPI.mpich.cpython-310-darwin.so(针对 MPICH 实现的扩展模块) 和 MPI.openmpi.cpython-310-darwin.so(针对 OpenMPI 实现的扩展模块)
      • _mpiabi._install_finder() 的作用之一是根据系统中实际安装的 MPI 库(通过环境变量或系统命令探测),选择对应的 .so 文件
    • 5) MPI 扩展模块被加载,并绑定到 mpi4py 命名空间,导入成功
  • 方式二失败的原因 :方式一(import mpi4py 后 mpi4py.MPI)失败是因为:
    • import mpi4py 仅执行 __init__.py,不会触发元路径导入器对 MPI 的查找
    • __init__.py 中没有显式导入或定义 MPI(如 from .MPI import MPI),因此 mpi4py.MPI 不存在于命名空间中

附录:传统 MPI 风格函数的使用

  • 在 mpi4py 中,函数名的大小写是有区别的,主要涉及两种不同的编程接口风格:Pythonic 风格(小写) 和 传统 MPI C/Fortran 风格(大写)
  • 总结:
    • 小写函数 :更 Pythonic,返回数据,适合动态数据,代码简洁,不需要管理缓冲区
    • 大写函数 :类似 C/Fortran MPI,需要缓冲区,适合高性能计算(避免数据拷贝)

小写函数(Pythonic 风格)

  • 返回数据 ,而不是修改传入的缓冲区(更符合 Python 习惯)

  • 通常更简洁,适合 Python 风格的编程

  • 适用于 NumPy 数组和 Python 对象(如列表、字典等)

  • 以 gather 为例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    # 使用小写 gather(返回结果,不修改传入的 recvbuf)
    gathered_data = comm.gather(data, root=0)
    if rank == 0:
    print("Gathered data:", gathered_data) # 输出:[0, 10, 20, ...]
  • 特点:

    • gather 返回一个列表 ,包含所有进程发送的数据(仅在 root 进程有效)
    • 不需要预先分配 recvbuf ,适合动态数据

大写函数(传统 MPI 风格)

  • 需要预先分配接收缓冲区(recvbuf) ,类似 C/Fortran 的 MPI 接口

  • 直接修改传入的缓冲区 ,而不是返回数据

  • 适用于高性能计算(特别是 NumPy 数组 ,避免数据拷贝)

  • 以 gather 为例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    # 每个进程准备自己的数据(NumPy 数组)
    sendbuf = np.array([rank * 10], dtype=int)

    if rank == 0:
    # 预先分配 recvbuf(大小必须匹配)
    recvbuf = np.empty(size, dtype=int)
    else:
    recvbuf = None # 非 root 进程不需要 recvbuf

    # 使用大写 Gather(修改 recvbuf)
    comm.Gather(sendbuf, recvbuf, root=0)

    if rank == 0:
    print("Gathered data:", recvbuf) # 输出:[0, 10, 20, ...]
  • Gather 需要预先分配 recvbuf (在 root 进程)

  • 适用于高性能计算(避免 Python 对象的额外开销)

类似函数总结

小写(Pythonic) 大写(传统 MPI) 用途
bcast Bcast 广播数据
scatter Scatter 分散数据
gather Gather 收集数据
allgather Allgather 全收集
reduce Reduce 规约计算
allreduce Allreduce 全规约
scan Scan 前缀计算

附录:分布式中 GPU 主要集体通信操作介绍

  • 分布式计算中常用的集体通信(Collective Communication)操作有All-Gather、All-Reduce 和 Reduce-Scatter,主要用于多进程或多设备(如GPU)之间的数据交互

All-Gather(全收集)

  • All-Gather:每个进程提供一块数据,最终所有进程收集到所有其他进程的数据,结果是一个包含所有数据的聚合
  • 输入:每个进程有一块独立数据(如 data_i)
  • 输出:所有进程得到相同的全量数据([data_0, data_1, ..., data_n])
  • 举例:进程0有 A,进程1有 B,进程2有 C -> 最终所有进程得到 [A, B, C]
  • 功能:参数广播、分布式训练中同步模型参数

All-Reduce(全规约)

  • All-Reduce:先对所有进程的数据进行规约操作(如求和、最大值等),然后将结果分发给所有进程
  • 输入:每个进程有一块数据(如 data_i)
  • 计算:对所有 data_i 执行规约(如 sum(data_0, data_1, ..., data_n))
  • 输出:所有进程得到相同的规约结果(如 sum)
  • 举例:进程0有 1,进程1有 2,进程2有 3 -> 求和后所有进程得到 6
  • 功能:梯度聚合(如分布式训练中多卡梯度的全局求和)

Reduce-Scatter(规约散播)

  • Reduce-Scatter:先对所有进程的数据进行规约操作,然后将结果按块分散到不同进程中,每个进程只获得结果的一部分
  • 输入:每个进程有一块数据(如 data_i)
  • 计算:规约所有数据(如 sum),然后将结果按进程数切分
  • 输出:进程 i 获得结果的第 i 块
  • 举例:进程0有 [1, 2],进程1有 [3, 4],进程2有 [5, 6] -> 全局求和为 [9, 12],然后进程0得到 9,进程1得到 12
  • 功能:分布式矩阵计算中分块结果的聚合

整体总结

操作 输入 计算步骤 输出 是否全量同步
All-Gather 每个进程一块数据 收集所有数据并广播 所有进程获得全量数据 是
All-Reduce 每个进程一块数据 规约所有数据并广播结果 所有进程获得相同的规约结果 是
Reduce-Scatter 每个进程一块数据 规约所有数据并按块分发 每个进程只获得结果的一部分 否

特别说明

  • All-Reduce 可以拆分为 Reduce-Scatter + All-Gather(先局部规约后全局同步)
  • 性能差异 :All-Reduce 通常比分开的两步操作更高效(优化后的算法如 Ring-AllReduce)

附录:Ring-AllReduce

  • Ring-AllReduce 是 All-Reduce 的一种高效实现算法,也写为 Ring AllReduce、Ring All-Reduce 等

以数据并行(DP)场景为例

  • 假设有 N 个 GPU,每个 GPU 上都有全部参数和一部分数据
  • 目标是保证所有 GPU 都能完成一次完整的参数更新
  • 特别注意:
    • 每个 GPU 都只有一部分数据,所以需要拿到其他 GPU 的数据才能计算梯度(仅依赖自身甚至算不出任何一个参数的梯度)
    • 有很多参数,每个 GPU 都要完成所有参数的更新

Ring AllReduce 的核心思想

  • Ring AllReduce 的核心思想是:
    • 先将所有 GPU 上的局部梯度数据按照参数分成不同的块(每个块包含一部分参数的局部梯度)
    • 把所有 GPU 组成一个逻辑上的环形结构 ,大家只和自己的“邻居”交流 ,然后通过多次传递 ,最终让所有人都得到完整的梯度结果

Ring AllReduce 工作流程(分两步走)

  • 为了方便理解,本文假设有 4 个 GPU(A, B, C, D)
分块并传递(Reduce-Scatter 阶段)
  • 这一阶段的目标是每个 GPU 负责一部分参数的梯度聚合(编码为 \(S_0, S_1, S_2, S_3\))
  • 首先,每个 GPU 把自己手里的局部梯度按照参数平均分成 N 份(每份叫一个“块”)
  • 然后,大家开始轮流传递:每个 GPU 把自己一部分的“块”传给右边的邻居,同时从左边的邻居那里接收一个“块”
  • 当收到邻居的“块”后,就把这个“块”和自己对应的“块”进行合并(比如求和)
  • 这个过程重复进行 N-1 次,直到每个人的手里都拿到了自己负责的参数对应那部分的梯度并完成聚合
  • 举个例子:
    • A 有局部梯度 \(A_0, A_1, A_2, A_3\),下标表示参数块的索引
    • B 有局部梯度 \(B_0, B_1, B_2, B_3\),下标表示参数块的索引
    • C 有局部梯度 \(C_0, C_1, C_2, C_3\),下标表示参数块的索引
    • D 有局部梯度 \(D_0, D_1, D_2, D_3\),下标表示参数块的索引
    • 假设某个 GPU 的目标是最终实现参数块 0 的梯度累加计算,即 \(S_0 = A_0+B_0+C_0+D_0\)
      • 其他 GPU 的最终目标分别是实现 1,2,3 块参数的梯度累加
  • 第一轮:
    • A 把 \(A_1\) 给 B,从 D 收到 \(D_0\)
    • B 把 \(B_2\) 给 C,从 A 收到 \(A_1\)
    • C 把 \(C_3\) 给 D,从 B 收到 \(B_2\)
    • D 把 \(D_0\) 给 A,从 C 收到 \(C_3\)
    • 然后大家各自合并收到的数据:
      • A 现在有 \(A_1, A_2, A_3\), 还有 \((A_0+D_0)\)(计算后的结果)
      • B 现在有 \(B_0, B_2, B_3\), 还有 \((B_1+A_1)\)(计算后的结果)
      • C 现在有 \(C_0, C_1, C_3\), 还有 \((C_2+B_2)\)(计算后的结果)
      • D 现在有 \(D_0, D_1, D_2\), 还有 \((D_3+C_3)\)(计算后的结果)
  • 第二轮:
    • A 把 \((A_0+D_0)\) 给 B,从 D 收到 \((D_3+C_3)\)
    • B 把 \((B_1+A_1)\) 给 C,从 A 收到 \((A_0+D_0)\)
    • C 把 \((C_2+B_2)\) 给 D,从 B 收到 \((B_1+A_1)\)
    • D 把 \((D_3+C_3)\) 给 A,从 C 收到 \((C_2+B_2)\)
    • 然后大家合并各自收到的数据:
      • A 现在有 \((A_0+D_0), A_1, A_2\), 还有 \((A_3+C_3+D_3)\)(计算后的结果)
      • …
      • D 现在有 \(D_0, D_1, (D_3+C_3)\), 还有 \((B_2+C_2+D_2)\)(计算后的结果)
  • 这个过程会进行 \(N-1\) 次(N 是参与者的数量),最终每个人手里都会有一部分“最终总和”的数据
    • 最终得到,A 手里有 \(S_2\) ,B 手里有 \(S_3\) ,C 手里有 \(S_0\),D 手里有 \(S_1\)

收集并广播(All-Gather 阶段)

  • 所有 GPU 再次开始传递,这次传递不再进行计算,而是把自己手里已经计算好的梯度,传给右边的邻居
  • 当收到邻居的数据后,就把它保存下来
  • 这个过程也重复进行,直到每个人都收到了所有梯度
  • 不是一般性,假定(注意:与上面不同,但是不影响推导和理解):
    • 现在 A 手里有 \(S_0\) ,B 有 \(S_1\) ,C 有 \(S_2\) ,D 有 \(S_3\)
  • 第一轮:
    • A 把 \(S_0\) 给 B
    • B 把 \(S_1\) 给 C
    • C 把 \(S_2\) 给 D
    • D 把 \(S_3\) 给 A
    • 然后大家各自保存收到的数据:
      • A 现在有了 \(S_0\) (自己的) 和 \(S_3\) (从 D 收到)
      • B 现在有了 \(S_1\) (自己的) 和 \(S_0\) (从 A 收到)
  • 第二轮:
    • A 把 \(S_3\) 传递给 B
    • …
    • D 把 \(S_2\) 传递给 A
  • 第三轮:
    • A 把 \(S_2\) 传递给 B,至此,B 从 A 处收到了 \(S_0, S_2, S_3\),加上自己的 \(S_1\),也就得到了完整的 \(S_0, S_1, S_2, S_3\)
    • 其他节点也一次类推
  • 这个过程同样进行 \(N-1\) 次,最终每个人都会收集到所有的 \(S_0, S_1, S_2, S_3\),从而得到了完整的总和

Ring AllReduce 的优点

  • 高效利用带宽 :去中心化设计是的它不会让某个节点成为瓶颈
    • 比如像传统的“参数服务器”模式,所有数据都传给一个中心服务器,而 Ring AllReduce 是让所有节点都参与数据传输和计算,充分利用了网络的带宽
  • 良好的可扩展性 :即使参与的设备数量很多,它的通信效率也能保持相对稳定
    • 非常适合大规模的分布式训练(比如训练大型深度学习模型)
  • 降低延迟 :通过分块和流水线式的传输,它能有效地减少数据同步的等待时间
  • 特别说明:每个 GPU 只能和自己邻居交流这个约束,看似是限制,实际是非常巧妙地设计,是的数据通道建立一次即可,且传输是连续的

通信量对比

  • Ring All-Reduce 是 All-Reduce 的一种高效实现方式,在通信量方面相比传统的All-Reduce有显著优势。
  • 假设存在 \(N\) 个设备(如GPU),每个设备的数据大小为 \(\Phi\)
  • 传统 All-Reduce 的原始版本中,每个GPU需要发送 \((N - 1)\Phi\) 个数据,\(N\) 个 GPU 的总通信量为
    $$N(N - 1)\Phi$$
    • 其通信量与GPU数量呈(N^2) 复杂度
  • Ring All-Reduce 将每个 GPU 存储的数据顺序切分为 \(N\) 块,每块的数据量是 \(\frac{\Phi}{N}\)
    • Ring All-Reduce 包含 Reduce-Scatter 和 All-Gather 两个步骤,每个步骤都需要 \(N - 1\) 次通信,每次通信的数据量为 \(\frac{\Phi}{N}\)
    • 所以每个 GPU 的通信数据量为
      $$\frac{2(N - 1)\Phi}{N} \approx 2\Phi$$
  • 与传统 All-Reduce 相比,Ring All-Reduce 的每个 GPU 通信量显著减少,且通信量与设备数量 \(N\) 无关,只受限于逻辑环中最慢的两个 GPU 的连接
1…353637…61
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

608 posts
49 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4