Jiahong 的个人博客

凡事预则立,不预则废


  • Home

  • Tags

  • Archives

  • Navigation

  • Search

Python——Typer工具的使用

  • 参考链接:
    • Typer, build great CLIs. Easy to code. Based on Python type hints.
      • 本文大部分内容基于这篇说明文档

Typer 整体介绍

  • Typer 是一款基于 Python 类型注解(Type Hints)的 CLI 开发库,被称为CLI 界的 FastAPI ,是 FastAPI 的同作者开发的“轻兄弟”库
  • 理解:Typer(是 CLI 开发库)的本质是将 Python 函数变成可以在 Terminal 直接调用的命令形式
  • Typer 基于 Click 构建,兼具代码简洁、自动生成帮助文档、终端自动补全、支持复杂子命令等特性,开发者仅需少量代码就能构建专业、易用的命令行工具,同时对新手极其友好
    • 注:Click 是 Python 生态中一款经典、成熟的命令行界面(CLI)开发框架,也是 Typer 的底层核心依赖——Typer 本质上是对 Click 的高级封装,基于 Click 实现了所有 CLI 核心能力,同时通过 Python 类型注解简化了 Click 的使用流程

前置补充:Click 介绍

  • Click 由 Pallets 团队开发(同 Flask 开发团队),是 Python 中最主流的 CLI 开发库之一,解决了 Python 内置 argparse 模块代码繁琐、体验不佳、扩展能力弱的问题,核心优势是:
    • 基于装饰器实现简洁的命令 / 参数定义
    • 原生支持子命令、选项 / 参数解析、终端补全
    • 兼容所有主流终端,支持跨平台(Windows/Linux/macOS)
    • 提供丰富的扩展能力(如进度条、密码输入、颜色输出)

Typer 核心优势

  • 基于 Python 原生类型注解,无需学习新语法,少量代码即可实现 CLI 功能
  • 自动生成 --help 帮助文档、终端自动补全(支持 Bash/Zsh/Fish/PowerShell),bool 类型参数自动生成 --xxx/--no-xxx 双选项
  • 可灵活扩展,从简单单命令到多层嵌套子命令,可随项目复杂度无缝升级
  • 无缝兼容 Python 代码,无需修改现有 Python 脚本,直接通过 typer 命令将普通函数转为 CLI 工具
  • 内置美观的错误提示(基于 Rich)、进度条、彩色输出,提升用户使用体验

Typer 安装

  • Typer 支持 Python 3.6+,推荐在虚拟环境中安装,执行以下命令即可:

    1
    pip install typer
  • 安装完成后会自动附带三个核心依赖:

    • Click :Python 经典 CLI 框架,Typer 的底层基础
    • Rich :实现美观的格式化输出、彩色错误提示
    • shellingham :自动检测当前终端类型,支持自动补全安装

Typer 使用简单示例

  • Typer 的使用分为无侵入式运行普通脚本 和显式使用 Typer 开发 两种方式,先从最简单的无侵入式开始
  • 示例来自:

无侵入式:普通脚本直接转 CLI

  • 无需在代码中引入 Typer,直接将普通带类型注解的 Python 函数转为 CLI 工具

  • Step 1,创建 main.py,编写普通函数:

    1
    2
    3
    # 仅需普通Python代码+类型注解,无Typer相关代码
    def main(name: str):
    print(f"Hello {name}!")
  • Step 2,通过 typer 命令运行:

    1
    2
    3
    4
    # 查看帮助
    typer main.py run --help
    # 传入参数运行
    typer main.py run 张三
  • Step 3,效果:自动识别 name 为必填字符串参数,缺失时会抛出美观的错误提示,无需手动处理参数解析

显式使用(入侵式):引入 Typer 开发

  • 在代码中引入 Typer,可直接通过 python 命令运行,更适合正式开发

  • 修改 main.py,仅需2行新增代码(导入+运行):

    1
    2
    3
    4
    5
    6
    7
    import typer  # 新增:导入Typer

    def main(name: str):
    print(f"Hello {name}!")

    if __name__ == "__main__":
    typer.run(main) # 新增:运行Typer应用
  • 直接通过 Python 运行,体验与原生 CLI 工具一致:

    1
    2
    3
    4
    # 查看帮助
    python main.py --help
    # 传入参数运行
    python main.py 张三

Typer 进阶用法:多子命令开发

  • 当 CLI 工具需要多个功能时,Typer 支持通过 @app.command() 装饰器创建子命令 ,类似 Git 的 git add/git commit 模式,结构清晰

基础多子命令代码示例

  • 创建包含 hello(问候)和 goodbye(告别)两个子命令的工具,goodbye 新增布尔可选参数 formal(正式模式):
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    import typer

    # 1. 创建Typer应用实例,作为CLI入口
    app = typer.Typer()

    # 2. 用@app.command()装饰函数,转为子命令
    @app.command()
    def hello(name: str):
    """子命令:普通问候"""
    print(f"Hello {name}!")

    @app.command()
    def goodbye(name: str, formal: bool = False):
    """子命令:告别,--formal 开启正式模式"""
    if formal:
    print(f"Goodbye Ms. {name}. Have a good day!")
    else:
    print(f"Bye {name}!")

    # 3. 运行应用
    if __name__ == "__main__":
    app()

基础多子命令运行与使用示例

  • 查看全局帮助(显示所有子命令):

    1
    python main.py --help
    • 会自动列出 hello、goodbye 两个子命令,以及 --install-completion(安装终端补全)等全局选项
  • 查看子命令帮助:

    1
    2
    # 查看goodbye子命令的帮助,会显示--formal/--no-formal选项
    python main.py goodbye --help
  • 执行子命令:

    1
    2
    3
    4
    # 普通告别
    python main.py goodbye 张三
    # 正式模式告别(bool参数自动生成--formal选项)
    python main.py goodbye --formal 张三

基础多子命令示例关键特性说明

  • 布尔参数自动优化 :定义 formal: bool = False 后,Typer 会自动生成 --formal(开启)和 --no-formal(关闭)两个选项,无需手动配置
  • 帮助文档自动生成 :函数的文档字符串("""注释""")会自动作为子命令的帮助说明,--help 中直接显示
  • 命令名省略规则 :仅单个命令时,可直接 python main.py 参数;多个子命令时,必须显式指定子命令名(如 python main.py hello 张三)

Typer 核心语法:参数定义

  • Typer 完全基于Python 原生类型注解 定义 CLI 参数,无需学习额外的装饰器或语法,支持所有常见类型

基础类型参数

  • 直接通过类型注解定义,Typer 自动解析为 CLI 位置参数/选项:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    @app.command()
    def demo(
    # 必填字符串参数(位置参数)
    name: str,
    # 可选整数参数,默认值10(选项参数)
    age: int = 10,
    # 布尔参数,默认False(自动生成--is-adult/--no-is-adult)
    is_adult: bool = False
    ):
    # 推荐使用 `typer.echo()` 替代 `print()`,支持彩色输出、跨终端兼容
    typer.echo(f"姓名:{name},年龄:{age},是否成年:{is_adult}")

常用参数类型

  • 除基础类型外,Typer 还支持以下常用类型,直接注解即可:
    • 列表 :List[str],接收多个参数
    • 路径 :Path,自动校验文件/目录是否存在
    • 枚举 :Enum,实现参数可选值限制
    • 文件 :typer.File(),直接读取文件对象
  • 示例(枚举+列表):
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    from typing import List
    from enum import Enum
    import typer

    # 枚举:限制gender的可选值
    class Gender(str, Enum):
    MALE = "male"
    FEMALE = "female"

    @app.command()
    def user(
    name: str,
    gender: Gender, # 仅能传入male/female
    hobbies: List[str] = None # 接收多个爱好,如--hobbies 篮球 读书
    ):
    typer.echo(f"姓名:{name},性别:{gender},爱好:{hobbies}")

Typer 高级功能:子命令组与全局配置

  • 当 CLI 工具功能复杂时,可创建子命令组(如按模块拆分:db backup/db restore),并通过 @app.callback() 实现全局配置(如环境、全局参数)

Typer 子命令组示例(数据库工具)

  • 示例代码

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    import typer

    # 全局应用入口
    app = typer.Typer(help="多功能CLI工具")

    # 创建子命令组:db(数据库相关操作)
    db_app = typer.Typer(help="数据库操作子命令组")
    # 将子命令组添加到全局应用,命令名为db
    app.add_typer(db_app, name="db")

    # db子命令组下的子命令:backup(备份)
    @db_app.command()
    def backup(path: str = "./backup.db"):
    typer.echo(f"正在备份数据库到:{path}")

    # db子命令组下的子命令:restore(恢复)
    @db_app.command()
    def restore(path: str = "./backup.db"):
    typer.echo(f"正在从{path}恢复数据库")

    # 全局子命令:无归属,直接在根目录
    @app.command()
    def version():
    typer.echo("CLI工具版本:v1.0.0")

    if __name__ == "__main__":
    app()
  • 运行:

    1
    2
    3
    4
    5
    6
    # 查看数据库子命令组帮助
    python main.py db --help
    # 执行数据库备份
    python main.py db backup --path D:/mydb.db
    # 执行全局版本命令
    python main.py version

Typer 全局配置(@app.callback())

  • 通过 @app.callback() 定义全局参数(如运行环境 env),所有子命令均可共享:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    import typer
    from typer import Context

    app = typer.Typer(help="带全局配置的CLI工具")

    # 全局回调:定义全局参数,所有子命令生效,所有子命令执行前,都会先执行这个函数
    @app.callback()
    def global_config(ctx: Context, env: str = typer.Option("dev", help="运行环境:dev/prod/test")):
    # 将全局参数存入ctx,供子命令获取
    ctx.ensure_object(dict)
    ctx.obj["env"] = env # 注:所有子命令的 ctx 对象都可以传入 env 参数,并可通过 ctx.obj["env"] 获取到 env

    @app.command()
    def run(ctx: Context):
    # 获取全局配置的env
    env = ctx.obj["env"]
    typer.echo(f"在{env}环境中运行程序...")

    if __name__ == "__main__":
    app()
  • 运行:

    1
    2
    3
    4
    # 用默认dev环境运行
    python main.py run
    # 指定prod环境运行
    python main.py run --env prod

Typer 实用功能:终端补全与进度条

安装终端自动补全

  • Typer 支持一键安装终端自动补全,输入以下命令后按提示操作即可:

    1
    python main.py --install-completion
    • 支持 Bash、Zsh、Fish、PowerShell 等主流终端,安装后输入命令时按 Tab 即可自动补全子命令和参数

内置进度条

  • 处理耗时操作(如下载、批量处理)时,Typer 内置进度条功能,无需额外安装库,一行代码实现:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    import typer
    import time

    app = typer.Typer()

    @app.command()
    def download(total: int = 10):
    """模拟下载,显示进度条"""
    # 用 typer.progressbar 创建进度条
    with typer.progressbar(range(total), label="下载中") as progress:
    for i in progress:
    time.sleep(0.5) # 模拟耗时操作
    typer.echo("下载完成!")

    if __name__ == "__main__":
    app()

Python——Ray-远程函数与本地函数的区别


整体说明

  • 远程函数与本地函数的区别主要在 序列化机制 和 执行位置 两个维度
  • 序列化本质差异:
    • 本地函数可以理解为“传引用”,依赖执行环境已有定义
      • 注:本地函数也不仅仅是 “传引用”
        • Python 的 pickle 序列化函数时,实际上是序列化函数的名称和所在模块的路径
        • 反序列化时,需要在目标环境中导入同名模块、找到同名函数
        • 因此 Python 本地函数调用依赖目标环境与源环境“同构”
        • Ray 跨节点时,Worker 进程的 __main__ 模块通常与 Driver 不同,所以会失败
    • Ray 远程函数是 “传定义+环境” ,集群自动同步,支持跨节点;
  • 执行位置差异:
    • 本地函数固定在调用方进程,无分布式能力;
    • Ray 远程函数由集群调度,可分布式并发执行;
  • 使用场景:
    • 本地函数:适用于单进程/单节点的简单逻辑,无需分布式;
    • Ray 远程函数:适用于分布式计算、并发任务、跨节点执行,是 Ray 分布式能力的核心
  • 核心差异总览
    对比维度 本地函数(未用 @ray.remote 装饰) Ray 远程函数(用 @ray.remote 装饰)
    序列化方式 依赖 Python 原生 pickle,仅序列化「函数引用」 Ray 自定义序列化(结合 pickle+集群元数据),序列化「函数元信息+代码定义」
    序列化限制 无法跨节点传递(远程节点无函数定义,引用失效) 可跨节点传递(集群自动同步函数定义到执行节点)
    执行位置 固定在「调用方所在的本地进程/线程」 分布式调度到「集群任意节点的 Worker 进程」(可指定资源)
    执行特性 同步执行,阻塞调用方;无并发调度能力 异步执行,返回 ObjectRef;支持集群级并发/分布式调度
    依赖传递 需手动确保执行环境有函数依赖(如导入、变量) Ray 自动打包函数依赖(如嵌套函数、闭包变量)并分发

序列化机制:“仅传引用” vs “传定义+元信息”

  • 序列化的核心目的是:让函数能在「非定义环境」中被正确执行
  • 两者的序列化逻辑完全不同:

本地函数:仅序列化“函数引用”,无实际代码

  • Python 原生 pickle 序列化本地函数时,不会打包函数的代码本身 ,只会记录函数的「模块路径+函数名」(比如 __main__.add)
  • 这种“引用式序列化”仅在「同一进程/同一节点且函数已定义」的场景下有效,跨节点会直接失效
  • 错误示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    import ray

    ray.init(ignore_reinit_error=True)

    # 本地函数
    def add_remote(a, b):
    return a + b

    # 直接传递远程函数的引用(Ray 自动处理序列化)
    @ray.remote
    def execute_remote_func(func, x, y):
    return func(x,y) # 远程工作进程无法识别调用方的 local func,错误

    # 跨节点调度执行(单节点可以成功,但集群有多个节点会失败)
    result_ref = execute_remote_func.remote(add_remote, 2, 3)
    print(ray.get(result_ref)) # 单节点输出:5(成功执行);多节点执行错误

    ray.shutdown()

Ray 远程函数:序列化“函数元信息+代码定义”

  • Ray 对远程函数的序列化做了增强 :
    • 1)序列化时,不仅记录函数引用,还会打包函数的代码定义、依赖模块、闭包变量(若有);
    • 2)远程节点接收后,会自动还原函数的执行环境(无需手动导入);
    • 3)底层用 Ray 自定义的序列化器(兼容 pickle,但更适合分布式场景)
  • 正确示例:远程函数跨节点调用成功
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    import ray

    ray.init(ignore_reinit_error=True)

    # Ray 远程函数(已注册,自动序列化代码)
    @ray.remote
    def add_remote(a, b):
    return a + b

    # 直接传递远程函数的引用(Ray 自动处理序列化)
    @ray.remote
    def execute_remote_func(func, x, y):
    return ray.get(func.remote(x, y)) # 远程节点能识别并执行

    # 跨节点调度执行(即使集群有多个节点也能成功)
    result_ref = execute_remote_func.remote(add_remote, 2, 3) # 注意:传入的参数 add_remote 本身也需要是 @ray.remote 封装过的 Ray 远程函数
    print(ray.get(result_ref)) # 输出:5(成功执行)

    ray.shutdown()

补充:Ray 还支持 嵌套远程函数 闭包变量传递

  • 比如在远程函数中引用本地变量,Ray 会自动序列化传递:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    import ray

    ray.init(ignore_reinit_error=True)

    @ray.remote
    def outer_remote(x):
    # 闭包变量 x 会被 Ray 自动序列化到远程节点
    @ray.remote
    def inner_remote(y):
    return x + y
    return inner_remote.remote(10)

    print(ray.get(ray.get(outer_remote.remote(5)))) # 输出:15

    ray.shutdown()

执行位置:“本地固定” vs “集群分布式调度”

  • 执行位置的差异是两者最直观的区别,直接决定了是否能利用集群资源:

本地函数:执行在 调用方所在进程

  • 本地函数的执行位置完全固定:
    • 无论在哪里调用(即使在远程函数内部调用本地函数),函数都会在 发起调用的进程 中执行【存疑】
      • 问题:这里描述有错(部分书籍会这样写),理论上远程函数内部无法调用本地函数,所以应该加上一句,在可以调用成功的前提下
    • 若在远程函数中调用本地函数,本质是在「远程节点的 Worker 进程」中执行,但该进程没有本地函数的定义(除非手动同步代码),所以必然失败;
    • 无并发能力:多个调用会串行执行在同一个进程/线程(或 Python 多进程的子进程,但需手动管理)

