Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

DL——条件扩散模型

文本介绍Conditional Diffusion的相关知识

  • 参考链接:
    • 生成扩散模型漫谈(九):条件控制生成结果:论文的证明大多参考自苏神的博客
    • What are Diffusion Models?:Lilian Weng的博客
    • Diffusion Models Beat GANs on Image Synthesis, OpenAI, 2021:第一篇条件扩散模型的文章,Classifier Guidance 方法;
      • 文章附录:Supplemental中有相关推导
      • 实现讲解:sunlin-ai.github.io/guided-diffusion
    • Classifier-Free Diffusion Guidance, Google Research, Brain team, 2022:第一篇使用Classifier-free Guidance方法的文章
    • Score-based generative modeling through stochastic differential equations, Stanford & Google, 2021:随机微分方程的视角
    • 【笔记】扩散模型(五):Classifier-Free Guidance 理论推导与代码实现

整体概述

  • 在图像生成领域,条件一般可以分成类别标签或者是一段文本
  • 从方法上来看,条件控制生成的方式主要分两种:Classifier Guidance 和 Classifier-free Guidance
    • Classifier Guidance :先训练无条件的模型,训练后在生成时引入条件分类器(Classifier)。这种方式灵活,无需重新训练,但是生成效果较差
    • Classifier-free Guidance:训练时引入条件,生成时无需Classifier。这种方式需要重新训练Diffusion模型,但生成效果更好些

前置推导

  • 正太分布 \(\mathbf{x} \sim \mathcal{N}(\mathbf{\mu}, \sigma^2 \boldsymbol{I})\) 的梯度计算如下:
    $$\nabla_{\mathbf{x}}\log p(\mathbf{x}) = \nabla_{\mathbf{x}} \Big(-\frac{1}{2\sigma^2}(\mathbf{x} - \boldsymbol{\mu})^2 \Big) = - \frac{\mathbf{x} - \boldsymbol{\mu}}{\sigma^2} = - \frac{\boldsymbol{\epsilon}}{\sigma}$$
  • 结合 \(q(\mathbf{x}_t \vert \mathbf{x}_0) \sim \mathcal{N}(\sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1 - \bar{\alpha}_t)\boldsymbol{I})\),可得:
    $$
    \mathbf{s}_\theta(\mathbf{x}_t, t)
    \approx \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t)
    = \mathbb{E}_{q(\mathbf{x}_0)} [\nabla_{\mathbf{x}_t} q(\mathbf{x}_t \vert \mathbf{x}_0)]
    = \mathbb{E}_{q(\mathbf{x}_0)} \Big[ - \frac{\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)}{\sqrt{1 - \bar{\alpha}_t}} \Big]
    = - \frac{\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)}{\sqrt{1 - \bar{\alpha}_t}}
    $$
    • 其中 \(\mathbf{s}_\theta(\mathbf{x}_t, t) \) 是Noise Conditional Score Network (NCSN),详情参见论文:(NCSN)Generative Modeling by Estimating Gradients of the Data Distribution, Stanford, 2020

Classifier Guidance 方法

Classifier Guidance Diffusion的直观理解

  • 参考自:通俗理解Classifier Guidance 和 Classifier-Free Guidance 的扩散模型

    2021年OpenAI在「Diffusion Models Beat GANs on Image Synthesis」中提出Classifier Guidance,使得扩散模型能够按类生成。后来「More Control for Free! Image Synthesis with Semantic Diffusion Guidance」把Classifier Guidance推广到了Semantic Diffusion,使得扩散模型可以按图像、按文本和多模态条件来生成,例如,风格化可以通过content和style两者共同进行引导,这些都是通过梯度引导来实现

    • Classifier Guidance可以通过Score function(问题:Score function是什么?)直观地解释,用贝叶斯定理将条件生成概率进行对数分解:
      $$
      \begin{align}
      \nabla_{x_t} \log p(x_t|y) &= \nabla_{x_t} \log \Big( \frac{p(x_t)p(y|x_t)}{p(y)}\Big) \\
      &= \nabla_{x_t} \log p(x_t) + \nabla_{x_t} \log p(y|x_t) - \nabla_{x_t} \log p(y) \\
      &= \underbrace{\nabla_{x_t} \log p(x_t)}_{\text{unconditional score}} + \underbrace{\nabla_{x_t} \log p(y|x_t)}_{\text{classifier gradient}}
      \end{align}
      $$
    • 从上式可以看到,Classifier Guidance条件生成只需额外添加一个classifier的梯度来引导。从成本上看,Classifier Guidance 需要训练噪声数据版本的classifier网络,推理时每一步都需要额外计算classifier的梯度
  • 以上是简单的理解,推导不严谨,论文接下来对DDPM和DDIM对应的Classifier Guidance Diffusion方法进行详细推导

Conditional Diffusion Process

  • 本小节相关推导参考自:Diffusion Models Beat GANs on Image Synthesis, OpenAI, 2021——Supplemental

    • 其他参考:深度学习(生成式模型)——Classifier Guidance Diffusion
  • 条件扩散模型相关定义 :

    • 参考普通扩散模型的前向过程 \(q(x_{t+1}|x_{t})\) 和后向过程为 \(q(x_{t}|x_{t+1})\)
    • 定义带条件的扩散模型的前向过程 \(\hat{q}(x_{t+1}|x_{t}, y)\) 和后向过程为 \(\hat{q}(x_t|x_{t+1}, y)\)
  • 设定:条件扩散模型的前向过程与普通扩散模型的前向过程一致,与条件无关
    $$
    \begin{align}
    \hat{q}(x_0) &:= q(x_0) \\
    \hat{q}(y|x_0) &:= \text{Known labels per sample} \\
    \hat{q}(x_{t+1}|x_t,y) &:= q(x_{t+1}|x_t) \\
    \hat{q}(x_{1:T}|x_0,y) &:= \prod_{t=1}^T \hat{q}(x_{t}|x_{t-1}, y) \\
    \end{align}
    $$

    • 直观理解:上式表名前向过程中对 \(x_0\) 如何增加噪音(\(\epsilon \sim \mathcal{N}(0,1)\))确实与 \(x_0\) 类别无关
  • 进一步地,我们可以推导得到,前向过程中的每一步加噪都与条件无关(以下内容详细推导可见论文:Diffusion Models Beat GANs on Image Synthesis, OpenAI, 2021——Supplemental):
    $$
    \begin{align}
    \hat{q}(x_{t+1}|x_{t}) &= \hat{q} (x_{t+1}|x_{t},y) \\
    \hat{q}(x_{1:T}|x_0) &= q(x_{1:T}|x_0) \\
    \hat{q}(x_t) &= q(x_t) \\
    \hat{q}(y|x_t, x_{t+1}) &= \hat{q}(y|x_t) \\
    \hat{q}(x_{t}|x_{t+1}) &= q(x_{t}|x_{t+1}) \\
    \end{align}
    $$

    • 注意,其中: \(\hat{q}(x_{t}|x_{t-1})\) 中不含有条件,实际是一个边际分布(或边缘分布):
      • \(\hat{q}(x_{t}|x_{t-1}) = \mathbb{E}_{y}[\hat{q}(x_{t}, y|x_{t-1})]\)
      • 同理有: \(\hat{q}(x_t) = \mathbb{E}_{y}[\hat{q}(x_t,y)] = \mathbb{E}_{y}[\mathbb{E}_{x_{0:t-1}}[\hat{q}(y,x_0,\cdots,x_{t})]]\)
      • 或: \(\hat{q}(x_t) = \mathbb{E}_{x_{0:t-1}}[\hat{q}(x_0,\cdots,x_{t})] = \mathbb{E}_{x_{0:t-1}}[\mathbb{E}_{y}[\hat{q}(y,x_0,\cdots,x_{t})]]\)
    • 关于 \(\hat{q}(y|x_t, x_{t+1}) = \hat{q}(y|x_t)\) 的理解:因为 \(x_{t+1}\) 是 \(x_t\) 增加前向噪音得到的,而前向过程不依赖于条件,该式子的证明如下:
      $$
      \begin{align}
      \hat{q}(y|x_{t},x_{t+1}) &= \frac{\hat{q}(y,x_{t},x_{t+1})}{\hat{q}(x_{t},x_{t+1})} = \frac{\hat{q}(x_{t+1}|y,x_{t})\hat{q}(y|x_{t})\hat{q}(x_{t})}{\hat{q}(x_{t+1}|x_{t})\hat{q}(x_{t})} \\
      &= \hat{q} (x_{t+1}|x_{t}, y) \frac{\hat{q}(y|x_{t})}{\hat{q}(x_{t+1}|x_{t})} \\
      &= \hat{q} (x_{t+1}|x_{t}) \frac{\hat{q}(y|x_{t})}{\hat{q}(x_{t+1}|x_{t})}\\
      &= \hat{q}(y|x_{t}) \\
      \end{align}
      $$
      • 其他补充说明 :因为前向过程与条件无关 ,可以理解为已知 \(x_t\) 时,条件 \(y\) 就与反向过程的前一步 \(x_{t+1}\) 无关了,故可以消掉 \(x_{t+1}\) ;但由于反向过程依赖条件 \(y\),不能消掉 \(x_t\),即 \(\hat{q}(y|x_t, x_{t+1}) \neq \hat{q}(y|x_{t+1})\),可理解为在相同的 \(x_{t+1}\) 下, \(x_t\) 与条件 \(y\) 相关( \(x_{t+1}\) 想通的条件下不同的条件 \(y\) 会生成不同的 \(x_t\) )
    • 关于 \(\hat{q}(x_{t}|x_{t+1}) = q(x_{t}|x_{t+1})\) 的理解:,该式子的证明如下:
      $$
      \begin{align}
      \hat{q}(x_{t}|x_{t+1}) &= \frac{\hat{q}(x_{t+1}|x_{t})\hat{q}(x_{t})}{\hat{q}(x_{t+1})} = \frac{q(x_{t+1}|x_{t})q(x_{t})}{q(x_{t+1})} = q(x_t|x_{t+1})
      \end{align}
      $$
  • 接下来我们到了最终目标:推导条件扩散模型的反向过程
    $$
    \begin{align}
    \hat{q}(x_{t}|x_{t+1},y) &= \frac{\hat{q}(x_{t},x_{t+1},y)}{\hat{q}(x_{t+1},y)} = \frac{\hat{q}(y|x_{t},x_{t+1})\hat{q}(x_{t}|x_{t+1})\hat{q}(x_{t+1})}{\hat{q}(y|x_{t+1})\hat{q}(x_{t+1})} \\
    &= \frac{\hat{q}(y|x_{t},x_{t+1})\hat{q}(x_{t}|x_{t+1})}{\hat{q}(y|x_{t+1})} \\
    &= \frac{\hat{q}(y|x_{t})\hat{q}(x_{t}|x_{t+1})}{\hat{q}(y|x_{t+1})} \\
    &= \frac{q(x_{t}|x_{t+1})\hat{q}(y|x_{t})}{\hat{q}(y|x_{t+1})} \\
    \end{align}
    $$

  • 至此,在已知 \(x_{t+1}\) 时,在条件 \(y\) 下,为了采样生成 \(x_t\),我们可以按照带条件的后向过程的概率 \(\hat{q}(x_{t}|x_{t+1},y) = \frac{q(x_{t}|x_{t+1})\hat{q}(y|x_{t})}{\hat{q}(y|x_{t+1})}\) 采样来实现。其中 \(\hat{q}(y|x_{t+1})\) 与 \(x_t\) 无关,可以设置为常数;于是可按照下面的形式采样(这里使用 \(\frac{1}{Z}\) 更容易理解):
    $$ x_t \sim \frac{1}{Z} q(x_{t}|x_{t+1})\hat{q}(y|x_{t}) $$

    • 具体场景下, \(q(x_{t}|x_{t+1})\) 是无条件的Diffusion后向过程,可以通过无条件的Diffusion网络 \(p_\theta(x_{t}|x_{t+1})\) 来实现; \(\hat{q}(y|x_{t})\) 则可通过在扰动后的数据集 \(x_t \sim q(x_t)\) 上训练一个分类器(Classifier)来实现;故最终采样形式可以表示如下形式:

