Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

ML——LightGBM概述

本文介绍LightGBM的一些特性和并行实现方法

  • 参考博客: https://www.cnblogs.com/ldt-/p/10206356.html

LightGBM vs XGBoost

  • 树结点切分方式不同:
    • XGBoost树节点切分是 Level-wise
    • LightGBM树节点切分是 Leaf-wise
  • LightGBM直接支持类别特征 ,对类别特征不必进行 One-Hot 处理
  • 实现方面:
    • 直方图算法(近似切分算法)最初由LightGBM提出,后来的XGBoost算法也实现了直方图算法
  • XGBoost的近似搜索算法和LightGBM的直方图算法不同
    • XGBoost的近似搜索算法是保存所有样本的二阶梯度,用分位数确定划分方法,他的分割点是可以直接通过计算总的样本梯度和和分位数点得到的.
        * LightGBM算法是将所有样本放进对应的桶中,并以当前桶作为划分点,计算左右桶的最大增益,它的最优点是遍历所有的桶才能得到的.
  • LightGBM 通信优化 比 XGBoost 做得好
    • 这里有亲身体验: XGBoost使用多个处理器时,有时候处理器数量增加训练速度不增加,甚至反而变慢,xgboost.XGBClassifier
  • LightGBM 使用了 GOSS(Gradient-based One-Side Sampling) 来做采样算法
    • GOSS 是通过区分不同梯度的实例,保留较大梯度实例同时对较小梯度随机采样的方式减少计算量,从而达到提升效率的目的
    • GOSS 算法流程 :
      • 根据样本的梯度将样本降序排序
      • 保留前 \(a \ (0 < a < 1)\) 比例大小的数据样本,作为数据子集 \(Z_1\),也就是保留 \(a * len(all\_samples)\) 的数据样本
      • 对于剩下的数据的样本,随机采样获得大小为 \(b \ (0 < b < 1)\) 的数据子集 \(Z_2\),也就是采样 \(b * len(all\_samples)\) 的数据样本
      • 计算信息增益时对采样的 \(Z_2\) 样本的梯度数据乘以 \((1-a)/b\) (目的是不改变原数据的分布)
    • GOSS的思想是,梯度大的实例正常使用,梯度小的实例就通过采样实现部分拟合总体的方法(拟合时不改变原来的分布,除以采样率就行了)
  • LightGBM 使用了 EFB (Exclusive Feature Bundling)
    • EFB 通过特征捆绑的方式减少特征维度(其实是降维技术)的方式,提升计算效率
    • 通常被捆绑的特征都是互斥的(一个特征值为0,一个特征值不为0), 这样两个特征捆绑起来也不会造成信息丢失
    • 当两个特征不是完全互斥时,可以用一个指标对特征不互斥程度进行评估,称为冲突比率
    • 冲突比率较小时,我们可以将他们捆绑而不太影响训练结果
    • EFB 算法流程:
      • 将特征按照非零值的个数进行排序
      • 计算不同特征之间的冲突比率
      • 遍历每个特征并尝试合并特征,使冲突比率最小化
  • LightGBM 的内存优化
    • XGBoost 中 需要对每个特征进行预排序(注意:这里不能在算是XGBoost的缺点了,后来的XGBoost也实现了这个直方图算法,不需要预排序了)
    • LightGBM使用直方图算法替代了预排序

LightGBM的优点

  • 相对XGBoost:
    • 速度快 (GOSS, EFB, Histogram等)
    • 内存少 (XGBoost中排序)
    • 精度高(效果不明显, 与XGBoost相当, 本人测试, 实际使用中有时候不如XGBoost, 可能是参数调节的问题)

缺点

  • 虽然官方一再强调LightGBM的精度不比XGBoost低,而且可能超过XGBoost,但是实践中, LightGBM除了比XGBoost快以外, 精度方面没什么优势, 甚至精度还不如XGBoost(注意: 也可能是我调参技术不行)
    • 问题: 为什么某些数据集上出现LightGBM比XGBoost精度差的情况?
    • 回答: (个人理解)因为GOSS和EFB会带来一定的精度损失

总结

  • LightGBM = XGBoost + Histogram + GOSS + EFB
    • XGBoost: 不同于XGBoost的, 树节点切分不同, LightGBM使用了 Leaf-wise 的生长策略
    • Histogram:
      • Histogram方法牺牲了一定的精度,但是作者强调了在实验中精度降低并不多
      • 开始的XGBoost使用的是预先排序的方式, 后来在 XGBoost 中也实现了Histogram
      • LightGBM 对 Histogram 算法进一步加速
      • 一个叶子节点的 Histogram 可以直接由父节点的Histogram和兄弟节点的Histogram做差得到, 一般情况下,构造Histogram需要遍历该叶子上的所有数据,通过该方法,只需要遍历Histogram的k个捅, 速度提升了一倍
    • GOSS: 对于梯度小的样本, 使用采样部分代替总体的方法省时间
    • EFB: 互斥特征捆绑,提升计算效率
  • LightGBM 真正做到了把并行做到极致
    • 特征并行: 在不同的机器在不同的特征集合上分别寻找最优的分割点, 然后再机器间同步最优的分割点.
    • 数据并行: 让不同的机器先在本地构造直方图, 然后进行全局的合并,然后在合并的直方图上寻找最优的分割点.
  • LightGBM 支持类别特征
    • 无需将类别特征 One-Hot
  • 问题: 为什么XGBoost也使用了直方图,但是速度依然比较慢?
    • 直方图算法的实现有两种:
      • 1)全局构造 ,在树的构造过程中只建立一次直方图, 每次分裂都从缓存的直方图中寻找分裂点
      • 2)局部构造 ,每次树分裂到一层的时候就建立一次直方图
      • XGBoost使用的是局部构造的方式, 所以速度会较慢

ML——LR-逻辑回归


手动推导流程

  • 假设有m样本: \(X = (x_{1}, x_{2}, x_{3},\dots x_{m})\)
    • 样本点为: \(((x_{1}, y_{1}), (x_{2}, y_{2}), (x_{3}, y_{3}), \dots (x_{m}, y_{m}))\)
  • 假设 \(w, \theta, x_{i}\) 等所有向量都为列向量

确定分类决策函数

  • 线性回归模型
    $$f(x) = w^{T}x + b$$
  • 令
    $$
    \begin{align}
    \theta &= (w; b) \\
    x_{i} &= (x_{i}; 1)
    \end{align}
    $$
  • 有
    $$f(x) = \theta^{T} x$$
  • 逻辑回归决策函数在线性回归模型上加一个sigmoid函数
    $$
    \begin{align}
    h_{\theta}(x) &= sigmoid(f(x)) \\
    &= \frac{1}{1+e^{-f(x)}} \\
    &= \frac{1}{1+e^{-\theta^{T} x}} \\
    \end{align}
    $$
  • 即
    $$
    \begin{align}
    p(y=1|x) &= h_{\theta}(x) = \frac{1}{1+e^{-\theta^{T} x}} = \frac{e^{\theta^{T} x}}{1+e^{\theta^{T} x}}\\
    p(y=0|x) &= 1-h_{\theta}(x) = \frac{e^{-\theta^{T} x}}{1+e^{-\theta^{T} x}} = \frac{1}{1+e^{\theta^{T} x}} \\
    \end{align}
    $$
  • 对数几率(log odds, 也称为logit) 定义为: \(ln \frac{p}{1-p}\),在LR中有:
    $$
    \begin{align}
    ln\frac{h_{\theta}(x)}{1-h_{\theta}(x)} = \theta^T x
    \end{align}
    $$
  • 分类超平面不确定,与最终的阈值有关, \(\alpha\) 的值与最终阈值相关
    $$w^{\star}x + b^{\star} = \alpha$$
    • 分类超平面由 \((w, b)\) 和阈值唯一确定,(注意: SVM的分类超平面由 \((w, b)\) 唯一确定)
    • 这一点和SVM不同,SVM的分类超平面是确定的,详情参看ML——SVM-支持向量机

确定优化目标

  • LR中使用极大似然法
    $$
    \begin{align}
    L(\theta) &= p(Y|X;\theta) \\
    &= \prod_{i=1}^{m}p(y_{i}|x_{i}; \theta) \\
    &= \prod_{i=1}^{m}(p(y_{i} = 1|x_{i};\theta))^{y_{i}}(p(y_{i} = 0|x_{i};\theta))^{1-y_{i}}
    \end{align}
    $$
  • 上面的式子不易求导优化,我们使用与上面的式子单调性和最优点等价的对数似然函数为
    $$
    \begin{align}
    LL(\theta) &= \log L(\theta) \\
    &= \log \prod_{i=1}^{m}(p(y_{i} = 1|x_{i};\theta))^{y_{i}}(p(y_{i} = 0|x_{i};\theta))^{1-y_{i}} \\
    &= \sum_{i=1}^{m}\left (y_{i}\log(p(y_{i} = 1|x_{i};\theta)) + (1-y_{i})\log(p(y_{i} = 0|x_{i};\theta)) \right ) \\
    &= \sum_{i=1}^{m}\left (y_{i}\log\frac{p(y_{i} = 1|x_{i};\theta)}{p(y_{i} = 0|x_{i};\theta)} +\log(p(y_{i} = 0|x_{i};\theta))\right ) \\
    \end{align}
    $$
    • 上面的式子中:
      $$\sum_{i=1}^{m}\left (y_{i}\log(p(y_{i} = 1|x_{i};\theta)) + (1-y_{i})\log(p(y_{i} = 0|x_{i};\theta)) \right )$$
    • 加个负号即为为交叉熵损失函数
      $$-\sum_{i=1}^{m}\left (y_{i}\log(p(y_{i} = 1|x_{i};\theta)) + (1-y_{i})\log(p(y_{i} = 0|x_{i};\theta)) \right )$$
    • 所以交叉熵损失函数与逻辑回归的对数似然损失函数(=逻辑回归对数似然函数的负数)是等价的
  • 由前面的推导有
    $$
    \begin{align}
    \log \frac{p(y_{i} = 1|x_{i};\theta)}{p(y_{i} = 0|x_{i};\theta)} = \log\frac{\frac{e^{\theta^{T} x}}{1+e^{\theta^{T} x}}}{\frac{1}{1+e^{\theta^{T} x}}} = \log e^{\theta^{T} x} = \theta^{T}x\\
    \end{align}
    $$
    • 且:
      $$\log(p(y_{i} = 0|x_{i};\theta)) =\log(\frac{1}{1+e^{\theta^{T} x}}) = -\log(1+e^{\theta^{T} x})$$
  • 故而有
    $$
    \begin{align}
    LL(\theta) &= \sum_{i=1}^{m}\left (y_{i}\log\frac{p(y_{i} = 1|x_{i};\theta)}{p(y_{i} = 0|x_{i};\theta)} +\log(p(y_{i} = 0|x_{i};\theta))\right )\\
    &= \sum_{i=1}^{m}\left ( y_{i}\theta^{T}x - \log(1+e^{\theta^{T}x}) \right)
    \end{align}
    $$

损失函数

  • 最大化似然函数等价于最小化似然函数的负数
  • LR中使用极大似然法 ,所以对应的损失函数自然为对数似然损失函数
    $$loss(\theta) = -LL(\theta) = \sum_{i=1}^{m}\left (- y_{i}\theta^{T}x +\log(1+e^{\theta^{T}x}) \right)$$