远程函数:执行在「集群 Worker 进程」

  • Ray 远程函数的执行位置由 Ray 集群的调度器统一管理:

    • 1)调用 func.remote() 时,会向 Ray 调度器提交一个任务
    • 2)调度器根据集群节点的资源(CPU、GPU、内存)情况,将任务分配到任意可用节点的 Worker 进程
    • 3)执行完成后,结果会存储在 Ray 的对象存储中,通过 ray.get() 可获取
    • 4)支持并发:多个 remote() 调用会被调度到不同 Worker 进程/节点,并行执行
  • 示例:远程函数分布式执行(多节点/多进程并发)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
       import ray
    import os
    import time

    # os.environ["RAY_DEDUP_LOGS"] = "0" # 本意是让每个进程结果都完整输出,但这行代码仅当前进程生效,需要启动前配置环境变量才可以
    # # 如果是通过代码定义,则在 os.environ 设置在 ray.init() 之前进行才能生效,因为 Worker 进程在初始化时读取环境变量
    ray.init(ignore_reinit_error=True)

    # Ray 远程函数:打印执行节点的进程 ID 和节点名
    @ray.remote
    def add_remote(a, b):
    node_name = ray.util.get_node_ip_address() # 获取执行节点 IP
    pid = os.getpid() # 获取执行进程 ID
    print(f"在节点 {node_name} 的进程 {pid} 执行 add({a}, {b})")
    time.sleep(1) # 模拟耗时操作
    return a + b

    # 提交 5 个并发任务(会被调度到不同 Worker 进程)
    start = time.time()
    result_refs = [add_remote.remote(i, i*2) for i in range(5)]
    results = ray.get(result_refs) # 等待所有任务完成
    end = time.time()

    print("结果:", results) # 输出:[0, 3, 6, 9, 12]
    print(f"总耗时: {end - start:.2f}s") # 约 1s(并发执行,而非 5s 串行)

    ray.shutdown()
  • 执行上述脚本:

    1
    2
    export RAY_DEDUP_LOGS=0
    python demo.py
    • 注意:仅在代码里面添加 os.environ["RAY_DEDUP_LOGS"] = "0" 是不够的,因为:
      • Ray 的日志去重功能是在 Worker 进程启动时就决定的,而 Worker 是由 Ray 的主进程(Driver)启动的
      • 上面的代码在 ray.init() 之后才启动 Worker,那么环境变量必须在 Driver 启动 Worker 之前就传递过去,否则 Worker 进程会继承默认的去重配置
      • 所以最安全的打印所有日志的方式就是再启动脚本前配置环境变量
    • 另一种实现方式是在远程函数中返回 PID,然后由 Driver 打印
  • 输出示例:

    1
    2
    3
    4
    5
    6
    7
    8
    2025-11-04 11:42:43,175 INFO worker.py:1918 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
    (add_remote pid=14393) 在节点 127.0.0.1 的进程 14393 执行 add(2, 4)
    (add_remote pid=14399) 在节点 127.0.0.1 的进程 14399 执行 add(4, 8)
    (add_remote pid=14398) 在节点 127.0.0.1 的进程 14398 执行 add(1, 2)
    (add_remote pid=14400) 在节点 127.0.0.1 的进程 14400 执行 add(3, 6)
    (add_remote pid=14396) 在节点 127.0.0.1 的进程 14396 执行 add(0, 0)
    结果: [0, 3, 6, 9, 12]
    总耗时: 1.62s

附录:远程调用时传入的函数指针必须是远程函数

  • 在 Ray 中不支持直接传入 local 函数指针作为远程函数的执行对象,需通过 Ray 装饰器(@ray.remote)将函数注册为远程可执行,再通过 函数名.remote() 调用(本质是基于函数标识而非指针传递)
  • 总结:
    • 不推荐将普通函数作为参数传递给 Ray 远程函数
    • 推荐使用 @ray.remote 装饰器或在远程函数内部定义逻辑
    • 注意:一些代码在单机环境下可能碰巧能运行,但不具有可移植性和可靠性(这一点需要注意 Ray 本地调试通过可能也无法分布式运行)

错误示例(未注册本地函数)

  • 若 add 未被 @ray.remote 注册,它只是一个本地函数 ,无法在 Ray 分布式环境中执行,直接传递给远程函数(如 execute_func)会报错

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    import ray

    ray.init(ignore_reinit_error=True)

    # 未注册的本地函数
    def add(a, b):
    return a + b

    # 已注册的远程函数
    @ray.remote
    def execute_func(func, x, y):
    # 此处调用本地函数会失败,因为 func 在远程节点无定义
    # # 远程节点的工作进程无法导入本地主模块的 add_local 函数,也无法序列化传递普通函数,可能会直接抛出 SerializationError
    # # 单进程/单节点下调用指针函数可以执行,但是分布式情况下,local_func 无法被序列化,会出错
    return func(x, y) # 报错:NameError,PicklingError 或 SerializationError

    # 调用会抛出异常
    try:
    result = ray.get(execute_func.remote(add, 4, 6))
    except Exception as e:
    print("错误:", e) # 提示无法序列化或找不到函数

    ray.shutdown()
  • 核心原因:Ray 远程函数执行依赖序列化传输和集群节点间代码同步

    • 未注册的本地函数无法被序列化为集群可识别的任务,且远程节点没有该函数的定义,会导致执行失败

正确示例(远程函数调用)

  • Ray 的远程函数依赖集群调度,通过 @ray.remote 显式注册后使用远程调用函数调用

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    import ray

    ray.init(ignore_reinit_error=True)

    # 定义远程函数(会注册到 Ray 集群)
    @ray.remote
    def add(a, b):
    return a + b

    # 远程函数,可接收其他远程函数的调用结果
    @ray.remote
    def execute_func(func, x, y):
    # 这里 func 是远程函数标识,通过 .remote() 触发执行
    result = ray.get(func.remote(x, y)) # 使用远程调用的方式调用函数指针,实现调用远程函数,正确!
    # result = func(x, y) # remote 函数无法被直接调用,错误!
    # result = add(x,y) # remote 函数无法被直接调用,错误!
    # result = add_local(x, y) # add_local 当做 local 函数调用(注意:不再是指针传入),正确!
    return result

    # # 不使用 remote 直接调用 远程函数,错误
    # result1 = add(2, 3)

    # 使用remote直接调用远程函数,正确
    result1 = ray.get(add.remote(2, 3))
    print("直接调用结果:", result1) # 输出:5

    # 间接通过另一个远程函数调用(模拟"传递函数逻辑")
    result2 = ray.get(execute_func.remote(add, 2, 3))
    print("间接调用结果:", result2) # 输出:10

    ray.shutdown()
  • Ray 的远程函数依赖集群调度,需通过 @ray.remote 显式注册,无法像本地代码那样传递函数指针(内存地址在分布式环境中无效)

  • 若需在远程函数中复用其他函数逻辑,直接传递已注册的远程函数名(如示例中的 add),再通过 func.remote() 调用即可

Math——f-divergence


f-divergence定义

  • \( f \)-散度(\( f \)-divergence)是概率论和信息论中的一种概念,用于衡量两个概率分布之间的差异
  • 形式上,对于两个概率分布 \( P \) 和 \( Q \),定义在一个共同的样本空间上, \( f \)-散度可以被定义为:
    $$ D_f(P|Q) = \int_{\Omega} q(x) f\left(\frac{p(x)}{q(x)}\right) dx $$
    • \( \Omega \) 是样本空间
    • \( p(x) \) 和 \( q(x) \) 分别是 \( P \) 和 \( Q \) 的概率密度函数
    • \( f \) 是一个凸函数,满足 \( f(1) = 0 \),这是为了保证当 \( P = Q \) 时 \( D_f(P|Q) = 0 \)
  • \( f \)-散度的一个重要性质是它是非负的,即 \( D_f(P|Q) \geq 0 \),并且只有当 \( P = Q \) 时等号成立。这意味着 \( f \)-散度可以作为两个概率分布之间距离的一种度量,尽管它不满足距离的所有公理(比如对称性)

常见的f-divergence例子

  • Kullback-Leibler散度 (KL散度),其中 \( f(u) = u \log u \)
    • 注意带入以后可以消去分母得到KL散度的最终公式
  • Hellinger距离,这里 \( f(u) = (\sqrt{u} - 1)^2 \)
  • 总变差距离,此时 \( f(u) = |u - 1| \)
  • χ²散度,使用 \( f(u) = \frac{(u - 1)^2}{u} \)

附录:KL散度的非负性证明

  • 核心,利用Jensen不等式证明 Kullback-Leibler(KL)散度是非负的
  • 对于两个概率分布 \( P \) 和 \( Q \) 在同一空间 \( \mathcal{X} \) 上,KL 散度定义为:
    $$
    D_{\text{KL} }(P \parallel Q) = \sum_{x \in \mathcal{X} } P(x) \log \frac{P(x)}{Q(x)}
    $$
    • 注意 KL 散度的积分权重和分子是相同的(这是由其含义和非负性决定的,详情见附录),若对换分子分母,得到的是 KL 的负数值
  • 对于连续变量,定义为:
    $$
    D_{\text{KL} }(P \parallel Q) = \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)} , dx
    $$
    • 其中 \( p(x) \) 和 \( q(x) \) 分别是 \( P \) 和 \( Q \) 的概率密度函数
  • 进一步地,KL散度可以表示为:
    $$
    D_{\text{KL} }(P \parallel Q) = \mathbb{E}_{P} \left[ \log \frac{P(x)}{Q(x)} \right]
    $$
    • 即 \( \log \frac{P(x)}{Q(x)} \) 在分布 \( P \) 下的期望

应用Jensen不等式求负KL散度

  • 由于 \( \log(x) \) 是一个凹函数(伞状),根据Jensen不等式,对于凹函数有:
    $$
    \mathbb{E}[\log X] \leq \log \mathbb{E}[X]
    $$
    • 令 \( X = \frac{Q(x)}{P(x)} \),则:
      $$
      \mathbb{E}_{P} \left[ \log \frac{Q(x)}{P(x)} \right] \leq \log \left( \mathbb{E}_{P} \left[ \frac{Q(x)}{P(x)} \right] \right)
      $$
  • 计算期望:
    $$
    \mathbb{E}_{P} \left[ \frac{Q(x)}{P(x)} \right] = \sum_{x} P(x) \cdot \frac{Q(x)}{P(x)} = \sum_{x} Q(x) = 1
    $$
  • 因此:
    $$
    \mathbb{E}_{P} \left[ \log \frac{Q(x)}{P(x)} \right] \leq \log(1) = 0
    $$

推导KL散度的非负性

  • 注意到:
    $$
    \mathbb{E}_{P} \left[ \log \frac{Q(x)}{P(x)} \right] = -D_{\text{KL} }(P \parallel Q)
    $$
  • 因此:
    $$
    -D_{\text{KL} }(P \parallel Q) \leq 0 \implies D_{\text{KL} }(P \parallel Q) \geq 0
    $$
  • 当且仅当 \( P(x) = Q(x) \) 对所有 \( x \) 成立时,\( \frac{P(x)}{Q(x)} = 1 \),此时:
    $$
    D_{\text{KL} }(P \parallel Q) = \sum_{x} P(x) \log 1 = 0
    $$
  • KL散度始终满足:
    $$
    D_{\text{KL} }(P \parallel Q) \geq 0
    $$
  • 且 \( D_{\text{KL} }(P \parallel Q) = 0 \) 当且仅当 \( P = Q \)

附录:KL 散度=交叉熵与熵的差 推导

  • KL 散度本质上是交叉熵与熵的差,反映了用错误模型编码时的“额外信息量”

熵 \(H(P)\) 和 交叉熵 \(H(P, Q)\) 的定义

  • 对于离散分布 \(P(x)\),熵定义为:
    $$
    H(P) = -\sum_x P(x) \log P(x)
    $$
    • 它表示在分布 \(P\) 下,平均需要多少信息量(比特或 nats)来编码事件
  • 交叉熵定义为:
    $$
    H(P, Q) = -\sum_x P(x) \log Q(x)
    $$
    • 它表示在真实分布是 \(P\) 时,如果用分布 \(Q\) 来编码,平均需要的信息量

KL 散度=两者的差

  • 交叉熵与熵的差:
    $$
    \begin{align}
    H(P, Q) - H(P) &= \left[ -\sum_x P(x) \log Q(x) \right] - \left[ -\sum_x P(x) \log P(x) \right] \\
    &= -\sum_x P(x) \log Q(x) + \sum_x P(x) \log P(x) \\
    &= \sum_x P(x) \left[ \log P(x) - \log Q(x) \right] \\
    &= \sum_x P(x) \log \frac{P(x)}{Q(x)} \\
    &= D_{\mathrm{KL} }(P | Q)
    \end{align}
    $$

理解

  • 熵 \(H(P)\) :理想编码长度
  • 交叉熵 \(H(P, Q)\) :用错误分布 \(Q\) 编码的平均长度
  • KL 散度 :额外的编码长度,也就是交叉熵比真实熵多出来的部分

附录:卡方散度 和 KL 散度对比

卡方散度定义

  • 设 \(P,Q\) 为两个概率分布,且 \(Q\) 绝对连续于 \(P\)(\(P(x)=0\Rightarrow Q(x)=0\)), 皮尔逊卡方散度(简称卡方散度) 定义为:
    $$
    \chi^2(P|Q) = \int \frac{(P(x)-Q(x))^2}{Q(x)} dx
    $$
  • 离散形式:
    $$
    \chi^2(P|Q) = \sum_i \frac{(P_i-Q_i)^2}{Q_i}
    $$
  • 卡方散度的其他形式:
    $\chi^2(P | Q) = \sum_x \frac{(P(x) - Q(x))^2}{Q(x)} = \sum_x Q(x) \left(\frac{P(x)}{Q(x)} - 1\right)^2 = E_Q\left[\left(\frac{P}{Q} - 1\right)^2\right]$$
    • 等价形式:
      $$
      \begin{align}
      \chi^2(P | Q) &= \mathbb{E}_Q\left[\left(\frac{P}{Q} - 1\right)^2\right] = \mathbb{E}_Q\left[\left(\frac{P}{Q}\right)^2\right] - 2\mathbb{E}_Q\left[\frac{P}{Q}\right] + \mathbb{E}_Q[1] \\
      &= \mathbb{E}_Q\left[\left(\frac{P}{Q}\right)^2\right] - 2\sum_x Q(x)\cdot\frac{P(x)}{Q(x)} + 1 \\
      &= \mathbb{E}_Q\left[\left(\frac{P}{Q}\right)^2\right] - 2 + 1 \\
      &= \mathbb{E}_Q\left[\left(\frac{P}{Q}\right)^2\right] - 1 \\
      \end{align}
      $$

回顾 KL 散度定义

  • 连续形式
    $$
    D_{\mathrm{KL} }(P|Q) = \int P(x)\log\frac{P(x)}{Q(x)} dx
    $$
  • 离散形式:
    $$
    D_{\mathrm{KL} }(P|Q) = \sum_i P_i\log\frac{P_i}{Q_i}
    $$

卡方散度与 KL 散度的关系

  • 关系1:泰勒展开关系
    • 当 \(P\) 接近 \(Q\) 时,对 \(\log\frac{P}{Q}\) 在 \(P=Q\) 处展开:
      $$
      D_{\mathrm{KL} }(P|Q) = \frac{1}{2}\chi^2(P|Q) + o\bigl(|P-Q|^2\bigr)
      $$
      • 即:KL 散度在局部等价于卡方散度的 1/2
  • 关系2:不等式关系
    • 由 Jensen 不等式可证:
      $$
      D_{\mathrm{KL} }(P|Q) \le \chi^2(P|Q)
      $$

两者特点对比

  • 对比详情:
    特性 KL 散度 \(D_{\mathrm{KL} }(P|Q)\) 卡方散度 \(\chi^2(P|Q)\)
    形式 含对数,信息论度量 二次型,统计检验度量
    对称性 非对称:\(D_{\mathrm{KL} }(P|Q)\neq D_{\mathrm{KL} }(Q|P)\) 非对称:\(\chi^2(P|Q)\neq \chi^2(Q|P)\)
    非负性 满足 \(D_{\mathrm{KL} }\ge 0\) 满足 \(\chi^2\ge 0\)
    对小 \(Q_i\) 敏感,但存在对数约束,爆炸缓慢 及其敏感,但无对数,更容易爆炸
    权重 按 \(P_i\) 加权 按 \(1/Q_i\) 加权
    来源 信息论、编码、熵 皮尔逊卡方检验、拟合优度
    优化 常用于变分推断、生成模型 常用于密度比、分布检验
  • 重点:卡方散度比 KL 散度更不稳定(在 \(Q(x)\) 极小时,卡方散度很容易出现爆炸)
    • 卡方散度是被 \(\frac{1}{Q}\) 修饰的,当 \(Q(x)\) 减小时,是线性增长
    • KL 散度是被 \(\log \frac{1}{Q}\) 修饰的,当 \(Q(x)\) 减小时,对数增长就慢很多

Math——线性规划求解方法和理解

本文包含对线性规划的直观理解,不严谨,后续有新的问题/理解持续更新

  • 参考链接:
    • 运筹学中应该如何理解互补松弛性。这条性质又该如何运用?
    • 第4章 对偶理论和敏感度分析
    • 线性规划对偶问题的定义,有什么直觉上的解释吗?:原始问题到对偶问题最好的一种很简洁的解释
    • 互联网广告算法漫谈——浅谈广告中的出价技术。注意:该参考链接中没有把预算约束相关的互补松弛定理写出来,且2.3中存在一些较为明显的小bug,但整体求解思路和结论没问题

原始问题

  • 问题描述:
    • 假设你是一个木匠有200单位的木头和90单位的时间
    • 木匠可以制作桌子或者椅子
      • 桌子成本为5单位木头+2单位时间,售价10元
      • 椅子成本为2单位木头+1单位时间,售价3元
  • 目标:在已有资源情况下,最大化收入,应该生产多少桌子和椅子?
  • 问题形式化描述:
    • 假设应该生产 \(x_1\) 把桌子和 \(x_1\) 把椅子
      $$
      \begin{align}
      \max \ \ 10x_1 &+ 3x_2 \\
      5x_1 + 2x_2 &<= 200 \\
      3x_1 + \ \ x_2 &<= 90 \\
      x_1,x_2 &>= 0 \\
      \end{align}
      $$
  • 作图法可求得最优解为 \(x_1^* = 30, x_2^*=0\),此时最大收益为300
    • 在二维坐标轴上先画出可行域,然后按照目标直线斜率找到最优点