Classifier Guidance DDPM

  • 本小节相关推导参考自:Diffusion Models Beat GANs on Image Synthesis, OpenAI, 2021——Supplemental,论文附录有其他相关推导可做参考

  • 推导过程:

    • 推导过程中使用到一个技巧,在高斯分布相关的推导中,将指数函数临时取对数来进行推导,推导完成后得到的仍然可以看做一个高斯分布的指数部分(在论文附录中提供的另一种推导方法没有使用到这个点)
  • 其他说明:在Diffusion Models Beat GANs on Image Synthesis, OpenAI, 2021中,作者发现,在分类器上增加一个梯度缩放参数 \(s\) (gradient scale \(s\) )可以更好的调节生成效果:
    $$
    \begin{equation}\boldsymbol{x}_{t-1} = \boldsymbol{\mu}(\boldsymbol{x}_t) \color{red}{+} {\color{red}{\underbrace{ \color{black}{s} \sigma_t^2 \nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{y}|\boldsymbol{x}_t)}_{\text{新增项}}}} + \sigma_t\boldsymbol{\varepsilon},\quad \boldsymbol{\varepsilon}\sim \mathcal{N}(\boldsymbol{0},\boldsymbol{I})\end{equation}
    $$

    • 当 \(s > 1\) 时,生成过程中将更多由Classifier主导,提升生成结果和条件 \(y\) 的相关性,降低生成结果的多样性(没有改变方差,实际上多样性不会变化吧?);反之,则会降低生成结果的相关性,增加多样性
    • 关于 \(s\) 这个参数还有更多理解和推导,讨论详情可见:生成扩散模型漫谈(九):条件控制生成结果:
      • \(s\) 可以理解为在概率函数上的指数,用来控制图片的聚焦程度, \(s\) 越大,下面的条件概率 \(\tilde{p}(\boldsymbol{y}|\boldsymbol{x}_t)\) 越趋近于one-hot
        $$\begin{equation}\tilde{p}(\boldsymbol{y}|\boldsymbol{x}_t) = \frac{p^{s}(\boldsymbol{y}|\boldsymbol{x}_t)}{Z(\boldsymbol{x}_t)},\quad Z(\boldsymbol{x}_t)=\sum_{\boldsymbol{y}} p^{s}(\boldsymbol{y}|\boldsymbol{x}_t)\end{equation}$$
  • 其他证明方式见附录

Classifier Guidance DDIM

  • 本小节相关推导参考自:Diffusion Models Beat GANs on Image Synthesis, OpenAI, 2021——Supplemental,论文附录有其他相关推导可做参考

  • 推导过程:

    • 上图第一个公式中的梯度证明见论文的“前置推导”
    • 推导过程中使用了一个技巧,通过求解梯度的相似形式来对比得到最终结果
    • 公式(58)中 \(p_\theta(x_t)p_\phi(y|x_t)\) 没有直接来源,但可以从下面的推导看出,\(\nabla_{x_t} \log (p_\theta(x_t)p_\phi(y|x_t)) = \nabla_{x_t} \log p(x_t|y)\)
      $$
      \begin{align}
      \nabla_{x_t} \log p(x_t|y) &= \nabla_{x_t} \log p(x_t) +\nabla_{x_t} \log p(y|x_t) - \nabla_{x_t} \log p(y) \\
      &= \nabla_{x_t} \log p(x_t) +\nabla_{x_t} \log p(y|x_t) \\
      &= \nabla_{x_t} \log (p(x_t)p(y|x_t)) \\
      \end{align}
      $$
  • 其他说明,这里也可以在分类器的梯度上增加一个权重 \(w\),调节条件相关性和多样性:
    $$ \hat{\boldsymbol{\epsilon}}_\theta(\mathbf{x}_t, t) = \boldsymbol{\epsilon}_\theta(x_t, t) - \sqrt{1 - \bar{\alpha}_t} w \nabla_{\mathbf{x}_t} \log p_\phi(y \vert \mathbf{x}_t) $$

  • 其他证明方式见附录

基于Classifier Guidance的DDPM和DDIM的采样过程

  • Classifier Guidance DDPM和DDIM下伪代码如下:

    • 其中 \(\Sigma\) 表示协方差矩阵,有不同的实现方式,IDDPM中,使用了 \(\mathbf{\Sigma}_\theta(\mathbf{x}_t, t) = \exp(\mathbf{v} \log \beta_t + (1-\mathbf{v}) \log \tilde{\beta}_t)\), \(\mathbf{v}_{\theta}(x_t, t)\) 维度是与 \(\mathbf{x}_t\) 相同的,(详情见improved_diffusion源码:模型定义和improved_diffusion源码:向量使用)
    • 问题:矩阵 \(\Sigma\) 和梯度向量相乘的结果是向量,这里梯度向量是列向量才能相乘

Classifier-free Guidance

  • Classifier-free Guidance的方法直接讲条件添加到原始Diffusion模型中,从而使得模型在训练时就能感知到条件,因此,在采样(生成)时,不需要额外训练Classifier了
  • 一个直观的解释是:
    • 在原始分布 \(p(x_{t-1}|x_t)\) 增加条件 \(y\),则原始分布变成 \(p(x_{t-1}|x_t,y)\)
      $$
      \begin{equation}
      p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{y}) = \mathcal{N}(\boldsymbol{x}_{t-1}; \boldsymbol{\mu}(\boldsymbol{x}_t, \boldsymbol{y}),\sigma_t^2\boldsymbol{I})
      \end{equation}$$
    • 接着,在训练时,噪声网络上也增加输入 \(y\),即 \(\epsilon_\theta(x_t, y, t)\)
    • 最后,可以使用下面的式子来实现加权
      $$\begin{equation}\tilde{\boldsymbol{\epsilon}}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, \boldsymbol{y}, t) = (1 + w)\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, \boldsymbol{y}, t) - w \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)\end{equation}
      $$
      • 其中,无条件的 \(\epsilon_\theta(x_t, t)\) 通过设置一个特定的条件 \(\varnothing\) 获得,即 \(\epsilon_\theta(x_t, t) = \epsilon_\theta(x_t, \varnothing, t)\)
      • \(w\) 也叫做 guidance scale,用于调整生成条件相关性和多样性
  • 接下来我们介绍详细推导过程,首先,回顾条件概率转换:
    $$
    \begin{align}
    \nabla_{x_t} \log p(x_t|y) &= \nabla_{x_t} \log \Big( \frac{p(x_t)p(y|x_t)}{p(y)}\Big) \\
    &= \nabla_{x_t} \log p(x_t) + \nabla_{x_t} \log p(y|x_t) - \nabla_{x_t} \log p(y) \\
    &= \underbrace{\nabla_{x_t} \log p(x_t)}_{\text{unconditional score}} + \underbrace{\nabla_{x_t} \log p(y|x_t)}_{\text{classifier gradient}}
    \end{align}
    $$
  • 于是有:
    $$ \nabla_{\mathbf{x}_t} \log p(y \vert \mathbf{x}_t) = \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t \vert y) - \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t) $$
  • 进一步地,有(以下推导来自What are Diffusion Models?的推导):
    $$
    \begin{aligned}
    \nabla_{\mathbf{x}_t} \log p(y \vert \mathbf{x}_t)
    &= \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t \vert y) - \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t) \\
    &= - \frac{1}{\sqrt{1 - \bar{\alpha}_t}}\Big( \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t, y) - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) \Big) \\
    \end{aligned}
    $$
  • 回顾DDIM Classifier Guidance里面的推导有:
    $$
    \begin{aligned}
    \bar{\boldsymbol{\epsilon}}_\theta(\mathbf{x}_t, t, y)
    &= \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t, y) - \sqrt{1 - \bar{\alpha}_t} \ w \nabla_{\mathbf{x}_t} \log p(y \vert \mathbf{x}_t) \\
    &= \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t, y) + w \big(\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t, y) - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) \big) \\
    &= (w+1) \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t, y) - w \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)
    \end{aligned}
    $$
  • 采样时,利用将DDPM或者DDIM中的 \(\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\) 替换为 \(\bar{\boldsymbol{\epsilon}}_\theta(\mathbf{x}_t, t, y)\), 正常采样即可
  • 总结一下,Classifier-free Guidance方法的训练流程和采样流程如下:

Classifier-free Guidance和Classifier Guidance谁更好?

  • GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models, OpenAI, 2022对CLIP Guidance和Classifier-free Guidance两种策略进行实验,结果发现Classifier-free Guidance效果更好

    The guided diffusion model, GLIDE (Nichol, Dhariwal & Ramesh, et al. 2022), explored both guiding strategies, CLIP guidance and classifier-free guidance, and found that the latter is more preferred. They hypothesized that it is because CLIP guidance exploits the model with adversarial examples towards the CLIP model, rather than optimize the better matched images generation.


附录:DDPM和DDIM加入条件以后的采样形式为什么不同呢?

  • 可以证明,在设置DDPM中协方差矩阵 \(\Sigma\) 使用固定值后,DDPM和DDIM的Classifier Guidance 形式基本是非常相似的,都可以看做是用梯度 \(\nabla_{x_t}\log p_\phi(y|x_t)\) 对 \(\mathbf{\epsilon}_\theta(x_t, t)\) 进行修正【有时间时可以详细推导一下】

  • 补充对比:

    • 带条件的训练形式:

    • 原始DDPM采样形式
      $$ x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\Big( x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t, t) \Big) + \sigma_t z $$

    • 原始DDIM采样形式
      $$x_s = \sqrt{\bar{\alpha}_s}\left(\frac{x_k-\sqrt{1-\bar{\alpha}_k}\epsilon_{\theta}(x_k,k)}{\sqrt{\bar{\alpha}_k}}\right) + \sqrt{1-\bar{\alpha}_s-a_1\sigma_k^2}\epsilon_{\theta}(x_k,k) + a_2\sigma_k \epsilon$$


