TensorFlow——分布式训练之PS架构


整体说明

  • 深度学习的PS(Parameter Server,参数服务器)架构是一种分布式训练框架,用于高效管理大规模模型参数的更新与同步
  • 其核心思想是将参数存储和计算分离,通过一个或多个中心服务器(Parameter Server)集中维护模型参数,而工作节点(Worker)负责本地计算和梯度更新,并通过网络与服务器通信
  • PS 架构支持异步或同步训练,适合处理海量数据和超大规模模型,显著提升了分布式深度学习的扩展性和效率

角色介绍

Parameter Server(参数服务器,PS)

  • 角色
    • 负责存储和更新模型的全局参数(如神经网络的权重、梯度等)
    • 聚合来自不同Worker的梯度,应用优化算法(如SGD)更新参数
    • 提供参数的集中式管理,确保一致性
  • 机器类型
    • 通常是高性能的CPU服务器(因为参数更新是计算密集型操作,但不需要GPU的并行计算能力)
    • 可以是单台机器或多台机器组成的集群(根据模型规模横向扩展)

Worker(工作节点)

  • 角色
    • 负责计算 :读取训练数据、执行前向传播和反向传播,计算本地梯度
    • 将梯度发送给PS,或从PS拉取最新的参数(同步或异步更新)
    • 如果Worker包含Chief(主Worker),它还负责模型初始化、检查点保存、日志汇总等额外任务
  • 机器类型
    • 通常是配备GPU的机器(适合大规模矩阵运算)
    • 每个Worker可以独立处理一个数据分片(数据并行)

其他辅助角色:

  • Chief Worker(可选):
    • 一个特殊的Worker(通常编号为worker0),负责全局协调(如初始化参数、恢复训练、保存模型等)
    • 在TF 2.x中,这部分功能逐渐被整合到更高级的API(如tf.distribute.Strategy)中
  • Evaluator(评估器,可选):
    • 独立于训练过程,定期加载模型快照进行验证/测试

一次完整训练过程

  • 以下是包含 4个PS8个Worker(含1个Chief Worker)和 1个Evaluator 的分布式TensorFlow训练过程的详细步骤

初始化阶段

  • Chief Worker(worker0)
    • 构建计算图,定义模型结构(如神经网络层、损失函数、优化器等)
    • 生成参数的初始值(如随机初始化),并将这些参数分片推送到4个PS(每个PS存储部分参数)
    • 通知其他Worker和Evaluator初始化完成
  • PS
    • 存储Chief Worker推送的初始参数(每个PS负责存储分配给自己的参数分片)
    • 等待Worker的梯度更新请求
  • 其他Worker(worker1~worker7)
    • 等待Chief Worker完成参数初始化
    • 从PS拉取各自所需的参数分片
  • Evaluator
    • 从PS拉取初始参数,准备后续的验证任务

训练阶段(同步更新示例)

步骤① Worker计算梯度
  • 每个Worker(包括Chief):
    • 从本地数据分片中读取一个batch的数据
    • 从PS拉取最新的参数(全量或分片)
    • 执行前向传播和反向传播,计算本地梯度
步骤② 梯度聚合与参数更新
  • 等待所有8个 Worker 完成梯度计算并上传梯度到 PS
  • PS 聚合所有 Worker 的梯度(求平均)
  • PS 应用优化器(如SGD)更新参数
  • 等待 PS 参数更新完成后,Worker 拉取新参数进入下一轮训练
步骤③ 循环迭代
  • 重复步骤①~②,直到达到最大训练步数或收敛

验证(由 Evaluator 并行执行)

  • 每隔 N 个训练步:
    • Evaluator 从 PS 拉取最新参数
    • 在独立的验证数据集上计算指标(如准确率、损失)
    • 将结果反馈给 Chief Worker(可选)

检查点保存(由 Chief Worker 负责)

  • 每隔 M 个训练步:
    • Chief Worker 将模型参数和训练状态保存到磁盘(Chief 可部分读取PS参数,增量写入 Checkpoint)
    • 如果训练中断,可从检查点恢复

关键交互流程说明

1
2
3
4
5
Chief -> PS : 初始化(推送)/保存模型(拉取)
Worker -> PS : 拉取参数
Worker -> PS : 推送梯度
PS -> PS : 内部同步参数分片(如需,一般怒需要这一步)
Evaluator -> PS : 拉取参数验证

角色分工总结

角色 数量 职责
PS 4 存储参数分片,接收梯度并更新参数。
Worker 8 计算梯度(worker0是Chief,负责初始化/保存模型)。
Evaluator 1 定期验证模型性能,不影响训练流程。

其他注意事项

  • PS 不需要 GPU,使用大内存 + CPU 即可
  • PS 的数量一般为1个即可,除非参数量很大(一个存不下),此外,如果 worker 数量太多时,也可以适当增加 PS 数量,防止网络带宽成为瓶颈
  • 可选择异步训练,此时Worker无需等待其他节点,直接推送梯度到PS并更新参数(但可能梯度冲突)
  • Worker不一定需要存储全部参数,每次可以仅拉去一个层或者某个特定参数进行计算
  • 在 TensorFlow 的 PS 架构中,训练过程默认是异步的 ,但也可以通过配置实现同步训练
  • PS 架构可以支持Torch,但需要结合特定的第三方工具或框架来实现分布式训练