对偶问题

  • 对偶问题描述:
    • 上述原始问题可以换一个视角看
    • 假设现在你是一个原材料收购商(想要以最低价格收购木匠的原材料)
    • 目标:对单位木头和单位时间进行出价,以最低的价格买完木匠的资源(假设木匠愿意卖出的前提是收购上出价的最小值不小于木匠原始问题中收益的最大值)
      • 实际上最好是刚好等于木匠原始问题的最大收益
  • 对偶问题形式化描述
    $$
    \begin{align}
    \min \ \ 200p_1 &+ 90p_2 \quad – 总付款 \\
    5p_1 + 3p_2 &>= 10 \quad – 一张桌子的资源售价不低于一张桌子的收益 \\
    2p_1 + \ \ p_2 &>= 3 \quad – 一张椅子的资源售价不低于一张椅子的收益 \\
    p_1,p_2 &>= 0 \quad – 售价不为负数 \\
    \end{align}
    $$
  • 其中 \(p_1, p_2\) 分别称为单位木头和单位时间的影子价格
  • 作图法可求得最优解为 \(p_1^* = 0, p_2^* = 3.3\),此时最小支付金额为300

互补松弛定理的理解

从原始问题的约束视角出发

等价于从对偶问题的解出发

  • 对偶问题中,最优解是 \(p_1^* = 0, p_2^* = 3.3\)
    • \(p_1^* = 0\) 意味着我们的木材过量了,其实不需要这么多木材,原始问题中,最优解对应的木材约束是松的( \(5x_1^* + 2x_2^*=150 < 200\) )
    • \(p_2^* = 3.3\) 说明时间资源非常紧俏,原始问题中,最优解对应的时间约束是紧的( \(3x_1^* + \ \ x_2^* = 90\) )
  • 对应互补松弛的含义:
    • 如果在最优条件下一个约束不等式是松的(木材),那么这个约束对应的影子价格为0
    • 反过来说,如果某个约束对应的影子价格严格大于0,那么这个约束不等式一定是紧的
    • 总的来说,原始问题的约束和对偶问题变量(影子价格)总有一个要为0

从对偶问题的约束视角出发

等价于从原始问题的解出发

  • 原始问题中,最优解是 \(x_1^* = 30, x_2^*=0\)
    • \(x_1^* = 30\) 意味着桌子非常合算,应该多生产桌子,对偶问题中,桌子约束是紧的( \(5p_1^* + 3p_2^* = 10\) )
    • \(x_2^*=0\) 以为这椅子不合算,不应该生产椅子,对偶问题中,椅子的约束是松的( \(2p_1^* + \ \ p_2^* = 3.3 > 3\) )
  • 补充互补松弛的含义:
    • 如果在对偶最优条件下一个约束不等式是松的(椅子),那么这个约束对应的原始问题变量最优解( \(x_2^*\) )为0
    • 反过来说,如果某个原始问题变量(桌子)对应的解( \(x_1^*\) )严格大于0,那么对偶问题中这个约束不等式一定是紧的
    • 总的来说,对偶问题的约束和对应原始问题变量总有一个要为0

互补松弛定理的公式化

$$
(5p_1^* + 3p_2^* - 10)x_1^* = 0 \\
(2p_1^* + p_2^* - 3)x_2^* = 0 \\
(5x_1^* + 2x_2^* - 200)p_1^* = 0 \\
(3x_1^* + x_2^* - 90)p_2^* = 0 \\
$$


附录:USCB推导

  • 《A Unified Solution to Constrained Bidding in Online Display Advertising》——论文阅读
    • 这篇文章的约束很多,每个商家都有自己的约束
    • 推导时用到的对偶变换和互补松弛定理均可由论文推导得出【有时间再详细推导】

附录:BCB推导(单约束)

  • 《Budget Constrained Bidding by Model-free Reinforcement Learning in Display Advertising》——论文原文
    • 这篇文章中的问题定义比较简单,整体只有一个预算约束
    • 上述结果详细的推导可以参考:
      • 智能出价——BCB求解
      • 互联网广告算法漫谈——浅谈广告中的出价技术。注意:该参考链接中没有把预算约束相关的互补松弛定理写出来,且2.3中存在一些较为明显的小bug,但整体求解思路和结论没问题
    • 推导结果 \(bid = \frac{v_i}{\lambda}\) 与常用的方法(RL-MPCA)结果不一致,但可以证明本质是等价的

附录:CPC约束推导(单约束)

  • 问题描述:单位置、二价拍卖,且CPM计费场景,CPC约束下最大化商家点击量
  • 推导过程可参考论文Bid Optimization by Multivariable Control in Display Advertising
  • 基本推导思路:先通过拉格朗日乘子法得到最优解的形式(这里先忽略边际条件 \(0\le x_i \le 1\) ),再将原始问题转换成对偶问题,进一步分情况讨论得到最终解
  • 问题定义
    $$
    \begin{align}
    &\max \sum_i x_i \cdot ctr_i \\
    \text{s.t.} &\quad \frac{\sum_i x_i \cdot wp_i}{\sum_i x_i \cdot ctr_i} \le cpc \\
    &\quad 0 \le x_i \le 1, \forall i
    \end{align}
    $$
  • 第一步:推导最优出价形式:
    • 写出拉格朗日函数并求导:
      $$\mathcal{L}(x, \lambda, \mu) = - \sum_i x_i \cdot ctr_i + \lambda \left(\sum_i x_i \cdot wp_i - \sum_i x_i \cdot ctr_i \cdot cpc\right) + \sum_i \mu_i (x_i - 1)$$
    • 对任意的 \(x_i\) 求导有:
      $$ \frac{\partial \mathcal{L}(x, \lambda, \mu)}{\partial x_i} = - \sum_i ctr_i + \lambda \sum_i wp_i - \lambda \sum_i ctr_i \cdot cpc + \sum_i \mu_i $$
    • 令上述导数为0有(\(\mu_i\) 来自边界条件 \(0\le x_i \le 1\),为了得到最优解形式,接下来先忽略边界条件,最后会证明在满足边界条件下,该形式也是最优的):
      $$
      \begin{align}
      wp_i &= \frac{ctr_i + \lambda \cdot cpc \cdot ctr_i}{\lambda} \\
      &= \frac{1 + \lambda \cdot cpc}{\lambda} \cdot ctr_i
      \end{align}
      $$
      • 所以我们令出价等于下面的形式:
        $$bid_i = \frac{1 + \lambda \cdot cpc}{\lambda} \cdot ctr_i$$
  • 第二步:验证最优出价形式:
    • 原始问题对应的对偶问题为:
      $$
      \begin{align}
      &\mathop{\min}_{\lambda, r_i} \sum_i r_i \\
      \text{s.t.} &\quad \lambda(wp_i - cpc\cdot ctr_i) + r_i \ge ctr_i \quad \text(1)\\
      &\quad \lambda \ge 0 \\
      &\quad r_i \ge 0, \forall i
      \end{align}
      $$
    • 互补松弛条件:
      $$
      \begin{align}
      x_i(\lambda(wp_i - cpc\cdot ctr_i) + r_i - ctr_i) = 0 \quad &\text{(2)} \\
      r_i(x_i - 1) = 0, \forall i \quad &\text{(3)}
      \end{align}
      $$
    • 将最优出价公式 \(bid_i = \frac{ctr_i + \lambda \cdot cpc \cdot ctr_i}{\lambda}\) 带入公式(2)可得:
      $$ x_i(\lambda(wp_i - bid_i) + r_i) = 0$$
      • 当 \(x_i \gt 0\) 时,有 \(wp_i - bid_i = -\frac{r_i}{\lambda} \lt 0\),进一步推得 \(bid_i \ge wp_i\)
      • 当 \(x_i = 0\) 时,由公式(3)有 \(r_i = 0\);将最优出价公式 \(wp_i = \frac{ctr_i + \lambda \cdot cpc \cdot ctr_i}{\lambda}\) 带入公式(1)可得 \(\lambda(wp_i - bid_i) + r_i \ge 0\),进一步推得 \(wp_i - bid_i \ge 0\),即\(bid_i \le wp_i\)
    • 证毕
  • 如何理解最优出价形式?
    $$
    \begin{align}
    bid_i = \frac{1 + \lambda \cdot cpc}{\lambda} \cdot ctr_i = \color{red}{(\frac{1}{\lambda \cdot cpc} + 1)} \cdot cpc \cdot ctr_i
    \end{align}
    $$
    • 二价计费场景中 ,计费比未知,所以引入了一个大于 1 的出价系数:
      $$ k = \color{red}{\frac{1}{\lambda \cdot cpc} + 1} $$
      • 用来提升出价以做到目标CPC达成(可以证明,在整个周期内流量足够多的情况下,如果实际CPC小于目标CPC,则此时一定不是点击最大化的出价策略)
      • 实际使用中,由于 \(1\) 是全局固定值,商家的 \(cpc\) 是商家粒度的固定值,\(\lambda\) 是商家粒度的变量,可以合并成一个变量即可(\(\lambda\) 和 \(k\) 是一一对应的),最终可以忽略 \(k\) 值的具体形式,只需要直接调节 \(k\) 即可,此时最优公式为:
        $$ bid_i = \color{red}{k} \cdot cpc \cdot ctr_i $$
    • 如果竞争环境非常激烈,计费比趋近于1(同时考虑预估值准确),此时每次出价都按照 \(\color{red}{bid_i = cpc \cdot ctr_i} \),可保证投放周期内实际CPC的期望刚好等于目标CPC,\(\color{red}{k=1}\) 就是最优的出价策略
    • 调控系数的其他功能 :从推导来看,系数 \(k\) 可以用于补足二价计费的Gap;在实际应用中,这个 k 值还可以解决 CTR 均值预估值不准确的问题,比如CTR预估过高 ,\(k\) 会小于1 ,从而保证不超成本
      • 可以注意到:在这个假设下有矛盾点,\(k < 1\) 时对应的 \(\lambda < 0\),并不满足对拉格朗日乘子的要求,但不用担心,这里实际上 \(k = k_1 \cdot k_2\),其中,由 \(\lambda\) 导出的 \(k_1\) 依然是大于1的,用来调平CTR预估值 \(k_2\) 是小于1的,实际上,\(\lambda \geq 0\) 始终成立

附录:oCPC场景约束推导(单约束)

  • 问题描述1:单位置、二价拍卖,且CPC计费场景,CPS约束下最大化商家订单量
  • 实际上,本问题中与上文单位置拍卖的CPM计费场景,CPC约束下最大化商家点击量非常相似,仅需把对应的参数替换一下即可(\(ctr_i \rightarrow cvr_i\),\(cpc \rightarrow cps\)),于是有最优出价形式是:
    $$
    \begin{align}
    wp_i = \frac{1 + \lambda \cdot cps}{\lambda} \cdot cvr_i = \color{red}{(\frac{1}{\lambda \cdot cps} + 1)} \cdot cps \cdot cvr_i
    \end{align}
    $$
    • 出价系数:
      $$ k = \color{red}{\frac{1}{\lambda \cdot cps} + 1} $$
    • 注:实际使用中,同上描述,最终可以忽略 \(k\) 值的具体形式,只需要直接调节 \(k\) 即可,此时最优公式为:
      $$ bid_i = \color{red}{k} \cdot cps \cdot cvr_i $$
  • 问题描述2:单位置、二价拍卖,且CPC计费场景,ROI约束下最大化商家Revenue
  • 此时可以进一步表达为如下形式(\(cvr_i \rightarrow cvr_i\cdot rev_i\),\(cps \rightarrow rate = \frac{1}{ROI}\), ):
    $$
    \begin{align}
    wp_i &= \frac{1 + \lambda \cdot 1/ROI}{\lambda} \cdot cvr_i \cdot rev_i \\
    &= \frac{ROI + \lambda}{\lambda \cdot ROI} \cdot rev_i \cdot cvr_i \\
    &= \frac{ROI + \lambda}{\lambda} \cdot \frac{rev_i \cdot cvr_i}{ROI} \\
    &= \color{red}{(\frac{ROI}{\lambda} + 1)} \cdot \frac{rev_i \cdot cvr_i}{ROI} \\
    \end{align}
    $$
    • 出价系数:
      $$ k = \color{red}{\frac{ROI}{\lambda} + 1}$$
    • 注:实际使用中,同上描述,最终可以忽略 \(k\) 值的具体形式,只需要直接调节 \(k\) 即可,此时最优公式为:
      $$ bid_i = \color{red}{k} \cdot \frac{rev_i \cdot cvr_i}{ROI} $$

附录:紧约束和松约束

  • 紧约束(Tight Constraint)和松约束(Slack Constraint)是描述约束条件对可行解集影响的两个概念
  • 紧约束指的是那些在其边界上限制了最优解的约束条件。换句话说,如果改变某个约束条件会直接影响到最优解的位置或值,那么这个约束条件就是紧的。例如,在线性规划问题中,如果一个不等式约束以“=”的形式满足于最优解处,那么这个约束就是紧约束。紧约束对于确定最优解至关重要,因为它们直接定义了最优解所在的位置
  • 松约束则指的是那些在最优解处并没有起到实际限制作用的约束条件。也就是说,即使这些约束不存在,也不会改变问题的最优解。这类约束条件提供了额外的空间,但在这个空间内的点并不会比边界上的点更优。因此,松约束的存在不会影响最终的优化结果,但在某些情况下,它们可能为寻找最优解提供便利或增加灵活性。

Math——运筹优化开源求解器-GLPK的使用

本文介绍各种运筹优化开源求解器-GLPK的使用

  • GLPK是一款完全开源免费的运筹优化求解器,可以任意商用

Ubuntu安装GLPK

  • 据说Ubuntu安装较为方便,所以建议首选Ubuntu

  • 在网站下载文件:https://ftp.gnu.org/gnu/glpk/

    • 可以下载任意版本,建议选最新
  • 安装命令

    1
    2
    3
    4
    tar -xzvf glpk-xxx.tar.gz
    ./configure
    make
    sudo make install
  • 安装后直接执行可能出现错误

    1
    error while loading shared libraries: libglpk.so.36:...
  • 解决方案(原始解决方案地址):

    1
    https://github.com/rstudio/renv/issues/1881

Ubuntu下GLPK的使用

  • 下列式子参考了:线性规划工具 GLPK 的安装及基本使用

  • 创建问题描述文件glpkDemo.mod

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    /* Variables */
    var x1 >= 0;
    var x2 >= 0;
    var x3 >= 0;

    /* Object function */
    maximize z: 3*x1 + x2 +2*x3;

    /* Constrains */
    s.t. con1: x1 + x2 + 3*x3 <= 30;
    s.t. con2: 2*x1 +2*x2 + 5*x3 <= 24;
    s.t. con3: 4*x1 + x2 + 2*x3 <= 36;

    end;
  • 执行命令解决问题

    1
    glpsol -m glpkDemo.mod -o ./output/glpkDemo.sol
  • 输出文件

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    Problem:    glpkDemo
    Rows: 4
    Columns: 3
    Non-zeros: 12
    Status: OPTIMAL
    Objective: z = 28 (MAXimum)

    No. Row name St Activity Lower bound Upper bound Marginal
    ------ ------------ -- ------------- ------------- ------------- -------------
    1 z B 28
    2 a B 12 30
    3 b NU 24 24 0.166667
    4 c NU 36 36 0.666667

    No. Column name St Activity Lower bound Upper bound Marginal
    ------ ------------ -- ------------- ------------- ------------- -------------
    1 x1 B 8 0
    2 x2 B 4 0
    3 x3 NL 0 0 -0.166667

    Karush-Kuhn-Tucker optimality conditions:

    KKT.PE: max.abs.err = 0.00e+00 on row 0
    max.rel.err = 0.00e+00 on row 0
    High quality

    KKT.PB: max.abs.err = 0.00e+00 on row 0
    max.rel.err = 0.00e+00 on row 0
    High quality

    KKT.DE: max.abs.err = 2.22e-16 on column 1
    max.rel.err = 3.17e-17 on column 1
    High quality

    KKT.DB: max.abs.err = 0.00e+00 on row 0
    max.rel.err = 0.00e+00 on row 0
    High quality

    End of output
  • Activity这一列就是想要的解

  • 其他输出项如何理解?

自动化生成问题

  • 使用shell或者Python自动生成.mod文件,然后自然解析.sol文件,实现自动化测试参数

CV——ViT

  • 参考链接:
    • 原始论文:
      • (ViT)An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, Google Brain Team, ICLR 2021

Paper Summary

  • 整体总结:
    • ViT(Vision Transformer)已经成为计算机视觉必不可少的组件了,目前围绕 ViT 已经有了许多变体,本文是 ViT 相关的第一篇文章
    • ViT 核心:将 Transformer 用在图像领域
      • 与之前在计算机视觉中使用自注意力的工作不同
        • 除了初始的 patch 提取步骤外,ViT 没有向架构引入特定于图像的归纳偏置 (image-specific inductive biases)
          • 理解:这里所谓的归纳偏置就是类似 平移等变性和局部性 等 CNN 对图像领域的核心假设
        • ViT 将图像建模为一系列 patches,并使用与 NLP 中相同的标准 Transformer 编码器进行处理
          • 这种简单但可扩展的策略,在与大规模数据集上的预训练相结合时,效果出人意料地好
    • 注:本文之前的状况:
      • 在视觉领域,注意力机制要么与卷积网络结合使用,要么用于替换卷积网络的某些组件,同时保持其整体结构不变
      • 本文表明这种对 CNN 的依赖并非必要,直接应用于图像块序列的纯 Transformer 可以在图像分类任务上表现良好
      • Vision Transformer (ViT) 相比卷积网络更好,同时训练所需的计算资源要少得多

