整体说明
- 深度学习的PS(Parameter Server,参数服务器)架构是一种分布式训练框架,用于高效管理大规模模型参数的更新与同步
- 其核心思想是将参数存储和计算分离,通过一个或多个中心服务器(Parameter Server)集中维护模型参数,而工作节点(Worker)负责本地计算和梯度更新,并通过网络与服务器通信
- PS 架构支持异步或同步训练,适合处理海量数据和超大规模模型,显著提升了分布式深度学习的扩展性和效率
角色介绍
Parameter Server(参数服务器,PS)
- 角色 :
- 负责存储和更新模型的全局参数(如神经网络的权重、梯度等)
- 聚合来自不同Worker的梯度,应用优化算法(如SGD)更新参数
- 提供参数的集中式管理,确保一致性
- 机器类型 :
- 通常是高性能的CPU服务器(因为参数更新是计算密集型操作,但不需要GPU的并行计算能力)
- 可以是单台机器或多台机器组成的集群(根据模型规模横向扩展)
Worker(工作节点)
- 角色 :
- 负责计算 :读取训练数据、执行前向传播和反向传播,计算本地梯度
- 将梯度发送给PS,或从PS拉取最新的参数(同步或异步更新)
- 如果Worker包含
Chief(主Worker),它还负责模型初始化、检查点保存、日志汇总等额外任务
- 机器类型 :
- 通常是配备GPU的机器(适合大规模矩阵运算)
- 每个Worker可以独立处理一个数据分片(数据并行)
其他辅助角色:
- Chief Worker(可选):
- 一个特殊的Worker(通常编号为
worker0),负责全局协调(如初始化参数、恢复训练、保存模型等) - 在TF 2.x中,这部分功能逐渐被整合到更高级的API(如
tf.distribute.Strategy)中
- 一个特殊的Worker(通常编号为
- Evaluator(评估器,可选):
- 独立于训练过程,定期加载模型快照进行验证/测试
一次完整训练过程
- 以下是包含 4个PS、8个Worker(含1个Chief Worker)和 1个Evaluator 的分布式TensorFlow训练过程的详细步骤
初始化阶段
- Chief Worker(worker0) :
- 构建计算图,定义模型结构(如神经网络层、损失函数、优化器等)
- 生成参数的初始值(如随机初始化),并将这些参数分片推送到4个PS(每个PS存储部分参数)
- 通知其他Worker和Evaluator初始化完成
- PS :
- 存储Chief Worker推送的初始参数(每个PS负责存储分配给自己的参数分片)
- 等待Worker的梯度更新请求
- 其他Worker(worker1~worker7) :
- 等待Chief Worker完成参数初始化
- 从PS拉取各自所需的参数分片
- Evaluator :
- 从PS拉取初始参数,准备后续的验证任务
训练阶段(同步更新示例)
步骤① Worker计算梯度
- 每个Worker(包括Chief):
- 从本地数据分片中读取一个batch的数据
- 从PS拉取最新的参数(全量或分片)
- 执行前向传播和反向传播,计算本地梯度
步骤② 梯度聚合与参数更新
- 等待所有8个 Worker 完成梯度计算并上传梯度到 PS
- PS 聚合所有 Worker 的梯度(求平均)
- PS 应用优化器(如SGD)更新参数
- 等待 PS 参数更新完成后,Worker 拉取新参数进入下一轮训练
步骤③ 循环迭代
- 重复步骤①~②,直到达到最大训练步数或收敛
验证(由 Evaluator 并行执行)
- 每隔 N 个训练步:
- Evaluator 从 PS 拉取最新参数
- 在独立的验证数据集上计算指标(如准确率、损失)
- 将结果反馈给 Chief Worker(可选)
检查点保存(由 Chief Worker 负责)
- 每隔 M 个训练步:
- Chief Worker 将模型参数和训练状态保存到磁盘(Chief 可部分读取PS参数,增量写入 Checkpoint)
- 如果训练中断,可从检查点恢复
关键交互流程说明
1 | Chief -> PS : 初始化(推送)/保存模型(拉取) |
角色分工总结
| 角色 | 数量 | 职责 |
|---|---|---|
| PS | 4 | 存储参数分片,接收梯度并更新参数。 |
| Worker | 8 | 计算梯度(worker0是Chief,负责初始化/保存模型)。 |
| Evaluator | 1 | 定期验证模型性能,不影响训练流程。 |
其他注意事项
- PS 不需要 GPU,使用大内存 + CPU 即可
- PS 的数量一般为1个即可,除非参数量很大(一个存不下),此外,如果 worker 数量太多时,也可以适当增加 PS 数量,防止网络带宽成为瓶颈
- 可选择异步训练,此时Worker无需等待其他节点,直接推送梯度到PS并更新参数(但可能梯度冲突)
- Worker不一定需要存储全部参数,每次可以仅拉去一个层或者某个特定参数进行计算
- 在 TensorFlow 的 PS 架构中,训练过程默认是异步的 ,但也可以通过配置实现同步训练
- PS 架构可以支持Torch,但需要结合特定的第三方工具或框架来实现分布式训练
附录:训练代码示例
一个简单的 PS 架构分布式训练代码Demo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101import tensorflow as tf
import os
import argparse
def build_model():
"""简单模型构建"""
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu', input_shape=(784,)),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
return model
def main(args):
# 设置TF_CONFIG环境变量
tf_config = {
'cluster': {
'chief': [args.chief],
'worker': args.workers,
'ps': args.ps_servers, # 参数服务器角色
'evaluator': [args.evaluator]
},
'task': {
'type': args.task_type,
'index': args.task_index
}
}
os.environ['TF_CONFIG'] = tf.constant(tf_config) # 所有任务的这个配置都是一样的
# 根据任务类型选择不同的分布式策略
if args.task_type in ['chief', 'worker', 'ps']:
# 使用ParameterServerStrategy
strategy = tf.distribute.experimental.ParameterServerStrategy()
else:
# Evaluator不需要分布式策略
strategy = tf.distribute.get_strategy()
# 数据加载和预处理
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0
# 创建数据集
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(64)
# 根据任务类型执行不同操作
if args.task_type == 'chief':
print("Running chief task...")
with strategy.scope():
# chief下,下面的语句会进行参数初始化
model = build_model() # 上面指定的strategy策略会依据不同角色做不同的事情,这里 ParameterServerStrategy 策略会在首席节点(chief)对模型参数进行初始化,之后再把这些参数分发给各个工作节点(worker)和参数服务器(PS)
callbacks = [
tf.keras.callbacks.ModelCheckpoint(filepath='./checkpoints/model.ckpt'),
tf.keras.callbacks.TensorBoard(log_dir='./logs')
]
model.fit(train_dataset, epochs=10, callbacks=callbacks)
print("Chief training completed.")
elif args.task_type == 'worker':
print(f"Running worker task {args.task_index}...")
with strategy.scope():
model = build_model() # strategy策略会依据不同角色做不同的事情
model.fit(train_dataset, epochs=10)
print(f"Worker {args.task_index} training completed.")
elif args.task_type == 'ps':
print(f"Running parameter server task {args.task_index}...")
# 参数服务器不需要显式代码,策略会自动管理
server = tf.distribute.Server(
tf_config['cluster']['ps'][args.task_index],
job_name="ps",
task_index=args.task_index
)
server.join() # 参数服务器会一直运行直到训练结束
elif args.task_type == 'evaluator':
print("Running evaluator task...")
model = build_model()
model.load_weights('./checkpoints/model.ckpt') # 加载最新的模型,进行评估工作
eval_results = model.evaluate(test_dataset)
print(f"Evaluation results: {eval_results}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--chief', type=str, required=True, help='Chief address')
parser.add_argument('--workers', type=str, nargs='+', required=True, help='Worker addresses')
parser.add_argument('--ps_servers', type=str, nargs='+', required=True, help='Parameter server addresses')
parser.add_argument('--evaluator', type=str, required=True, help='Evaluator address')
parser.add_argument('--task_type', type=str, required=True, choices=['chief', 'worker', 'ps', 'evaluator'], help='Task type')
parser.add_argument('--task_index', type=int, required=True, help='Task index')
args = parser.parse_args()
main(args)注:以上是TensorFlow 2.x版本(tf.distribute.experimental.ParameterServerStrategy()就是TensorFlow 2.x才有的)
- 实际上,TensorFlow 1.x也可以类似实现(使用tf.train.Server)
- 同时TensorFlow 1.x的 Estimator API提供了自己的一些自己的训练形式(tf.estimator)
启动脚本
注:以下脚本均以localhost为例,实际使用时需要替换为对应不同服务器的IP
启动chief节点:
1
2
3
4
5
6
7python distributed_training.py \
--chief="localhost:2222" \
--workers="localhost:2223" "localhost:2224" \
--ps_servers="localhost:2225" "localhost:2226" \
--evaluator="localhost:2225" \
--task_type="chief" \
--task_index=0启动worker节点:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17# Worker 0
python distributed_training.py \
--chief="localhost:2222" \
--workers="localhost:2223" "localhost:2224" \
--ps_servers="localhost:2225" "localhost:2226" \
--evaluator="localhost:2225" \
--task_type="worker" \
--task_index=0
# Worker 1
python distributed_training.py \
--chief="localhost:2222" \
--workers="localhost:2223" "localhost:2224" \
--ps_servers="localhost:2225" "localhost:2226" \
--evaluator="localhost:2225" \
--task_type="worker" \
--task_index=1启动evaluator节点:
1
2
3
4
5
6
7python distributed_training.py \
--chief="localhost:2222" \
--workers="localhost:2223" "localhost:2224" \
--ps_servers="localhost:2225" "localhost:2226" \
--evaluator="localhost:2225" \
--task_type="evaluator" \
--task_index=0启动 PS 节点
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17# 启动参数服务器0
python distributed_training.py \
--chief="localhost:2222" \
--workers="localhost:2223" "localhost:2224" \
--ps_servers="localhost:2225" "localhost:2226" \
--evaluator="localhost:2227" \
--task_type="ps" \
--task_index=0
# 启动参数服务器1
python distributed_training.py \
--chief="localhost:2222" \
--workers="localhost:2223" "localhost:2224" \
--ps_servers="localhost:2225" "localhost:2226" \
--evaluator="localhost:2227" \
--task_type="ps" \
--task_index=1