附录:其他推导-Classifier Guidance DDPM

  • 下面的推导参考苏神生成扩散模型漫谈(九):条件控制生成结果的推导
  • 在推导过程中,在原始分布 \(p(x_{t-1}|x_t)\) 增加条件 \(y\),则原始分布变成 \(p(x_{t-1}|x_t,y)\)
  • 根据贝叶斯公式有:
    $$\begin{equation}p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{y}) = \frac{p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)p(\boldsymbol{y}|\boldsymbol{x}_{t-1}, \boldsymbol{x}_t)}{p(\boldsymbol{y}|\boldsymbol{x}_t)}\label{eq:bayes-1}\end{equation}$$
  • 由于前向过程与条件无关, \(x_t\) 是由 \(x_{t-1}\) 加上与条件无关的噪声得到的,所以 \(p(\boldsymbol{y}|\boldsymbol{x}_{t-1}, \boldsymbol{x}_t) = p(\boldsymbol{y}|\boldsymbol{x}_{t-1})\),这在前文已经得到证明
  • 于是有:
    $$ \begin{equation}p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{y}) = \frac{p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)p(\boldsymbol{y}|\boldsymbol{x}_{t-1})}{p(\boldsymbol{y}|\boldsymbol{x}_t)} = p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) \exp({\log p(\boldsymbol{y}|\boldsymbol{x}_{t-1}) - \log p(\boldsymbol{y}|\boldsymbol{x}_t)})\label{eq:bayes-2}\end{equation} $$
  • 接下来我们要求解 \(\log p(\boldsymbol{y}|\boldsymbol{x}_{t-1}) - \log p(\boldsymbol{y}|\boldsymbol{x}_t)\) 这一项,首先做一下近似,当 \(T\) 足够大时,有 \(x_t\) 和 \(x_{t-1}\) 非常接近,故而可以用泰勒展开来近似:
    $$ \begin{equation}\log p(\boldsymbol{y}|\boldsymbol{x}_{t-1}) - \log p(\boldsymbol{y}|\boldsymbol{x}_t)\approx (\boldsymbol{x}_{t-1} - \boldsymbol{x}_t)\cdot\nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{y}|\boldsymbol{x}_t)\end{equation}$$
    • 其中 \(\nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{y}|\boldsymbol{x}_t)\) 是一种简写形式,实际上有 \(\nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{y}|\boldsymbol{x}_t) = \nabla_{\boldsymbol{x}} \log p(\boldsymbol{y}|\boldsymbol{x})\vert_{\boldsymbol{x} = \boldsymbol{x}_t}\),表示函数在 \(\boldsymbol{x}_t\) 处的梯度
  • 假设 \(p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)=\mathcal{N}(\boldsymbol{x}_{t-1};\boldsymbol{\mu}(\boldsymbol{x}_t),\sigma_t^2\boldsymbol{I})\propto \exp({-\frac{\Vert \boldsymbol{x}_{t-1} - \boldsymbol{\mu}(\boldsymbol{x}_t)\Vert^2}{2\sigma_t^2}})\),则有
    $$
    \begin{aligned}
    p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t, \boldsymbol{y}) =&\ p(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t) \exp({\log p(\boldsymbol{y}|\boldsymbol{x}_{t-1}) - \log p(\boldsymbol{y}|\boldsymbol{x}_t)}) \\
    \propto&\ \exp\Big({-\frac{\Vert \boldsymbol{x}_{t-1} - \boldsymbol{\mu}(\boldsymbol{x}_t)\Vert^2}{2\sigma_t^2} + (\boldsymbol{x}_{t-1} - \boldsymbol{x}_t)\cdot\nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{y}|\boldsymbol{x}_t)}\Big) \\
    \propto&\ \exp\Big(\frac{-\Vert \boldsymbol{x}_{t-1} - \boldsymbol{\mu}(\boldsymbol{x}_t) - \sigma_t^2 \nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{y}|\boldsymbol{x}_t)\Vert^2}{2\sigma_t^2}\Big)
    \end{aligned}
    $$
    • 其中 \(\exp\Big({-\frac{\Vert \boldsymbol{x}_{t-1} - \boldsymbol{\mu}(\boldsymbol{x}_t)\Vert^2}{2\sigma_t^2} + (\boldsymbol{x}_{t-1} - \boldsymbol{x}_t)\cdot\nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{y}|\boldsymbol{x}_t)}\Big) \propto\ \exp\Big(\frac{-\Vert \boldsymbol{x}_{t-1} - \boldsymbol{\mu}(\boldsymbol{x}_t) - \sigma_t^2 \nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{y}|\boldsymbol{x}_t)\Vert^2}{2\sigma_t^2}\Big) \) 是指数中配方后,略去了与 \(\boldsymbol{x}_{t-1}\) 无关的项,所以使用的是“正比于”(在指数部分,包括 \(\boldsymbol{x}_{t}\) 也与 \(\boldsymbol{x}_{t-1}\) 无关,可以略去,因为分布仅留下与 \(\boldsymbol{x}_{t-1}\) 相关的部分即可,注:标准的正太分布概率密度函数为 \(p(x) = \frac{1}{\sqrt{2\pi}\sigma} \exp(-\frac{(x-\mu)^2}{2\sigma^2}) \propto \exp(-\frac{(x-\mu)^2}{2\sigma^2})\),只关注与 \(x\) 有关的项即可)
  • 将上面的形式转换成正太分布,有:
    $$ \mathcal{N}(\boldsymbol{x}_{t-1};\boldsymbol{\mu}(\boldsymbol{x}_t) + \sigma_t^2 \nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{y}|\boldsymbol{x}_t),\sigma_t^2\boldsymbol{I})$$
  • 也就是说按照下面形式采样即可实现条件的加入:
    $$
    \begin{equation}\boldsymbol{x}_{t-1} = \boldsymbol{\mu}(\boldsymbol{x}_t) \color{red}{+} {\color{red}{\underbrace{\sigma_t^2 \nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{y}|\boldsymbol{x}_t)}_{\text{新增项}}}} + \sigma_t\boldsymbol{\varepsilon},\quad \boldsymbol{\varepsilon}\sim \mathcal{N}(\boldsymbol{0},\boldsymbol{I})\end{equation}
    $$
  • 在原始论文Diffusion Models Beat GANs on Image SynthesisClassifier Guidance DDPM的采样形式为:
    $$\begin{equation}\sigma_t^2 \nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{y}|\boldsymbol{x}_t)|_{\boldsymbol{x}_t=\boldsymbol{\mu}(\boldsymbol{x}_t)}\end{equation}$$
    • 这个形式与论文推导结果不同,生成扩散模型漫谈(九):条件控制生成结果中指出两者差不多(问题:零阶近似是指在 \(x_t=\mu(x_t)\) 处进行泰勒展开吗?需要进一步理解):

      论文中梯度项在 \(\boldsymbol{\mu}(\boldsymbol{x}_t)\) 处的结果而非 \(\boldsymbol{x}_t\) 处,而一般情况下 \(\boldsymbol{\mu}(\boldsymbol{x}_t)\) 的零阶近似正是 \(\boldsymbol{x}_t\),所以两者结果是差不多的

    • 理解:梯度的位置决定了建模价值模型时使用 \(\boldsymbol{x}_t\) 还是 \(\boldsymbol{\mu}(\boldsymbol{x}_t)\),直观上理解,应该是在 \(\boldsymbol{\mu}(\boldsymbol{x}_t)\) 上进行梯度偏移,所以应该是在 \(\boldsymbol{\mu}(\boldsymbol{x}_t)\) 的梯度才对
      • Diffuser中代码实现比较奇怪,与论文伪代码不同:这部分实现的Diffuser详细代码可见Diffuser源码-采样函数
      • 其他代码实现参考:

附录:其他推导-Classifier Guidance DDIM

本小节推导主要参考自What are Diffusion Models?

  • 由[(NCSN)Generative Modeling by Estimating Gradients of the Data Distribution, Stanford, 2020]和[Score-based generative modeling through stochastic differential equations, Stanford & Google, 2021]可以知道,只需要求得目标分布的对数概率梯度即可按照该梯度进行采样,原始分布的对数概率梯度为(详细证明见论文“前置推导”):
    $$ \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t) = - \frac{1}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) $$
  • 增加条件以后的联合概率梯度为(此时的 \(\nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t, y)\) 与 \(\nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t\vert y)\) 是等价的,因为 \(\nabla_{\mathbf{x}_t}\log q(y)=0\) ):
    $$
    \begin{aligned}
    \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t, y)
    &= \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t) + \nabla_{\mathbf{x}_t} \log q(y \vert \mathbf{x}_t) \\
    &\approx - \frac{1}{\sqrt{1 - \bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) + \nabla_{\mathbf{x}_t} \log f_\phi(y \vert \mathbf{x}_t) \\
    &= - \frac{1}{\sqrt{1 - \bar{\alpha}_t}} (\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) - \sqrt{1 - \bar{\alpha}_t} \nabla_{\mathbf{x}_t} \log f_\phi(y \vert \mathbf{x}_t))
    \end{aligned}
    $$
  • 对照两个概率的梯度形式,可以得到 \(\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\) 变成如下形式即可将 \(q(\mathbf{x}_t)\) 替换为带条件的 \(q(\mathbf{x}_t, y)\):
    $$\bar{\boldsymbol{\epsilon}}_\theta(\mathbf{x}_t, t) = \boldsymbol{\epsilon}_\theta(x_t, t) - \sqrt{1 - \bar{\alpha}_t} \nabla_{\mathbf{x}_t} \log f_\phi(y \vert \mathbf{x}_t)$$
  • 为了权衡Classifier Guidance的强度,可以添加一个权重 \(w\):
    $$ \bar{\boldsymbol{\epsilon}}_\theta(\mathbf{x}_t, t) = \boldsymbol{\epsilon}_\theta(x_t, t) - \sqrt{1 - \bar{\alpha}_t} w \nabla_{\mathbf{x}_t} \log f_\phi(y \vert \mathbf{x}_t) $$

附录:问题汇总

  • Classifier Guidance方法在训练Classifier时,需要包含随机扰动后的数据吧?
  • DDPM推导时(附录和正文)得到了两种不同的梯度位置 \(x_t\) 和 \(\mu_t\),实现时应该是用哪个呢?
  • Classifier Guidance方法训练过程中有哪些需要注意的?
  • Classifier-free Guidance方法训练过程中有哪些需要注意的?

DL——梯度检查点技术

  • 参考链接:
    • 原始论文 Training Deep Nets with Sublinear Memory Cost, MIT, 2016:论文以这篇文章的内容为主进行介绍,后续会补充一些 transformer 场景下的梯度检查点技术

整体总结

  • 梯度检查点(Gradient Checkpointing)技术广泛应用于当前大规模深度学习模型训练中,能有效降低显存的使用
  • 梯度检查点技术也叫做 重计算(re-materialization)技术
  • Training Deep Nets with Sublinear Memory Cost, MIT, 2016这篇文章的核心贡献是:
    • 提出并系统化了一种 “用计算换内存”的通用方法(在自动微分领域也被称为梯度检查点技术)
    • 通过分段和重计算,巧妙地将深度学习训练过程中的内存瓶颈从 \(O(n)\) 降低到 \(O(\sqrt{n})\),甚至理论上的 \(O(\log n)\),极大地拓展了论文使用现有硬件能够训练的模型规模和深度

核心问题:为什么训练深度网络如此消耗内存?

  • 在训练神经网络时,标准的训练流程包含两个步骤:
    • 前向传播 (Forward Pass) :
      • 输入数据从网络的第一层开始,逐层计算,直到最后一层得到输出
      • 在这个过程中,每一层的输出(称为“激活”或“特征图”)都需要被保存下来
    • 反向传播 (Backward Pass) :
      • 计算输出的损失(预测值与真实值的差距),然后从最后一层开始,反向逐层计算梯度
      • 为了计算某一层参数的梯度,通常需要用到该层在前向传播时产生的激活值
  • 问题出现:
    • 为了进行反向传播,必须在内存中保留网络中每一层的激活值
    • 如果一个网络有 \(n\) 层,那么内存消耗就大致与 \(n\) 成正比 ,即内存成本为 \(O(n)\)
    • 对于现在动辄成百上千层的深度模型(如 ResNet),这笔内存开销会迅速占满顶配 GPU 的几十 GB 显存,从而限制了探索更深、更复杂模型的能力

传统内存优化方法(治标不治本)

  • 论文首先提到了一些已有的内存优化技术,这些技术主要通过分析计算图 (Computation Graph) 来实现
    • 原地操作 (In-place Operation) :
      • 如果一个操作的输入值在后续计算中不再需要,那么其输出可以直接覆盖输入的内存空间
      • 例如,y = relu(x),如果 x 后面用不到了,y 的结果可以直接写在 x 的内存里
    • 内存共享 (Memory Sharing) :
      • 分析所有变量的“生命周期”,将生命周期不重叠的变量共享同一块内存
  • 这些方法能将内存占用降低2到3倍,但无法改变内存消耗随网络层数线性增长的趋势,当网络深到一定程度时,内存瓶颈依然存在

梯度检查点:用计算换内存 (Trade Computation for Memory)

  • 既然保存所有中间结果是内存消耗的根源,那么论文提出的核心思想非常直接:
    • 不保存所有中间结果,只保存其中一部分
    • 当反向传播需要用到某个被丢弃的中间结果时,再临时重新计算它
  • 这是一种典型的“用时间换空间”的策略
    • 虽然会增加一些计算量(因为需要重新执行部分前向计算),但可以极大地降低内存峰值

工作原理一:\(O(\sqrt{n})\) 内存成本算法(实用策略)

  • 这是论文提出的主要实用算法
  • \(O(\sqrt{n})\) 内存成本算法原理如下:
    • 1)分段 (Segmenting) :将一个包含 \(n\) 层的网络链条,切分成 \(k\) 个小段(segment)
    • 2)前向传播 :在正常的前向传播过程中,只保存每个分段的最终输出 ,而丢弃每个分段内部的所有中间激活值
    • 3)反向传播 :
      • 当反向传播进行到第 \(i\) 段时,由于计算该段的梯度需要其内部的激活值(这些值已经被丢弃了),算法会执行一次“局部前向传播”:
        • 利用保存的第 \(i-1\) 段的输出作为输入,重新计算一次第 \(i\) 段的前向传播,以得到所有需要的激活值
      • 计算完梯度后,这些临时重新计算的激活值可以立即被丢弃

关键推导:为什么是 \(O(\sqrt{n})\)?

  • 内存成本分析:假设网络总共有 \(n\) 层,被均匀地切分成 \(k\) 段,那么每一段的长度就是 \(n/k\)层
  • 总内存成本主要由两部分构成:
    • 1)段间内存 (Inter-segment Memory) :用于存储 \(k\) 个分段的输出,以便在反向传播时作为“检查点”(checkpoint)
      • 这部分的成本是 \(O(k)\)
    • 2)段内内存 (Intra-segment Memory) :在对任何一段进行反向传播时,需要临时重新计算并存储该段内部的所有激活值
      • 由于所有段中最大的内存开销决定了峰值,这部分的成本是 \(O(n/k)\)
  • 因此,总的内存成本可以表示为:
    $$\text{Cost}(n, k) = O(k) + O(n/k)$$
  • 为了让总成本最低,论文需要让这两部分达到一个平衡。一个简单的优化方法是让它们的量级相等(实际上等价于导数为 0 的推导):
    $$k \approx \frac{n}{k} \implies k^2 \approx n \implies k = \sqrt{n}$$
  • 当选择 \(k=\sqrt{n}\) 时,总内存成本为:
    $$\text{Cost} = O(\sqrt{n}) + O(n/\sqrt{n}) = O(\sqrt{n}) + O(\sqrt{n}) = O(\sqrt{n})$$
  • 至此,我们就成功地将内存成本从线性 \(O(n)\) 降到了亚线性 \(O(\sqrt{n})\)
  • 作为代价,整个训练过程大约需要额外进行一次完整的前向传播计算(因为每个分段都被重新计算了一次),使得训练时间增加了约 30%
    • 问题:30% 的计算时间怎么来的呢?
    • 回答:反向传播的时间复杂度大约是前向传播的 2~3 倍,折合计算以后大致能算出这个数字(增加了一次前向传播计算),详情见附录