Introduction and Discussion

  • 基于自注意力的架构(特别是 Transformer),已成为自然语言处理中的首选模型
    • 主导方法是在大型文本语料库上进行预训练,然后在较小的特定任务数据集上进行微调 (2019)
    • 得益于 Transformer 的计算效率和可扩展性,训练具有前所未有规模的模型(超过 100B 参数)已成为可能 (2020;2020)
    • 且随着模型和数据集的增长,性能仍未出现饱和迹象
  • 在计算机视觉中,卷积架构仍然占主导地位 (1989;2012;2016)
    • 受 NLP 成功经验的启发,多项工作尝试将类似 CNN 的架构与自注意力相结合 (2018;2020),有些则完全替换了卷积 (2019;2020a)
      • 完全替换了卷积的这一类模型理论上是高效的,但由于使用了专门的注意力模式,尚未在现代硬件加速器上有效扩展
    • 结论:在当前大规模图像识别中,经典的 ResNet 类架构仍然是 SOTA (2018;2020;2020)
  • 本文尝试将标准的 Transformer 直接应用于图像,并进行尽可能少的修改
    • 本文将图像分割成块 (patches),并将这些块的线性嵌入序列作为 Transformer 的输入
    • 图像块的处理方式与 NLP 应用中的 Token(词)相同
    • 本文以监督方式训练模型进行图像分类
  • ViT 效果与数据量级有关:
    • 数据量不足时:
      • 当在中型数据集(如未使用强正则化的 ImageNet)上训练时,这些模型的准确率适中,比同等大小的 ResNet 低几个百分点
      • 这一看似不及预期的结果可能是意料之中的:Transformer 缺乏 CNN 固有的一些归纳偏置 (inductive biases),例如平移等变性和局部性,因此在数据量不足的情况下无法很好地泛化
        • 理解:这里所谓的归纳偏置就是类似 平移等变性和局部性 等 CNN 对图像领域的核心假设
    • 数据量充足时:
      • 如果模型在更大的数据集(14M-300M 张图像)上训练,情况就会发生变化:大规模训练胜过归纳偏置
        • 当在足够规模的数据上进行预训练并迁移到数据点较少的下游任务时, ViT 获得了出色的结果
      • 在公共 ImageNet-21k 数据集或内部的 JFT-300M 数据集上预训练时,ViT 在多个图像识别基准上接近或超越了 state of the art
        • 最佳模型在 ImageNet 上达到了 \(88.55\%\) 的准确率,在 ImageNet-Real 上达到了 \(90.72\%\),在 CIFAR-100 上达到了 \(94.55\%\),在包含 19 个任务的 VTAB 套件上达到了 \(77.63\%\)

Related Work

  • 详情见原文
  • Transformer 已成为许多 NLP 任务中的 SOTA 方法
    • 基于 Transformer 的大型模型通常先在大型语料库上进行预训练,然后针对手头任务进行微调:
      • BERT (2019) 使用去噪自监督预训练任务
      • GPT 系列工作使用语言建模作为其预训练任务 (2018;2019;2020)
  • 像素粒度的 Transformer:
    • 将自注意力直接应用于图像需要每个像素关注其他每个像素(注意:这里是每个像素,而像素是很多很多的)
    • 由于计算成本与像素数量呈二次关系,这无法扩展到实际的输入尺寸
      • 为了在图像处理中应用 Transformer,过去已经尝试了几种近似方法
    • Parmar 等人 (2018) 仅对每个查询像素的局部邻域应用自注意力,而不是全局
      • 这种局部多头点积自注意力块可以完全替代卷积 (2019;2019;2020)
    • Sparse Transformers (2019) 采用可扩展的近似方法来处理全局自注意力 ,以便应用于图像
    • 另一种扩展注意力的方法是将其应用于不同大小的块中 (2019),极端情况下仅沿单个轴应用 (2019;2020a)
      • 理解:这里
    • 许多这些专门的注意力架构在计算机视觉任务上展示了有希望的结果,但需要复杂的工程才能在硬件加速器上高效实现
  • 块粒度的 Transformer
    • (2020) 的模型从输入图像中提取大小为 \(2 \times 2\) 的块,并在其上应用完全的自注意力
    • 该模型与 ViT 非常相似,但 ViT 更进一步,证明了大规模预训练使 vanilla Transformer 能够与最先进的 CNN 竞争(甚至更好)
    • (2020) 使用了 \(2 \times 2\) 像素的小块大小,这使得该模型仅适用于小分辨率图像,而也处理中等分辨率的图像
  • 将卷积神经网络与各种形式的自注意力结合的方法:
    • 通过增强特征图用于图像分类 (2019),或使用自注意力进一步处理 CNN 的输出
    • 用于目标检测 (2018;2020)、视频处理 (2018;2019)、图像分类 (2020)、无监督目标发现 (2020) 或统一文本-视觉任务 (2020c;2019;2019)
  • image GPT (iGPT) (2020a) 在降低图像分辨率和颜色空间后将 Transformer 应用于图像像素
    • iGPT 模型以无监督方式作为生成模型进行训练,然后可以对生成的表示进行微调或线性探查以用于分类性能,在 ImageNet 上达到了 \(72\%\) 的最大准确率
  • 工作为越来越多探索超出标准 ImageNet 数据集的更大规模图像识别的研究增添了新的内容
    • 使用额外的数据源可以在标准基准上达到 SOTA 的结果 (2018;2019;2020)
    • Sun 等人 (2017) 研究了 CNN 性能如何随数据集大小扩展
    • Djolonga 等人 (2020) 对从大规模数据集(如 ImageNet-21k 和 JFT-300M)进行的 CNN 迁移学习进行了实证探索
  • 也关注后两个数据集,但训练的是 Transformer,而不是先前工作中使用的基于 ResNet 的模型

Method

  • 在模型设计上,尽可能严格地遵循原始 Transformer (2017)
  • 这种刻意简化的设置的一个优点是,可扩展的 NLP Transformer 架构及其高效实现几乎可以开箱即用

Vision Transformer

  • 模型概述如图 1 所示
  • 标准 Transformer 接收的输入是一个 1D 的 Token 嵌入序列
  • 定义 2D 图像:
    $$ \mathbf{x} \in \mathbb{R}^{H \times W \times C} $$
  • 将上述图像重塑为一系列展平的 2D 图像块
    $$ \mathbf{x}_p \in \mathbb{R}^{N \times (P^2 \cdot C)} $$
    • \((H, W)\) 是原始图像的分辨率
    • \(C\) 是通道数
    • \((P, P)\) 是每个图像块的分辨率
    • \(N = HW / P^2\) 是产生的图像块数量,也是 Transformer 的有效输入序列长度
  • Transformer 在其所有层中使用恒定的隐向量尺寸 \(D\),因此本文将图像块展平,并使用一个可训练的线性投影将其映射到 \(D\) 维(公式 1)
    • 本文将此投影的输出称为图像块嵌入 (patch embeddings)
  • 类似于 BERT 的 [class] Token,本文在向嵌入后的图像块序列 \((\mathbf{z}_0^0 = \mathbf{x}_{\text{class} })\) 前添加一个可学习的嵌入,其在 Transformer 编码器输出端的状态 \((\mathbf{z}_L^0)\) 用作图像表示 \(\mathbf{y}\)(公式 4)
    • 在预训练和微调期间,一个分类头 (classification head) 被附加到 \(\mathbf{z}_L^0\) 上
    • 分类头在预训练时通过一个带有一个隐藏层的 MLP 实现,在微调时通过一个单线性层实现
  • 位置嵌入 (position embeddings) 被添加到图像块嵌入中以保留位置信息
    • 本文使用标准的可学习 1D 位置嵌入
      • 注:因为本文作者没有观察到使用更高级的 2D 感知位置嵌入能带来显著的性能提升(附录 D.4)
    • 得到的嵌入向量序列作为编码器的输入
  • Transformer 编码器 (2017) 由多头自注意力 (MSA, multiheaded self-attention, 见附录 A) 和 MLP 块的交替层组成(公式 2, 3)
    • 在每个块之前应用层归一化 (LN, Layernorm),在每个块之后应用残差连接 (2019; 2019)
      $$
      \begin{array}{rlr}\mathbf{z}_0 &= [\mathbf{x}_{\text{class} };\mathbf{x}_p^1\mathbf{E};\mathbf{x}_p^2\mathbf{E};\dots ;\mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos}, \quad \quad & \mathbf{E}\in \mathbb{R}^{(P^2\cdot C)\times D},\mathbf{E}_{pos}\in \mathbb{R}^{(N + 1)\times D}\\
      \mathbf{z}’_\ell &= \text{MSA}(\text{LN}(\mathbf{z}_{\ell - 1})) + \mathbf{z}_{\ell - 1}, & \ell = 1\dots L\\
      \mathbf{z}_\ell &= \text{MLP}(\text{LN}(\mathbf{z}’_\ell)) + \mathbf{z}’_\ell, & \ell = 1\dots L\\
      \mathbf{y} &= \text{LN}(\mathbf{z}_L^0)
      \end{array} \tag {4}
      $$
      • 注:这里 MSA 本质就是 Transformer 中的 MHA
Inductive bias,归纳偏置
  • Vision Transformer 的图像特定归纳偏置比 CNN 少得多
    • 在 CNN 中,局部性、二维邻域结构和平移等变性被融入到整个模型的每一层中
    • 在 ViT 中,只有 MLP 层是局部的和平移等变的,而自注意力层是全局的
  • 二维邻域结构的使用非常少:
    • 在模型开始时通过将图像切割成图像块,以及在微调时为不同分辨率的图像调整位置嵌入(如下所述)
  • 除此之外,初始化时的位置嵌入不携带任何关于图像块二维位置的信息,所有图像块之间的空间关系都必须从头学习
Hybrid Architecture
  • 作为原始图像块的替代方案,输入序列可以从 CNN 的特征图中形成 (1989)
  • 在这种混合模型中,图像块嵌入投影 \(\mathbf{E}\)(公式 1)被应用于从 CNN 特征图中提取的图像块
  • 作为一种特殊情况,图像块的空间尺寸可以是 1x1,这意味着输入序列是通过简单地展平特征图的空间维度并投影到 Transformer 维度而获得的
  • 分类输入嵌入和位置嵌入如上所述被添加

Fine-Tuning And Higher Resolution

  • 在大型数据集上预训练 ViT,然后微调到(较小的)下游任务
  • 移除预训练的预测头,并附加一个零初始化的 \(D \times K\) 前馈层
    • \(K\) 是下游类别的数量
    • 理解:此时相当于用前馈层作为分类头,实现多分类任务
  • 以比预训练更高的分辨率进行微调通常是有益的 (2019; 2020)
    • 当输入更高分辨率的图像时,保持图像块大小不变,这会导致更长的有效序列长度
  • 理论上 Vision Transformer 可以处理任意的序列长度(直到内存限制)
    • 但预训练的位置嵌入可能不再有意义
  • 可以考虑根据预训练位置嵌入在原始图像中的位置,对其进行 2D 插值
    • 注:这种分辨率调整和图像块提取是 Vision Transformer 中手动注入关于图像 2D 结构的归纳偏置的唯一两个点

Experiments

  • 本文评估了 ResNet、Vision Transformer (ViT) 以及混合模型的表示学习能力
  • 为了解每个模型的数据需求,本文在不同规模的数据集上进行了预训练,并评估了许多基准任务
  • 在考虑模型预训练的计算成本时,ViT 表现得非常出色,以更低的预训练成本在大多数识别基准上达到了最先进的水平
  • 最后使用自监督进行了一个小型实验,并表明自监督 ViT 为未来带来了希望

Setup

Datasets
  • 为探索模型的可扩展性,使用了
    • 1)包含 1k 个类别和 1.3M 张图像的 ILSVRC-2012 ImageNet 数据集(下文中简称为 ImageNet)
    • 2)1)的超集 ImageNet-21k 包含 21k 个类别和 14M 张图像 (2009)
    • 3)JFT (2017):包含 18k 个类别和 303M 张高分辨率图像
  • 遵循 Kolesnikov 等 (2020) 的做法,针对下游任务的测试集对预训练数据集进行了去重
  • 本文将这些数据集上训练的模型迁移到几个基准任务上:
    • 使用原始验证标签和清理后的 ReaL 标签的 ImageNet (2020),CIFAR-10/100 (2009),Oxford-IIIT Pets (2012),以及 Oxford Flowers-102 (2008)。对于这些数据集,预处理步骤遵循 Kolesnikov 等 (2020)
Model Variants
  • 基于 BERT (2019) 使用的配置来设置 ViT 配置,如表 1 总结
  • “Base”和“Large”模型直接采用自 BERT,本文增加了更大的“Huge”模型
    • 在下文中,本文使用简短的符号来表示模型大小和输入图像块大小:
      • 例如,ViT-L/16 表示“Large”变体,输入图像块大小为 \(16\times 16\)
    • 注:Transformer 的序列长度与图像块大小的平方成反比 ,因此具有更小图像块大小的模型计算成本更高
  • 对于基线 CNN,使用 ResNet (2016),但将批量归一化层 (2015) 替换为组归一化 (2018),并使用了标准化卷积 (2019)
    • 这些修改改善了迁移性能 (2020),将修改后的模型记为“ResNet (BiT)”
  • 对于混合模型,将中间特征图馈送到 ViT,图像块大小为一个“像素”
  • 为了试验不同的序列长度
    • (i) 采用标准 ResNet50 第 4 阶段的输出
    • (ii) 移除第 4 阶段,将相同数量的层放在第 3 阶段(保持总层数不变),并采用这个扩展后的第 3 阶段的输出
      • 注: (ii) 会产生 4 倍长的序列长度,以及一个计算成本更高的 ViT 模型
Training & Fine-tuning
  • 使用 Adam (2015) 训练所有模型,包括 ResNets
    • 设置 \(\beta_{1} = 0.9\),\(\beta_{2} = 0.999\),批大小为 4096,并应用 0.1 的高权重衰减
    • 注:作者实验中发现这对所有模型的迁移都有用(附录 D.1 表明,与常见做法相反,在本文设置中,Adam 对 ResNets 的效果略优于 SGD)
  • 本文使用线性学习率预热和衰减,详见附录 B.1
  • 对于微调,本文对所有模型使用带动量的 SGD,批大小为 512,见附录 B.1.1
  • 对于表 2 中的 ImageNet 结果,本文以更高分辨率进行了微调:
    • ViT-L/16 为 512,ViT-H/14 为 518,并且还使用了 Polyak & Juditsky (1992) 平均,因子为 0.9999 (2019; 2020b)
Metrics
  • 通过 few-shot 或微调准确率报告下游数据集上的结果
    • 微调准确率反映了模型在相应数据集上微调后的性能
    • few-shot 准确率通过求解一个正则化最小二乘回归问题获得,该问题将(冻结的)训练图像子集的表示映射到 \(\{- 1,1\}^{K}\) 目标向量
  • 这种公式化允许作者以封闭形式获得精确解
  • 虽然本文主要关注微调性能,但有时本文会使用线性 Few-shot 准确率进行快速的即时评估,因为微调成本太高
Comparison To State of the art
  • 首先将最大的模型(ViT-H/14 和 ViT-L/16)与文献中最先进的 CNN 进行比较
    • 第一个比较点是 Big Transfer (BiT) (2020)
      • Big Transfer 使用大型 ResNet 进行有监督迁移学习
    • 第二个是 Noisy Student (2020)
      • Noisy Student 是一个大型 EfficientNet,使用半监督学习在 ImageNet 和移除标签的 JFT-300M 上训练
      • 目前,Noisy Student 是 ImageNet 上的 SOTA 模型,BiT-L 是本文报告的其他数据集上的最先进模型
    • 注:所有模型均在 TPUv3 硬件上训练,本文报告了预训练每个模型所需的 TPUv3 核心天数 (TPUv3-core-days),即用于训练的核心数乘以训练天数
  • 表 2 显示了结果
    • 在 JFT-300M 上预训练的较小的 ViT-L/16 模型在所有任务上均优于 BiT-L(在同一数据集上预训练),同时所需的训练计算资源大大减少
    • 更大的模型 ViT-H/14 进一步提升了性能,尤其是在更具挑战性的数据集上(ImageNet、CIFAR-100 和 VTAB 套件)
    • 注:该模型预训练所需计算量仍然远少于先前的先进模型
    • 注:预训练效率不仅可能受到架构选择的影响,还可能受到其他参数的影响,例如训练计划、优化器、权重衰减等
      • 本文在第 4.4 节中对不同架构的性能与计算量进行了受控研究
    • 最后,在公共 ImageNet-21k 数据集上预训练的 ViT-L/16 模型在大多数数据集上也表现良好,同时预训练所需的资源更少:使用一个标准 8 核的云 TPUv3 训练大约需要 30 天
  • 图 2 将 VTAB 任务分解为各自的任务组,并与该基准上之前的 SOTA 方法进行了比较:
    • BiT、VIVI(一个在 ImageNet 和 Youtube 上共同训练的 ResNet (2020))和 S4L(在 ImageNet 上进行有监督加半监督学习 (2019a))
    • ViT-H/14 在自然 (Natural) 和结构化 (Structured) 任务上优于 BiT-R152x4 和其他方法
    • 在专业化 (Specialized) 任务上,前两个模型的性能相似