附录:训练代码示例

  • 一个简单的 PS 架构分布式训练代码Demo

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    import tensorflow as tf
    import os
    import argparse

    def build_model():
    """简单模型构建"""
    model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
    )
    return model

    def main(args):
    # 设置TF_CONFIG环境变量
    tf_config = {
    'cluster': {
    'chief': [args.chief],
    'worker': args.workers,
    'ps': args.ps_servers, # 参数服务器角色
    'evaluator': [args.evaluator]
    },
    'task': {
    'type': args.task_type,
    'index': args.task_index
    }
    }
    os.environ['TF_CONFIG'] = tf.constant(tf_config) # 所有任务的这个配置都是一样的

    # 根据任务类型选择不同的分布式策略
    if args.task_type in ['chief', 'worker', 'ps']:
    # 使用ParameterServerStrategy
    strategy = tf.distribute.experimental.ParameterServerStrategy()
    else:
    # Evaluator不需要分布式策略
    strategy = tf.distribute.get_strategy()

    # 数据加载和预处理
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
    x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

    # 创建数据集
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(64)
    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(64)

    # 根据任务类型执行不同操作
    if args.task_type == 'chief':
    print("Running chief task...")
    with strategy.scope():
    # chief下,下面的语句会进行参数初始化
    model = build_model() # 上面指定的strategy策略会依据不同角色做不同的事情,这里 ParameterServerStrategy 策略会在首席节点(chief)对模型参数进行初始化,之后再把这些参数分发给各个工作节点(worker)和参数服务器(PS)

    callbacks = [
    tf.keras.callbacks.ModelCheckpoint(filepath='./checkpoints/model.ckpt'),
    tf.keras.callbacks.TensorBoard(log_dir='./logs')
    ]

    model.fit(train_dataset, epochs=10, callbacks=callbacks)
    print("Chief training completed.")

    elif args.task_type == 'worker':
    print(f"Running worker task {args.task_index}...")
    with strategy.scope():
    model = build_model() # strategy策略会依据不同角色做不同的事情

    model.fit(train_dataset, epochs=10)
    print(f"Worker {args.task_index} training completed.")

    elif args.task_type == 'ps':
    print(f"Running parameter server task {args.task_index}...")
    # 参数服务器不需要显式代码,策略会自动管理
    server = tf.distribute.Server(
    tf_config['cluster']['ps'][args.task_index],
    job_name="ps",
    task_index=args.task_index
    )
    server.join() # 参数服务器会一直运行直到训练结束

    elif args.task_type == 'evaluator':
    print("Running evaluator task...")
    model = build_model()
    model.load_weights('./checkpoints/model.ckpt') # 加载最新的模型,进行评估工作
    eval_results = model.evaluate(test_dataset)
    print(f"Evaluation results: {eval_results}")

    if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--chief', type=str, required=True, help='Chief address')
    parser.add_argument('--workers', type=str, nargs='+', required=True, help='Worker addresses')
    parser.add_argument('--ps_servers', type=str, nargs='+', required=True, help='Parameter server addresses')
    parser.add_argument('--evaluator', type=str, required=True, help='Evaluator address')
    parser.add_argument('--task_type', type=str, required=True, choices=['chief', 'worker', 'ps', 'evaluator'], help='Task type')
    parser.add_argument('--task_index', type=int, required=True, help='Task index')

    args = parser.parse_args()
    main(args)
  • 注:以上是TensorFlow 2.x版本(tf.distribute.experimental.ParameterServerStrategy()就是TensorFlow 2.x才有的)

    • 实际上,TensorFlow 1.x也可以类似实现(使用tf.train.Server)
    • 同时TensorFlow 1.x的 Estimator API提供了自己的一些自己的训练形式(tf.estimator)

启动脚本

  • 注:以下脚本均以localhost为例,实际使用时需要替换为对应不同服务器的IP

  • 启动chief节点:

    1
    2
    3
    4
    5
    6
    7
    python distributed_training.py \
    --chief="localhost:2222" \
    --workers="localhost:2223" "localhost:2224" \
    --ps_servers="localhost:2225" "localhost:2226" \
    --evaluator="localhost:2225" \
    --task_type="chief" \
    --task_index=0
  • 启动worker节点:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    # Worker 0
    python distributed_training.py \
    --chief="localhost:2222" \
    --workers="localhost:2223" "localhost:2224" \
    --ps_servers="localhost:2225" "localhost:2226" \
    --evaluator="localhost:2225" \
    --task_type="worker" \
    --task_index=0

    # Worker 1
    python distributed_training.py \
    --chief="localhost:2222" \
    --workers="localhost:2223" "localhost:2224" \
    --ps_servers="localhost:2225" "localhost:2226" \
    --evaluator="localhost:2225" \
    --task_type="worker" \
    --task_index=1
  • 启动evaluator节点:

    1
    2
    3
    4
    5
    6
    7
    python distributed_training.py \
    --chief="localhost:2222" \
    --workers="localhost:2223" "localhost:2224" \
    --ps_servers="localhost:2225" "localhost:2226" \
    --evaluator="localhost:2225" \
    --task_type="evaluator" \
    --task_index=0
  • 启动 PS 节点

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    # 启动参数服务器0
    python distributed_training.py \
    --chief="localhost:2222" \
    --workers="localhost:2223" "localhost:2224" \
    --ps_servers="localhost:2225" "localhost:2226" \
    --evaluator="localhost:2227" \
    --task_type="ps" \
    --task_index=0

    # 启动参数服务器1
    python distributed_training.py \
    --chief="localhost:2222" \
    --workers="localhost:2223" "localhost:2224" \
    --ps_servers="localhost:2225" "localhost:2226" \
    --evaluator="localhost:2227" \
    --task_type="ps" \
    --task_index=1