ML——xgboost包使用笔记

xgboost包中包含了XGBoost分类器,回归器等, 本文详细介绍XGBClassifier


安装和导入

  • 安装

    1
    pip install xgboost
  • 导入

    1
    import xgboost as xgb
  • 使用

    1
    clf = xgb.XGBClassifier()

模型参数

普通参数

以下参数按照我理解的重要性排序

  • booster:
    • ‘gbtree’: 使用树模型作为基分类器
    • ‘gbliner’: 使用线性模型作为基分类器
    • 默认使用模型树模型即可,因为使用线性分类器时XGBoost相当于退化成含有L1和L2正则化的逻辑回归(分类问题中)或者线性回归(回归问题中)
  • n_estimators: 基分类器数量
    • 每个分类器都需要一轮训练,基分类器越多,训练所需要的时间越多
    • 经测试发现,开始时越大越能提升模型性能,但是增加到一定程度后模型变化不大,甚至出现过拟合
  • max_depth[default=3]: 每棵树的最大深度
    • 树越深,越容易过拟合
  • objective[default="binary:logistic"]: 目标(损失函数)函数,训练的目标是最小化损失函数
    • ‘binary:logistic’: 二分类回归, XGBClassifier默认是这个,因为XGBClassifier是分类器
    • ‘reg:linear’: 线性回归, XGBRegressor默认使用这个
    • ‘multi:softmax’: 多分类中的softmax
    • ‘multi:softprob’: 与softmax相同,但是每个类别返回的是当前类别的概率值而不是普通的softmax值
  • n_jobs: 线程数量
    • 以前使用的是nthread, 现在已经不使用了,直接使用n_jobs即可
    • 经测试发现并不是越多越快, 猜测原因可能是因为各个线程之间交互需要代价
  • reg_alpha: L1正则化系数
  • reg_lambda: L2正则化系数
  • subsample: 样本的下采样率
  • colsample: 构建每棵树时的样本特征下采样率
  • scale_pos_weight: 用于平衡正负样本不均衡问题, 有助于样本不平衡时训练的收敛
    • 具体调参实验还需测试[待更新]
    • 这个值可以作为计算损失时正样本的权重
  • learning_rate: shrinkage参数
    • 更新叶子结点权重时,乘以该系数,避免步长过大,减小学习率,增加学习次数
    • 在公式中叫做eta, 也就是 \(\eta\)
  • min_child_weight[default=1]: [待更新]
  • max_leaf_nodes: 最大叶子结点数目
    • 也是用于控制过拟合, 和max_depth的作用差不多
  • importance_type: 指明特征重要性评估方式, 只有在booster为’gbtree’时有效
    • ‘gain’: [默认], is the average gain of splits which use the feature
    • ‘cover’: is the average coverage of splits which use the feature
    • ‘weight’: is the number of times a feature appears in a tree
    • ‘total_gain’: 整体增益
    • ‘total_cover’: 整体覆盖率

常用函数

  • feature_importances_:

    • 返回特征的重要性列表
    • 特征重要性可以由不同方式评估
    • 特征重要性评估指标(importance_type)在创建时指定, 使用plot_importance函数的话,可以在使用函数时指定
  • plot_importance: 按照递减顺序给出每个特征的重要性排序图

    • 使用方式

      1
      2
      3
      4
      from xgboost import plot_importance
      from matplotlib import pyplot
      plot_importance(model)
      pyplot.show()
    • 详细定义

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      def plot_importance(booster, ax=None, height=0.2,
      xlim=None, ylim=None, title='Feature importance',
      xlabel='F score', ylabel='Features',
      importance_type='weight', max_num_features=None,
      grid=True, show_values=True, **kwargs):
      """Plot importance based on fitted trees.
      Parameters
      ----------
      booster : Booster, XGBModel or dict
      Booster or XGBModel instance, or dict taken by Booster.get_fscore()
      ax : matplotlib Axes, default None
      Target axes instance. If None, new figure and axes will be created.
      grid : bool, Turn the axes grids on or off. Default is True (On).
      importance_type : str, default "weight"
      How the importance is calculated: either "weight", "gain", or "cover"
      * "weight" is the number of times a feature appears in a tree
      * "gain" is the average gain of splits which use the feature
      * "cover" is the average coverage of splits which use the feature
      where coverage is defined as the number of samples affected by the split
      max_num_features : int, default None
      Maximum number of top features displayed on plot. If None, all features will be displayed.
      height : float, default 0.2
      Bar height, passed to ax.barh()
      xlim : tuple, default None
      Tuple passed to axes.xlim()
      ylim : tuple, default None
      Tuple passed to axes.ylim()
      title : str, default "Feature importance"
      Axes title. To disable, pass None.
      xlabel : str, default "F score"
      X axis title label. To disable, pass None.
      ylabel : str, default "Features"
      Y axis title label. To disable, pass None.
      show_values : bool, default True
      Show values on plot. To disable, pass False.
      kwargs :
      Other keywords passed to ax.barh()
      Returns
      -------
      ax : matplotlib Axes
      """