Pre-training Data Requirements,预训练的数据需求

  • Vision Transformer 在大型 JFT-300M 数据集上预训练时表现良好
    • 与 ResNets 相比,ViT 对视觉的归纳偏置较少,那么数据集大小有多关键?进行了两个系列的实验
第一个实验:在规模递增的数据集上预训练 ViT 模型
  • ImageNet、ImageNet-21k 和 JFT300M
  • 为了提升在较小数据集上的性能,本文优化了三个基本的正则化参数:权重衰减、Dropout 和标签平滑
  • 图 3 显示了微调到 ImageNet 后的结果(其他数据集上的结果见表 5)
    • 当在最小的数据集 ImageNet 上预训练时,尽管进行了(适度的)正则化,ViT-Large 模型的性能仍不如 ViT-Base 模型
      • 使用 ImageNet-21k 预训练时,它们的性能相似
    • 只有在使用 JFT-300M 时,才看到更大模型带来的全部好处
    • 图 3 还展示了不同大小的 BiT 模型所跨越的性能区域
    • BiT CNN 在 ImageNet 上优于 ViT,但随着数据集变大,ViT 实现了反超
第二个实验:在 JFT-300M 数据集的随机子集(9M、30M 和 90M,以及完整数据集)上训练 ViT 模型
  • 本文没有对较小的子集进行额外的正则化,而是在所有设置中使用相同的超参数
    • 注:本文评估的是模型的内在属性,而不是正则化的效果
  • 但本文使用了 early-stopping,并报告训练期间达到的最佳验证准确率
  • 图 4 包含了结果,为了节省计算量,本文报告 Few-shot 线性准确率而非完整的微调准确率
    • 在较小的数据集上,Vision Transformer 比计算成本相当的 ResNet 更容易过拟合
      • 例如,ViT-B/32 比 ResNet50 稍快,且 ViT-B/32 在 9M 子集上的表现要差得多
    • 在 90M+ 子集上表现更好
    • 注:ResNet152x2 和 ViT-L/16 也是如此
    • 这个结果强化了直觉:卷积归纳偏置对较小的数据集有用,但对于较大的数据集,直接从数据中学习相关模式就足够了,甚至是有益的
  • ImageNet 上的 Few-shot 结果(图 4)以及 VTAB 上的低数据量结果(表 2)对于极低数据量的迁移来说似乎很有希望
    • 对 ViT Few-shot 特性的进一步分析是未来工作的一个令人兴奋的方向

Scaling Study

  • 本文通过评估从 JFT-300M 迁移的性能,对不同模型进行了受控的扩展性研究
    • 在这种设置下,数据大小不会成为模型性能的瓶颈,本文评估了每个模型的性能与预训练成本的关系
  • 模型集合包括:
    • 7 个 ResNets,R50x1,R50x2,R101x1,R152x1,R152x2,预训练 7 个 epoch
    • R152x2 和 R200x3 预训练 14 个 epoch
    • 6 个 Vision Transformers,ViT-B/32,B/16,L/32,L/16,预训练 7 个 epoch
    • L/16 和 H/14 预训练 14 个 epoch
    • 5 个混合模型,R50+ViT-B/32,B/16,L/32,L/16 预训练 7 个 epoch
    • R50+ViT-L/16 预训练 14 个 epoch(对于混合模型,模型名称末尾的数字代表的不是图像块大小,而是 ResNet 主干网络中的总降采样率)
  • 图 5 包含了迁移性能与总预训练计算量的对比(关于计算成本的详细信息,请参见附录 D.5)
    • 每个模型的详细结果在附录的表 6 中提供
    • 可以观察到几种模式
      • 第一,在性能/计算量的权衡上,Vision Transformer 主导了 ResNet
        • 为了达到相同的性能(在 5 个数据集上平均),ViT 使用的计算量大约减少 \(2 - 4\) 倍
      • 第二,在小的计算预算下,混合模型的性能略优于 ViT,但对于更大的模型,这种差异消失了
        • 这一结果有点出乎意料,因为人们可能期望卷积局部特征处理能在任何规模下帮助 ViT
      • 第三,在尝试的范围内,Vision Transformer 似乎没有出现饱和(未来前景无限)

Inspecting Vision Transformer,审视 ViT

  • 为理解 Vision Transformer 如何处理图像数据,本文分析了其内部表示
  • Vision Transformer 的第一层将展平的图像块线性投影到一个低维空间(公式 1)
    • 图 7(左)显示了学习到的嵌入滤波器的前几个主成分
      • 这些成分类似于每个图像块内部精细结构的低维表示的合理基函数
  • 在投影之后,一个学习到的位置嵌入被添加到图像块表示中
    • 图 7(中)显示,模型学习了在位置嵌入的相似性中编码图像内的距离,即更近的图像块倾向于具有更相似的位置嵌入
    • 此外:行-列结构出现了
      • 同一行/列的图像块具有相似的嵌入
      • 理解:Position Embedding 看起来已经被隐含的学到了
    • 最后,对于更大的网格,有时会显现出一种正弦结构(附录 D)
      • 位置嵌入学习表示二维图像拓扑结构
        • 这一事实解释了为什么手工制作的二维感知嵌入变体没有带来改进(附录 D.4)
  • 自注意力使 ViT 即使在最底层也能整合整个图像的信息
    • 本文研究了网络在多大程度上利用了这种能力
    • 基于注意力权重计算了整合信息的图像空间中的平均距离(图 7,右)
      • 这个“注意力距离”类似于 CNN 中的感受野大小
  • 发现:
    • 一些头在最底层就已经关注到图像的大部分区域,这表明模型确实使用了全局整合信息的能力
      • 其他注意力头在低层始终具有很小的注意力距离
      • 这种高度局部化的注意力在混合模型(在 Transformer 之前应用了 ResNet)中不太明显(图 7,右),表明它可能起到了与 CNN 中早期卷积层类似的功能
    • 注意力距离随着网络深度增加而增加
    • 从全局来看,模型关注的图像区域与分类任务语义相关(图 6)

Self-Supervision,自监督

  • Transformer 在 NLP 任务中的成功很大程度上不仅源于其出色的可扩展性,还源于大规模的自监督预训练 (2019; 2018)
    • 本文还对用于自监督的掩码图像块预测 (masked patch prediction) 任务进行了初步探索,模仿了 BERT 中使用的掩码语言建模任务
    • 通过自监督预训练,较小的 ViT-B/16 模型在 ImageNet 上达到了 \(79.9\%\) 的准确率,比从头开始训练显著提高了 \(2\%\),但仍比有监督预训练低 \(4\%\)
    • 附录 B.1.2 包含了更多细节
  • 注:本文将对对比预训练 (2020b; 2020; 2019; Hé2020) 的探索留给未来的工作

附录 A:MultiHead Self-Attention,多头自注意力

  • 标准的 qkv 自注意力 (SA, (2017)) 中,对于输入序列 \(\mathbf{z}\in \mathbb{R}^{N\times D}\) 中的每个元素,计算序列中所有值 \(\mathbf{v}\) 的加权和
  • 注意力权重 \(A_{ij}\) 基于序列中两个元素之间的成对相似性及其各自的查询 \(\mathbf{q}^i\) 和键 \(\mathbf{k}^j\) 表示
    $$\begin{array}{rlr}\left[\mathbf{q},\mathbf{k},\mathbf{v}\right] = \mathbf{z}\mathbf{U}_{qkv} & \mathbf{U}_{qkv}\in \mathbb{R}^{D\times 3D_h}, & (5)\ A = \text{softmax}\left(\mathbf{q}\mathbf{k}^\top /\sqrt{D_h}\right) & A\in \mathbb{R}^{N\times N}, & (6)\ \text{SA}(\mathbf{z}) = Av. & & \end{array} \tag {7}$$
  • 多头自注意力 (MSA) 是 SA 的扩展,在其中并行运行 \(k\) 个自注意力操作(称为“头”),并投影它们的拼接输出
    • 为了在改变 \(k\) 时保持计算量和参数数量恒定,\(D_h\) (公式 5) 通常设置为 \(D / k\)
      $$\text{MSA}(\mathbf{z}) = [\text{SA}_1(z);\text{SA}_2(z);\dots ;\text{SA}_k(z)]\mathbf{U}_{msa}\qquad \mathbf{U}_{msa}\in \mathbb{R}^{k\cdot D_h\times D} \tag {8}$$

附录 B:Experiment Details

B.1 Training

  • 表 3 总结了作者针对不同模型的训练设置
    • 发现:在 ImageNet 上从头开始训练模型时,强正则化是关键
    • Dropout(如果使用)在每个密集层之后应用,除了 qkv 投影层以及在将位置嵌入添加到 patch 嵌入之后直接应用
    • 混合模型使用与其对应的 ViT 模型完全相同的设置进行训练
    • 最后,所有训练均在 224 分辨率下进行
  • 表 3:训练的超参数
    • 所有模型均以 4096 的批量大小和 10k 步的学习率预热进行训练
    • 对于 ImageNet,本文发现额外应用全局范数为 1 的梯度裁剪是有益的
    • 训练分辨率为 224
B.1.1 Fine-Tuning
  • 使用带动量为 0.9 的 SGD 微调所有 ViT 模型
  • 本文对学习率进行小范围网格搜索,学习率范围见表 4
  • 使用训练集中的小子集(Pets 和 Flowers 为 10%,CIFAR 为 2%,ImageNet 为 1%)作为开发集,并在剩余数据上进行训练
  • 为了获得最终结果,在整个训练集上进行训练,并在相应的测试数据上进行评估
  • 对于微调 ResNet 和混合模型,使用完全相同的设置,唯一的例外是 ImageNet,在学习率扫描中增加了另一个值 0.06
  • 对于 ResNet,也运行 Kolesnikov 等人 (2020) 的设置,并在此次运行和作者的扫描中选择最佳结果
    • 除非另有说明,否则所有微调实验均在 384 分辨率下运行(以不同于训练的分辨率进行微调是常见做法 (2020))
  • 将 ViT 模型迁移到另一个数据集时,本文会移除整个头(两个线性层),并将其替换为一个零初始化的、输出目标数据集所需类别数的线性层
    • 本文发现这比简单地重新初始化最后一层更稳健一些
  • 对于 VTAB,遵循 Kolesnikov 等人 (2020) 的协议,并对所有任务使用相同的超参数设置
    • 使用 0.01 的学习率并训练 2500 步 (表 4)
    • 通过对两个学习率和两个调度进行小范围扫描,并选择在 200 个示例的验证集上具有最高 VTAB 分数的设置来选定此设置
    • 遵循 Kolesnikov 等人 (2020) 中使用的预处理,除了本文不使用特定于任务的输入分辨率
    • 相反,本文发现 Vision Transformer 从对所有任务采用高分辨率 \((384 \times 384)\) 中获益最多
  • 表 4:微调的超参数
    • 所有模型均使用余弦学习率衰减、批量大小为 512、无权重衰减以及全局范数为 1 的梯度裁剪进行微调
    • 除非另有说明,否则微调分辨率为 384
B.1.2 Self-Supervision
  • 采用掩码 patch 预测目标进行初步的自监督实验
    • 破坏 \(50\%\) 的 patch 嵌入,通过将其嵌入替换为可学习的 [mask] 嵌入 \((80\%)\)、随机的其他 patch 嵌入 \((10\%)\) 或保持不变 \((10\%)\)
      • 注:此设置与 Devlin 等人 (2019) 用于语言的设置非常相似
    • 使用每个被破坏 patch 的相应 patch 表示来预测其 3 位平均颜色(即总共 512 种颜色)
  • 在 JFT 上以 4096 的批量大小训练了自监督模型 100 万步(约 14 个 epoch)
    • 使用 Adam,基础学习率为 \(2 \cdot 10^{-4}\),预热 10k 步并采用余弦学习率衰减
    • 作为预训练的预测目标,本文尝试了以下设置:
      • 1)仅预测平均 3 位颜色(即 512 种颜色的 1 个预测)
      • 2)并行预测 \(16 \times 16\) 个 patch 的 \(4 \times 4\) 下采样版本及 3 位颜色(即 16 个 512 种颜色的预测)
      • 3)使用 L2 对整个 patch 进行回归(即对 3 个 RGB 通道进行 256 次回归)
      • 令人惊讶的是,所有方法都运行良好(L2 稍差一些)
    • 注:本文仅报告选项 1)的最终结果,因为它显示出最佳的 few-shot 性能
      • 本文还尝试了 Devlin 等人 (2019) 使用的 \(15\%\) 破坏率,但根据本文的 few-shot 指标,结果也稍差
  • 最后,作者指出,本文掩码 patch 预测实例化不需要大量的预训练也不需要像 JFT 这样的大型数据集就能在 ImageNet 分类上带来类似的性能提升
    • 也就是说,在 10 万预训练步骤后下游性能的提升出现递减,并且在 ImageNet 上进行预训练时也看到了类似的提升

附录 C:Additional Results

  • 本文报告了与论文中图表相对应的详细结果
  • 表 5 对应论文中的图 3,显示了在不同规模的数据集(ImageNet、ImageNet-21k 和 JFT-300M)上预训练的不同 ViT 模型的迁移性能
  • 表 6 对应论文中的图 5,显示了不同规模的 ViT、ResNet 和混合模型的迁移性能,以及它们预训练的预估计算成本

附录 D:Additional Analysis

D.1 SGD vs. Adam for ResNets

  • ResNet 通常使用 SGD 进行训练,而本文使用 Adam 优化器则非常规
  • 本文比较了在 JFT 上用 SGD 和 Adam 预训练的两个 ResNet(50x1 和 152x2)的微调性能
    • 对于 SGD,本文使用 Kolesnikov 等人 (2020) 推荐的超参数
  • 结果呈现在表 7 中
    • Adam 预训练在大多数数据集上和平均表现上均优于 SGD 预训练
    • 这证明了选择 Adam 作为在 JFT 上预训练 ResNet 的优化器的合理性
    • 注:绝对数值低于 Kolesnikov 等人 (2020) 报告的值,因为本文仅预训练了 7 个 epoch,而不是 30 个

D.2 Transformer Shape

  • 本文对 Scaling Transformer 架构的不同维度进行了消融实验,以找出哪些维度最适合扩展到非常大的模型
  • 图 8 显示了不同配置下 ImageNet 上的 5-shot 性能
    • 所有配置均基于一个具有 8 层、\(D = 1024\)、\(D_{MLP} = 2048\) 和 patch 大小为 32 的 ViT 模型,即所有线的交点
    • 结论:
      • Scaling 深度带来的改进最大,直到 64 层都非常明显
        • 但在 16 层之后已经可以看到收益递减,开始明显,后续逐步变成对数趋势
      • 缩放网络的宽度似乎带来的变化最小
      • 减小 patch 大小从而增加有效序列长度,显示出惊人地稳健的改进,且未引入参数
        • 注:图中 Patch Size 应该是表达计算成本为 X 轴的,不是其绝对值
    • 这些发现表明,计算量可能比参数数量更能预测性能,并且如果可能的话,缩放应侧重于深度而非宽度
      • 总的来说,按比例缩放所有维度会带来稳健的改进

D.3 Head Type and Class Token

  • 为尽可能接近原始的 Transformer 模型
    • 本文使用了额外的 [class] token,并将其作为图像表示
    • 然后通过一个小型 MLP 将该 token 的输出转换为类别预测,该 MLP 在单个隐藏层中使用 tanh 作为非线性激活函数
  • 此设计继承自用于文本的 Transformer 模型,本文在整篇论文中都使用它
    • 最初尝试仅使用图像 patch 嵌入,对其进行全局平均池化 (GAP),然后接一个线性分类器(就像 ResNet 的最终特征图一样),效果非常差
    • 本文发现:
      • 这既不是由于额外的 token,也不是由于 GAP 操作
      • 相反,性能差异完全可以通过对不同学习率的需求来解释,详见图 9

D.4 Position Embedding

  • 本文对使用位置嵌入编码空间信息的不同方式进行了消融实验
  • 本文尝试了以下情况:
    • 不提供位置信息 (Providing no positional information) :将输入视为一个无序的 patch 集合 (a bag of patches)
    • 一维位置嵌入 (1-dimensional positional embedding) :将输入视为按光栅扫描顺序 (raster order) 排列的 patch 序列(本文所有其他实验的默认设置)
    • 二维位置嵌入 (2-dimensional positional embedding) :将输入视为二维平面上的 patch 网格
      • 在这种情况下,学习两组嵌入,每组对应一个坐标轴,即 X 轴嵌入 (X-embedding) 和 Y 轴嵌入 (Y-embedding),每个的维度大小为 \(D/2\)。然后,根据 patch 在输入中的坐标,作者拼接 X 和 Y 的嵌入来得到该 patch 的最终位置嵌入
    • 相对位置嵌入 (Relative positional embeddings) :考虑 patch 之间的相对距离而不是它们的绝对位置来编码空间信息
      • 使用一维相对注意力 (1-dimensional Relative Attention),其中定义了所有可能的 patch 对之间的相对距离
      • 对于每一对给定的 patch(一个作为 query,另一个作为注意力机制中的 key/value),有一个偏移量 \(p_q - p_k\),每个偏移量关联一个嵌入
      • 然后简单地运行一个额外的注意力,使用原始的 query(query 的内容),但使用相对位置嵌入作为 keys
      • 然后将来自相对注意力的 logits 作为一个偏置项 (bias term),并在应用 softmax 之前将其加到主注意力(基于内容的注意力)的 logits 上
  • 除了不同的空间信息编码方式,本文还尝试了将这些信息整合到模型中的不同方法
  • 对于一维和二维位置嵌入,本文尝试了三种不同的情况:
    • (1) 在模型的主干(Stem)之后,将输入馈送到 Transformer 编码器之前(本文中所有其他实验的默认设置)
    • (2) 在每层开始时学习并添加位置嵌入到输入中
    • (3) 在每层开始时添加学习到的位置嵌入到输入中(层之间共享)
  • 表 8 总结了在 ViT-B/16 模型上进行的此消融研究的结果
    • 没有位置嵌入的模型与有位置嵌入的模型性能之间存在巨大差距,但不同编码位置信息的方式之间几乎没有区别
    • 推测:ViT 的 Transformer 编码器在 patch 级别的输入上操作,而不是像素级别,因此如何编码空间信息的差异不那么重要
    • 更准确地说,在 patch 级别的输入中,空间维度远小于原始像素级别的输入,例如 \(14 \times 14\) 对比 \(224 \times 224\),并且在这种分辨率下学习表示空间关系对于这些不同的位置编码策略来说同样容易
      • 即便如此,网络学习到的位置嵌入相似性的具体模式也取决于训练超参数(图 10)