工作原理二:\(O(\log n)\) 内存成本算法(理论极限)

  • 论文进一步展示,通过递归 (Recursion) 的方式,可以实现更低的内存成本
  • 我们可以把一个分段本身看作一个“超级操作符”
  • 对这个超级操作符内部的计算,作者同样可以再次应用分段和重计算的策略

关键推导:\(O(\log n)\) 的递推关系

  • 让 \(g(n)\) 表示训练一个 \(n\) 层网络所需的内存
  • 假设将这个网络分成 \(k+1\) 个子问题,每个子问题的规模是 \(n/(k+1)\)
  • 为了连接这些子问题,需要存储 \(k\) 个中间结果
  • 那么,\(g(n)\) 可以表示为递推公式:
    $$g(n) = k + g\left(\frac{n}{k+1}\right)$$
    • 这是一个典型的对数关系
    • \(k\) 是存储这 \(k\) 个结果的成本
    • \(g\left(\frac{n}{k+1}\right)\) 是解决其中一个子问题所需的成本
    • 注:这里使用 \(k+1\) 或 \(k\) 不影响最终结果
  • 通过解这个递推公式,我们可以得到:
    $$g(n) = k \cdot \log_{k+1}(n)$$
  • 作为一个特例,如果每次只将问题一分为二,即只存储一个中间结果(\(k=1\)),那么递推关系变为 \(g(n) = 1 + g(n/2)\),解得:
    $$g(n) = \log_2(n)$$
    • 注:简单理解一下,展开 \(g(n)\) 后,大致共有 \(\log_2(n)\) 个 1
  • 这揭示了一个终极的理论可能性:
    • 训练一个 \(n\) 层网络的内存成本可以降低到 \(O(\log n)\)
    • 不过,这种方法的计算开销会大得多(需要 \(O(\log n)\) 次额外的前向传播,因此在实践中不如 \(O(\sqrt{n})\) 策略常用

实验效果与结论

  • 作者通过在深度残差网络 (ResNet) 和长短期记忆网络 (LSTM) 上的实验,验证了该方法的有效性
    • 对于一个 1000层 的 ResNet,标准优化方法需要 48GB 显存,而使用亚线性算法后仅需 7GB
    • 在 LSTM 上,该方法同样带来了 超过4倍 的内存节省
    • 代价是训练速度降低了大约30%,这对于能够训练原本无法训练的模型来说,是一个非常值得的交换

Transformer 中的梯度检查点

  • Transformer中的梯度检查点(Gradient Checkpointing)与上述论文中的基本原理上是完全相同的 ,但其应用方式和带来的收益上,针对Transformer 的结构有更多特点
  • 无论是用于CNN、RNN还是Transformer,梯度检查点的核心思想始终是:
    • 目标 :打破模型训练时内存消耗与网络深度(层数)之间的线性关系
    • 方法 :在前向传播时,不再保存所有中间层的激活值,而是只保存少数几个关键节点(检查点)
    • 代价 :在反向传播时,当需要用到被丢弃的激活值时,就从最近的一个检查点开始,重新进行一小段前向计算来恢复它们
    • 权衡 :本质上都是“用计算换内存”的策略

针对 Transformer 结构的应用说明

  • Transformer的独特结构使得梯度检查点的应用非常直接,且效果尤其显著
  • 1. 应用位置非常明确
    • 一个标准的 Transformer 模型是由一个个完全相同的 Transformer Block 堆叠而成的
      • 每个块通常包含一个多头自注意力(Multi-Head Self-Attention)层和一个前馈神经网络(Feed-Forward Network, FFN)层
    • 最自然、最常见的应用方式就是 将每一个Transformer块作为一个分段(Segment)
      • 前向传播时 :
        • 当数据流经第 \(i\) 个Transformer块时,只保留送入这个块的输入(也就是第 \(i-1\) 块的输出)
        • 在块内部计算过程中产生的所有中间结果,例如注意力分数矩阵(Attention Scores)、注意力权重(Attention Weights)、FFN层的中间激活等,计算完毕后立即被丢弃
      • 反向传播时 :
        • 当反向传播回第 \(i\) 个块时,算法会利用之前保存的输入,重新执行一次该块的前向计算,从而得到计算梯度所必需的那些中间结果
  • 2. 带来的收益为何对 Transformer 尤其显著
    • 梯度检查点能极大缓解Transformer在两个维度上的内存压力:
    • 深度(层数 \(L\)) :
      • 现代的大型语言模型(如GPT、LLaMA)可以有几十甚至上百层
      • 如果没有梯度检查点,内存消耗会随着层数 \(L\) 线性暴增。梯度检查点将这个成本从 \(O(L)\) 降到了 \(O(\sqrt{L})\) ,使得训练极深的Transformer 成为可能
    • 序列长度(Sequence Length \(S\)) :这是Transformer最独特的内存瓶颈
      • 注意力矩阵的二次方开销 :
        • 自注意力机制的核心是计算一个注意力分数矩阵,其大小为 (序列长度 x 序列长度)
        • 这意味着内存开销与序列长度成二次方关系 ,即 \(O(S^2)\)
        • 当序列很长时(例如4096、8192甚至更长),这个矩阵会变得异常巨大
      • 梯度检查点的作用 :
        • 梯度检查点不能改变单次注意力计算需要 \(O(S^2)\) 内存峰值的事实
        • 但它能确保不必同时在内存中保留每一层的这个巨大矩阵
        • 在没有检查点的情况下,内存中需要为 \(L\) 个注意力矩阵的激活值(或其相关值)分配空间
        • 有了检查点,在任何时候,只需要为当前正在重计算的那一个块的注意力矩阵分配内存
        • 这极大地降低了总体的内存占用

附录:梯度检查点技术增加了多少训练成本?

  • 论文中 “训练时间增加了约30%” 这个数字主要是一个经验性的测量结果 ,来源于作者在特定硬件上进行的基准测试,并且这个结果也与理论上的计算开销分析相符

实验测量结果(经验来源)

  • 作者在论文的第5.4节 (Impact on Training Speed) 和 图7 专门讨论了这个问题
    • 测试方法 :
      • 作者在单个 Titan X GPU 上对不同的内存分配策略进行了速度基准测试
      • 测量了在 ResNet 和 LSTM 两种模型上,处理一个批次(batch)数据所需的实际运行时间(秒)
    • 对比对象 :
      • 比较了采用标准内存优化(论文中称为 “sharing”)的策略和采用亚线性内存成本(”sublinear plan”)策略的速度
    • 测量结论 :
      • 实验结果图表(图7)直观地显示,”sublinear plan” 的时间成本曲线始终在 “sharing” 曲线之上
      • 论文在图7的说明文字和正文中明确指出,使用亚线性内存方案会带来“大约30%的额外运行时成本”

理论计算分析(理论支撑)

  • 训练一个批次的主要计算量如下:
    • 标准训练流程 :包含一次完整的前向传播(Forward Pass)和一次完整的反向传播(Backward Pass):
      $$ 总计算量 \approx 1F + 1B $$
    • 亚线性方案的流程 :它在反向传播过程中需要重新计算前向传播,因此:
      $$ 总计算量 \approx 1F + (1F_recompute + 1B) = 2F + 1B $$
  • 论文中提到,通常一次反向传播的计算量大约是前向传播的两倍(\(B \approx 2F\))
    • 实际上是大约 2~3 倍的样子
  • 基于上述这个假设,我们可以估算增加的计算开销:
    • 标准流程计算量 :
      $$T_{standard} = T_F + T_B \approx T_F + 2T_F = 3T_F$$
    • 亚线性方案计算量 :
      $$T_{sublinear} = 2T_F + T_B \approx 2T_F + 2T_F = 4T_F$$
  • 增加的运行时间百分比约为:
    $$\frac{T_{sublinear} - T_{standard}}{T_{standard}} = \frac{4T_F - 3T_F}{3T_F} = \frac{1T_F}{3T_F} = \frac{1}{3} \approx 33.3%$$
  • 这个理论计算出的 33.3% 与实验中测量到的 约30% 基本吻合(考虑到反向传播大致是前向传播的 2~3 倍,那就基本符合了)

KG——知识图谱的描述

参考博客: https://blog.csdn.net/u011801161/article/details/78833958


RDF

  • Resource Description Framework
  • 资源描述框架
  • 本质是一个数据模型
  • 提供了统一的描述实体和资源的标准
  • 形式上表现为主谓宾(SPO, Subject-Predication-Object)三元组, 也称为一条语句(Statement), 知识图谱中称为一条知识

RDF的序列化方法

参考博客: https://blog.csdn.net/u011801161/article/details/78833958

  • RDF/XML: 用XML格式来表示RDF数据

  • N-Triples: 用多个三元组来表示RDF数据集合,是最直观的表示方法,每一行表示一个三元组,方便机器解析和处理,DBpedia 是按照这个方式来发布数据的

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    <http://www.kg.com/person/1> <http://www.kg.com/ontology/chineseName> "罗纳尔多·路易斯·纳萨里奥·德·利马"^^string.
    <http://www.kg.com/person/1> <http://www.kg.com/ontology/career> "足球运动员"^^string.
    <http://www.kg.com/person/1> <http://www.kg.com/ontology/fullName> "Ronaldo Luís Nazário de Lima"^^string.
    <http://www.kg.com/person/1> <http://www.kg.com/ontology/birthDate> "1976-09-18"^^date.
    <http://www.kg.com/person/1> <http://www.kg.com/ontology/height> "180"^^int.
    <http://www.kg.com/person/1> <http://www.kg.com/ontology/weight> "98"^^int.
    <http://www.kg.com/person/1> <http://www.kg.com/ontology/nationality> "巴西"^^string.
    <http://www.kg.com/person/1> <http://www.kg.com/ontology/hasBirthPlace> <http://www.kg.com/place/10086>.
    <http://www.kg.com/place/10086> <http://www.kg.com/ontology/address> "里约热内卢"^^string.
    <http://www.kg.com/place/10086> <http://www.kg.com/ontology/coordinate> "-22.908333, -43.196389"^^string.
  • RDFa: (The Resource Description Framework in Attributes)

  • Turtle是最常用的RDF序列化方式, 比RDF/XML更紧凑, 可读性比N-Triples更好

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    Example2 Turtle:

    @prefix person: <http://www.kg.com/person/> .
    @prefix place: <http://www.kg.com/place/> .
    @prefix : <http://www.kg.com/ontology/> .

    person:1 :chineseName "罗纳尔多·路易斯·纳萨里奥·德·利马"^^string.
    person:1 :career "足球运动员"^^string.
    person:1 :fullName "Ronaldo Luís Nazário de Lima"^^string.
    person:1 :birthDate "1976-09-18"^^date.
    person:1 :height "180"^^int.
    person:1 :weight "98"^^int.
    person:1 :nationality "巴西"^^string.
    person:1 :hasBirthPlace place:10086.
    place:10086 :address "里约热内卢"^^string.
    place:10086 :address "-22.908333, -43.196389"^^string.
    • 同一个实体拥有多个属性(数据属性)或关系(对象属性),我们可以只用一个subject来表示,使其更紧凑。我们可以将上面的Turtle改为
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      Example3 Turtle:

      @prefix person: <http://www.kg.com/person/> .
      @prefix place: <http://www.kg.com/place/> .
      @prefix : <http://www.kg.com/ontology/> .

      person:1 :chineseName "罗纳尔多·路易斯·纳萨里奥·德·利马"^^string;
      :career "足球运动员"^^string;
      :fullName "Ronaldo Luís Nazário de Lima"^^string;
      :birthDate "1976-09-18"^^date;
      :height "180"^^int;
      :weight "98"^^int;
      :nationality "巴西"^^string;
      :hasBirthPlace place:10086.
      place:10086 :address "里约热内卢"^^string;
      :address "-22.908333, -43.196389"^^string.
  • JSON-LD: 即“JSON for Linking Data”,用键值对的方式来存储RDF数据

    1
    2
    3
    4
    5
    6
    7
    {
    "@context": "https://json-ld.org/contexts/person.jsonld",
    "@id": "http://dbpedia.org/resource/John_Lennon",
    "name": "John Lennon",
    "born": "1940-10-09",
    "spouse": "http://dbpedia.org/resource/Cynthia_Lennon"
    }

RDF的缺点

  • 表达能力有限
    • 无法区分雷和对象
    • 无法定义和描述类的关系/属性

RDFS/OWL

  • 是RDF的一种扩展
  • 是用来描述RDF数据的
  • 本质上是一些预定义词汇(Vocabulary)构成的集合
  • 用于对RDF进行类似的类定义以及属性的定义

RDFS/OWL的序列化方法

  • RDFS/OWL序列化方式和RDF没什么不同,其实在表现形式上,它们就是RDF
  • 常用的方式主要是RDF/XML,Turtle

RDFS

  • Resource Description Framework Schema
  • 是RDF的一种扩展
  • RDFS几个比较重要,常用的词汇:
    • rdfs:Class. 用于定义类。
    • rdfs:domain. 用于表示该属性属于哪个类别。
    • rdfs:range. 用于描述该属性的取值类型。
    • rdfs:subClassOf. 用于描述该类的父类。比如,我们可以定义一个运动员类,声明该类是人的子类。
    • rdfs:subProperty. 用于描述该属性的父属性。比如,我们可以定义一个名称属性,声明中文名称和全名是名称的子类
    • 其实rdf:Property和rdf:type也是RDFS的词汇,因为RDFS本质上就是RDF词汇的一个扩展。我们在这里不罗列进去,是不希望读者混淆, 更多RDFS词汇的用法参考W3C官方文档
  • 举例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    @prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> .
    @prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> .
    @prefix : <http://www.kg.com/ontology/> .

    ### 这里我们用词汇rdfs:Class定义了“人”和“地点”这两个类
    :Person rdf:type rdfs:Class.
    :Place rdf:type rdfs:Class.

    ### rdfs当中不区分数据属性和对象属性,词汇rdf:Property定义了属性,即RDF的“边”
    :chineseName rdf:type rdf:Property;
    rdfs:domain :Person;
    rdfs:range xsd:string .

    :career rdf:type rdf:Property;
    rdfs:domain :Person;
    rdfs:range xsd:string .

    :fullName rdf:type rdf:Property;
    rdfs:domain :Person;
    rdfs:range xsd:string .

    :birthDate rdf:type rdf:Property;
    rdfs:domain :Person;
    rdfs:range xsd:date .

    :height rdf:type rdf:Property;
    rdfs:domain :Person;
    rdfs:range xsd:int .

    :weight rdf:type rdf:Property;
    rdfs:domain :Person;
    rdfs:range xsd:int .

    :nationality rdf:type rdf:Property;
    rdfs:domain :Person;
    rdfs:range xsd:string .

    :hasBirthPlace rdf:type rdf:Property;
    rdfs:domain :Person;
    rdfs:range :Place .

    :address rdf:type rdf:Property;
    rdfs:domain :Place;
    rdfs:range xsd:string .

    :coordinate rdf:type rdf:Property;
    rdfs:domain :Place;
    rdfs:range xsd:string .

OWL

  • Web Ontology Language

  • 是对RDFS的一个扩展,添加了额外的预定义词汇

  • 提供快速,灵活的数据建模能力

  • 高效的自动推理能力

  • 描述属性特征的词汇

    • owl:TransitiveProperty. 表示该属性具有传递性质。例如,我们定义“位于”是具有传递性的属性,若A位于B,B位于C,那么A肯定位于C。
    • owl:SymmetricProperty. 表示该属性具有对称性。例如,我们定义“认识”是具有对称性的属性,若A认识B,那么B肯定认识A。
    • owl:FunctionalProperty. 表示该属性取值的唯一性。 例如,我们定义“母亲”是具有唯一性的属性,若A的母亲是B,在其他地方我们得知A的母亲是C,那么B和C指的是同一个人。
    • owl:inverseOf. 定义某个属性的相反关系。例如,定义“父母”的相反关系是“子女”,若A是B的父母,那么B肯定是A的子女
  • 本体映射词汇(Ontology Mapping)

    • owl:equivalentClass. 表示某个类和另一个类是相同的。
    • owl:equivalentProperty. 表示某个属性和另一个属性是相同的。
    • owl:sameAs. 表示两个实体是同一个实体
  • 举例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    @prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> .
    @prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> .
    @prefix : <http://www.kg.com/ontology/> .
    @prefix owl: <http://www.w3.org/2002/07/owl#> .

    ### 这里我们用词汇owl:Class定义了“人”和“地点”这两个类
    :Person rdf:type owl:Class.
    :Place rdf:type owl:Class.

    ### owl区分数据属性和对象属性(对象属性表示实体和实体之间的关系)。词汇owl:DatatypeProperty定义了数据属性,owl:ObjectProperty定义了对象属性
    :chineseName rdf:type owl:DatatypeProperty;
    rdfs:domain :Person;
    rdfs:range xsd:string .

    :career rdf:type owl:DatatypeProperty;
    rdfs:domain :Person;
    rdfs:range xsd:string .

    :fullName rdf:type owl:DatatypeProperty;
    rdfs:domain :Person;
    rdfs:range xsd:string .

    :birthDate rdf:type owl:DatatypeProperty;
    rdfs:domain :Person;
    rdfs:range xsd:date .

    :height rdf:type owl:DatatypeProperty;
    rdfs:domain :Person;
    rdfs:range xsd:int .

    :weight rdf:type owl:DatatypeProperty;
    rdfs:domain :Person;
    rdfs:range xsd:int .

    :nationality rdf:type owl:DatatypeProperty;
    rdfs:domain :Person;
    rdfs:range xsd:string .

    :hasBirthPlace rdf:type owl:ObjectProperty;
    rdfs:domain :Person;
    rdfs:range :Place .

    :address rdf:type owl:DatatypeProperty;
    rdfs:domain :Place;
    rdfs:range xsd:string .

    :coordinate rdf:type owl:DatatypeProperty;
    rdfs:domain :Place;
    rdfs:range xsd:string .
  • 举个例子体现对两个不同知识图谱的融合

    1
    2
    3
    http://www.zhangsan.com/ontology/Person rdf:type owl:Class . 
    http://www.lisi.com/ontology/Human rdf:type owl:Class .
    http://www.zhangsan.com/ontology/Person owl:equivalentClass http://www.lisi.com/ontology/Human .

Python——跨文件类中isinstance函数困境


不同文件为入口文件时

  • 文件一(fruit.py):

    1
    2
    3
    4
    5
    6
    7
    # file: fruit.py
    class Apple:
    def __init__(self):
    name = "HongFuShi"

    apple = Apple()
    print apple.__class__
  • 文件二(run.py):

    1
    2
    # file: run.py
    import fruit
  • 考虑一个文件名为 fruit.py 的文件夹中定义了一个类Apple,同时初始化一个对象apple

    • 若执行 python fruit.py:
      • 输出 “main.Apple”
    • 若将当前文件导入到另一个文件 run.py 中,然后执行python run.py:
      • 输出 “fruit.Apple”
    • 也就是说,执行不同文件,类 Apple 的前缀不同

跨文件类中isinstance函数的困境

  • 困境说明:isinstance的困境:看起来是同一个类,但执行isinstance后返回False

  • 文件一(fruit.py):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    # file: fruit.py
    class Apple(object):
    pass

    if __name__ == "__main__":
    from getapple import get_apple # 注意,这里 import 必须放到 "__main__" 中,不能放到外面,否则会发生文件的递归依赖,出现初始化错误
    apple1 = Apple()
    print apple1.__class__ # <class '__main__.Apple'>
    apple2 = get_apple()
    print apple2.__class__ # <class 'fruit.Apple'>
    print isinstance(apple2, Apple) # False
  • 文件二(getapple.py):

    1
    2
    3
    4
    5
    6
    7
    # file: getapple.py
    from fruit import Apple

    def get_apple():
    apple = Apple()
    print apple.__class__ # <class 'fruit.Apple'>
    return apple
  • 此时执行 python getapple.py,无任何输出(符合预期,因为这里 getapple.py 只是用于定义函数)

  • 若执行 python fruit.py,则输出如下:

    1
    2
    3
    4
    <class '__main__.Apple'>
    <class 'fruit.Apple'>
    <class 'fruit.Apple'>
    False
    • 此时 fruit.py 是程序的入口文件
    • 在入口文件中执行 apple1 = Apple() 后得到的类将是 __main__.Apple
    • 在入口文件被导入到 getapple.py 文件中后,执行 apple2 = Apple() 后得到的类将是 fruit.Apple
    • 此时,由于下面的原因造成 isinstance(apple2, Apple) 返回 False
      • apple2 的类别是 fruit.Apple(在 getapple.py 中定义的)
      • Apple 是 __main__.Apple(在 fruit.py 中定义的)
  • isinstance的困境总结:看起来是同一个类,但执行isinstance后返回False

Python——数字范围边界等问题

C++中不同类型的数字有自己不同的边界和范围,Python中呢?如何判断边界问题?


最大最小整数

C++

1
2
int minInt = 0x80000000;
int maxInt = 0xffffffff

Python

1
2
minInt = -0xffffffff
maxInt = 0xffffffff
  • Python中int大小为24个字节,数字太大时不会越界,会变为long类型,long类型的字节占位可以非常大(24以下为int,之后为long,分别可以为36,44,52,60等,每次8位递加?),不会越界

    • 测试:当一个数字太大时,使用int(a)强制字符转换也不能将数字转换为int类型,将一直为long类型
    • 测试: Python中24个字节存储一个int类型对象,但是并不是所有空间都存值,只有一部分用来存储数值
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      MinInt = int(0x80000000000000000)
      print MinInt
      import sys
      print sys.getsizeof(MinInt)
      print type(int(MinInt))

      # output:
      147573952589676412928
      36
      <type 'long'>
  • Python中定义最小负数时可以使用float最小值或者是很大的整数的负数,而不是像C++一样


最大最小浮点数

Python

1
2
minFloat = float("-inf")
maxFloat = float("inf")

ML——主成分分析-PCA


与TSVD的对比

  • 关于SVD的进一步了解可参考Math——奇异值分解-SVD
  • PCA与TSVD目标不同
  • TSVD奇异值与PCA分解得到的对角矩阵元素意义不同
    • PCA得到矩阵对角元素的是该维度的方差
    • TSVD得到的是某种重要的隐形意义(注意,不是方差)
  • PCA等价于下面两个步骤:
    • 对数据X中心化
    • 对数据做TSVD分解

与ICA的对比

  • ICA得到的变量满足独立性
  • PCA得到的变量满足不相关性
  • 独立与不相关的关系
    • 变量独立 \(=>\) 变量不相关
    • 变量不相关 \(\neq>\) 变量独立
    • 当变量是正态分布时:变量独立 \(<=>\) 变量不相关

Tips——一些有用的tips总结

本文对一些程序员日常可能用到的小 tips 进行总结和记录


海量字符串的合并

问题描述

  • 将大量的(现实遇到的是 18W+ 级别的句子,句子长度平均在 100 个字符以上)字符串需要合并为一个字符串

  • 如果直接迭代并使用下面的语句合并,花费很多时间,随着字符串的增大,合并速度越来越慢

    1
    2
    all_texts += " %s" % text
    all_texts = "%s %s" % (all_text, text)
  • 现实生活中发现到了 5000 个字符串以上时速度变得极慢

解决方案

  • 分批次合并:将词语分批次分别合并为一个比较小的,最后再合并到一起
  • 实际中我按照 2000 个句子一份合并完成,再最终合并,速度提升了非常多

Emoji 搜索网址

  • Emoji 大全:www.unicode.org/emoji

Python——多线程可中断文件逐行处理示例


整体说明

  • 本文示例使用 AI 辅助生成,Prompt 为:
    1
    2
    3
    4
    5
    写一个多线程 python 代码,从一个文件读取数据,然后逐行进行处理,加载为 json 后从中读取 'input' 字段并在前后添加 '```',处理完成后写入另一个文件中,要求如下:
    1. 处理过程中实时打印处理进度
    2. 要求不使用 queue 等高级的包,用原生的 Python 和 threading 包实现即可
    3. 要求写入文件顺序和原始文件的顺序相同
    4. 由于文件很大,且执行过程中可能会随时中断,请用一个文件维护完成情况(完整写入文件才算完成),保证可以随时重启(指定参数 resume=True 时则从中断处启动,否则从头开始重新执行)

多线程可中断文件逐行处理示例

  • 代码示例,仅修改 process_line 即可使用:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    import threading
    import json
    import os
    import signal
    import sys
    from typing import Optional
    import traceback

    class MultiThreadFileProcessor:
    def __init__(self, input_file: str, output_file: str, progress_file: str = "progress.txt", num_threads: int = 4):
    """
    初始化文件处理器

    Args:
    input_file: 输入文件路径
    output_file: 输出文件路径
    progress_file: 进度记录文件路径
    num_threads: 线程数量
    """
    self.input_file = input_file
    self.output_file = output_file
    self.progress_file = progress_file
    self.num_threads = num_threads

    # 用于线程同步的锁
    self.task_lock = threading.Lock()
    self.result_lock = threading.Lock()
    self.write_lock = threading.Lock()

    # 所有待处理的行
    self.all_lines = []

    # 当前要分配的任务索引
    self.current_task_index = 0

    # 存储处理结果的字典
    self.results = {}

    # 已完成的行号集合
    self.completed_lines = set()

    # 下一个要写入的行号
    self.next_write_line = 0

    # 总行数
    self.total_lines = 0

    # 处理完成的行数
    self.processed_count = 0

    # 输出文件句柄
    self.output_handle = None

    # 停止标志(用于优雅退出)
    self.stop_flag = threading.Event()

    # 注册信号处理器
    signal.signal(signal.SIGINT, self.signal_handler)
    signal.signal(signal.SIGTERM, self.signal_handler)

    def signal_handler(self, signum, frame):
    """处理 Ctrl+C 和终止信号"""
    print("\n\n收到中断信号,正在优雅退出...")
    print("已处理的数据会保存,可以使用 resume=True 继续")
    self.stop_flag.set() # 设置停止标志

    def load_progress(self) -> set:
    """加载进度文件,返回已完成的行号集合"""
    if os.path.exists(self.progress_file):
    with open(self.progress_file, 'r') as f:
    completed = set(int(line.strip()) for line in f if line.strip())
    return completed
    return set()

    def save_progress(self, line_num: int):
    """保存进度到文件"""
    with open(self.progress_file, 'a') as f:
    f.write(f"{line_num}\n")
    f.flush()

    def process_line(self, line: str) -> str:
    """
    处理单行数据

    Args:
    line: 原始行数据

    Returns:
    处理后的数据
    """
    try:
    data = json.loads(line.strip())
    if 'input' in data:
    data['input'] = f"```{data['input']}```"
    return json.dumps(data, ensure_ascii=False)
    except json.JSONDecodeError as e:
    print(f"\nJSON解析错误: {e}, 原始数据: {line[:100]}")
    return line.strip()

    def get_next_task(self) -> Optional[tuple]:
    """
    获取下一个待处理的任务(线程安全)

    Returns:
    (行号, 行内容) 或 None(无任务)
    """
    with self.task_lock:
    # 检查停止标志
    if self.stop_flag.is_set():
    return None

    # 跳过已完成的任务
    while self.current_task_index < len(self.all_lines):
    line_num, line = self.all_lines[self.current_task_index]
    self.current_task_index += 1

    if line_num not in self.completed_lines:
    return (line_num, line)
    else:
    # 已完成的任务也计入进度
    self.processed_count += 1

    return None

    def worker(self):
    """工作线程函数 - 动态获取任务"""
    try:
    while not self.stop_flag.is_set():
    # 获取下一个任务
    task = self.get_next_task()
    if task is None:
    break # 没有任务了或收到停止信号

    line_num, line = task

    # 处理数据
    processed = self.process_line(line)

    # 检查是否需要停止
    if self.stop_flag.is_set():
    # 将未写入的结果放回(不保存进度)
    break

    # 将结果存储到字典中
    with self.result_lock:
    self.results[line_num] = processed
    self.processed_count += 1

    # 实时打印进度
    progress = (self.processed_count / self.total_lines) * 100
    print(f"\r处理进度: {self.processed_count}/{self.total_lines} ({progress:.2f}%) | 待写入: {len(self.results)}", end='', flush=True)

    # 尝试写入文件(按顺序)
    self.try_write_results()

    except Exception as e:
    print(f"\n线程 {threading.current_thread().name} 发生错误: {e}")
    traceback.print_exc()
    self.stop_flag.set() # 发生错误时通知其他线程停止

    def try_write_results(self):
    """尝试按顺序写入结果到文件"""
    if self.stop_flag.is_set():
    return # 如果收到停止信号,不再写入

    with self.write_lock:
    # 按顺序写入所有可以写入的行
    while self.next_write_line in self.results:
    line_num = self.next_write_line

    with self.result_lock:
    content = self.results.pop(line_num)

    # 写入文件
    self.output_handle.write(content + '\n')
    self.output_handle.flush() # 确保写入磁盘

    # 保存进度
    self.save_progress(line_num)

    # 更新下一个要写入的行号
    self.next_write_line += 1

    def process(self, resume: bool = False):
    """
    主处理函数

    Args:
    resume: 是否从中断处继续
    """
    # 如果不是恢复模式,清空输出文件和进度文件
    if not resume:
    if os.path.exists(self.output_file):
    os.remove(self.output_file)
    if os.path.exists(self.progress_file):
    os.remove(self.progress_file)
    self.completed_lines = set()
    self.next_write_line = 0
    else:
    # 加载已完成的行
    self.completed_lines = self.load_progress()
    self.next_write_line = len(self.completed_lines)
    print(f"从第 {self.next_write_line} 行继续处理...")

    # 读取所有行
    print("正在读取文件...")
    with open(self.input_file, 'r', encoding='utf-8') as f:
    self.all_lines = [(i, line) for i, line in enumerate(f)]

    self.total_lines = len(self.all_lines)
    print(f"文件总行数: {self.total_lines}")
    print(f"已完成行数: {len(self.completed_lines)}")
    print(f"待处理行数: {self.total_lines - len(self.completed_lines)}")

    if len(self.completed_lines) >= self.total_lines:
    print("所有行已处理完成!")
    return

    # 打开输出文件(追加模式)
    self.output_handle = open(self.output_file, 'a', encoding='utf-8')

    try:
    # 创建并启动线程(设置为守护线程)
    threads = []
    for i in range(self.num_threads):
    thread = threading.Thread(target=self.worker, name=f"Worker-{i}")
    thread.daemon = False # 不设置为守护线程,以便优雅退出
    threads.append(thread)
    thread.start()

    # 等待所有线程完成
    for thread in threads:
    thread.join()

    # 检查是否是被中断的
    if self.stop_flag.is_set():
    print(f"\n程序被中断")
    print(f"已完成 {self.next_write_line} 行的处理和写入")
    print(f"使用 resume=True 可以继续处理")
    else:
    # 确保所有结果都已写入
    self.try_write_results()
    print(f"\n处理完成! 输出文件: {self.output_file}")

    except KeyboardInterrupt:
    print("\n\n检测到键盘中断...")
    self.stop_flag.set()

    # 等待线程退出(最多等待5秒)
    for thread in threads:
    thread.join(timeout=5)

    print(f"已完成 {self.next_write_line} 行的处理和写入")

    finally:
    # 关闭输出文件
    if self.output_handle:
    self.output_handle.close()
    print("输出文件已安全关闭")


    # 使用示例
    if __name__ == "__main__":
    # 创建处理器实例
    processor = MultiThreadFileProcessor(
    input_file="input.jsonl",
    output_file="output.jsonl",
    progress_file="progress.txt",
    num_threads=4
    )

    # 从头开始处理
    # processor.process(resume=False)

    # 从中断处继续
    processor.process(resume=True)

Python——多线程和多进程


整体说明

  • 在 Python 里,多线程和多进程都可以实现并行处理
  • Python 还提供了方便使用的 ThreadPoolExecutor 和 ProcessPoolExecutor 类用于多线程和多进程并行处理
  • 多线程切换和启动任务开销小,但在 Python 中受到 全局解释器锁(GIL) 限制,导致更适合一些 IO 密集型任务
  • 多进程切换和启动任务开销大,但在 Python 中不受 全局解释器锁(GIL) 限制,更适合一些 CPU 密集型任务

多线程(Multithreading)

  • 多线程是在同一个进程内运行多个线程,这些线程共享进程的内存空间
  • 但在 Python 中,由于 全局解释器锁(GIL) 的存在,在 CPU 密集型任务中,多线程无法充分利用多核 CPU 的优势 ,更适合 I/O 密集型任务,像网络请求、文件读写这类
    • 注:这是 Python 特有的问题( 全局解释器锁(GIL))其他语言没有这个问题
  • 多线程的示例代码(原生形式):
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    import threading
    import time

    def io_task(task_id):
    print(f"Task {task_id} starts")
    time.sleep(1) # 模拟非计算密集型的 I/O 操作
    print(f"Task {task_id} ends")

    if __name__ == "__main__":
    start_time = time.time()
    threads = []

    # 创建并启动线程
    for i in range(3):
    thread = threading.Thread(target=io_task, args=(i,))
    threads.append(thread)
    thread.start() # 启动单个线程

    # 等待所有线程完成
    for thread in threads:
    thread.join() # 等待单个线程完成
    print(f"Total time taken: {time.time() - start_time:.2f} seconds")

    # Task 0 starts
    # Task 1 starts
    # Task 2 starts
    # Task 1 ends
    # Task 2 ends
    # Task 0 ends
    # Total time taken: 1.01 seconds

多进程(Multiprocessing)

  • 多进程是指运行多个独立的进程,每个进程都有自己独立的内存空间
  • 多进程不受 GIL 的限制 ,能够充分发挥多核 CPU 的性能 ,所以更适合 CPU 密集型任务 ,例如科学计算、图像处理等
  • 与多线程相比,除了将 threading.Thread 替换成 multiprocessing.Process,用法几乎一模一样
  • 多进程的示例代码:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    import multiprocessing
    import time

    def cpu_task(task_id):
    print(f"Task {task_id} starts,Process ID: {multiprocessing.current_process().pid}")
    sum([i * i for i in range(10**7)]) # 模拟 CPU 密集型操作
    print(f"Task {task_id} ends")

    if __name__ == "__main__":
    start_time = time.time()
    processes = []
    # 创建并启动多个进程
    for i in range(3):
    process = multiprocessing.Process(target=cpu_task, args=(i,))
    processes.append(process)
    process.start() # 启动单个进程
    # 等待所有进程完成
    for process in processes:
    process.join() # 等待单个进程完成
    print(f"Total time taken: {time.time() - start_time:.2f} seconds")

    # Task 0 starts,Process ID: 70140
    # Task 1 starts,Process ID: 70141
    # Task 2 starts,Process ID: 70142
    # Task 0 ends
    # Task 1 ends
    # Task 2 ends
    # Total time taken: 2.03 seconds

ProcessPoolExecutor 和 ThreadPoolExecutor 的使用

  • 这两个类都位于 concurrent.futures 模块中,为我们提供了更高级的异步执行接口

    • ThreadPoolExecutor :用于多线程编程
    • ProcessPoolExecutor :用于多进程编程
  • 它们都提供了 submit() 和 map() 方法,还能通过 with 语句来自动管理资源

    • submit():单个任务提交
    • map():批量任务提交
  • 下面是使用这两个类的示例代码:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    import concurrent.futures
    import time

    def io_task(args):
    task_id, content = args
    print(f"IO Task {task_id} starts")
    time.sleep(1)
    print(f"IO Task received the content: '{content}'")
    print(f"IO Task {task_id} ends")
    return args

    def cpu_task(args):
    task_id, content = args
    print(f"CPU Task {task_id} starts")
    sum([i * i for i in range(10**7)])
    print(f"IO Task received the content: '{content}'")
    print(f"CPU Task {task_id} ends")
    return args

    def io_task_multi_args(task_id, content):
    print(f"IO Task {task_id} starts")
    time.sleep(1)
    print(f"IO Task received the content: '{content}'")
    print(f"IO Task {task_id} ends")
    return task_id, content

    def cpu_task_multi_args(task_id, content):
    print(f"CPU Task {task_id} starts")
    sum([i * i for i in range(10**7)])
    print(f"IO Task received the content: '{content}'")
    print(f"CPU Task {task_id} ends")
    return task_id, content

    if __name__ == "__main__":
    print("== 单参数形式:")
    print("=== ThreadPoolExecutor Demo ===")
    start_time = time.time()
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    params = [(1, "content1"), (2, "content2"), (3, "content3")]
    results = list(executor.map(io_task, params, timeout=10))
    print(f"Total time taken for IO tasks: {time.time() - start_time:.2f} seconds, results={results}")

    print("\n=== ProcessPoolExecutor Demo ===")
    start_time = time.time()
    with concurrent.futures.ProcessPoolExecutor(max_workers=3) as executor:
    params = [(1, "content1"), (2, "content2"), (3, "content3")]
    results = list(executor.map(cpu_task, params, timeout=10))
    print(f"Total time taken for CPU tasks: {time.time() - start_time:.2f} seconds, results={results}")


    print("== 多参数形式:")
    print("=== ThreadPoolExecutor Demo ===")
    start_time = time.time()
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
    params = [(1, "content1"), (2, "content2"), (3, "content3")]
    futures = list(executor.submit(io_task_multi_args, *param) for param in params) # 需要使用 lambda 关键字定义新函数
    for future in concurrent.futures.as_completed(futures, timeout=10):
    print(f"future result: {future.result()}")
    print(f"Total time taken for IO tasks: {time.time() - start_time:.2f} seconds")

    print("\n=== ProcessPoolExecutor Demo ===")
    start_time = time.time()
    with concurrent.futures.ProcessPoolExecutor(max_workers=3) as executor:
    params = [(1, "content1"), (2, "content2"), (3, "content3")]
    futures = list(executor.submit(io_task_multi_args, *param) for param in params) # 需要使用 lambda 关键字定义新函数
    for future in concurrent.futures.as_completed(futures, timeout=10):
    print(f"future result: {future.result()}")
    print(f"Total time taken for CPU tasks: {time.time() - start_time:.2f} seconds")

    # == 单参数形式:
    # === ThreadPoolExecutor Demo ===
    # IO Task 1 startsIO Task 2 starts
    #
    # IO Task 3 starts
    # IO Task received the content: 'content1'IO Task received the content: 'content2'
    # IO Task 2 ends
    # IO Task received the content: 'content3'
    # IO Task 3 ends
    #
    # IO Task 1 ends
    # Total time taken for IO tasks: 1.01 seconds, results=[(1, 'content1'), (2, 'content2'), (3, 'content3')]
    #
    # === ProcessPoolExecutor Demo ===
    # CPU Task 1 starts
    # CPU Task 2 starts
    # CPU Task 3 starts
    # IO Task received the content: 'content2'
    # CPU Task 2 ends
    # IO Task received the content: 'content3'
    # CPU Task 3 ends
    # IO Task received the content: 'content1'
    # CPU Task 1 ends
    # Total time taken for CPU tasks: 1.13 seconds, results=[(1, 'content1'), (2, 'content2'), (3, 'content3')]
    # == 多参数形式:
    # === ThreadPoolExecutor Demo ===
    # IO Task 1 starts
    # IO Task 2 starts
    # IO Task 3 starts
    # IO Task received the content: 'content3'
    # IO Task 3 ends
    # future result: (3, 'content3')
    # IO Task received the content: 'content2'
    # IO Task 2 ends
    # future result: (2, 'content2')
    # IO Task received the content: 'content1'
    # IO Task 1 ends
    # future result: (1, 'content1')
    # Total time taken for IO tasks: 1.01 seconds
    #
    # === ProcessPoolExecutor Demo ===
    # IO Task 1 starts
    # IO Task 2 starts
    # IO Task 3 starts
    # IO Task received the content: 'content1'
    # IO Task 1 ends
    # future result: (1, 'content1')
    # IO Task received the content: 'content2'
    # IO Task 2 ends
    # IO Task received the content: 'content3'
    # IO Task 3 ends
    # future result: (2, 'content2')
    # future result: (3, 'content3')
    # Total time taken for CPU tasks: 1.46 seconds
  • 特别说明:concurrent.futures.as_completed() 是 Python 标准库中的一个函数,用于处理异步执行的多个任务

    • 输入:接收一个包含 Future 对象的可迭代对象(如列表)
    • 返回:一个迭代器,该迭代器会在每个 Future 完成时立即返回它的结果(按完成顺序,而非提交顺序)

multiprocessing.spawn 用法

  • multiprocessing.spawn 是 Python multiprocessing 模块中用于启动子进程的方法,适用于需要在新进程中执行函数的场景,尤其在分布式训练(如 PyTorch 多GPU训练)中常用

  • spawn 会创建全新的Python解释器进程,每个子进程独立运行,拥有独立的内存空间,避免了多线程中的全局解释器锁(GIL)限制

  • 基本用法示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    import multiprocessing as mp

    # 子进程要执行的函数
    def worker_function(rank, world_size, shared_arg):
    print(f"子进程 {rank}/{world_size} 启动,参数: {shared_arg}")
    # 子进程的具体逻辑...

    def main():
    # 总进程数(例如等于GPU数量)
    world_size = 4
    # 共享参数(会被传递给每个子进程)
    shared_arg = "hello from main"

    # 启动子进程
    # args 是传递给 worker_function 的参数元组(除了 rank 之外的参数)
    mp.spawn(
    worker_function, # 子进程执行的函数
    args=(world_size, shared_arg), # 传递给函数的参数(第一个参数固定为 rank)
    nprocs=world_size, # 子进程数量
    join=True # 是否等待所有子进程结束后再继续
    )
    print("所有子进程执行完毕")

    if __name__ == "__main__":
    # 在Windows系统中必须放在 if __name__ == "__main__" 下
    mp.set_start_method("spawn") # 显式指定启动方法(可选,默认可能为fork)
    main()
    • fn:子进程要执行的函数(第一个参数必须是 rank,表示进程编号,从0开始)
    • args:传递给函数的额外参数(元组形式)
    • nprocs:要启动的子进程数量
    • join:若为 True,主进程会等待所有子进程执行完毕再继续
    • daemon:是否将子进程设为守护进程(主进程退出时自动终止子进程)
  • 注意:多线程没有完全对应的API ,只能通过 threading.Thread 实现类似的多线程启动逻辑


使用 subprocess.Popen 函数启动进程

  • subprocess.Popen 是 Python 标准库 subprocess 模块中的一个类,用于启动一个新的进程
  • subprocess.Popen 可以连接到其输入/输出/错误管道,并获得其返回码
  • subprocess.Popen 为更复杂的进程管理提供了灵活的接口,是替代 os.system、os.spawn*、os.popen* 等旧有方法的推荐方式

基本用法即常用参数说明

  • 用法说明

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    import subprocess

    p = subprocess.Popen(
    args,
    bufsize=-1,
    executable=None,
    stdin=None,
    stdout=None,
    stderr=None,
    preexec_fn=None,
    close_fds=True,
    shell=False,
    cwd=None,
    env=None,
    universal_newlines=False,
    startupinfo=None,
    creationflags=0
    )
  • subprocess.Popen 常用参数详细说明见下文

  • args

    • 字符串或序列(如列表)
    • 指定要执行的命令及其参数- 当 shell=False 时,推荐使用列表,如 ['ls', '-l']
      • 当 shell=True 时,通常传递字符串,如 "ls -l"
    • 注:如果命令中包含空格或特殊字符,建议使用列表方式,避免命令解析错误
  • bufsize

    • 整数
    • 设置缓冲策略
      • 0:无缓冲(直接读写)
      • 1:行缓冲(文本模式下有效)
      • 其他正整数:指定缓冲区大小(以字节为单位)
      • -1(默认):使用系统默认缓冲策略
    • 通常用法 :一般保持默认即可,特殊需求时才设置
  • executable

    • 字符串
    • 指定要执行的程序的路径(用于替换默认可执行文件)
      • 例如,executable="/usr/bin/python3" 可以强制使用指定的解释器
  • stdin, stdout, stderr

    • 取值可以是文件对象、文件描述符、subprocess.PIPE、subprocess.DEVNULL、None、subprocess.STDOUT
    • 分别指定子进程的标准输入、输出、错误
      • None(默认):继承父进程的对应流
      • subprocess.PIPE:创建管道,允许父进程与子进程通信
      • subprocess.DEVNULL:丢弃输入/输出
      • 文件对象:如 open('output.txt', 'w'),将输出写入文件
    • 举例:
      • stdout=subprocess.PIPE 表示捕获标准输出,后续可通过 p.communicate() 返回读取;
      • stderr=subprocess.PIPE 表示捕获标准错误
      • stdout=fp 将标准输出写到指定文件中(常用 a 追加形式打开文件 fp)
      • stderr=subprocess.STDOUT 将子进程的标准错误(stderr)也重定向到标准输出(stdout),即错误信息也写入日志文件
  • preexec_fn

    • 可调用对象(函数)
    • 在子进程启动前执行指定的函数(仅限Unix)
    • 比如设置进程组、修改环境等
    • 注:在Windows平台无效
  • close_fds

    • 布尔值
    • 是否在子进程中关闭除 stdin/stdout/stderr 以外的所有文件描述符
      • 默认:True(Unix),False(Windows)
    • 通常建议保持默认,除非有特殊文件句柄传递需求
      • 默认定义是:close_fds: bool = ... 表示 bool 类型的占位,在实际调用 subprocess.Popen 时,close_fds 的默认值由具体的实现决定(如在 Unix 下默认 True,Windows 下默认 False)
      • 也可以在调用时显式传递 close_fds=True 或 close_fds=False
  • shell

    • 布尔值
    • 是否通过 shell 运行命令
      • True:命令通过 shell 解析(如 /bin/sh 或 cmd.exe),可用 shell 特性(如重定向、管道)
      • False(默认):直接执行指定的程序
    • 安全提示 :shell=True 存在命令注入风险,处理外部输入时需谨慎
  • cwd

    • 字符串
    • 指定子进程的工作目录
    • 举例:cwd="/tmp" 表示子进程在 /tmp 目录下运行
  • env

    • 字典
    • 设置子进程的环境变量
      • 若为 None,则继承父进程环境
      • 可自定义环境变量,如:env={"PATH": "/usr/bin", "USER": "test"}
    • 注:未指定的变量将丢失,需包含必需的环境变量
  • universal_newlines / text

    • 布尔值
    • 指明子进程是否以文本模式处理输入输出(自动编码/解码)
      • True:与子进程通信时,输入输出为字符串(str),自动处理换行
      • False(默认):以字节流处理(bytes)
    • 处理文本数据时设为 True 或 text=True
    • 注:Python 3.7 以后,建议使用 text 参数替代 universal_newlines 参数
      • 虽然 universal_newlines 依然还在,但不建议使用
      • text 和 universal_newlines 是等价参数,不能同时设置;如果同时传递,会抛出 ValueError 异常
  • startupinfo, creationflags(仅Windows)

    • startupinfo:用于指定进程启动信息(如窗口显示方式)
    • creationflags:用于指定进程创建标志(如 subprocess.CREATE_NEW_CONSOLE)
  • 其他参数

    • restore_signals(3.2+):是否恢复信号处理(Unix)
    • start_new_session(3.2+):是否在新会话中启动进程(Unix)

用法示例

  • 最简单的用法:

    1
    2
    3
    import subprocess

    p = subprocess.Popen(['ls', '-l']) # 启动进程并执行,此外不做任何操作
  • 捕获输出:

    1
    2
    3
    4
    5
    import subprocess

    p = subprocess.Popen(['ls', '-l'], stdout=subprocess.PIPE) # 启动进程并执行,同时创建管道,允许父进程与子进程通信
    out, err = p.communicate() # 与主进程通信,并返回执行结果信息
    print(out.decode()) # 若没有 stdout=subprocess.PIPE,这里的 out 是 None,不可以执行 decode() 命令
  • 通过 shell 运行命令:

    1
    2
    3
    4
    import subprocess

    p = subprocess.Popen("echo Hello World", shell=True) # 启动进程并执行,此外不作任何操作,这里没有任何输出
    out, err = p.communicate() # 与主进程通信,由于没有 stdout=subprocess.PIPE 和 stderr=subprocess.PIPE,out 和 err 均为 None
  • 传递输入:

    1
    2
    3
    4
    5
    import subprocess

    p = subprocess.Popen(['cat'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) # 启动进程并执行,cat命令启动进入等待输入,cat 命令启动后,接下来每次输入数据都会被原样输出
    out, err = p.communicate(input=b'Hello\n') # 与主进程通信,并返回执行结果信息
    print(out.decode())
  • 管道操作示例:

    1
    2
    3
    4
    5
    6
    7
    import subprocess

    p1 = subprocess.Popen(['ls'], stdout=subprocess.PIPE)
    p2 = subprocess.Popen(['grep', 'py'], stdin=p1.stdout, stdout=subprocess.PIPE)
    p1.stdout.close()
    out, err = p2.communicate()
    print(out.decode())

Popen 类对象相关常用方法汇总

  • communicate(input=None, timeout=None):
    • 与子进程交互,发送数据到 stdin,读取 stdout 和 stderr,等待进程结束
  • wait(timeout=None):
    • 等待子进程结束,返回退出码
  • poll():
    • 检查子进程是否结束,未结束返回 None
  • terminate():
    • 终止子进程(发送 SIGTERM)
  • kill():
    • 强制杀死子进程(发送 SIGKILL)
  • pid:
    • 子进程的进程号
  • returncode:
    • 子进程的返回码

使用注意事项

  • 如果需要捕获输出,记得设置 stdout=subprocess.PIPE 和/或 stderr=subprocess.PIPE,否则无法通过 communicate() 获取输出
  • 使用 shell=True 时,命令通常以字符串形式传递,并注意安全风险(如命令注入)
  • 如果子进程输出很大,建议及时读取输出,避免死锁,详细理解见附录
  • Python 3.5+ 推荐用 subprocess.run 简化常见用法,但 Popen 适合更复杂的场景

附录:subprocess.Popen 死锁情况分析

  • 问题说明:当子进程输出数据量很大时,父进程必须及时读取这些数据,否则操作系统管道缓冲区会被写满,导致子进程和父进程互相等待,程序陷入死锁

  • 管道缓冲区有限 ,当用 subprocess.PIPE 捕获子进程的标准输出(stdout)或标准错误(stderr)时,父进程和子进程之间通过一个操作系统管道通信

  • 这个管道是有缓冲区大小限制的(通常几 KB 到几十 KB,依赖操作系统)

  • 如果子进程输出的数据量超过了缓冲区大小,而父进程没有及时读取这些数据,缓冲区会被写满

  • 当缓冲区被写满后,子进程会被阻塞,无法继续写入新的输出(即子进程暂停在写操作)

  • 如果此时父进程又在等待子进程结束(比如调用 wait() 或 communicate()),但没有及时读取管道内容,父子进程就会互相等待,导致死锁 :

    • 子进程等着缓冲区有空间继续输出;
    • 父进程等着子进程结束,却没读取缓冲区内容;
    • 结果两者都无法继续执行,程序卡死
  • 问题代码示例:

    1
    2
    3
    4
    5
    import subprocess

    p = subprocess.Popen(['cat ./big_document.txt'], stdout=subprocess.PIPE)
    p.wait() # 只等子进程结束,不读取输出
    output = p.stdout.read() # 这一步可能永远无法执行到
  • 正确用法建议:及时读取输出 ,常用的做法是用 communicate(),它会在等待子进程结束的同时,自动持续读取所有输出,防止缓冲区写满:

    1
    2
    3
    4
    import subprocess

    p = subprocess.Popen(['cat ./big_document.txt'], stdout=subprocess.PIPE)
    output, _ = p.communicate() # 推荐做法,自动避免死锁
  • 如果需要实时处理输出,可以循环读取:

    1
    2
    3
    4
    5
    6
    import subprocess

    p = subprocess.Popen(['cat ./big_document.txt'], stdout=subprocess.PIPE)
    for line in p.stdout:
    process(line) # 逐行处理,及时清空缓冲区
    p.wait()

附录:使用 subprocess.call 函数启动进程

  • subprocess.call 可用于启动进程执行外部命令,是高层封装的函数,本质上是 Popen 的简化接口
    • 它内部会创建 Popen 对象并等待命令执行完成,返回命令的退出码
    • 适用于简单场景,只需知道是否成功(退出码),无需复杂交互
  • subprocess.Popen 是底层核心类,提供最完整的功能和灵活性
    • 它直接创建进程对象,允许用户与子进程进行复杂交互(如输入/输出处理、异步执行等)
    • 适合需要精细控制的场景

subprocess.call 的使用示例

  • call() 是阻塞式的:调用后会等待命令执行完毕才返回,返回值是命令的退出码(0 表示成功),示例:

    1
    2
    import subprocess
    ret_code = subprocess.call(["echo", "hello"]) # 等待命令完成,返回 0
  • 注:Popen 默认是非阻塞式的:创建进程后立即返回 Popen 对象,不会等待命令结束(需显式调用 wait() 或 communicate() 等待完成),示例如下:

    1
    2
    3
    import subprocess
    proc = subprocess.Popen(["echo", "hello"]) # 立即返回,不等待
    ret_code = proc.wait() # 手动等待命令完成,获取退出码
  • 注:Popen 支持更多高级操作,而 call() 不支持,比如

    • Popen 支持输入/输出重定向(通过 stdin/stdout/stderr 与子进程交互),但 call 不支持:

      1
      2
      3
      # 捕获命令输出
      proc = subprocess.Popen(["ls"], stdout=subprocess.PIPE)
      output, _ = proc.communicate() # 获取输出
    • Popen 支持异步执行,可以在命令运行时做其他事情,再回头处理结果

    • Popen 支持信号处理,可通过 send_signal() 向子进程发送信号(如终止进程)

    • Popen 支持管道操作,可多个 Popen 对象可通过管道连接(类似 Linux 管道 |)

使用 subprocess.check_all 启动进程并监控失败异常

  • subprocess.check_call(args, ...) 执行指定的命令,等待命令运行结果的返回码(return code),同时还具备 call 没有的抛出异常功能
    • 如果命令执行成功(返回码为 0),则无返回值(或说返回 0);
    • 如果命令执行失败(返回码非 0),则会抛出 subprocess.CalledProcessError 异常(call 不会抛出异常)
    • 示例:
      1
      2
      3
      4
      5
      6
      import subprocess
      try:
      subprocess.check_call(["ls", "-l"]) # 执行 ls -l 命令
      print("命令执行成功")
      except subprocess.CalledProcessError as e:
      print(f"命令执行失败,返回码:{e.returncode}")

使用 subprocess.check_output 启动进程、监控失败异常,并读取输出结果

  • subprocess.check_output(args, ...) 执行指定的命令,并返回命令的输出结果(stdout),同时还具备 call 没有的抛出异常功能

    • 如果命令执行成功(返回码为 0),返回输出内容(字节串,可通过 text=True 参数转为字符串);
    • 如果命令执行失败(返回码非 0),则会抛出 subprocess.CalledProcessError 异常
    • 示例:
      1
      2
      3
      4
      5
      6
      import subprocess
      try:
      result = subprocess.check_output(["echo", "hello"], text=True)
      print(f"命令输出:{result.strip()}") # 输出:hello
      except subprocess.CalledProcessError as e:
      print(f"命令执行失败,返回码:{e.returncode}")
  • 与 subprocess.check_call(args, ...) 的区别为,subprocess.check_output(args, ...) 返回值为输出结果而不是执行状态(返回码, return code)


附录:multiprocessing.Pool 的用法

  • 使用 multiprocessing.Pool 的示例
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    import multiprocessing
    import time
    import random

    # 定义一个待并行执行的任务函数
    def process_task(x):
    time.sleep(random.uniform(0.1, 0.5)) # 模拟任务耗时
    result = x * x
    print(f"处理 {x} -> {result} (进程ID: {multiprocessing.current_process().pid})")
    return result

    def main():
    # 生成待处理的数据
    data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    print(f"原始数据: {data}\n")

    # 创建进程池(默认使用 CPU 核心数,也可指定 processes 参数)
    with multiprocessing.Pool(processes=4) as pool: # 推荐使用 with 语句管理进程池的资源
    # 使用 map:阻塞式,返回列表(按输入顺序)
    print("=== 使用 pool.map ===")
    map_results = pool.map(process_task, data)
    print(f"map 结果: {map_results}\n")

    # 使用 imap:非阻塞式,返回迭代器(按输入顺序,逐步获取结果),注:返回的迭代器只能遍历一次,且遍历过程中会阻塞等待对应任务完成
    print("=== 使用 pool.imap ===")
    imap_iter = pool.imap(process_task, data)
    # 迭代获取结果(每次迭代会阻塞直到对应任务完成,完成一个返回一个)
    imap_results = [res for res in imap_iter] # 注:返回结果是有序的,与 data 的序一致
    print(f"imap 结果: {imap_results}\n")

    # imap 高阶用法,可指定单次分配给进程的任务数量,可避免多次分配 (下面的代码每批给进程池分配 50 个任务)
    imap_iter = pool.imap(process_task, data, chunksize=50) # 默认 chunksize=1
    # for 循环读取方式(读取到的结果是有序的,与输入数据 data 的顺序一致),注:pool.imap(process_task, data) 的结果也可以这样读取
    for i, result in enumerate(imap_iter):
    if i < 5:
    print(result) # 依次输出:1,2,3,4,5
    else:
    break

    # 使用 imap_unordered:非阻塞式,返回迭代器(按任务完成顺序),注:返回的迭代器只能遍历一次,且遍历过程中会阻塞等待对应任务完成
    print("=== 使用 pool.imap_unordered ===")
    imap_unordered_iter = pool.imap_unordered(process_task, data)
    unordered_results = [res for res in imap_unordered_iter]
    print(f"imap_unordered 结果(无序): {unordered_results}\n")

    # 使用 apply:单次提交任务(阻塞式,适合单个任务)
    print("=== 使用 pool.apply ===")
    single_result = pool.apply(process_task, args=(100,)) # 传入单个参数
    print(f"apply 单个结果: {single_result}\n")

    # 使用 starmap:处理多参数任务(类似 map,但支持元组拆包)
    print("=== 使用 pool.starmap ===")
    # 定义一个多参数函数
    def multi_param_task(a, b):
    return a + b
    # 数据为元组列表(每个元组对应一组参数)
    multi_data = [(1, 2), (3, 4), (5, 6)]
    starmap_results = pool.starmap(multi_param_task, multi_data)
    print(f"starmap 结果: {starmap_results}")

    # 使用 starmap_async:异步版本的starmap(非阻塞)
    print("=== 使用 pool.starmap_async ===")
    # 提交异步任务,立即返回AsyncResult对象,不阻塞主进程
    async_result = pool.starmap_async(multi_param_task, multi_data)
    # 可以在这里执行其他操作(演示非阻塞特性)
    print("等待异步任务完成...")
    time.sleep(0.5) # 模拟主进程其他工作
    # 获取结果(get()方法会阻塞直到任务完成)
    starmap_async_results = async_result.get()
    print(f"starmap_async 结果: {starmap_async_results}")

    if __name__ == "__main__":
    main()

相关核心函数说明

  • pool.map(func, iterable)
    • 阻塞式:等待所有任务完成后返回结果列表
    • 结果顺序与输入 iterable 一致
    • 适合简单的单参数任务
  • pool.imap(func, iterable)
    • 非阻塞式:返回一个迭代器,可逐步获取结果(迭代时会阻塞直到对应任务完成)
    • 结果顺序与输入一致,适合处理大量数据时节省内存(无需等待全部完成)
  • pool.imap_unordered(func, iterable)
    • 非阻塞式:返回迭代器,但结果顺序与任务完成顺序一致(不保证输入顺序)
    • 适合对结果顺序无要求的场景,可更快获取部分结果
  • pool.apply(func, args)
    • 阻塞式:单次提交一个任务,args 为函数参数
    • 适合偶尔提交单个任务,效率较低(不建议批量使用)
  • pool.starmap(func, iterable_of_tuples)
    • 类似 map,但支持多参数函数:iterable_of_tuples 中的每个元组会被拆分为函数的参数(如 (a,b) 对应 func(a,b))

注意事项

  • 进程池使用 with 语句可自动关闭,无需手动调用 pool.close() 和 pool.join()
  • imap 和 imap_unordered 返回的迭代器只能遍历一次,且遍历过程中会阻塞等待对应任务完成

Python——队列和栈使用

本文从总结Python中栈和队列的基本使用
Python 中queue模块是线程安全的,为多线程任务设计的,没有peek()操作

  • 双端队列(deque)是一个具有栈和队列性质的数据结构,可以从两端弹出

普通的栈和队列

栈

list实现栈
1
2
3
4
5
6
7
8
9
10
11
# init
stack = list()
# push
stack.append(1)
# pop
stack.pop()
# peek
top = stack[-1]
# determine whether it is empty
if len(stack) == 0:
print("stack is empty")
deque实现栈
1
2
3
4
5
6
7
8
9
10
11
12
from collections import deque
# init
stack = deque([1, 2, 3])
# push
stack.append(4)
# pop
stack.pop()
# peek
top = stack[-1]
# determine whether it is empty
if len(stack) == 0:
print("stack is empty")

队列

list实现栈
1
2
3
4
5
6
7
8
9
10
11
12
# init
queue = [1, 2, 3]
# push
queue.append(4)
# pop
queue.pop(0)
# peek
first = queue[0]
last = queue[-1]
# determine whether it is empty
if len(queue) == 0:
print("queue is empty")
deque实现队列
1
2
3
4
5
6
7
8
9
10
11
12
13
from collections import deque
# init
queue = deque([1, 2, 3])
# push
queue.append(4)
# pop
queue.popleft()
# peek
first = queue[0]
last = queue[-1]
# determine whether it is empty
if len(queue) == 0:
print("queue is empty")

线程安全的栈和队列

queue模块实现队列和栈

1
2
3
4
5
6
7
8
9
10
11
import queue
# init, stack and queue
sstack = queue.LifoQueue()
squeue = queue.Queue()
# push
sstack.put(item)
# pop
sstack.get()
# determine whether it is empty
if sstack.empty():
print("sstack is empty")
1…424344…66
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

659 posts
53 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4