梯度下降法优化

注意: 这里优化目标也可以用牛顿法

  • 目标,求一个 \(\theta^{\star}\) 满足似然函数最大化或者损失函数最小化
    $$\theta^{\star} = \mathop{\arg\max}_{\theta} LL(\theta) = \mathop{\arg\min}_{\theta} -LL(\theta) = \mathop{\arg\min}_{\theta} loss(\theta)$$
  • 对数似然函数对参数 \(\theta\) 求导有
    $$
    \begin{align}
    \frac{\partial loss(\theta)}{\partial\theta} &= \sum_{i=1}^{m}\left ( -y_{i}x_{i} + \frac{x_{i}e^{\theta^{T}x}}{1+e^{\theta^{T}}x}\right ) \\
    &= \sum_{i=1}^{m}x_{i}\left ( -y_{i} + \frac{e^{\theta^{T}x}}{1+e^{\theta^{T}}x}\right ) \\
    &= \sum_{i=1}^{m}x_{i}\left ( -y_{i} + h_{\theta}(x_{i})\right ) \\
    \end{align}
    $$
  • LR模型的梯度下降参数迭代公式
    $$
    \begin{align}
    \theta^{t+1} &= \theta^{t} - \alpha\sum_{i=1}^{m}x_{i}\left ( -y_{i} + h_{\theta^{t}}(x_{i})\right ) \\
    &= \theta^{t} + \alpha\sum_{i=1}^{m}x_{i}\left ( y_{i} - h_{\theta^{t}}(x_{i})\right )
    \end{align}
    $$
    • 其中 \(\alpha\) 为步长
  • 线性回归和LR模型的梯度下降法参数迭代公式表达式看似相同,但是不同的模型对应的 \(h_{\theta}\) 函数并不相同

其他总结

参考:https://www.cnblogs.com/ModifyRong/p/7739955.html

  • LR中使用极大似然法 ,所以对应的损失函数自然为对数似然损失函数(对数损失函数)
    • 对数似然损失函数定义为:
      $$Loss(\theta) = -P(Y|X; \theta)$$
    • 《统计学习方法》P213 中定义 LR 的损失函数为逻辑斯蒂损失函数,在LR模型和最大熵模型中,逻辑斯蒂损失函数本质上与对数似然损失函数等价,可推导得到
  • 一句话概括逻辑回归:
    • 逻辑回归是假设数据服从伯努利分布,通过极大似然函数的方法,运用梯度下降法或者牛顿法来求解参数,来达到将数据二分类的目的
    • 逻辑回归的假设: 数据服从伯努利分布 , \(p(y=1|x) = 1-p(y=0|x)\)
    • 逻辑回归的损失函数: 对数似然损失函数(交叉熵损失函数) ,也就是对数似然函数的负数
    • 逻辑回归的求解方法: 梯度下降法或牛顿法
    • 逻辑回归的目的: 将数据二分类
    • 逻辑回归如何分类: 预测结果是连续的[0-1]的数 ,我们一般选择0.5作为阈值来分类,但是这个值可能是可以变化的,因为损失函数最小并不意味着0.5时分类精度最高
  • 为什么要用极大似然法?等价于为什么要用对数似然损失函数作为损失函数?
    • 损失函数一般有平方损失函数,对数损失函数,合页损失函数,绝对值损失函数等,极大似然函数取对数后等同于对数损失函数,在逻辑回归这个模型中,推导可以得到,对数损失函数训练求解参数的迭代函数只与 \(x_{i}, y_{i}\) 相关,与sigmoid函数的梯度等无关.这样的参数更新自始至终都比较稳定
    • 为什么不选平方损失函数的呢?其一是因为如果你使用平方损失函数,你会发现梯度更新的速度和sigmod函数本身的梯度是很相关的,sigmod函数在它在定义域内的梯度都不大于0.25, 这样训练会非常的慢
  • 逻辑回归中,如果某些特征高度相关,甚至某些特征完全相同,会造成什么影响?
    • 损失函数收敛后,没有影响,因为特征的相关性并不影响分类器效果,重复特征会分化特征的权重(10个重复特征和单个特征训练结果差别在于前者每个特征的权重是后者的十分之一),本质上最终结果不变的
    • 但是训练时由于特征重复,参数增多,模型复杂度增加,训练时长,内存等都会增加
  • 为什么需要去掉高度相关的特征?
    • 去掉高度相关的特征使得模型可解释性更好
    • 提高训练时间,节约内存,减少参数数量
    • 特征的提取本身也需要时间,实际工程项目中可以少提取一个特征往往能节约很多时间
  • logistic 与 logit 的区别?
    • logit: 又名 log adds , 指的是”对数几率”, 定义为 \(ln\frac{p}{1-p}\)
    • logistic: 又叫Sigmoid函数, 指的是”对数几率函数”, 本质上是一种”Sigmoid”函数, 定义为 \(f(x) = \frac{1}{1+e^{-x}}\)
  • 简单介绍LR模型的优缺点:
    • 优点:
      • 模型简单,可解释性好,(如果对数据特征进行了归一化处理的话)可以从特征的权重看到不同特征对最终结果的影响
      • 模型效果往往不错(特征工程做得好的话)
      • 训练速度快, 成熟的SGD优化方法(SGD可以分布式)等
      • 内存占用小
      • 输出结果可以作为概率值,然后可以对阈值根据实际进行划分,不一定是确定的0.5,只是一般选择0.5而已
    • 缺点:
      • 难以处理非线性数据,本质上是线性分类面
      • 难以处理数据不平衡问题 , 这里如果正例远远多于负例,那么全都预测为正例,整体损失函数也不会太大
      • LR 本身无法筛选特征 ,有时候会用GBDT和XGBoost来筛选特征,然后再用LR模型
  • 扩展:逻辑回归可像SVM一样引入核函数处理非线性分类问题吗?
    • 一般来说不可以
    • [存疑]理论上通过对原始样本非线性映射,似乎也可以,如果将 \(f(x) = \theta^{T}x\) 中的 \(f(x)\) 看做 \(\theta\) 看做变量,然后类比SVM的核函数,定义一个关于 \(x_{i}\) 的非线性映射
      • 这里 \(x_{i}\) 表示第 \(i\) 个样本, 用 \(x_{i}^{j}\) 表示第 \(i\) 个样本的第 \(j\) 个维度
        $$x_{i}^{j} = \phi_{j}(x_{i}^{j})$$
      • 基于上述非线性映射函数的定义,我们对每个样本都进行线性映射,每个维度用不同的映射函数(不同样本相同维度映射函数相同)
      • 这里的非线性映射与SVM的核函数不同,SVM不使用核函数的话,也可以通过相同的非线性映射的方式实现非线性分类
      • 使用核技巧后的LR模型将变得很慢,SVM与kernels是相配的,而LR与kernels会十分慢(来源SVM核技巧)
  • LR 模型训练完成后,输出概率多少的样本应该评估为正样本?【以下分析为个人理解,暂无严格证明】
    • LR模型的损失函数本质上是交叉熵损失函数,交叉熵损失函数本质是最小化预估分布与训练样本分布之间的差距,故而预估均值与真实训练样本均值应该相等 ,即LR模型的预估值均值理论上与训练样本标签均值相同(这里LR的预估值是Sigmoid的输出值,训练样本负样本标签为0,正样本标签为1)。【PS,一种辅助理解思考:假设训练集中的样本特征完全相同,但其中30%是正样本,另外70%为负样本,那么优秀的LR模型在预估该训练样本时应该输出约为0.3】
    • 进一步来说,当预估的平均值大于训练样本的均值时,即可判断为正样本
      • 举例,假设训练样本的均值为0.4,那么预估值大于0.4的样本均可视为正样本(思考:这种判定下模型的准确率 \( \text{Accuracy} = \frac{TP+TN}{TP+TN+FP+FN}\) 应该是最高的?)

附录:多分类任务重的 logits

  • TLDR:多分类任务中模型输出的 logits 和 LR 中的 logit 含义不完全相同
    • 二者是同源但适用场景和维度不同的概念
  • LR 的 logit 是二分类专属的标量对数几率;
  • 多分类模型的 logits 是多分类任务的向量原始得分 ,是 logit 概念在多类别场景下的推广,需通过 Softmax 转换为概率分布

同源:基于对数几率的定义

  • 二者的本质都源于对数几率(log-odds) ,即事件发生概率与不发生概率的比值的自然对数,公式为:
    $$\text{logit}(p) = \ln\left(\frac{p}{1-p}\right)$$
    • 这个公式是连接线性得分和概率的桥梁

二分类场景下的等价性

  • 当多分类任务退化为二分类时,模型的 logits 向量为 \([\text{logit}_0, \text{logit}_1]\)
    • 若满足 \(\text{logit}_0 = -\text{logit}_1\)
    • 则 \(2\text{logit}_1\) 就完全等价于 LR 中的 logit(正例的对数几率)
  • 此时 Softmax 激活等价于 Sigmoid 激活:
    $$ p(y=1) = \frac{e^{\text{logit}_1}}{e^{\text{logit}_0}+e^{\text{logit}_1}} = \frac{e^{\text{logit}_1}}{e^{-\text{logit}_1}+e^{\text{logit}_1}} = \frac{1}{1+e^{-2\text{logit}_1}} = \frac{1}{1+e^{-\theta^T x}} = \frac{1}{1+e^{-\text{logit}}}$$
    • 若令 \(\text{logit} = 2\text{logit}_1\),则与 LR 的 Sigmoid 输出完全一致

多分类 LR 的 logits

  • 多分类逻辑回归(Softmax 回归)的输出 logits 就是上述多分类模型的 logits 向量,每个元素对应一个类别的线性得分,这是 logit 概念从标量到向量的扩展

ML——XGBoost-vs-传统GBDT

本文主要介绍XGBoost和其他传统GBDT的比较的优劣
XGBoost又叫(Newton Boosting Tree)

  • GBDT推导: ML——GBDT-梯度提升树-推导过程
  • XGBoost推导: ML——XGBoost-推导过程

XGBoost的优点