D.5 Empirical Computation Costs,计算成本

  • 计算成本详情见原始论文

D.6 Axial Attention

  • 轴向注意力 (Axial Attention) (2020;2019) 是一种简单而有效的技术,用于在组织为多维张量的大型输入上运行自注意力
    • 轴向注意力的一般思想是执行多个注意力操作,每个操作都沿着输入张量的单个轴进行,而不是将一维注意力应用于输入的扁平化版本
    • 在轴向注意力中,每个注意力沿着特定轴混合信息,同时保持其他轴上的信息独立
  • 沿着这个思路,Wang 等人 (2020b) 提出了 AxialResNet 模型,其中 ResNet50 中所有 \(3 \times 3\) 卷积核大小的卷积都被轴向自注意力(即行和列注意力)取代,并辅以相对位置编码
    • 本文已经实现了 AxialResNet 作为基线模型
  • 此外,本文修改了 ViT 以处理二维形状的输入,而不是一维的 patch 序列,并合并了轴向 Transformer 块 (Axial Transformer blocks)
    • 其中,本文不是使用一个自注意力后接一个 MLP,而是使用一个行自注意力加一个 MLP,后接一个列自注意力加一个 MLP
  • 图 13 展示了在 JFT 数据集上预训练的 Axial ResNet、Axial-ViT-B/32 和 Axial-ViT-B/16 在 ImageNet 5-shot 线性评估上的性能,与预训练计算量的关系,计算量以 FLOPs 数量(左)和推理时间(每秒样本数,右)两种形式表示
    • Axial-ViT-B/32 和 Axial-ViT-B/16 在性能方面均优于其对应的 ViT-B 模型,但这需要更多的计算量
      • 这是因为在 Axial-ViT 模型中,每个具有全局自注意力的 Transformer 块被两个轴向 Transformer 块(一个用于行自注意力,一个用于列自注意力)取代,并且尽管轴向情况下自注意力操作的序列长度较小,但每个 Axial-ViT 块中多了一个 MLP
    • 对于 AxialResNet,尽管在准确性/计算量权衡方面看起来合理(图 13,左),但在 TPU 上,简单的实现极其缓慢(图 13,右)

D.7 Attention Distance

  • 为了理解 ViT 如何使用自注意力跨图像整合信息,本文分析了不同层注意力权重所跨越的平均距离(图 11)
    • 此“注意力距离”类似于 CNN 中的感受野大小
  • 在较低层中,不同头之间的平均注意力距离差异很大,一些头关注图像的大部分区域,而另一些头则关注查询位置附近或周围的小区域
    • 随着深度增加,所有头的注意力距离都增加
      • 注:图 11 的纵轴是每个 Head 的 Attention 距离的平均观测值
    • 在网络的后半部分,大多数头广泛地关注各个 token
  • 图 11:按头和网络深度划分的注意力区域大小
    • 通过平均查询像素与所有其他像素之间的距离(按注意力权重加权),为 128 个示例图像计算注意力距离
    • 每个点表示在某一层的 16 个头之一在所有图像上的平均注意力距离
    • 图像宽度为 224 像素

D.8 Attention Maps

  • 为计算从输出 token 到输入空间的注意力图(图 6 和 14),本文使用了 Attention Rollout (2020)
    • 简而言之,我们平均 ViT-L/16 在所有头上的注意力权重,然后递归地乘以所有层的权重矩阵
    • 这说明了注意力在所有层中跨 token 的混合

D.9 ObjectNet Results

  • 本文还按照 Kolesnikov 等人 (2020) 的评估设置,在本文的旗舰模型 ViT-H/14 上评估了 ObjectNet 基准,获得了 \(82.1\%\) 的 Top-5 准确率和 \(61.7\%\) 的 Top-1 准确率

D.10 VTAB Breakdown

  • 表 9 显示了在 VTAB-1k 的每个任务上获得的分数

RL——IQL

  • 参考链接
    • 原始论文:ICLR 2022 Poster, Offline reinforcement learning with implicit q-learning
    • 相关论文:(AWR)ADVANTAGE-WEIGHTED REGRESSION: SIMPLE AND SCALABLE OFF-POLICY REINFORCEMENT LEARNING

IQL 的基本思想

  • 常规的方法会直接约束策略或者正则来减少OOD问题,IQL则通过SARSA style的方法仅在见过的state-action上进行学习,不直接面对OOD问题
  • 策略学习使用了AWR(Advantage Weighted Regression)方法

多步动态规划和 Single-step 方法

多步动态规划(Multi-step DP)

  • 多步动态规划方法(multi-step dynamic programming methods,简写作Multi-step DP)
  • 已有Offline RL方法的很大一部分是基于约束或正则化的近似动态规划(例如,Q-learning 或 actor-critic 方法),constraint或Regularization用于限制与行为策略的偏差。 我们将这些方法称为多步动态规划(Multi-step DP)算法,因为它们对多次迭代执行真正的动态规划,因此如果提供高覆盖率数据,原则上可以恢复最优策略。通常情况下Multi-step DP问题也可以分为:
    • 显式密度模型(explicit density model):BRAC,BCQ,BEAR等
    • 隐式差异约束(implicit divergence constraints):AWAC,CRR,AWR等
  • 如何理解显示密度模型和隐式约束模型的定义?
    • 显式密度模型:直接建模State-Action的价值分布,从而得到最优策略
    • 隐式差异约束:不直接建模State-Action的价值分布,更多是模仿优质策略行为的思想
  • 问题:显示密度模型中的“密度”是什么意思?
    • 这里的密度是指概率密度,显示密度模型即会直接定义并学习概率密度函数的模型

Single-step 方法

  • Single-step 方法(Single-step Methods)是指一类方法,这类方法仅依赖于单步策略迭代的方法,即对行为策略的价值函数或Q函数进行拟合,然后提取相应的贪心策略,或者完全避免价值函数并利用行为克隆目标
  • 这类方法避免了访问看不见的状态动作对,因为它们要么根本不使用价值函数,要么学习行为策略的价值函数
  • IQL 就是一种 Single-step 方法
  • 传统的模仿学习也属于 Single-step 方法

多步动态规划和 Single-step 方法的比较

  • from https://zhuanlan.zhihu.com/p/497358947

IQL 之前的方案

一般的 Offline RL 学习方法

  • 思路:按照贝尔曼最优方程迭代
  • 损失函数:
    $$
    L_{TD}(\theta) = \mathbb{E}_{(s,a,s’) \sim D} \left[ (r(s, a) + \gamma \max_{a’} Q_{\theta’}(s’, a’) - Q_\theta(s, a))^2 \right]
    $$
  • 分析:
    • 直接使用上述损失函数存在值高估问题
    • 大多数最近的离线RL方法修改了上述值函数损失(或直接约束argmax这个策略本身选择动作的方位),以正则化值函数,使其生成的策略接近数据,缓解值高估问题

能避免 OOD 的学习方法

  • 思路:按照SARSA-style的方法迭代,即贝尔曼期望方程( \(a’\sim \pi_\beta\) )
  • 损失函数:SARSA-style的损失函数如下
    $$
    L(\theta) = \mathbb{E}_{(s,a,s’,a’) \sim D} \left[ (r(s, a) + \gamma Q_{\theta’}(s’, a’) - Q_\theta(s, a))^2 \right]
    $$
    • 按照上面的损失函数学习,学到的 \(Q_\theta(s,a)\) 本质是行为策略对应的Q值,也就是说,当样本无限时,Q值收敛到
      $$
      Q_\theta^*(s, a) \approx r(s, a) + \gamma \mathbb{E}_{s’ \sim p(\cdot|s,a), a’ \sim \pi_\beta(\cdot|s’)} \left[ Q_{\theta’}(s’, a’) \right]
      $$
  • 分析:
    • 本质上是在估计数据集上的状态和动作分布下,Q值的期望
    • 显然上面学到的只是行为策略对应的Q值,不是我们想要的最优Q值(行为策略不一定是最优策略)
    • 上面的方法更像是在对行为策略进行模仿

Offline RL 的最优 Q 值目标

  • 思路:避免OOD且能学到“最优策略”的迭代形式,限制了argmax动作不访问OOD的状态动作对
  • 损失函数:
    $$
    L(\theta) = \mathbb{E}_{(s,a,s’) \sim D} \left[ (r(s, a) + \gamma \max_{a’ \in A, \pi_\beta(a’|s’) > 0} Q_{\theta’}(s’, a’) - Q_\theta(s, a))^2 \right]
    $$
  • 分析:
    • 既保证使用的最大Q值对饮动作不超过数据集(避免了OOD),又可以在支持集上最大化当前策略
    • 上面的定义实际上也可能访问到支持集以外的动作,后续需要使用期望回归来改进为SARSA-style的形式
  • 注意:IQL 并不直接学习上述目标( \(\pi_\beta(a’|s’) > 0\) 导致无法学习),只是隐式的学习上述目标 ,具体方法是引入期望回归(Expectile Regression)
    • BCQ等方法已经学习过上述目标的改进版本
    • 上述目标无法直接学习,因为判断 \(\pi_\beta(a’|s’) > 0\) 需要维护一个表格,统计所有数据,状态动作空间很大时无法实现,除非像BCQ一样,用一个网络去学习概率

IQL 的解决方案

期望回归与分位数回归

  • 期望回归(Expectile Regression) ,是估计随机变量的各种统计量的方法,定义如下:

    • 某个随机变量 \(X\) 的 \(\tau \in (0, 1)\) 期望值定义为以下非对称最小二乘问题的解:
      $$
      \mathop{\arg\min}_{m_\tau} \mathbb{E}_{x \sim X} \left[ L_\tau^2(x - m_\tau) \right], \quad \text{ Where } \quad L_\tau^2(u) = |\tau - 1(u < 0)| u^2.
      $$
    • \(L_\tau^2(u)\) 也常常写作 \(L_\tau^e(u)\)
    • 给定 \(\tau\), \(m_\tau\) 就是在拟合随机变量的某个 \(\tau\) 期望点,不同的 \(\tau\) 下 \(m_\tau\) 也会不同,学到的,比如 \(\tau=0.5\) 时就是对应期望
    • 分析:
      • 当 \(\tau > 0.5\) 时,这种非对称损失函数会降低小于 \(m_\tau\) 的 \(x\) 值的权重,而增加大于 \(m_\tau\) 的 \(x\) 值的权重
      • 当 \(\tau = 0.5\) 时,损失函数退化成对称的,等价于均方误差MSE(这里把 \(u\) 看做是误差项)
        $$ L^{\tau=0.5}_{2}(u) = |0.5 - \Bbb{1}(u<0)|u^2 = \frac{1}{2}u^2 $$
  • 条件随机变量的期望回归

    • 对于给定的条件随机变量 \(y = f(x)\),假定 \((x,y)\) 成对出现在数据集 \(\mathcal{D}\) 中,则可以定义:
      $$\mathop{\arg\min}_{m_\tau(x)} \mathbb{E}_{(x,y) \sim \mathcal{D}} \left[ L_\tau^2(y - m_\tau(x)) \right]$$
    • 给定 \(\tau\), \(m_\tau(x)\) 是一个关于 \(x\) 的函数,不同的 \(\tau\) 得到的拟合函数不同,相同的 \(\tau\),给定不同的 \(x\) 会得到不同的 \(m_\tau(x)\), \(m_\tau(x)\) 本质是在拟合 \(y\),下图中最右侧的图展示了条件随机变量的期望回归
  • 分位数回归(Quantile Regression)定义如下:
    $$
    \mathop{\arg\min}_{m_\tau} \mathbb{E}_{x \sim X} \left[ L_\tau^1(x - m_\tau) \right], \quad \text{ Where } \quad L_\tau^1(u) = (\tau - 1(u < 0)) u.
    $$

    • \(L_\tau^1(u)\) 也常常写作 \(L_\tau^q(u)\)
    • \((\tau - 1(u < 0)) u\) 不使用绝对值的原因是此时无论 \(u\) 取值正负 \(L_\tau^1(u) \ge 0\) 都成立,相当于已经给整体加了绝对值了,最终目标是类似MAE的形式
  • 分位数回归和期望回归的对比

    • 常规的MSE叫做mean,等价于求均值,等价于 \(\tau = 0.5\) 的期望回归(expectile regression)
    • 常规的MAE叫做median,等价于求中位数,等价于 \(\tau = 0.5\) 的分位数回归(quantile regression)
  • 更多比较

    • 修正:左边第二行需要使用绝对值 \(\mathcal{R}_\tau^e(u) = u^2|\tau - \mathbf{1}(u < 0)|\)
  • 问题:为什么使用期望回归而不是分位数回归?

    • 审稿人也有这个疑问,作者的回答是实验得到的,没有正面给出回答?, \(\tau=0.9\) 时效果最好

基于期望回归的 Q 值学习

  • 借助期望回归来学习Q值:
    $$
    L(\theta) = \mathbb{E}_{(s,a,s’,a’) \sim D} \left[ L_\tau^2(r(s, a) + \gamma Q_{\theta’}(s’, a’) - Q_\theta(s, a)) \right]
    $$
  • 其中 \(\mathcal{D} \sim \pi_\beta\),选择合适的 \(\tau\) 后,可以学到一个大于 \(Q^{\pi_\beta}(s,a)\) (行为策略对应的Q值)的 \(Q(s,a)\)
  • 理解:给定 \((s,a)\) 的情况下,存在许多不同的 \((s’,a’)\) 样本,当 \(\tau > 0.5\) 时,相当于是通过这种非对称损失函数降低小于 \(Q_\theta(s, a)\) 的动作状态对 \((s’, a’)\) 所对应的目标值 \(r(s, a) + \gamma Q_{\theta’}(s’, a’)\) 的权重,增加大于 \(Q_\theta(s, a)\) 的动作状态对 \((s’, a’)\) 所对应的目标值 \(r(s, a) + \gamma Q_{\theta’}(s’, a’)\) 的权重,从而学到较大的 \((s’,a’)\) 对应的目标值,极端情况下,学到的是最大值 \(r(s, a) + \gamma \max_{(s,a,s’,a’) \sim \mathcal{D}} Q_{\theta’}(s’, a’)\)
  • 上面的损失函数还存在一些不足,由于环境可能是动态变化的,状态 \(s’\) 是按照概率 \(p(s’|s,a)\) 出现,所以以上损失函数还使得Q学到了环境转换的信息。具体来说,学到的Q值高不一定是选到了优秀动作的反应,还可能是因为运气好碰上了转移到一个较好的状态 \(s’\) 上
    • 补充说明1:即使是随机环境,在状态 \(s\) 下,选择 \(a\) 后有一定概率得到较优秀的 \(s’\),能说明在状态 \(s\) 下,选择 \(a\) 是较为优秀的吗?回答是不一定!因为在这种随机环境的情况下,最优贝尔曼方程里面,我们也需要对 \(s’\) 计算期望 \(\mathbb{E}_{s’\sim p(s’|s,a)}\) 而不是取最大 \(\max_{s’}\),这是我们的目标是找一个策略,使得按照这个策略交互得到的期望收益最大,而线上推断时,我们不能保证一定能走到最大的 \(s’\),除非是确定性环境,即 \((s,a)\) 确定后, \(s’\) 也是确定的
    • 补充问题1:如果是确定性的环境,是否可以直接使用上述损失函数?

IQL 的 Q 值学习

  • 由于基于期望回归的Q值学习引入了状态转移随机偏差,存在问题,所以需要进行改进:
  • 第一步:使用期望回归去从已知的 \(Q_{\hat{\theta}}(s,a)\) 中学习 \(V(s)\)
    $$ L_V(\psi) = \mathbb{E}_{(s,a) \sim D} \left[ L_\tau^2(Q_{\theta’}(s, a) - V_\psi(s)) \right] $$
    • 这里可以看出 \(V(s)\) 学到的是 \(\max_a Q_{\hat{\theta}}(s,a)\) 的思想,即对应V值的贝尔曼最优方程
  • 第二步:使用最优的 \(V\) 去学习 \(Q\)
    $$L_Q(\theta) = \mathbb{E}_{(s,a,s’) \sim D} \left[ (r(s, a) + \gamma V_\psi(s’) - Q_\theta(s, a))^2 \right] $$
    • 由于 \(V\) 在上一步已经通过期望回归学到了最优形式,这一步不需要继续使用期望回归了
  • 至此,我们已经实现了通过SARSA-style的形式,隐式的学到了近似最优Q值
  • 关于参数 \(\tau\) 的一些分析以及以上贝尔曼方程收敛性见附录