参考博客: https://www.cnblogs.com/massquantity/p/9794480.html

  • XGBoost损失函数是二阶泰勒展开(与牛顿法对应),GBDT是一阶泰勒展开(与梯度下降法对应)

    • 传统 GBDT 在优化时只用到一阶导数信息, 所以传统GBDT也叫 (Gradient Boosting)
    • XGBoost 则对目标函数进行了二阶泰勒展开,同时用到了一阶和二阶导数, 所以XGBoost又叫(Newton Boosting Tree)
  • XGBoost加了正则项 ,普通的GBDT没有,所以XGBoost学出来的模型更简单,能防止过拟合,提高模型的泛化性能

  • XGBoost Shrinkage(缩减)

    • 每次进行完一次迭代后,将叶子节点的权重乘以该系数(一般叫做eta \(\eta\))
    • 可以理解为这里Shrinkage是将学习速率调小,从而需要的迭代次数增多
    • 减小学习率实际上是减弱了每棵树对整体的影响,从而让后面的树有更多的学习空间
    • 下面的表述还有待确定:
      • 进一步惩罚决策树叶节点的值(惩罚的意思是叶节点越大,惩罚越多,损失函数越大)
      • 对叶节点的惩罚本身可以理解为一个正则化
  • 结点分裂的增益计算公式不同

    • 传统 GBDT 一般采用的是最小二乘法作为内部分裂的增益计算指标(因为用的都是回归树)
      • 注意: 这里本文中描述的最小绝对偏差回归(LAD)是第一步损失函数的定义,不是这一步中的信息增益计算
      • 查看源码: sklearn.ensemble.GradientBoostingClassifier在分裂结点时可以选择三种方式:
        • friedman_mse(默认), mean squared error with improvement score by Friedman
        • mse: mean squared error
        • mae: mean absolute error
    • 而 XGBoost 使用的是经过优化推导后的式子
      $$
      \begin{align}
      Gain = \frac{G_L^2}{H_L+ \lambda} + \frac{G_R^2}{H_R+ \lambda} - \frac{(G_L + G_R)^2}{H_L+ H_R + \lambda} - \gamma
      \end{align}
      $$
      • 注意: XGBoost中的信息增益计算形式固定为上面的计算方式,但是具体的值与损失函数的定义相关(因为 \(g_i\) 和 \(h_i\) 的是损失函数的一阶和二阶梯度)
  • XGBoost支持自定义的损失函数 ,支持一阶和二阶可导就行

    • 注意,这里的损失函数指的是 \(l(y_i,\hat{y}_i)\),单个样本预测值与目标值的差异, 也就是单个样本的损失函数
    • 从ML——XGBoost-推导过程中可知:
      • \(g_i = l’(y_i,\hat{y}_i^{t-1})\) 为 \(l(y_i,\hat{y}_i)\) 对 \(\hat{y}_i\) 的一阶导数在 \(\hat{y}_i = \hat{y}_i^{t-1}\) 处的值
      • \(h_i = l’’(y_i,\hat{y}_i^{t-1})\) 是 \(l(y_i,\hat{y}_i)\) 对 \(\hat{y}_i\) 的二阶导数在 \(\hat{y}_i = \hat{y}_i^{t-1}\) 处的值
    • XGBoost中只要损失函数二次可微分即可得到 \(g_i\) 和 \(h_i\)
      • \(g_i\) 和 \(h_i\) 本身与损失函数的定义形式无关, 只要求损失函数二阶可微分即可
    • 只要有了 \(g_i\) 和 \(h_i\) 我们即可
      • 根据预先推导的叶子节点分数表达式 \(w_j^{\star} = -\frac{G_j}{H_j+\lambda}\) 求得叶子结点的分数
      • 根据预先推导的信息增益公式 \(Gain = \frac{G_L^2}{H_L+ \lambda} + \frac{G_R^2}{H_R+ \lambda} - \frac{(G_L + G_R)^2}{H_L+ H_R + \lambda} - \gamma\) 确定分裂特征和分裂点
    • GBDT 损失函数关系一般选择最小二乘回归或者最小绝对偏差回归
      • 最小方差回归 : (Least-Squares Regression, LSR)
        $$\begin{align} L(y,F(x)) = \frac{1}{2}(y-F(x))^{2} \end{align}$$
      • 最小绝对偏差回归 : (Least Absolute Deviation Regression, LAD)
        $$\begin{align} L(y,F(x)) = |y-F(x)| \end{align}$$
      • 查看源码: sklearn.ensemble.GradientBoostingClassifier的损失函数是定义好的, 不能自己定义, 详细如下源码
        1
        2
        3
        4
        5
        6
        7
        8
        9
        10
        11
        12
        13
        14
        15
        LOSS_FUNCTIONS = {'ls': LeastSquaresError,
        'lad': LeastAbsoluteError,
        'huber': HuberLossFunction,
        'quantile': QuantileLossFunction,
        'deviance': None, # for both, multinomial and binomial
        'exponential': ExponentialLoss,
        }


        _SUPPORTED_LOSS = ('deviance', 'exponential')
        ....

        if (self.loss not in self._SUPPORTED_LOSS
        or self.loss not in LOSS_FUNCTIONS):
        raise ValueError("Loss '{0:s}' not supported. ".format(self.loss))
  • XGBoost 借鉴了随机森林的做法,支持列采样 ,不仅能降低过拟合,还能减少计算量,这也是 XGBoost 异于传统 GBDT 的一个特性

    • 列采样: 这借鉴于随机森林中的做法,每棵树不使用所有特征,而是部分特征参与训练
    • 可以减少计算量,同时还能降低过拟合,简直优秀
  • XGBoost 有缺失值自动处理 , 在计算分裂增益时不会考虑带有缺失值的样本,这样就减少了时间开销,在分裂点确定了之后,将带有缺失值的样本分别放在左子树和右子树,比较两者分裂增益,选择增益较大的那一边作为默认分裂方向

    • 进一步理解稀疏数据的支持: [待更新]
  • 并行化处理 :由于 Boosting 本身的特性,传统 GBDT 无法像随机森林那样树与树之间的并行化

    • XGBoost 的并行主要体现在特征粒度上,在对结点进行分裂时,由于已预先对特征排序并保存为block 结构,每个特征的增益计算就可以开多线程进行,极大提升了训练速度
  • 剪枝策略不同

    • 传统 GBDT 在损失不再减少时会停止分裂,这是一种预剪枝的贪心策略,容易欠拟合
    • XGBoost采用的是后剪枝的策略,先分裂到指定的最大深度 (max_depth) 再进行剪枝
      • 而且和一般的后剪枝不同, XGBoost 的后剪枝是不需要验证集的[待更新:XGBoost剪枝的策略是怎样的?只依赖信息增益指标吗?]
      • 和博客作者指出的一样,我这里并不是”纯粹的”后剪枝,因为提前设定了最大深度
  • 基分类器的选择不同:

    • 传统GBDT中原始论文使用树回归 ,本文见Firedman 1999,后来作者提出可以使用逻辑回归 ,本文见Friedman 2000
    • XGBoost后面的各种损失计算等都包含着树模型的复杂度,叶子节点分类等,所以是只能用CART,不能使用逻辑回归的
    • (从函数空间定义和后面的公式推导来看)XGBoost中基函数只使用CART回归树,不能使用逻辑回归
    • 但是事实上XGBoost的实现中是支持线性分类器作为基分类器的, 参数booster[default='gbtree'],可选为booster=gblinear
      • 使用线性分类器作为基分类器时, XGBoost相当于带有L1正则化和L2正则化的:
        • Logistic回归(分类问题)
        • 线性回归(回归问题)
  • 分桶策略算法不同

    • 传统的GBDT分桶时每个样本的权重都是相同的
    • XGBoost中每个样本的权重为损失函数在该样本点的二阶导数(对不同的样本,计算得到的损失函数的二阶导数是不同的), 这里优点AdaBoost的思想,重点关注某些样本的感觉
    • 这里影响的是划分点的位置(我们划分划分点[桶]时都是均匀划分样本到桶里面,当不同样本的权重不同时,每个桶里面的样本数量可能会不同)
    • 下图是一个示例

XGBoost的缺点

注意这里是

  • 空间消耗大
    • 因为要在训练之前先对每个特征进行预排序并将结果存储起来,所以空间消耗较大
    • GBDT无需预排序,但是每次重新排序很耗时间
  • 速度慢
    • 虽然XGBoost速度比传统 GBDT 快了不少, 但是不如 LightGBM 快, 且 LightGBM 占用内存更低

XGBoost为什么能够并行?

而GBDT是不能并行的,原因是:https://www.136.la/shida/show-187480.html

  • GBDT不能并行的原因是没有预排序(XGB的预排序结果会存储到block结构)等,在有了预排序结果后,同一个特征的切割方式可以并行尝试计算增益
  • 决策树最耗时间(包括GBDT)的步骤是对特征值的排序
    • 用于确定最佳分割点
  • XGBoost训练前,预先对数据进行了排序,称为预排序
    • 将预先排序的结果保存为block结构, 后面迭代的时候重复使用这个结构,从而实现一次排序,多次使用,大大减少计算量
  • 这个结构减少计算量的同时还为并行化提供了可能([待更新]实际上不用预排序也能并行的吧?只是每次需要先使用一个单一线程排序,然后再多个线程并行?只是不够并行)
    • 进行结点的分裂时,需要计算每个特征的增益,然后选择增益最大的那个特征分裂
    • 这里我们可以同时使用多个线程计算不同特征的增益, 从而实现并行化
  • 总结为三方面的并行, ([待更新]但是具体实现了哪些并行不确定)
    • 同一层级的结点间每个结点的分裂可以并行
    • 同一个结点内部不同特征增益的计算可以并行
    • 同一个结点同一个特征的不同分裂点的增益计算可以并行

GBDT为什么不能自定义损失函数?

GBDT为什么不能像XGBoost一样自定义损失函数?

  • 查看sklearn.ensemble.GradientBoostingClassifier的源码发现, 确实不支持自定义的损失函数
  • [待更新],因为涉及到后面的参数更新方式?

XGBoost如何使用自定义的损失函数?

模型直接调用fit函数无法传入自定义的损失函数, 需要在模型开始定义的时候传入或者使用xgb.train函数调用

  • 使用方法1:

    1
    2
    3
    from xgboost import XGBClassifier

    clf = XGBClassifier(objective=MyLossFunction)
  • 使用方法2:

    1
    2
    3
    4
    5
    import xgboost as xgb
    from xgboost import XGBClassifier

    clf = XGBClassifier()
    xgb.train(xgb_model=clf, obj=MyLossFunction)

Algorithm——AVL树和红黑树等各种树结构总结


各种树的介绍

树

  • 一个根节点,每个结点可能有多个子节点

二叉树

  • 一个根节点,或者为空
  • 每个结点只有两个子节点

二叉搜索树

也叫二叉查找树

  • 首先是一棵二叉树
  • 左边孩子结点的值都小于当前结点
  • 右边孩子结点的值都大于当前结点
缺点
  • 极端情况下,树模型会退化成链表,查找变成了 O(n) 复杂度的

平衡二叉搜索树(AVL树)

也叫平衡二叉查找树

  • 首先是一棵二叉搜索树
  • 对每个结点而言, 左右孩子结点的深度差值不能超过1 , 从而保证查找是 O(log n) 的
  • 控制平衡方法: 参考链接AVL树详解
    • 左-左型: 右旋
    • 右-右型: 左旋
    • 左-右型: 左旋 + 右旋
      • 第一步左旋后面部分,变成 左-左 型
      • 第二步使用右旋修正 左-左 型
    • 右-左型: 右旋 + 左旋
      • 第一步右旋后面部分,变成 右-右 型
      • 第二步使用左旋修正 右-右 型
缺点
  • 每棵树的左右子树高度最多差1这个要求太严了
  • 几乎每次插入或者删除结点时都会造成规则破坏, 也就需要频繁的通过左旋或者右旋操作来修正
  • 插入和删除太频繁的场景中不太适合使用AVL树, 性能会因为左右子树高度最多差1这个规则频繁被打破而降低

红黑树

  • 首先是一棵二叉搜索树
  • 每个结点不是黑色就是红色
  • 根节点为黑色
  • 每个叶子结点都为黑色的空结点(NULL)
    • 注意: 叶子节点不存数据
  • 任何相邻结点不同时为红色
    • 注意,相邻结点可以为黑色,且可以某条路径上全是黑色
  • 对每个结点而言,当前结点到叶子结点的所有路径包含相同的黑色结点数
优点
  • 能保证查找时间复杂度是 O(log n) 的, 和AVL树差不多
    • 证明: [待更新]
  • 插入和删除操作中不会频繁破坏红黑树的规则
红黑树的应用
  • 容器集合 HashMap, TreeMap等
    • HashMap是 链表 + 红黑树的实现, 当冲突时就需要使用红黑树加快检索
    • HashMap中 链表长度太短时使用链表, 太长时使用红黑树, 这个阈值一般设置为8

二三树

  • 红黑树是二三树的一个变形,一般用红黑树就够了
    [待更新]

B树

  • B树在大量的数据库和文件系统场景中都有使用
    [待更新]

B+树

[待更新]


总结

  • 可以说红黑树是一种不严格的平衡树

Python——Hydra库的使用

  • 参考链接:hydra.cc/docs/intro

整体说明

  • Hydra 是一个开源的 Python 框架 ,旨在简化复杂应用程序的配置管理
  • Hydra 的核心功能是能够通过组合动态创建分层配置 ,并且可以通过配置文件和命令行轻松覆盖这些配置
  • Hydra 的名字来源于神话中的九头蛇(Hydra) ,象征着它能够轻松地使用不同配置运行多个相似的作业(即 Multirun 功能),这在机器学习和科学实验中尤其有用
  • Hydra 的主要特点总结如下
    • 分层配置 (Hierarchical Configuration): 配置可以从多个独立的配置文件组合而成
    • 命令行覆盖 (Command-Line Overrides): 能够通过命令行参数轻松修改配置的任何部分
    • 多任务运行 (Multirun): 使用一个命令就能运行多次实验,每次实验使用不同的配置组合
    • 配置快照 (Configuration Snapshots): 自动保存每次运行的完整配置,确保结果的可复现性
    • 工作目录管理 (Working Directory Management): 每次运行都会在 outputs/ 或 multirun/ 目录下创建一个以日期和时间命名的新目录,将运行结果和日志隔离
  • Hydra 常常和 omegaconf 包一起使用

Hydra 安装

  • 通过 pip 安装 hydra-core:

    1
    pip install hydra-core --upgrade
    • 依赖的 omegaconf 包会自动安装

常用示例(必会)

  • 文件结构

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    tree
    .
    ├── config
    │   ├── color
    │   │   ├── blue.yaml
    │   │   └── green.yaml
    │   ├── config.yaml
    │   ├── config2.yaml
    │   └── person
    │   ├── alice.yaml
    │   └── bob.yaml
    └── hydra_demo.py
  • ./config/color/blue.yaml文件内容

    1
    2
    favorite_color: blue
    time: 10
  • ./config/color/green.yaml文件内容

    1
    favorite_color: green
  • ./config/person/alice.yaml文件内容

    1
    2
    name: Alice
    age: 30
  • ./config/person/bob.yaml文件内容

    1
    2
    name: Bob
    age: 25
  • config/config2.yaml 文件内容:

    1
    name_aux: 100
  • config/config.yaml 文件内容:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    # 定义到 defaults 的一定是配置文件,没有配置文件会出错,索引方式见下图
    defaults:
    # - _self_ # 放到最前面则用下面的默认参数覆盖当前文件定义(比如 person:name:0)
    - person: alice # 索引 ./person/alice.yaml,也可以被参数覆盖
    - color: blue # 索引 ./blue/blue.yaml,直接效果与 - color/blue 等价,但 - color/blue 覆盖参数需要使用 `+`,不建议使用
    - person@aux_person: bob # 索引 ./person/bob.yaml,同时重命名为 aux_person,后续通过 "aux_person" 替换 ”person" 作为引用
    - config2 # 直接引用同步目录下的其他文件,相关字段会被 config2.yaml 更新
    - _self_ # 放到最后则用当前文件定义覆盖前面的默认参数(比如 person:name:0)
    # 可以在这里添加其他全局配置
    full_name: "${person.name} Li" # 全局参数,要等到所有解析完成才解析这里,所以不用担心先后顺序,这个总是最后执行的
    modes: ??? # ??? 的变量比较特殊,在通过命令行传入该参数值前,无法直接使用,否则会报错:omegaconf.errors.MissingMandatoryValue: Missing mandatory value: modes
    person:
    name: "lilian" # 当前文件定义参数,是否覆盖引入的默认值与 `_self_` 的位置有关
    ENV_PATH: ${oc.env:PATH} # 读取环境变量 $PATH,环境变量不存在会出错
  • hydra_demo.py 文件内容

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    import hydra
    from omegaconf import OmegaConf
    import json

    @hydra.main(config_path="config", config_name="config", version_base=None)
    def main(cfg):
    print("===== to yaml =====:")
    print(OmegaConf.to_yaml(cfg))

    print("===== parse to json =====:")
    dict_obj = OmegaConf.to_container(cfg, resolve=True)
    json_str = json.dumps(dict_obj, indent=4, ensure_ascii=False)
    print(json_str)

    if __name__ == '__main__':
    main()
  • 执行命令1

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    python hydra_demo.py

    # ===== to yaml =====:
    # person:
    # name: lilian
    # age: 30
    # color:
    # favorite_color: blue
    # time: 10
    # aux_person:
    # name: Bob
    # age: 25
    # name_aux: 100
    # full_name: ${person.name} Li
    # modes: ???
    # ENV_PATH: ${oc.env:PATH}
    #
    # ===== parse to json =====:
    # {
    # "person": {
    # "name": "lilian",
    # "age": 30
    # },
    # "color": {
    # "favorite_color": "blue",
    # "time": 10
    # },
    # "aux_person": {
    # "name": "Bob",
    # "age": 25
    # },
    # "name_aux": 100,
    # "full_name": "lilian Li",
    # "modes": "???",
    # "ENV_PATH": "/Users/jiahong/.nvm/versions/node/v12.14.0/bin:/usr/local/opt/node@16/bin:/Users/jiahong/anaconda3/envs/torch_py310/bin:/Users/jiahong/anaconda3/condabin:/usr/local/bin:/System/Cryptexes/App/usr/bin:/usr/bin:/bin:/usr/sbin:/sbin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/local/bin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/bin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/appleinternal/bin"
    # }
  • 执行命令2

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    python hydra_demo.py +new_name=Joey person=bob color.time=15

    # ===== to yaml =====:
    # person:
    # name: lilian
    # age: 25
    # color:
    # favorite_color: blue
    # time: 15
    # aux_person:
    # name: Bob
    # age: 25
    # name_aux: 100
    # full_name: ${person.name} Li
    # modes: ???
    # ENV_PATH: ${oc.env:PATH}
    # new_name: Joey
    #
    # ===== parse to json =====:
    # {
    # "person": {
    # "name": "lilian",
    # "age": 25
    # },
    # "color": {
    # "favorite_color": "blue",
    # "time": 15
    # },
    # "aux_person": {
    # "name": "Bob",
    # "age": 25
    # },
    # "name_aux": 100,
    # "full_name": "lilian Li",
    # "modes": "???",
    # "ENV_PATH": "/Users/jiahong/.nvm/versions/node/v12.14.0/bin:/usr/local/opt/node@16/bin:/Users/jiahong/anaconda3/envs/torch_py310/bin:/Users/jiahong/anaconda3/condabin:/usr/local/bin:/System/Cryptexes/App/usr/bin:/usr/bin:/bin:/usr/sbin:/sbin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/local/bin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/bin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/appleinternal/bin",
    # "new_name": "Joey"
    # }

Multi-run:启动多个配置运行

  • 启动方式:

    1
    2
    3
    # 两种启动方式等价
    python my_app.py --multirun db=mysql,postgresql schema=warehouse,support,school
    python my_app.py -m db=mysql,postgresql schema=warehouse,support,school
    • 以上启动会生成6份任务,且串行执行
  • 使用 --multirun 启动的任务配置记录在 multirun/ 文件夹下(单任务启动方式的记录在 outputs/ 下)

Multi-run 的高阶用法

  • 通过覆盖 hydra.sweeper.param 实现启动多个任务

    1
    2
    3
    4
    5
    hydra:
    sweeper:
    params:
    db: mysql,postgresql
    schema: warehouse,support,school
  • 启动命令:

    1
    2
    3
    4
    5
    python my_app.py -m db=mysql
    # [2021-01-20 17:25:03,317][HYDRA] Launching 3 jobs locally
    # [2021-01-20 17:25:03,318][HYDRA] #0 : db=mysql schema=warehouse
    # [2021-01-20 17:25:03,458][HYDRA] #1 : db=mysql schema=support
    # [2021-01-20 17:25:03,602][HYDRA] #2 : db=mysql schema=school

日志文件说明

  • 每次执行命令后都会按照时间生成日志文件

    1
    2
    3
    4
    5
    6
    7
    $ tree outputs/2024-09-25/15-16-17
    outputs/2024-09-25/15-16-17
    ├── .hydra
    │ ├── config.yaml
    │ ├── hydra.yaml
    │ └── overrides.yaml
    └── my_app.log
  • config.yaml: A dump of the user specified configuration

  • hydra.yaml: A dump of the Hydra configuration

  • overrides.yaml: The command line overrides used

  • my_app.log: A log file created for this run

    • 用 Python 文件命令的日志文件,记录被 @hydra.main 注解过的函数中的 log 对象输出
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      import logging

      log = logging.getLogger(__name__)

      @hydra.main(config_path="config", config_name="config", version_base=None)
      def main(config):
      log.info("Info level message")
      log.debug("Debug level message") # 若输出日志的等级包含 debug,则这句话也会输出到日志文件
      pass

      if __name__ == '__main__':
      log.info("out info") # 不会输出到日志文件中(因为不在 `@hydra.main` 注解过的函数中)
      main()

特别需要注意的点

  • 参数覆盖规则:
    • 传入的参数 > 后定义的参数 > 先定义的参数
  • 传入参数的规则:
    • 被覆盖的参数必须是存在的,如 name=Joe 要求 name 已经存在,若不存在则会报错
    • 不存在的参数就需要使用 + 增加参数,如 +name=Joe (少用)
    • 如果存在的参数上使用 +name=Joe 也会出现错误(不可以同时出现两个相同的 key)
    • 注:由于传入的参数会影响生效的子配置文件,自配置文件的参数配置命名上可能不同,所以参数的判定有一定的复杂性
  • 对于子配置可以使用动态方式添加(+),但建议使用 defaults 关键字定义,方便管理,定义后可以被正常覆盖(不再需要 +)