IQL 的策略学习

  • 虽然我们已经得到了近似最优Q值,但为了避免使用样本外的动作,这里做策略学习时,我们不能直接遍历所有动作
  • AWR提供了一种方法从近似最优Q值里面提取策略(因为策略学习并不影响Q值,所以更像是从近似最优Q值中提取策略):
    $$
    L_\pi(\phi) = \mathbb{E}_{(s,a) \sim D} \left[ \exp(\beta (Q_{\theta’}(s, a) - V_\psi(s))) \log \pi_\phi(a|s) \right]
    $$
    • 其中 \(\beta \ge 0\) 是温度系数。对于较小的超参数值,该目标类似于行为克隆(近似所有样本权重相等的策略梯度,原始策略梯度中,样本权重是温度系数为1的Q值),而对于较大的值,它试图恢复Q函数的最大值(Q值越大,对应的样本权重越大)。正如AWR等先前工作所示,此目标学习一个在分布约束下的最大化Q值的策略
  • 注意,策略学习时Q值收敛以后进行的(Q和V是交替更新),Q值学习和策略学习是串行的,且Q值学习彻底完成以后才进行策略学习,并不是交替进行
  • 思考:使用期望回归学到的 \(V\) 值是 \(V^{\pi^*} = \max_a Q_{\hat{\theta}}(s,a)\),为什么可以用最优的 \(V\) 值来更新策略 \(Q_{\theta’}(s, a) - V_\psi(s)\) ?
    • 这种做法是可以的,Q值和V值符合优势函数的定义,因为传统优势函数的定义也是 \(A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s)\),其中 \(V^\pi(s) = \mathbb{E}_{a \sim \pi(\cdot|s)}[Q^\pi(s,a)]\),看似与 IQL 中学到的 \(V\) 值不同,但此时将当前Q值和V值对应策略 \(\pi(a|s)\) 理解为选择Q值最大的动作或近似动作,实际上 \(Q\) 值和 \(V\) 值都满足传统的优势函数了
    • 理解 :(即使不满足原始优势函数)虽然此时的 \(V\) 值是 \(\max_a Q_{\hat{\theta}}(s,a)\),但是 \(Q_{\theta’}(s, a) - V_\psi(s)\) 依然可以对动作的好坏进行区分。实际上,只要可以保证动作越好,优势函数越大即可,即使所有动作都是负的或者都是正的也没问题,因为策略的实现是一个softmax,大家都降低的时候,降的少的动作上对一个的概率自然会提升。实践也告诉我们,\(V\) 值是否是当前状态下动作的期望结果并不重要
    • 特别说明 :AWR 中使用的 \(V\) 值是从历史样本的累计奖励上学习的,相当于是历史样本上的期望,也就是行为策略 \(\mu\)(多轮迭代下可能是混合策略)对应的 \(V^\mu\) 值,AWR 的整个推导中奖励 \(\mathcal{R}^\mu_{\mathbf{s},\mathbf{a}}\) 和 \(V^\mu\) 值都是使用行为策略 \(\mu\) 来表示的,奖励使用的是蒙特卡洛估计 \(\mathcal{R}^D_{\mathbf{s},\mathbf{a}} = \sum_{t=0}^T\gamma^t r_t\)

IQL 训练流程

  • 伪代码如下(说明:伪代码中最后一行策略更新公式有问题,应该是加号,或者把损失函数添上负号,因为这里是想要最大化目标, 作者开源代码中是正确的github.com/ikostrikov/implicit_q_learning,论文中写错了):

附录:为什么 AWR 和策略梯度法损失函数不同?

  • 副标题:不同AC框架算法策略更新公式对比分析,为什么相同的目标推导出来完全不同的更新公式?
  • 问题补充:
    • 普通AC(策略梯度法)更新公式是:
      $$\mathop{\arg\max}_{\theta} \mathbb{E}_{(s,a) \sim \pi_{\theta_k}}\Big[(Q^{\pi_{\theta_k}}(s,a)-V^{\pi_{\theta_k}}(s))\log\pi_\theta(a|s)\Big]$$
    • PPO更新公式:
      $$\mathop{\arg\max}_{\theta} \mathbb{E}_{(s,a) \sim \pi_{\theta_k}}\Big[\frac{\pi_\theta(a|s)}{\pi_{\theta_k}(a|s)} A^{\pi_{\theta_k}}(s,a) - \beta D_{KL}(\pi_{\theta_{k}}(\cdot|s), \pi_\theta(\cdot|s))\Big]$$
    • DDPG更新公式
      $$\mathop{\arg\max}_{\theta} \mathbb{E}_{s_t \sim \rho^\beta(s)} [Q_w(s_t,\mu_\theta(s_t))] $$
    • SAC更新公式
      $$\mathop{\arg\max}_{\theta}\mathbb{E}_{s_t \sim \mathcal{D}, \epsilon_t \sim \mathcal{N}}[\log \pi_\theta(f_\theta(\epsilon_t;s_t)\vert s_t) - Q_\theta(s_t, f_\theta(\epsilon_t; s_t))]$$
    • AWR更新公式:
      $$\mathop{\arg\max}_{\theta} \mathbb{E}_{(s,a) \sim \pi_\beta}\Big[exp\Big(\frac{1}{\beta}(R_{s,a}^{\mathcal{D}}-V^{\mathcal{D}}(s))\Big)\log\pi_\theta(a|s)\Big]$$
      • 其中 \(R_{s,a}^{\mathcal{D}} = \sum_{t=0}^\infty \gamma^t r_t\),不是网络,是真实的轨迹收益
    • IQL更新公式:
      $$\mathop{\arg\max}_{\theta} \mathbb{E}_{(s,a) \sim \pi_\beta}\Big[exp\Big(\beta (Q_{\theta’}(s, a) - V_\psi(s))\Big)\log\pi_\theta(a|s)\Big]$$
    • AWAC更新公式:
      $$\mathop{\arg\max}_{\theta} \mathbb{E}_{(s,a) \sim \pi_\beta}\Big[exp(\frac{1}{\lambda} A^{\pi_{\theta_k}}(s,a))\log\pi_\theta(a|s)\Big]$$
  • 基本推导思路总结:
    • 策略梯度法 :推导是直接从最初目标出发,视图求最初目标相对策略的梯度
    • PPO :更新公式是从策略提升的视角出发得到梯度提升的目标,通过限制策略变化幅度和重要性采样分别将未知策略的状态和动作采样的问题切换到已知策略
    • DDPG :直接以最大化Q值为目标来更新,可直接传导策略梯度
    • SAC :的目标中增加了熵,可以看成是DDPG的增加熵的版本
    • AWR、IQL和AWAC :更新公式都是相同的形式,是从策略提升的视角出发得到梯度提升的目标,并对该目标进行推导,得到最终的最优策略形式,再带入最优策略形式,从而得到更新公式
  • 也就是说,AWR、IQL和AWAC这三个方法的目标是为了策略提升量最大化 ,而策略梯度法的目标是为了原始目标最大化(梯度提升法)

附录:为什么 IQL 效果比 AWR 好?

  • IQL和 AWR 的 Q 值是不同策略的优势函数,IQL 的优势函数是在 \(\tau\) 分位点期望动作策略分布上的 Q 和 V,即 \(A^{\pi^*}(s,a) = Q^{\pi^*}(s,a) - V^{\pi^*}(s)\),而AWR的优势函数是真实的轨迹回报和V值 \(A^{\pi_k}(s,a) = R_{s,a}^{\mathcal{D}} - V^{\pi_k}(s)\)
  • IQL 不是迭代训练,是先学好 Q 值(不依赖策略),再利用学好的 Q 值一次性提取策略
  • 标准的 AWR 是 off-policy 的,是一种迭代训练的流程,V 值学习依赖策略与环境交互的轨迹数据,策略学习也依赖上一步的V值,V值,策略,轨迹三者是不断优化的
  • 如果把 AWR 直接用到 Offline R L场景下,则不再与环境交互,AWR 退化到学习一次V值,接着一次性学习策略;
    • Offline RL 下学到的 V 值是行为策略对应的 V 值,不是最优的 V 值,但这本身应该没有问题
    • 基于统计的 \(R_{s,a}^{\mathcal{D}}\) 方差可能很大
  • 使用公式 \(L_\pi(\phi) = \mathbb{E}_{(s,a) \sim D} \left[ \exp(\beta (Q_{\theta’}(s, a) - V_\psi(s))) \log \pi_\phi(a|s) \right]\) 来迭代策略时,Q 值和 V 值应该使用什么样的才是最优的?
    • 这个公式是从最大化策略提升项得到的,在推导策略提升时,这里使用的A值(对应到Q值和V值)是上一步策略对应的值 \(A^\mu(s,a)\),即旧策略 \(\mu\) 对应 Q 值和 V 值,而我们的目标是在 \(\mu\) 的基础上有所提升,得到优秀的新策略 \(\pi\),所以 Q 值和 V 值最好是优秀的策略对应的Q值和V值,否则可能我们的策略 \(\pi\) 在不好的策略上提升,结果也可能不是很优秀
  • 补充问题:可以随便使用一个策略来评估优势函数吗?
    • 回答是不可以,因为不同策略下,A 值选择不同动作以后的值是不同的,显然学到的策略也不同,从推导看,必须使用上一步的才可以

附录:贝尔曼方程收敛性及 \(\tau\) 的分析

  • 关于参数 \(\tau\) 的一些分析,原始论文中关于 \(\tau\) 的分析如下:

  • 当 \(\tau = 0.5\),相当于是SARSA算法;当 \(\tau \rightarrow 1\),相当于是Q-Learning算法

  • 对于任意的 \(\tau\),Q值和V值迭代都会收敛,且Q值和V值会收敛到 \(Q_{\tau}(s,a)\) 和 \(V_{\tau}(s)\),Lamma1中最后两行就是两者的贝尔曼方程,其中 \(\mathbb{E}_{a \sim \mu(\cdot|s)}^\tau\) 表示 \(\mu(\cdot|s)\) 分布下的 \(\tau\) 期望分位值(或 \(\tau\) 阶期望分位数)。注意,我们在说分位数时,还需要说明是那个随机变量或者哪个分布的分位数,否则没有意义

  • 为什么说Q值和V值迭代都会收敛到 \(Q_{\tau}(s,a)\) 和 \(V_{\tau}(s)\) 呢?

    • 理解:这里的 \(\tau\) 期望分位动作可以视作是一个策略,每次选择动作时,不选择最优动作,也不选择随机动作,而是选择 \(\tau\) 期望分位点动作,这样,可以得到跟论文中一样的结论:当 \(\tau = 0.5\),相当于是SARSA算法;当 \(\tau \rightarrow 1\),相当于是Q-Learning算法
    • 证明:定义一个策略如下:
      $$\pi_\tau(s) = \mathop{\text{arg_expectile}^\tau}_a(Q(s,a))$$
      该策略表示在状态 \(s\) 下,该策略会选择使得Q值等于 \(Q(s,a)\) 关于动作 \(a\) 的 \(\tau\) 期望分位点的动作,则期望分位动作策略对应的贝尔曼方程跟普通策略下的贝尔曼方程没有区别
    • 更详细的来说:
      • Q值:假定已经有了 \(V_\tau(s’)\),此时Q值的更新是学习当前状态 \(s\) 下,按照当前状态对应的 \(\tau\) 期望分位动作,以及后续策略也采用 \(\tau\) 期望分位动作得到的价值 \(V_\tau(s’)\) 来进行拟合的目标值(注意,这里跟其他贝尔曼方程一样,一旦动作决定了, \(r(s,a)\) 就确定了,我们所说的期望分位动作就是对动作 \(a\) 的分布而言的, \(Q(s,a)\) 的拟合只考虑 \((s,a)\) 状态动作对即可,不需要考虑期望分位动作);
      • V值:假定已经有了 \(Q_{\tau}(s,a)\),V值可以从 \(Q_{\tau}(s,a)\) 中学到 \(V_\tau(s’)\),这里需要使用 \(Q_{\tau}(s,a)\) 而不是 \(Q_{\pi_\beta(s,a)}\) 的原因是,V的本质是 \(Q(s,a)\) 关于动作 \(a\) 期望,但直接求期望只到了当前状态 \(s\) 这一层,如果使用 \(Q_{\pi_\beta(s,a)}\) 来学习那么学到的不是 \(V_\tau(s’)\) ( \(V_\tau(s’)\) 是指后续的动作也是 \(\tau\) 期望分位动作来定义的,正如Q值和V值的常规贝尔曼方程一样)

Implicit 名字的来源

  • Implicit 含义是“隐式的”,与隐式约束的隐式不等价,在IQL中表示通过期望回归隐式的学到了最优价值函数 \(V^*(s) = \max Q(s,a)\)

IQL 可能存在的问题

  • IQL 没有没有像 CQL 一样对非行为策略的 Q 值进行打压(甚至学习过程中全程未学习未知状态动作对的 Q 值),也没有像 BCQ 一样对动作选择进行限制,理论上可能会因为对 OOD 状态动作 Q 值高估而出现问题
  • IQL 源码实现时的解法:采用 Twin Q 来缓解高估问题(理解:对于数据集中存在的,两个 Q 网络都能估准;对于数据集中不存在的,可能都估不准,但是我们取最小的那个,可以缓解对未知状态动作对 Q 值的高估问题)

Python——Ray-分布式架构简单了解


整体介绍

  • Ray 是一个用于分布式计算的开源框架,专为构建和运行分布式应用程序而设计
  • Ray 提供了简洁的 API,让开发者能够轻松地将单机程序扩展到分布式集群上,同时保持代码的可读性和可维护性
  • Ray 最初由 UC Berkeley 的 RISELab 开发,现在由 Anyscale 公司维护,广泛应用于机器学习、强化学习、并行计算等领域
  • Ray 既可以在本地实现并行计算,又可以非常容易的扩展到集群模式,实现分布式计算
  • Ray 与深度学习框架如 TensorFlow、PyTorch 和 MXNet 等互相兼容

Ray 的核心架构

  • Ray的系统架构采用了混合任务调度的思路,遵循典型的 Master-Slave 设计,但与传统分布式系统有所不同

Ray 中的关键组件总结

  • Ray在集群部署模式下启动了以下关键组件:
    • GlobalScheduler(全局调度器) :运行在Master节点上,负责接收本地调度器提交的任务,并将任务分发给合适的本地任务调度器执行
    • RedisServer :Master节点上启动的Redis服务器,用于保存分布式任务的状态信息(ControlState),包括对象机器的映射、任务描述、任务 debug 信息等
    • LocalScheduler(本地调度器) :每个 Slave 节点上启动的本地调度器,用于提交任务到全局调度器,以及分配任务给当前机器的 Worker 进程
    • Worker进程 :每个 Slave 节点上可以启动多个 Worker 进程执行分布式任务,并将计算结果存储到 ObjectStore
    • ObjectStore(对象存储) :每个 Slave 节点上的存储系统,用于存储只读数据对象,Worker 可以通过共享内存的方式访问这些对象数据,有效减少内存拷贝和对象序列化成本。ObjectStore 底层由 Apache Arrow 实现
    • Plasma :每个 Slave 节点上的ObjectStore管理器,当 Worker 访问本地 ObjectStore 上不存在的远程数据对象时,Plasma 会主动拉取其它 Slave 上的对象数据到当前机器

执行模型

  • Ray的执行模型基于动态任务图 ,这与 TensorFlow 中的静态计算图有本质区别:
    • TensorFlow的计算图用于表征神经网络,在单个应用中执行很多次
    • Ray的任务图用于表征整个应用,并仅执行一次
    • 任务图对于前台是未知的,随着应用的运行而动态地构建
    • 一个任务的执行可能创建更多的任务,形成动态依赖关系

代码示例

并行计算示例(无状态)

  • 基于 Ray 的并行计算代码 Demo:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    import ray
    import time
    import numpy as np

    # 初始化 Ray,默认在本地启动
    ray.init()

    # 使用 @ray.remote 装饰器将函数转换为分布式任务
    @ray.remote
    def compute_square(x):
    # 模拟耗时计算
    time.sleep(1)
    return x * x

    # 生成一些数据
    data = np.arange(10)

    # 并行执行任务
    start_time = time.time()
    # 创建任务对象引用
    square_refs = [compute_square.remote(i) for i in data]
    # 等待所有任务完成并获取结果
    results = ray.get(square_refs)
    end_time = time.time()

    print(f"串行计算结果: {[i*i for i in data]}")
    print(f"Ray 并行计算结果: {results}")
    print(f"Ray 并行计算耗时: {end_time - start_time:.4f} 秒")

    # 关闭 Ray
    ray.shutdown()

    # 串行计算结果: [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
    # Ray 并行计算结果: [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
    # Ray 并行计算耗时: 1.4069 秒

串行计算示例(有状态)

  • 基于 Ray 的串行计算代码 Demo:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    import ray
    import time

    # 初始化 Ray
    ray.init()

    # 使用 @ray.remote 装饰器定义 Actor 类
    @ray.remote
    class Counter:
    def __init__(self):
    self.count = 0

    def increment(self):
    time.sleep(1) # 模拟耗时操作
    self.count += 1
    return self.count

    def get_count(self):
    return self.count

    # 创建 Actor 实例
    counter = Counter.remote()

    # 并行调用 Actor 方法
    start_time = time.time()
    # 提交多个增量任务
    increment_refs = [counter.increment.remote() for _ in range(10)]
    # 获取所有增量任务的结果
    results = ray.get(increment_refs)
    # 获取最终计数
    final_count = ray.get(counter.get_count.remote())
    end_time = time.time()

    print(f"每次增量结果: {results}")
    print(f"最终计数: {final_count}")
    print(f"执行耗时: {end_time - start_time:.4f} 秒")

    # 关闭 Ray
    ray.shutdown()

    # 每次增量结果: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    # 最终计数: 10
    # 执行耗时: 10.1293 秒

分布式调度示例(集群模式)

  • 以上代码经过非常简单的修改即可进入集群模式

  • Ray 集群部署包括三个步骤(下面以 6379 端口为例展示流程)

  • 第一步:启动主节点 ,运行 ray start --head 从主节点启动集群

    1
    ray start --head --port=6379 --redis-password='your_secure_password_123'
    • 注:可通过 --redis-password 设置密码(可选),防止未授权节点加入,也可以不使用该参数
  • 第二步:启动工作节点 ,运行 ray start --address=<主节点IP> 加入集群

    1
    ray start --address='<head-node-ip>:6379' --redis-password='your_secure_password_123'
    • 执行上述命令后工作节点就会:
      • 自动连接到主节点
      • 等待接收任务
      • 执行主节点分配的计算任务
      • 将结果返回给主节点
  • 第三步:在主节点上运行的代码中连接集群

    1
    ray.init(address='auto', _redis_password='your_secure_password_123')
    • 注:以上代码仅在主节点上运行,工作节点不需要显示运行任何代码,仅需要启动并加入集群即可
  • 关闭 Ray 服务:

    1
    ray stop
  • 特别说明:集群模式与普通单机并行模式的区别很小,仅需要增加修改以上代码即可(其他代码都不需要修改)

  • Ray 在分布式下默认有许多默认功能:

    • 自动负载均衡:Ray 会自动将任务分配到空闲节点
    • 容错能力:如果某个工作节点失败,Ray 会重新调度任务
  • 集群模式工作流程总结:

    • 主节点通过 Redis 将任务(remote 函数或者类对象)放入队列
    • 空闲的工作节点从队列中获取任务
    • 工作节点执行任务
    • 将运算结果通过 共享内存/Object Store 返回给主节点

附录:工作节点启动高级配置

  • 可以通过参数调整工作节点行为:
    1
    2
    3
    4
    5
    ray start --address='<head-node-ip>:6379' \
    --redis-password='your_secure_password_123' \
    --num-cpus=8 \ # 限制使用8个CPU核心
    --num-gpus=1 \ # 声明有1个GPU可用
    --object-store-memory=100000000 \ # 设置对象存储大小

附录:Ray 集群状态监控

  • Ray 提供了 Web UI 用于监控集群状态
  • 在主节点启动时已经启用了 Dashboard(默认端口8265)
  • 在浏览器访问:http://<主节点IP>:8265
  • 在 Dashboard 中可以看到:
    • 集群节点列表和资源使用情况
    • 当前运行的任务
    • 历史任务统计
    • 每个节点的CPU/内存使用情况

DL——模型训练预热


整体说明

  • 预热(Warm-up)是一种训练技巧:
    • 在模型训练初期采用一些策略,逐步调整超参数(如学习率、 Batch Size 大小等)或模型状态 ,使得训练过程更加稳定、高效的初始化阶段
    • 通过合理预热,可以显著提升训练稳定性、收敛速度和最终性能
  • 预热的核心目的是避免训练初期因参数随机初始化或学习率过高导致的梯度不稳定、收敛困难等问题
  • 常见的预热技术主要包含两类:
    • 学习率预热(Learning Rate Warm-up) :训练初期从极小的学习率(如0)逐步线性或非线性增加到预设值
    • 优化器预热 :
      • Adam 预热阶段可用小学习率,比如正常值的 \(1/10\)(Adam 优化器的自适应动量在初期可能不准确);
      • Adam 在预热阶段启用偏差修正 ,避免初期估计偏差过大
  • 其他预热技术还包括:Batch Size 预热(Batch Size 从小到大),模型参数预热(逐步解冻模型层) 和 混合精度预热等(初期禁用混合精度)
  • 最常见的预热技术是学习率预热,其中 Transformer 常使用 学习率线性预热(比如 BERT 训练中常用 10,000 步线性预热)
  • 术语:warm-up ratio
    • 如 warm-up ratio 等于 0.03,表示 warm-up 阶段(学习率上升阶段)步数占总训练阶段步数的 3%

学习率预热的相关策略

  • 学习率预热(Learning Rate Warm-up)是训练初期逐步增加学习率的策略,旨在稳定训练并提升最终性能。以下是常见的具体方法及其细节:

线性预热(Linear Warm-up)

  • 在预热步数 \(N\) 内,学习率从 \(0\)(或极小值 \(\epsilon\))线性增长到初始学习率 \(lr_{\text{base} }\)
    $$
    lr_t = \epsilon + \left(\frac{t}{N}\right) \cdot (lr_{\text{base} } - \epsilon)
    $$
    • 其中 \(t\) 是当前步数,\(t \leq N\)
  • 最常用的方式之一

余弦预热(Cosine Warm-up)

  • 结合余弦函数曲线调整学习率,初期缓慢增长,后期平滑过渡到目标值
    $$
    lr_t = \frac{1}{2} \left(1 + \cos\left(\pi \cdot \left(1 - \frac{t}{N}\right)\right)\right) \cdot lr_{\text{base} }
    $$
  • 注:也可与余弦退火结合,预热后直接进入衰减阶段
  • 更平滑的过渡,减少初期学习率突变
  • 一些大模型中会使用到

指数预热(Exponential Warm-up)

  • 学习率从 \(\epsilon\) 开始指数增长到 \(lr_{\text{base} }\)
    $$
    lr_t = \epsilon \cdot \left(\frac{lr_{\text{base} } }{\epsilon}\right)^{\frac{t}{N} }
    $$
  • 较少使用,因可能过早进入高学习率阶段

阶梯预热(Step Warm-up)

  • 将预热阶段分为多个离散区间,逐步跳跃式增加学习率

附录:torch 自带预热和学习率调度代码示例

  • 一个完整的PyTorch示例:先进行学习率预热,再正常训练模型

  • 以简单的图像分类任务(CIFAR-10)为基础,结合线性预热和余弦退火调度器

  • 代码示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt

    class SimpleCNN(nn.Module):
    def__init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
    self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
    self.fc = nn.Linear(32 * 8 * 8, 10) # CIFAR-10输入为32x32,经过两次池化后为8x8
    self.pool = nn.MaxPool2d(2, 2)
    self.relu = nn.ReLU()

    def forward(self, x):
    x = self.pool(self.relu(self.conv1(x)))
    x = self.pool(self.relu(self.conv2(x)))
    x = x.view(-1, 32 * 8 * 8)
    x = self.fc(x)
    return x

    transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleCNN().to(device)
    ## 注:学习率包含在优化器 optimizer 中,使用不同的学习率调度器来执行 step,就可以实现不同的学习率调度
    optimizer = optim.AdamW(model.parameters(), lr=0.001) # 初始学习率设为0.001(预热目标值)

    warmup_steps = 500 # 预热步数
    total_steps = 5000 # 总训练步数

    # 线性预热函数
    def warmup_lambda(current_step):
    if current_step < warmup_steps:
    return float(current_step) / float(max(1, warmup_steps))
    else:
    return 1.0 # 预热结束后保持学习率

    # 预热阶段调度器
    warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda) # 基于优化器初始化调度器

    # 预热后的余弦退火调度器(从预热结束开始)
    cosine_scheduler = CosineAnnealingLR(
    optimizer, # 与预热阶段调度器初始化相同的优化器
    T_max=total_steps - warmup_steps, # 余弦周期长度
    eta_min=1e-6 # 最小学习率
    )

    criterion = nn.CrossEntropyLoss()
    lr_history = []
    for step in range(total_steps):
    inputs = torch.randn(128, 3, 32, 32).to(device)
    labels = torch.randint(0, 10, (128,)).to(device)

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # 更新学习率
    if step < warmup_steps:
    warmup_scheduler.step() # 预热阶段,step 函数会按照 warmup_scheduler 的定义来修改学习率
    else:
    cosine_scheduler.step() # 预热后余弦退火,step 函数会按照 cosine_scheduler 的定义来修改学习率

    # 记录学习率,可打印出来观测
    lr_history.append(optimizer.param_groups[0]['lr'])

    if step % 200 == 0:
    print(f"Step {step}: LR = {optimizer.param_groups[0]['lr']:.6f}, Loss = {loss.item():.4f}")
  • 预热阶段(前500步):学习率从 0 线性增长到初始值 0.001
    $$ lr = \text{base_lr} \times \frac{\text{current_step} }{\text{warmup_steps} } $$

  • 正常训练阶段(500步后):切换为余弦退火调度器(CosineAnnealingLR),学习率从 0.001 逐渐衰减到 1e-6

    • 注: 余弦退火的周期长度 \( T_{\text{max} } \) 设为总步数减去预热步数
  • 总体来说,学习率曲线是先线性上升,后余弦式下降(平滑振荡衰减)的过程


附录:transformers 库的模型训练预热调度示例

  • transformers 库中使用模型训练预热代码(按照初始学习率 1e-4, epochs= )

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    import matplotlib.pyplot as plt
    import transformers
    import torch

    initial_lr = 1.0e-4 # 初始学习率
    warmup_ratio = 0.1 # 预热比例

    num_training_steps = 1000 # 总训练 step 数
    num_warmup_steps = int(num_training_steps * warmup_ratio) # 计算 warmup 的 step 数

    optimizer = torch.optim.AdamW([torch.tensor(0.0)], lr=initial_lr) # [torch.tensor(0.0)] 是虚拟的模型参数,可随意设置

    # 使用 transformers 库创建余弦退火学习率调度器
    lr_scheduler = transformers.get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps, # warmup step 数
    num_training_steps=num_training_steps, # 训练总 step 数
    # num_cycles=0.5, # 对应 cosine 曲线的周期,默认值是0.5,也就是半周期(递减)
    # last_epoch=-1, # 用于从 checkpoint 启动时恢复训练,设置为 ckpt 对应 step-1 即可
    # 比如从第 500 步的 ckpt启动,设置为499,从第0步启动,设置为-1(默认值)
    )

    learning_rates = []
    for _ in range(num_training_steps):
    learning_rates.append(optimizer.param_groups[0]["lr"])
    lr_scheduler.step() # 更新 optimizer.param_groups[0]["lr"]

    # 设置中文字体
    plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]

    # 以下为可视化代码
    plt.figure(figsize=(10, 6))
    plt.plot(learning_rates)
    plt.title('学习率变化曲线')
    plt.xlabel('训练步骤')
    plt.ylabel('学习率')
    plt.grid(True)
    plt.axvline(x=num_warmup_steps, color='r', linestyle='--', label='预热结束')
    plt.legend()

    plt.annotate(f'初始学习率: {initial_lr}', xy=(num_warmup_steps, initial_lr),
    xytext=(num_warmup_steps + 50, initial_lr * 1.5),
    arrowprops=dict(facecolor='black', shrink=0.05))
    plt.annotate(f'预热起点: 0', xy=(0, 0),
    xytext=(50, initial_lr * 0.2),
    arrowprops=dict(facecolor='black', shrink=0.05))
    plt.annotate(f'最终学习率: {learning_rates[-1]:.8f}', xy=(num_training_steps-1, learning_rates[-1]),
    xytext=(num_training_steps-200, learning_rates[-1] * 10),
    arrowprops=dict(facecolor='black', shrink=0.05))
    plt.tight_layout()
    plt.savefig('warmup_learning_rate_curve_cycles_0.5.png', dpi=300)
    # plt.show()
  • 可视化结果(半周期余弦 num_cycles=0.5 的结果):

    • warmup 阶段,学习率从 0 开始逐步提升到最大值
    • 正式训练阶段,学习率按照余弦调度器波动
  • 如果设置为 num_cycles=1,则会在指定训练步数内完成两个周期的学习率变化:

  • 如果设置为 num_cycles=1.5,则会在指定训练步数内完成两个周期的学习率变化:

  • 如果设置为 num_cycles=2,则会在指定训练步数内完成两个周期的学习率变化:


附录:预热有什么用?

  • 解决梯度不稳定问题 :模型初始阶段参数随机初始化,直接使用高学习率可能导致梯度爆炸或震荡
  • 解决学习率敏感性问题 :过大的初始学习率可能使模型跳过最优解附近区域;过小则导致收敛缓慢
  • 保证优化器适应性 :如 Adam 等自适应优化器在初期需要积累梯度统计量(如动量、方差),预热阶段可为优化器提供更稳定的初始估计

附录:一般预热多少步更合适?

  • 预热步数通常取决于模型规模和数据集大小:
    • 小规模数据:数百到几千步
    • 大规模训练(如LLM):数万步甚至更长(例如 GPT-3 的数千批次预热)
  • 另一种设置方式是:通常为总训练步数的 5-10%(例如 BERT 的 10k 步预热,总步数 100k)

DL——深度学习并行技术总结


整体说明

  • 并行化技术一般在训练大型深度学习模型时使用
  • 并行化技术氛围三种:
    • 数据并行 (Data Parallelism)
    • 模型并行 (Model Parallelism)
    • 流水线并行 (Pipeline Parallelism),有的地方也翻译为管道并行

各种并行方法之间的关系总结

  • 整体可分为 模型并行 (Model Parallelism) 和 数据并行
    • 数据并行 :每个 GPU 都拥有一个完整的模型副本,但处理不同的数据批次
    • 模型并行 :每个 GPU 只负责模型的一部分,所有 GPU 共同处理一个完整的数据批次,包括 张量并行 和 流水线并行 两种具体实现
  • 模型并行的进一步介绍:当模型太大无法放入单个 GPU 时,就需要使用模型并行
    • 将模型的不同部分分配给不同的 GPU
    • 优点是可以训练显存无法容纳的巨大模型
    • 缺点是实现相对复杂,且由于不同 GPU 间的通信和等待,可能会导致 GPU 利用率不高
  • 实际应用中,为了充分利用资源并训练超大模型,通常会结合多种并行化技术,形成 混合并行 策略,例如同时使用数据并行、流水线并行和张量并行

数据并行 (Data Parallelism)

  • 最常见、也最容易理解的并行化方法
  • 数据并行的工作方式 :
    • 训练数据集被分成多个子集(例如,一个 128 张图片的批次被分成 4 个 32 张图片的子批次)
    • 每个 GPU 拥有一个完整的模型,并独立处理一个子批次的数据
  • 数据并行的训练过程 :
    • 1)每个 GPU 计算其子批次的前向传播和反向传播,得到各自的梯度
    • 2)通过 All-Reduce 这样的通信操作,将所有 GPU 的梯度进行汇总和平均
    • 3)每个 GPU 用这个平均后的梯度来更新自己的模型参数,从而确保所有模型副本保持同步
  • 优点是实现简单,对模型结构无特殊要求 ,比如使用 PyTorch 的 DP 类就可以实现
  • 缺点是每个 GPU 都需要存储完整的模型 ,当模型参数量非常大时,会超出单个 GPU 的显存限制 ,此时数据并行就无法使用

流水线并行 (Pipeline Parallelism)

  • 流水线并行将模型的不同“层”(或一组层)分配给不同的 GPU,形成一个“流水线”
    • 例如,GPU 1 负责模型的第 1-4 层,GPU 2 负责第 5-8 层,以此类推
  • 数据并行的训练过程 :
    • 一个数据批次被分解成更小的“微批次”(micro-batches)
    • GPU 1 处理第一个微批次,完成后将输出传给 GPU 2
    • 当 GPU 1 开始处理第二个微批次时,GPU 2 就可以同时处理第一个微批次
  • 优点是部分解决了模型过大的问题(单层过大仍然无法解决)
  • 缺点是存在 “流水线气泡”(pipeline bubble) 问题,即流水线开始和结束时,部分 GPU 会处于空闲等待状态,导致 GPU 利用率并非 100%
    • 注:通过流水线的方式(即错位并行处理不同微批次的方式),可以一定程度上重叠不同 GPU 的计算,提高整体效率

张量并行 (Tensor Parallelism)

  • 张量并行 不按层切分模型,而是将模型中某个操作内部的 张量(例如一个大型矩阵)切分到不同的 GPU 上
  • 举例来说:一个 \(A \times B\) 的矩阵乘法,可以把矩阵 B 按列切分,每个 GPU 分别计算,最后再通过通信操作将结果合并
  • 优点是:
    • 可以进一步解决单个层或单个操作的参数过大的问题,从根本上解决了模型过大的问题
    • 因为所有 GPU 都在同一时间处理同一个微批次,所以不会有流水线气泡问题 ,GPU 利用率通常更高
  • 缺点是
    • 对模型结构有要求,通常只能在某些特定操作(如矩阵乘法、线性层)中应用
    • 需要频繁的 GPU 间通信来同步切分后的张量,这要求非常高速的 GPU 互联(例如 NVLink)
1…222324…67
Joe Zhou

Joe Zhou

Stay Hungry. Stay Foolish.

662 posts
53 tags
GitHub E-Mail
© 2026 Joe Zhou
Powered by Hexo
|
Theme — NexT.Gemini v5.1.4