附录:使用 Structured Config

  • 在新增加文件的情况下,也可以使用 Python 类定义对象实现类似 yaml 文件的效果(不常用)

    • 详情见:https://hydra.cc/docs/tutorials/structured_config/config_store/
  • 示例(无需任何 yaml 文件配置):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    from dataclasses import dataclass
    import hydra
    from hydra.core.config_store import ConfigStore

    @dataclass
    class MySQLConfig:
    host: str = "localhost"
    port: int = 3306

    cs = ConfigStore.instance()
    # Registering the Config class with the name 'config'.
    cs.store(name="config", node=MySQLConfig)

    @hydra.main(version_base=None, config_name="config")
    def my_app(cfg: MySQLConfig) -> None:
    if cfg.port == 80:
    print("Is this a webserver?!")

    if __name__ == "__main__":
    my_app()
    • 等价于有了 config.yaml 配置文件,写入了下面的信息
      1
      2
      3
      # config.yaml
      'host': 'localhost'
      'port': 3306
  • 更高阶的层级示例(参考自:https://hydra.cc/docs/tutorials/structured_config/hierarchical_static_config/):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    from dataclasses import dataclass

    import hydra
    from hydra.core.config_store import ConfigStore

    @dataclass
    class MySQLConfig:
    host: str = "localhost"
    port: int = 3306

    @dataclass
    class UserInterface:
    title: str = "My app"
    width: int = 1024
    height: int = 768

    @dataclass
    class MyConfig:
    db: MySQLConfig = field(default_factory=MySQLConfig)
    ui: UserInterface = field(default_factory=UserInterface)

    cs = ConfigStore.instance()
    cs.store(name="config", node=MyConfig)

    @hydra.main(version_base=None, config_name="config")
    def my_app(cfg: MyConfig) -> None:
    print(f"Title={cfg.ui.title}, size={cfg.ui.width}x{cfg.ui.height} pixels")

    if __name__ == "__main__":
    my_app()
  • 更多详情参考:


附录:运行时文件工作路径获取

  • 使用 Python 命令获取,详情见原始路径
  • 参考链接:https://hydra.cc/docs/tutorials/basic/running_your_app/working_directory/

附录:调试参数配置情况

  • 在命令中添加 --cfg job 等来输出自己的配置
    • job: 个人配置参数生效情况,包括命令行传入的参数,这里是最终生效参数情况
    • hydra: Hydra’s config
    • all: The full config, which is a union of job and hydra. 二者融合
  • 参考链接:https://hydra.cc/docs/tutorials/basic/running_your_app/debugging/

Python——Jinja2模板引擎


整体说明

  • Jinja2 是一个功能强大的 Python 模板引擎,广泛广泛用于 Web 开发(如 Flask、Django 可集成)和文档生成等场景
  • 注:Jinja 和 Jinja2 实际上是同一个模板引擎的不同版本,Jinja2 是 Jinja 的升级版本,它们在模板语法格式上大部分是兼容的,但也存在一些差异和改进
  • 本文主要以 Jinja2 为主介绍简单的使用方法
  • Jinja2 是 Python 的库,安装 Jinja2 库使用 pip 即可:
    1
    pip install jinja2

Jinja2 基本概念介绍

  • 模板(Template) :包含固定内容和动态变量/逻辑的 Text 文件(如 HTML、TXT 等)
  • 变量(Variables) :模板中需要动态替换的值,用 {{ 变量名 }} 表示
  • 控制结构 :用于实现条件判断、循环等逻辑,用 {% 代码 %} 表示
  • 过滤器(Filters) :对变量进行处理(如格式化、转换),用 {{ 变量|过滤器 }} 表示
  • 模板继承 :通过 extends 和 block 实现模板复用
  • 语句分隔符 :用来包裹“控制语句”的那对标记符号,告诉模板引擎“这里不是普通文本,而是一条要执行的指令”;Jinja2 默认的语句分隔符是:
    • 语句块(for / if / set / macro …):开始 {%` 结束 `%}
    • 变量输出(把值打印到页面):开始 {{ 变量 }}
    • 注释 :开始 ``

Jinja2 基础语法介绍

变量定义及相关操作

  • 在 Jinja2 中,定义变量可以使用 {% set %} 标签,基本语法和用法总结:
    • 基础变量用 {% set 变量名 = 值 %} 定义
    • 需在循环中修改的变量,用 namespace 命名空间
    • 变量通过 {{ 变量名 }} 输出,支持列表、字典等复杂类型
    • 实际开发中,变量更多从外部(如 Python 代码)传递到模板
基本变量定义
  • 使用 {% set 变量名 = 值 %} 格式定义变量:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    {# 定义字符串变量 #}
    {% set name = "Jinja2" %}

    {# 定义数字变量 #}
    {% set version = 2 %}

    {# 定义布尔值变量 #}
    {% set is_active = true %}

    {# 定义列表变量 #}
    {% set fruits = ["apple", "banana", "cherry"] %}

    {# 定义字典变量 #}
    {% set user = {"name": "Alice", "age": 30} %}
使用变量
  • 定义后可以用 {{ 变量名 }} 输出变量值,或在控制结构中使用:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    <p>名称:{{ name }}</p>
    <p>版本:{{ version }}</p>

    {# 列表遍历 #}
    <ul>
    {% for fruit in fruits %}
    <li>{{ fruit }}</li>
    {% endfor %}
    </ul>

    {# 字典取值 #}
    <p>用户名:{{ user.name }}</p> {# 或 user["name"] #}
命名空间(namespace)变量
  • 如果需要在循环或嵌套结构中修改变量值 ,普通变量无法直接生效,需使用 namespace 命名空间:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    {# 定义命名空间变量 #}
    {% set ns = namespace(total=0) %}

    {# 在循环中修改命名空间变量 #}
    {% for num in [1, 2, 3, 4] %}
    {% set ns.total = ns.total + num %}
    {% endfor %}

    <p>总和:{{ ns.total }}</p> {# 输出:总和:10 #}
  • 普通变量在循环中修改会被重置(作用域限制),而命名空间变量可以跨循环保持状态

变量作用域说明
  • 变量默认在定义它的模板块(block)或宏(macro) 内有效
  • 全局变量可在父模板定义,子模板通过 {{ 变量名 }} 直接使用(需确保变量已传递到子模板)
从外部传递变量
  • 实际开发中,变量通常从 Python 代码中传递到模板,而非在模板内定义:
    1
    2
    3
    4
    5
    6
    # Python 代码
    from jinja2 import Template

    template = Template("Hello, {{ name }}!")
    result = template.render(name="World") # 传递变量 name
    print(result) # 输出:Hello, World!

注释

  • 用 `` 表示,渲染时会被忽略:
    1
    2
    {# 这是一段注释,不会被渲染 #}
    <p>{{ content }}</p>

变量输出

  • 用 {{ 变量名 }} 输出变量,支持嵌套结构(如字典、对象属性):

    1
    2
    3
    4
    <!-- 模板示例 -->
    <h1>{{ title }}</h1>
    <p>作者:{{ author.name }}</p>
    <p>年龄:{{ author.age }}</p>
  • 在 Python 中渲染:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    from jinja2 import Template

    # 定义模板内容
    template_str = """
    <h1>{{ title }}</h1>
    <p>作者:{{ author.name }}</p>
    <p>年龄:{{ author.age }}</p>
    """

    # 定义变量
    data = {
    "title": "Jinja2 教程",
    "author": {"name": "张三", "age": 30}
    }

    # 渲染模板
    template = Template(template_str)
    result = template.render(**data)
    print(result)

    # <h1>Jinja2 教程</h1>
    # <p>作者:张三</p>
    # <p>年龄:30</p>

控制结构

条件判断(if-elif-else)
  • 用于实现条件判断,仅执行符合条件的分支
    1
    2
    3
    4
    5
    6
    7
    {% if score >= 90 %}
    <p>优秀</p>
    {% elif score >= 60 %}
    <p>及格</p>
    {% else %}
    <p>不及格</p>
    {% endif %}
循环(for)
  • 用于遍历列表、字典等可迭代对象,支持 loop 辅助变量(如索引、是否为第一个元素):

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    <ul>
    {% for item in items %}
    <li>
    {{ loop.index }}. {{ item.name }} - {{ item.price }}元
    {% if loop.first %}(第一个){% endif %}
    {% if loop.last %}(最后一个){% endif %}
    </li>
    {% else %}
    <li>暂无数据</li> <!-- 当列表为空时执行 -->
    {% endfor %}
    </ul>
  • loop 不需要定义即可使用,是 jinja2 新给的 feature

  • loop 常用属性:

    • loop.index:当前迭代序号(从 1 开始)
    • loop.index0:当前迭代序号(从 0 开始)
    • loop.first:是否为第一个元素(布尔值)
    • loop.last:是否为最后一个元素(布尔值)

过滤器(Filters)

  • 对变量进行处理,格式为 {{ 变量|过滤器(参数) }} 。常用过滤器:
    过滤器 作用 示例
    upper 转为大写 {{ name|upper }}
    lower 转为小写 {{ name|lower }}
    capitalize 首字母大写 {{ name|capitalize }}
    length 获取长度 {{ list|length }}
    join 列表拼接为字符串 {{ list|join(', ') }}
    default 变量不存在时使用默认值 {{ value|default('暂无') }}
    date 日期格式化(需传入 datetime) {{ now|date('%Y-%m-%d') }}
  • 示例:
    1
    2
    3
    <p>姓名(大写):{{ name|upper }}</p>
    <p>列表长度:{{ items|length }}</p>
    <p>列表拼接:{{ items|join('、') }}</p>

模板继承(重要功能)

  • 通过继承可以复用模板中的公共部分(如页面头部、底部),核心是 extends 和 block

  • 父模板(base.html) :定义公共结构和可替换的块(block)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    <!DOCTYPE html>
    <html>
    <head>
    <meta charset="UTF-8">
    <title>{% block title %}默认标题{% endblock %}</title>
    </head>
    <body>
    <header>公共头部</header>

    <main>
    {% block content %}{% endblock %} <!-- 子模板替换这里 -->
    </main>

    <footer>公共底部</footer>
    </body>
    </html>
  • 子模板(page.html) :继承父模板并替换块

    1
    2
    3
    4
    5
    6
    7
    8
    {% extends "base.html" %}  <!-- 继承父模板 -->

    {% block title %}首页{% endblock %} <!-- 替换标题块 -->

    {% block content %} <!-- 替换内容块 -->
    <h1>这是首页内容</h1>
    <p>欢迎访问</p>
    {% endblock %}

加载外部模板文件

  • 实际开发中,模板通常存放在文件中(而非字符串),可通过 FileSystemLoader 加载:

  • 目录结构:

    1
    2
    3
    4
    5
    project/
    ├── templates/
    │ ├── base.html # 父模板
    │ └── page.html # 子模板
    └── app.py # 主程序
  • Python 代码:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    from jinja2 import Environment, FileSystemLoader

    # 配置模板目录
    env = Environment(loader=FileSystemLoader('templates'))

    # 加载并渲染子模板
    template = env.get_template('page.html')
    result = template.render() # 可传入变量,如 render(title="首页")
    print(result)

常用高级功能

  • 宏(Macro) :类似函数,用于复用代码片段:

    1
    2
    3
    4
    5
    6
    7
    {% macro input(name, value='', type='text') %}
    <input type="{{ type }}" name="{{ name }}" value="{{ value }}">
    {% endmacro %}

    <!-- 使用宏 -->
    {{ input('username') }}
    {{ input('password', type='password') }}
  • 包含(Include) :引入其他模板片段:

    1
    2
    <!-- 引入导航栏模板 -->
    {% include "navbar.html" %}
  • 自动转义 :默认开启(防止 XSS 攻击),可通过 autoescape 控制:

    1
    2
    3
    {% autoescape off %}
    {{ html_content }} <!-- 不转义,直接渲染HTML -->
    {% endautoescape %}

空白控制符

  • 在原有标签格式上加入 - 可以控制空白控制符
  • 当需要精确控制输出格式(如避免多余空行、压缩 HTML)时,使用带减号的形式
  • 对格式要求不严格时,使用默认形式更简洁
  • 减号仅影响空白字符,不改变标签的逻辑功能(如循环、条件判断等)
  • 最佳实践:一般建议都加上 {%- %} 来使用,格式更美观

控制结构中的空白控制符

  • {% %} (默认形式) :标签不会影响其前后的空白字符(空格、换行、制表符等)。例如:

    1
    2
    3
    4
    5
    <ul>
    {% for item in [0,1] %}
    <li>{{ item }}</li>
    {% endfor %}
    </ul>
    • 渲染后会保留循环标签前后的换行和缩进,可能产生多余空白:
      1
      2
      3
      4
      5
      6
      7
      <ul>

      <li>0</li>

      <li>1</li>

      </ul>
  • {%- %} (带减号的形式) :减号会移除标签一侧的空白字符(具体取决于减号的位置):

    • {%- ... %} :移除标签左侧(前面)的空白

    • {% ... -%} :移除标签右侧(后面)的空白

    • {%- ... -%} :同时移除标签两侧的空白

    • 例如,优化上面的循环:

      1
      2
      3
      4
      5
      <ul>
      {%- for item in items %}
      <li>{{ item }}</li>
      {%- endfor %}
      </ul>
      • 渲染后空白更紧凑:
        1
        2
        3
        4
        <ul>
        <li>0</li>
        <li>1</li>
        </ul>

变量输出中的空白控制符

  • {{ var }} 正常输出,不做额外处理

  • {{- var }} 或 {{ var -}} 或 {{- var -}} 同样可以用连字符 -,把变量前面或后面的空白吃掉

  • 举例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    from jinja2 import Template

    template_str = """
    {% set name = 'wo' %}
    <p> {{- name -}} </p>
    <p> {{ name }} </p>
    """

    template = Template(template_str)
    result = template.render()
    print(result)

    # <p>wo</p>
    # <p> wo </p>
  • 一句话:

    • - 就是“吃掉这条标签前/后的空白”,写在左边 ( {%- / {{-) 吃前面,写在右边 (-%} / -}} ) 吃后面

字符串连接符

  • ~ 是字符串连接运算符 ,用于将左右两边的元素拼接成一个字符串

  • ~ 作用类似于 Python 中的 + 运算符,但更灵活:

    • 会自动将非字符串类型(如变量、数字等)转换为字符串后再拼接
    • 不会像 + 那样在两边添加额外空格
    • 如果使用 + 运算符,需要确保两边都是字符串类型,而 ~ 则会自动处理类型转换,在模板中更常用
  • 以常见代码为例:

    1
    {{- "\nthinking_budget: < " ~ thinking_budget ~ "."}}
  • 这里的两个 ~ 会将三个部分拼接成一个完整字符串:

    • 1)"\nthinking_budget: < "(字符串字面量)
    • 2)thinking_budget(变量,会被转换为字符串)
    • 3)"."(字符串字面量)
  • 假设 thinking_budget 的值是 100,最终结果会是:

    1
    thinking_budget: < 100.

附录:补充示例(LongCat-Flash-Chat更详细一些)

  • 以美团开源的 LongCat-Flash-Chat/blob/main/tokenizer_config.json 为例,以下是格式化后的 chat_template Jinja2 代码及其逐行解释:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    {# 设置工具选择变量,默认值为'auto' #}
    {%- set tool_choice = tool_choice | default('auto') %}

    {# 创建命名空间变量,用于存储循环计数、工具类型和最后查询索引 #}
    {%- set ns = namespace(rounds = 0, tool_types = [], last_query_index = -1) %}


    {# 如果存在工具且工具选择不是'none',则输出工具相关信息 #}
    {%- if tools and tool_choice != 'none' %}
    {{- "# Tools\n" }}
    {{- "You have access to the following tools: \n\n" }}

    {# 遍历所有工具 #}
    {%- for tool in tools %}
    {# 只处理代码解释器和函数类型的工具 #}
    {%- if tool.type in ['code_interpreter', 'function'] %}
    {# 如果是新类型的工具,输出工具命名空间 #}
    {%- if tool.type not in ns.tool_types %}
    {%- set ns.tool_types = ns.tool_types + [tool.type] %}
    {{- "## Tool namespace: " ~ tool.type ~ "\n\n" }}
    {%- endif %}

    {# 如果是代码解释器工具,重新定义其配置 #}
    {%- if tool.type == 'code_interpreter' %}
    {%- set tool = {
    "type": "code_interpreter",
    "function": {
    "name": "code_interpreter_preview",
    "description": "The code will be executed in a stateful Jupyter notebook sandbox environment, only supports local computation, data processing, and file operations. \nCode sandbox environment (network isolated) Any external network requests or online API calls are prohibited. \nIf online functionality is needed, please use other permitted tools. \nCode will respond with the output of the execution or time out after 60.0 seconds. ",
    "parameters": {
    "type": "object",
    "properties": {
    "language": {
    "type": "string",
    "description": "The programming language of the code to be executed. Available values: python (Default), java, go, js, ts, c, c++."
    },
    "code": {
    "type": "string",
    "description": "Python code to be executed must not include the following:\n- Importing network libraries such as requests, httplib, etc.\n- Any form of HTTP requests.\n- External API calls.\n- Network port operations. Example: ```python\nimport pandas as pd\npd.DataFrame({'A':[1,2]})\n```"
    },
    "timeout": {
    "type": "number",
    "description": "The maximum execution time of the code, in seconds. Default is 60.0."
    }
    }
    },
    "required": ["code"]
    }
    } %}
    {%- endif %}

    {# 输出工具名称、描述和输入 schema #}
    {{- "### Tool name: " + tool.function.name + "\n\n" }}
    {{- "Description: " + tool.function.description + "\n\n" }}
    {{- "InputSchema: \n" + tool.function.parameters | tojson(indent=2) + "\n\n" }}
    {%- endif %}
    {%- endfor %}

    {# 输出工具调用格式说明 #}
    {{- '**Note** :For each function call, return a json object with function name and arguments within <longcat_tool_call></longcat_tool_call> XML tags as follows:
    <longcat_tool_call>
    {"name": <function-name>, "arguments": <args-dict>}
    </longcat_tool_call>
    ' }}
    {{- 'When multiple functions need to be called simultaneously, each function call should be wrapped in its own <longcat_tool_call> tag and placed consecutively. For example:
    <longcat_tool_call>
    {"name": <function-name>, "arguments": <args-dict>}
    </longcat_tool_call><longcat_tool_call>
    {"name": <function-name>, "arguments": <args-dict>}
    </longcat_tool_call>

    ' }}
    {{- "# Messages\n" }}

    {# 遍历消息,找到最后一个助手的非工具调用消息索引 #}
    {%- for idx in range(messages|length - 1) %}
    {%- set msg = messages[idx] %}
    {%- if msg.role == 'assistant' and not msg.tool_calls %}
    {%- set ns.last_query_index = idx %}
    {%- endif %}
    {%- endfor%}
    {%- endif %}


    {# 遍历所有消息并格式化输出 #}
    {%- for msg in messages %}
    {# 系统消息处理 #}
    {%- if msg.role == "system" %}
    {{- "SYSTEM:" + msg.content }}

    {# 用户消息处理 #}
    {%- elif msg.role == "user" %}
    {%- if loop.first %}
    {{- "[Round " ~ (ns.rounds) ~ "] USER:" }}
    {%- else %}
    {{- " [Round " ~ (ns.rounds) ~ "] USER:"}}
    {%- endif %}
    {%- set ns.rounds = ns.rounds + 1 %}

    {# 如果有文件,输出文件信息 #}
    {%- if msg["files"] %}
    {{- '<longcat_files>\n' ~ msg.files | tojson(indent=2) ~ '\n</longcat_files>' }}
    {%- endif %}
    {{- msg.content }}

    {# 助手消息处理 #}
    {%- elif msg.role == "assistant" %}
    {{- " ASSISTANT:" }}

    {# 如果启用思考模式且有思考内容,输出思考过程 #}
    {%- if enable_thinking == true and msg.reasoning_content and ns.tool_types != [] and loop.index0 > ns.last_query_index %}
    {{- "\n<longcat_think>\n" ~ msg.reasoning_content ~ "\n</longcat_think>\n" }}
    {%- endif %}

    {# 输出助手内容 #}
    {%- if msg.content%}
    {{- msg.content }}
    {%- endif %}

    {# 输出工具调用信息 #}
    {%- if msg.tool_calls %}
    {%- for tool_call in msg.tool_calls -%}
    {{- "<longcat_tool_call>\n" -}}
    {%- if tool_call.function.arguments is string -%}
    {"name": "{{ tool_call.function.name}}", "arguments": {{tool_call.function.arguments}}}
    {%- else -%}
    {"name": "{{ tool_call.function.name}}", "arguments": {{tool_call.function.arguments | tojson}}}
    {%- endif -%}
    {{- "\n</longcat_tool_call>" }}
    {%- endfor %}
    {%- endif %}
    {{- "</longcat_s>" -}}

    {# 工具返回结果处理 #}
    {%- elif msg.role == "tool" %}
    {{- " TOOL:" -}}
    {%- if msg.name -%}
    {"name": {{msg.name | tojson}}, "content": {{msg.content | tojson}}}
    {%- else -%}
    {"content": {{msg.content | tojson}}}
    {%- endif -%}
    {%- endif %}
    {%- endfor %}


    {# 如果需要生成提示,输出相应的提示信息 #}
    {%- if add_generation_prompt %}
    {%- if enable_thinking == true %}
    {{- " /think_on" }}
    {%- if thinking_budget %}
    {%- if thinking_budget < 1024 %}
    {%- set thinking_budget = 1024 %}
    {%- endif%}
    {{- "\nthinking_budget: < " ~ thinking_budget ~ "."}}
    {%- endif %}
    {{- " ASSISTANT:<longcat_think>\n"}}
    {%- elif enable_thinking == false %}
    {{- " /think_off ASSISTANT:<longcat_think>\n\n</longcat_think>\n" }}
    {%- else %}
    {{- " ASSISTANT:" }}
    {%- endif %}
    {%- endif %}
  • 关于示例的一些补充说明:

    • </longcat_s> 是结束符,不是开始符号,记忆:从 </s> 变形而来
    • add_generation_prompt 用于判断是否需要增加生成信息,一般来说是 serving 需要,trainging 不需要
    • thinking_budget 可以不加,默认没有预算约束
    • 上述示例还缺少的模版 features 为 RAG documents 参数的使用,详情可参考 huggingface.co/CohereLabs/c4ai-command-r-v01/blob/main/tokenizer_config.json

code_interpreter 的使用

  • 特别说明:使用 code_interpreter 时,只需要在 tools 里面加一项 { "type": "code_interpreter" },,这样 chat_template 会自动识别到该字段并输出一些使用信息,告诉模型如何给出代码,并告知模型这个代码可以被执行

  • 本文示例中 chat_template 的具体做法是先将 code_interpreter 包装成一个类似 function 的格式,再统一输出,最终效果就是让模型知道可以调用 code_interpreter 执行代码("code" 参数内容就是代码)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    {# 如果是代码解释器工具,重新定义其配置 #}
    {%- if tool.type == 'code_interpreter' %}
    {%- set tool = {
    "type": "code_interpreter",
    "function": {
    "name": "code_interpreter_preview",
    "description": "The code will be executed in a stateful Jupyter notebook sandbox environment, only supports local computation, data processing, and file operations. \nCode sandbox environment (network isolated) Any external network requests or online API calls are prohibited. \nIf online functionality is needed, please use other permitted tools. \nCode will respond with the output of the execution or time out after 60.0 seconds. ",
    "parameters": {
    "type": "object",
    "properties": {
    "language": {
    "type": "string",
    "description": "The programming language of the code to be executed. Available values: python (Default), java, go, js, ts, c, c++."
    },
    "code": {
    "type": "string",
    "description": "Python code to be executed must not include the following:\n- Importing network libraries such as requests, httplib, etc.\n- Any form of HTTP requests.\n- External API calls.\n- Network port operations. Example: ```python\nimport pandas as pd\npd.DataFrame({'A':[1,2]})\n```"
    },
    "timeout": {
    "type": "number",
    "description": "The maximum execution time of the code, in seconds. Default is 60.0."
    }
    }
    },
    "required": ["code"]
    }
    } %}
    {%- endif %}
  • 当 tools 的第一条信息是 { "type": "code_interpreter" } 时,chat_tempalte 格式化的结果为:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    # Tools
    You have access to the following tools:

    ## Tool namespace: code_interpreter

    ### Tool name: code_interpreter_preview

    Description: The code will be executed in a stateful Jupyter notebook sandbox environment, only supports local computation, data processing, and file operations.
    Code sandbox environment (network isolated) Any external network requests or online API calls are prohibited.
    If online functionality is needed, please use other permitted tools.
    Code will respond with the output of the execution or time out after 60.0 seconds.

    InputSchema:
    {
    "type": "object",
    "properties": {
    "language": {
    "type": "string",
    "description": "The programming language of the code to be executed. Available values: python (Default), java, go, js, ts, c, c++."
    },
    "code": {
    "type": "string",
    "description": "Python code to be executed must not include the following:\n- Importing network libraries such as requests, httplib, etc.\n- Any form of HTTP requests.\n- External API calls.\n- Network port operations. Example: ```python\nimport pandas as pd\npd.DataFrame({'A':[1,2]})\n```"
    },
    "timeout": {
    "type": "number",
    "description": "The maximum execution time of the code, in seconds. Default is 60.0."
    }
    }
    }

    ## Tool namespace: function

    ### Tool name: search

    Description: 网页搜索,使用传统搜索引擎,复杂问题需要拆分为简单query

    InputSchema:
    {
    "type": "object",
    "required": [
    "query"
    ],
    "properties": {
    "query": {
    "type": "string",
    "description": "适合传统搜索引擎的简单query"
    }
    }
    }
    ... 更多

Python——heapq模块-最大堆最小堆

由于queue不是Python标准库,所以在LeetCode等OJ上面不能直接使用,我们可以选择heapq来使用最大最小堆


使用示例

  • 堆排序示例

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    import heapq

    nums = [2, 3, 5, 1, 54, 23, 132]
    heap = []
    for num in nums:
    heapq.heappush(heap, num)
    # 等价于
    heapq.heapify(nums)

    # heap sort by incresing
    print([heapq.heappop(heap) for _ in range(len(nums))])
  • 加入元素

    1
    heapq.heappush(heap, num)
  • 弹出元素

    1
    num = heapq.heappop(heap)
  • 获取最大最小值

    1
    2
    3
    4
    5
    6
    7
    8
    9
    import heapq

    nums = [1, 3, 4, 5, 2]
    print(heapq.nlargest(3, nums))
    print(heapq.nsmallest(3, nums))

    #Output:
    [5, 4, 3]
    [1, 2, 3]
  • 获取堆顶元素

    1
    top = nums[0]

最大堆的实现

  • 由于Python heapq模块只实现了最小堆, 最大堆需要我们自己实现
  • 一种简单可行的实现方案:
    • 在加入和弹出时,把元素取反 ,从而实现最大堆

Python——easydict包的使用


整体说明

  • EasyDict 是一个轻量级的 Python 库,旨在简化字典操作,它允许用户像访问对象属性一样访问字典的键值对,从而提高代码的可读性和简洁性
  • EasyDict 通过重写字典的几个关键方法,如__getattr__和__setattr__等,实现了将字典键转换为对象属性的功能
  • EasyDict 不仅支持顶级字典的属性访问方式,还能递归应用于内嵌的字典,使得处理多层次数据结构变得简单易行
  • EasyDict 实例仍然遵循标准字典的所有操作,保证了灵活性

安装 EasyDict

  • 可以使用pip进行安装,命令如下:
    1
    pip install easydict

使用示例

  • 简单使用示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    from easydict import EasyDict as edict
    # 创建一个EasyDict对象
    data = edict({'name': 'John', 'age': 30, 'job': 'Engineer'})
    # 访问字典元素
    print(data.age)
    print(data.job)
    # 添加新的键值对
    data.gender = 'Male'
    print(data.gender)
  • 嵌套字典的访问:

    1
    2
    3
    4
    5
    6
    7
    8
    my_dict = edict({
    'level1': edict({
    'level2': edict({
    'key': 'value'
    })
    })
    })
    print(my_dict.level1.level2.key)
  • 动态设置属性:

    1
    2
    3
    my_dict = edict()
    my_dict.key1 = 'value1'
    print(my_dict.key1)
  • 常见的字典操作:

    1
    2
    3
    4
    5
    6
    7
    my_dict = edict({'key1': 'value1'})
    # 更新字典
    my_dict.update({'key2': 'value2'})
    print(my_dict.key2)
    # 删除一个键值对
    my_dict.pop('key1')
    print(my_dict.key1) # 会抛出AttributeError,因为'key1'不再存在
  • 获取默认值:

    1
    2
    3
    my_dict = edict({'name': 'Alice'})
    value = my_dict.get('nonexistent_key', 'default_value')
    print(value)

EasyDict 和 namedtuple 对比

  • EasyDict 和 namedtuple 都是 Python 中用于简化数据访问的工具
  • TDRL:namedtuple 是”先定义类,再用类创建实例”;EasyDict 是”直接用通用类创建实例,动态定义结构”
    • namedtuple 需要先定义特定结构的类(如Person),再创建该类的实例,适合固定结构的数据
    • EasyDict 直接使用通用的EasyDict类创建实例,实例的字段结构可以动态变化,适合灵活的数据场景
  • TDRL:若需 固定结构、不可变数据 ,追求性能和内存效率,用 namedtuple;若需 动态结构、灵活修改 ,优先便捷性,用 EasyDict

本质与继承关系

  • namedtuple 是 tuple 的子类,属于不可变(immutable)数据结构
    • 一旦创建,其字段值无法修改,类似元组的特性
  • namedtuple 定义时需要指定固定的字段名,结构是静态的,不能动态添加新字段
  • EasyDict 是 dict 的子类,属于可变(mutable)数据结构
    • 创建后可以随时修改字段值,也能动态添加/删除新字段,保留了字典的灵活性

数据访问方式

  • 两者都支持 属性式访问(如 obj.field)和 键值访问(如 obj['field']),但底层实现不同:
    • namedtuple 本质是元组,字段值存储在固定位置,访问速度更快
    • EasyDict 本质是字典,通过重写 __getattr__ 实现属性访问,性能略低于 namedtuple

可变性

  • namedtuple 不可变:创建后无法修改字段值,也不能添加新字段,类似常量集合,示例如下:

    1
    2
    3
    4
    from collections import namedtuple
    Person = namedtuple('Person', ['name', 'age'])
    p = Person('Alice', 30)
    p.age = 31 # 报错:'Person' object does not support item assignment
  • EasyDict 可变:支持修改现有字段、添加新字段、删除字段等操作,示例如下:

    1
    2
    3
    4
    5
    from easydict import EasyDict as edict
    p = edict(name='Alice', age=30)
    p.age = 31 # 允许修改
    p.gender = 'Female' # 允许添加新字段
    del p.age # 允许删除字段

定义类情况

  • namedtuple 显式定义了一个新的类(如Person),这个类继承自tuple,并且在定义时就固定了字段结构,例如:

    1
    2
    3
    4
    from collections import namedtuple
    # 这里显式创建了一个名为 Person 的类
    Person = namedtuple('Person', ['name', 'age'])
    print(type(Person)) # 输出:<class 'type'>,说明是一个类
    • 后续使用时,Person() 是创建该类的实例,每个实例都严格遵循预定义的字段结构
  • EasyDict 没有要求你显式定义新的类(如Person),但它本身是一个通用的 EasyDict 类,所有实例都属于这个类,例如:

    1
    2
    3
    4
    from easydict import EasyDict as edict
    # 直接创建 EasyDict 类的实例,无需预先定义结构
    p = edict(name='Alice', age=30)
    print(type(p)) # 输出:<class 'easydict.EasyDict'>
    • 你将 p 视为一个”动态对象”,它属于 EasyDict 类,但其字段可以灵活添加/修改,不需要提前定义特定的类(如Person)

适用场景

  • namedtuple 适合存储 固定结构、不可变的数据(如配置项、记录、坐标等),强调数据的稳定性和内存效率
    • 例如:表示点坐标 Point(x=1, y=2)、数据库查询结果等
  • EasyDict 适合处理 动态结构、需要灵活修改的数据(如嵌套配置、JSON 数据解析等),强调操作的便捷性
    • 例如:解析 API 返回的 JSON 数据(可动态添加/修改字段)、多层级的配置文件等

其他差异

  • 内存占用 :namedtuple 比 EasyDict 更轻量,内存占用更少
  • 序列化 :两者都支持序列化,但 namedtuple 可直接通过 _asdict() 转换为普通字典,EasyDict 本身就是字典,可直接序列化
  • 类型提示 :namedtuple 在定义时已明确字段,类型提示更友好;EasyDict 动态字段较多,类型提示较弱

Python——field的用法


整体说明

  • 在 Python 中,field 主要关联两个核心场景:
    • 一是标准库 dataclasses 模块的 field() 函数 ,用于定制数据类字段
    • 二是第三方库如 pydantic 的 Field 类 (注意:首字母是大写), 用于数据校验/序列化

dataclasses.field()

  • dataclasses 是 Python 内置的轻量级数据类工具
  • field() 用于精细化定义数据类的字段(替代默认的简单赋值),支持定制默认值、初始化行为、序列化等
    • 注:Python 3.7+ 内置

dataclasses.field() 基础语法

  • 用法示例:

    1
    2
    3
    4
    5
    from dataclasses import dataclass, field

    @dataclass
    class ClassName:
    name = field(...) # 字段名: 类型 = Field(参数1=值1, 参数2=值2, ...)
  • field() 核心参数说明

    • default
      • 字段默认值(仅当字段无默认值时使用,与 default_factory 二选一)
      • 示例:field(default=0)
    • default_factory
      • 动态生成默认值的工厂函数(如列表/字典等可变类型)
      • 示例:field(default_factory=list)
    • init
      • 是否参与 __init__ 方法(默认 True)
      • 示例:field(init=False)
    • repr
      • 是否出现在 __repr__ 输出中(默认 True)
    • compare
      • 是否参与比较(__eq__/__lt__ 等,默认 True)
    • hash
      • 是否参与 __hash__ 计算(默认 None,继承 compare 值)
    • metadata
      • 附加元数据(字典,供外部工具使用)
      • field(metadata={"desc": "用户ID"})

dataclasses.field() 常用示例

  • 示例:基础使用(默认值/工厂函数)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    from dataclasses import dataclass, field

    @dataclass
    class User:
    # 简单默认值(不可变类型)
    id: int = field(default=0)
    # 可变类型默认值(必须用 default_factory,避免所有实例共享同一对象)
    tags: list[str] = field(default_factory=list)
    # 字符串默认值
    name: str = field(default="未知用户")

    # 实例化
    u1 = User()
    print(u1) # User(id=0, tags=[], name='未知用户')
    u1.tags.append("admin")
    u2 = User()
    print(u2.tags) # [](独立的列表,无共享问题)
  • 示例2:定制初始化/序列化行为

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    @dataclass
    class Product:
    name: str
    # 不参与 __init__(手动赋值)
    price: float = field(init=False)
    # 不显示在 repr 中
    stock: int = field(default=0, repr=False)
    # 不参与比较
    sku: str = field(default="", compare=False)

    # 实例化(无需传 price 和 stock/sku)
    p = Product("手机")
    p.price = 2999.99 # 手动赋值
    print(p) # Product(name='手机', price=2999.99)(stock 未显示)
    print(p == Product("手机")) # True(sku 不参与比较)
  • 示例3:附加元数据

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    @dataclass
    class Student:
    id: int = field(metadata={"desc": "学生学号", "required": True})
    score: float = field(default=0.0, metadata={"min": 0, "max": 100})

    # 获取元数据
    s = Student(1001)
    # 方式1:通过 dataclasses.fields 获取
    from dataclasses import fields
    for f in fields(s):
    print(f.name, f.metadata)
    # 输出:
    # id {'desc': '学生学号', 'required': True}
    # score {'min': 0, 'max': 100}

pydantic.Field(第三方库,数据校验)

  • pydantic 是Python主流的数据校验库
  • Field 用于定义模型字段的校验规则、默认值、文档等,功能比 dataclasses.field 更丰富
  • 使用前先安装:
    1
    pip install pydantic

pydantic.Field 基础语法

  • 用法说明:

    1
    2
    3
    4
    from pydantic import BaseModel, Field

    class ClassName(BaseModel):
    name = field(...) # 字段名: 类型 = Field(默认值, 参数1=值1, 参数2=值2, ...)
  • 核心参数说明

    • default/default_factory:
      • 默认值/动态默认值(同dataclasses)
      • 示例:Field(default=10) / Field(default_factory=list)
    • alias
      • 字段别名(序列化/反序列化时可用)
      • 示例: Field(alias="user_id")
    • gt/ge/ lt/le
      • 数值大于/大于等于/小于/小于等于
      • 示例:Field(gt=0)(值必须>0)
    • min_length/max_length
      • 字符串最小/最大长度
      • 示例:Field(min_length=2, max_length=10)
    • pattern
      • 字符串正则匹配
      • 示例:Field(pattern=r"^[A-Z]+$")
    • description
      • 字段描述(文档生成)
      • 示例:Field(description="用户年龄")
    • nullable
      • 是否允许为None(Pydantic v1,v2需用 Optional)
      • 示例:Field(nullable=True)
    • examples
      • 示例值(OpenAPI文档)
      • 示例:Field(examples=[18, 20])

pydantic.Field常用示例

  • 示例1:基础校验

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    from pydantic import BaseModel, Field

    class User(BaseModel):
    name: str = Field(..., min_length=2, max_length=20, description="用户名(2-20字符)")
    age: int = Field(..., gt=0, le=120, description="年龄(1-120)")
    email: str = Field(None, pattern=r"^[\w-]+@[\w-]+\.[a-z]+$", description="邮箱(可选)")

    # 合法实例
    u1 = User(name="张三", age=25, email="zhangsan@example.com")
    print(u1.model_dump())
    # 输出:{'name': '张三', 'age': 25, 'email': 'zhangsan@example.com'}

    # 非法实例(触发校验错误)
    try:
    u2 = User(name="李", age=150, email="invalid-email")
    except Exception as e:
    print(e)
    # 输出:
    # 1 validation error for User
    # name
    # String should have at least 2 characters [type=string_too_short, input_value='李', input_type=str]
    # age
    # Input should be less than or equal to 120 [type=less_than_or_equal, input_value=150, input_type=int]
    # email
    # String should match pattern '^[\w-]+@[\w-]+\.[a-z]+$' [type=string_pattern_mismatch, input_value='invalid-email', input_type=str]
  • 示例2:别名与默认值

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    class Product(BaseModel):
    # 别名:序列化时用 product_id,反序列化时可传 id 或 product_id
    id: int = Field(..., alias="product_id")
    # 动态默认值(每次实例化生成新列表)
    tags: list[str] = Field(default_factory=lambda: ["未分类"])

    # 用别名传参
    p = Product(product_id=1001)
    print(p.id) # 1001
    print(p.tags) # ['未分类']
    # 序列化(输出别名)
    print(p.model_dump(by_alias=True)) # {'product_id': 1001, 'tags': ['未分类']}
  • 示例3:结合文档(OpenAPI)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    from pydantic import BaseModel, Field
    from pydantic.schema import schema

    class Order(BaseModel):
    order_id: str = Field(..., pattern=r"^ORD-\d{6}$", description="订单号(格式:ORD-6位数字)")
    amount: float = Field(..., gt=0, examples=[99.9, 199.0], description="订单金额(>0)")

    # 生成JSON Schema(用于OpenAPI文档)
    schema_dict = schema([Order])
    print(schema_dict)

三、pydantic.Field 和 dataclasses.field 区别

  • dataclasses.field:
    • 内置库
    • 简单数据存储、无校验需求
    • 无校验能力(仅基础类型注解)
  • pydantic.Field:
    • 第三方接口
    • 接口参数校验、数据清洗、API文档
    • 强大的数值/字符串/结构校验

使用注意:

  • 可变类型默认值 :
    • 无论是 dataclasses.field 还是 pydantic.Field,可变类型(list/dict/set)的默认值必须用 default_factory,否则所有实例会共享同一对象
    • 错误:tags: list = []
    • 正确:tags: list = field(default_factory=list)
  • Pydantic版本差异 :
    • v1 中 nullable=True 允许字段为None;v2 需用 Optional[类型](如 age: Optional[int] = Field(None))
    • v2 中 Field 的参数更简洁,推荐使用最新版
  • dataclasses 不可变字段 :
    • 若需不可变数据类,加 @dataclass(frozen=True),此时 init=False 的字段需在 __post_init__ 中赋值

Python——iter函数的用法


整体说明

  • 在 Python 里,iter() 函数主要用于生成迭代器
  • 迭代器用于遍历可迭代对象(像列表、元组、字典这样的),能逐个获取对象里的元素
  • 用法可总结如下:
    • iter() 函数的主要作用是把可迭代对象转变为迭代器
    • 迭代器通过 next() 函数来获取下一个元素
    • 可以通过自定义类并实现 __iter__() 和 __next__() 方法来自定义迭代器
    • 使用 StopIteration 异常或者设置哨值能够终止迭代

iter()函数的基本用法

  • 函数用法:

    1
    iter(iterable)
    • 这里的 iterable 可以是列表、元组、字符串、集合等可迭代对象
  • 示例:遍历列表

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    my_list = [1, 2, 3, 4, 5]
    my_iter = iter(my_list)

    # # 错误用法,会抛出异常
    # for i in range(10):
    # print(next(my_iter)) # 依次输出:1,2,3,4,5,第6次调用时直接抛出异常 StopIteration

    # 正确用法
    for i in my_iter:
    print(i) # 依次输出:1,2,3,4,5 然后停止

自定义迭代器(无终止迭代)

  • 借助 iter() 函数,还能自定义迭代器,这需要在类中实现 __iter__() 和 __next__() 方法
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    class MyNumbers:
    def __iter__(self):
    self.a = 1
    return self

    def __next__(self):
    x = self.a
    self.a += 1
    return x

    myclass = MyNumbers()
    myiter = iter(myclass)

    print(next(myiter)) # 输出:1
    print(next(myiter)) # 输出:2
    print(next(myiter)) # 输出:3

自定义迭代器(终止迭代)

  • 在自定义迭代器时,可以使用 StopIteration 异常来终止迭代
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    class MyNumbers:
    def __iter__(self):
    self.a = 1
    return self

    def __next__(self):
    if self.a <= 3:
    x = self.a
    self.a += 1
    return x
    else:
    raise StopIteration

    myclass = MyNumbers()
    myiter = iter(myclass)

    for x in myiter:
    print(x) # 依次输出1,2,3后停止(注意程序是自然退出的,不会抛出异常)

附录:Iterable 和 Iterator 的区别

  • 可迭代对象(Iterable) :任何可以被 iter() 函数调用并返回一个迭代器的对象
    • 常见例子:列表(list)、元组(tuple)、字符串(str)、字典(dict)、集合(set)等
    • 可以用 for 循环遍历,不存储迭代状态(即每次调用 iter() 都会生成一个新的迭代器)
    • 可以被多次迭代(每次都是新的迭代器)
  • 迭代器(Iterator) :实现了 __next__() 方法和 __iter__() 方法的对象
    • 常见例子:由 iter() 函数返回的对象、生成器(generator)等
    • 有“状态”,记录当前迭代位置
    • 调用 next() 方法会返回下一个元素,直到耗尽后抛出 StopIteration
    • 只能迭代一次(无法重置,耗尽后失效)
    • __iter__() 方法返回自身(所以迭代器也是一种可迭代对象)
  • 特别说明:DataLoader 本身是一个可迭代对象(iterable),而非一次性迭代器(iterator)

附录:特别说明 iter() 的第二个参数

  • iter() 函数还有一种不太常见的用法,就是接收两个参数
    • 第一个参数得是个可调用对象(像函数)
    • 第二个参数是哨值
  • 当可调用对象返回的值等于哨值时,迭代就会停止
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    def my_function():
    value = input("请输入内容(输入 'q' 结束):")
    return value

    # 创建迭代器,当输入 'q' 时停止
    my_iter = iter(my_function, 'q')

    for value in my_iter:
    print(f"你输入的是:{value}")

    # 请输入内容(输入 'q' 结束):123
    # 你输入的是:123
    # 请输入内容(输入 'q' 结束):q

附录:for 循环和 __iter__ 函数的用法

  • 在 for i in x 循环中,x.__iter__() 只会被调用一次 ,且该方法的返回值必须是一个 Iterator(迭代器)
    • 这属于 Python 迭代协议(Iteration Protocol)的核心要求
  • __iter__ 的调用 1 次后,循环的所有迭代过程,都基于 __iter__ 返回的同一个迭代器
  • __iter__ 的返回值要求:必须返回一个实现了迭代器协议的对象(即同时具有 __iter__() 和 __next__() 方法的对象)
    • 可迭代对象的 __iter__:“生产迭代器”(返回新的迭代器实例);
    • 迭代器的 __iter__:“暴露自己”(返回自己,因为自己就是 “干活的”)
    • 迭代器的 __next__(),真正用于返回下一个元素
  • 支持 for 循环的对象,一定是具有 __iter__() 函数的

原理拆解:for 循环的执行流程

  • for i in x 的底层逻辑完全遵循迭代协议,步骤如下:
    • 1)调用 x.__iter__(),获取一个迭代器对象(记为 it);
    • 2)反复调用 it.__next__(),每次返回的结果赋值给 i,执行循环体;
    • 3)当 it.__next__() 抛出 StopIteration 异常时,循环捕获该异常并正常终止(不会暴露给用户)
  • 注:整个过程中,x.__iter__() 只在第一步执行一次,后续所有迭代都依赖第一步返回的那个 it 迭代器
1…464748…61
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

608 posts
49 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4