
Converting a JAX Program to HLO and Executing It

Posted by YuChen on February 17, 2022

JAX source commit used in this post: 4dd1f001c626eb15f1a8deac58d97b578a1bd85c

1. Building a debug version of JAX from source

1.1 Preparation

  • Download the JAX source: git clone https://github.com/google/jax.git
  • Modify the WORKSPACE file in the repository root so that the build uses the XLA libraries from a pinned TensorFlow commit. The change is as follows:
# To compute the sha256: curl -L https://github.com/tensorflow/tensorflow/archive/4d5bc8437a60d31cc9d8e8e005cd90fbe6ebee8b.tar.gz | sha256sum
# The hash suffix of strip_prefix is the pinned commit id
http_archive(
  name = "org_tensorflow",
  sha256 = "78edf5464de6e0fc2e9358cfca93b2208ad9bf742147e40c35bc6274b5dedd64",
  strip_prefix = "tensorflow-4d5bc8437a60d31cc9d8e8e005cd90fbe6ebee8b",
  urls = [
      "https://github.com/tensorflow/tensorflow/archive/4d5bc8437a60d31cc9d8e8e005cd90fbe6ebee8b.tar.gz",
  ],
)
  • Modify the .bazelrc file in the repository root to provide dbg build options, largely following tensorflow/.bazelrc. Concretely, append the following after line 30:
# Debug config
build:dbg -c dbg
build:dbg --per_file_copt=+.*,-tensorflow.*@-g0
build:dbg --per_file_copt=+tensorflow/core/kernels.*@-g0
build:dbg --per_file_copt=+jaxlib.*@-g
build:dbg --per_file_copt=+examples/jax_cpp/main.*@-g
build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
build:dbg --copt -DDEBUG_BUILD
  • Add a config_args += ["--config=dbg"] statement after build/build.py#L512, as sketched below.
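
For context, the added line sits alongside the existing config_args logic in build/build.py; a one-line sketch (surrounding code elided, and the exact anchor line may drift across commits):

config_args += ["--config=dbg"]  # forward the dbg config defined in .bazelrc to bazel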

  • Because JAX is built with the C++14 standard while TensorFlow uses C++17, a debug build of JAX fails with an undefined reference to CompiledFunctionCache::kDefaultCapacity. Under C++14, a static constexpr data member that is odr-used still needs an out-of-line definition (only C++17 makes such members implicitly inline), and in optimized builds the use is constant-folded away, which is why only the debug build exposes the problem. The TensorFlow sources JAX checks out therefore need a small patch as well:

    • Use find ~/.cache/bazel/ -path "*external/org_tensorflow" | grep "_bazel_root/[A-Za-z0-9]*/external/org_tensorflow$" to locate the TensorFlow checkout that JAX depends on. Example: ~/.cache/bazel/_bazel_root/42210d9a2e5c41f7817f753f6f92c412/external/org_tensorflow
    • Change the definition of the static member kDefaultCapacity of the CompiledFunctionCache class in tensorflow/compiler/xla/python/jax_jit.cc as follows:
    
    -static constexpr int kDefaultCapacity = 4096;
    +static const int kDefaultCapacity;
    
    +const int CompiledFunctionCache::kDefaultCapacity = 4096;
    

1.2 Build and install

With everything in place, build and install the debug version of jaxlib with the following commands:

#!/bin/bash -ex

python build/build.py --enable_cuda
# pip uninstall -y jaxlib
pip install dist/*.whl

Once jaxlib is installed, run pip install -e . from the JAX repository root to install the Python part of JAX; a quick sanity check is sketched after the note below.

Upgrade Note: To upgrade to the latest version from GitHub, just run git pull from the JAX repository root, and rebuild by running build.py or upgrading jaxlib if necessary. You shouldn’t have to reinstall jax because pip install -e sets up symbolic links from site-packages into the repository.
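
To confirm that the freshly built packages are the ones actually in use (a minimal sketch; GPU devices are expected only with the CUDA build above):

# check_install.py: confirm the locally built jax/jaxlib are in use.
import jax
from jaxlib import version as jaxlib_version

print("jax", jax.__version__, "| jaxlib", jaxlib_version.__version__)
print(jax.devices())  # with a CUDA-enabled jaxlib, this should list GPU devices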

2. Converting a JAX program to HLO and executing it

The examples/jax_cpp/main.cc example executes the JIT-compiled module on the CPU backend by default. To run it on the GPU backend, the example needs a few modifications.

2.1 Using the GPU backend from the C++ source

Replace the contents of examples/jax_cpp/main.cc with the following code:

// An example that reads an HloModule from an HloProto file and executes the
// module on a PJRT GPU client.
//
// To build a HloModule,
//
// $ python3 jax/tools/jax_to_ir.py \
// --fn examples.jax_cpp.prog.fn \
// --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]")]' \
// --constants '{"z": 2.0}' \
// --ir_format HLO \
// --ir_human_dest /tmp/fn_hlo.txt \
// --ir_dest /tmp/fn_hlo.pb
//
// To load and run the HloModule,
//
// $ bazel build examples/jax_cpp:main --experimental_repo_remote_exec --check_visibility=false
// $ bazel-bin/examples/jax_cpp/main
// 2021-01-12 15:35:28.316880: I examples/jax_cpp/main.cc:65] result = (
// f32[2,2] {
//   { 1.5, 1.5 },
//   { 3.5, 3.5 }
// }
// )

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/gpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"

int main(int argc, char** argv) {
  tensorflow::port::InitMain("", &argc, &argv);

  // Load HloModule from file.
  std::string hlo_filename = "/tmp/fn_hlo.txt";
  std::function<void(xla::HloModuleConfig*)> config_modifier_hook =
      [](xla::HloModuleConfig* config) { config->set_seed(42); };
  std::unique_ptr<xla::HloModule> test_module =
      LoadModuleFromFile(hlo_filename, xla::hlo_module_loader_details::Config(),
                         "txt", config_modifier_hook)
          .ValueOrDie();
  const xla::HloModuleProto test_module_proto = test_module->ToProto();

  // Run it using JAX C++ Runtime (PJRT).

  // Get a GPU client.
  bool asynchronous = true;
  xla::GpuAllocatorConfig allocator_config;
  std::shared_ptr<xla::DistributedRuntimeClient> distributed_client{nullptr};
  int node_id = 0;
  std::unique_ptr<xla::PjRtClient> client =
      xla::GetGpuClient(asynchronous, allocator_config, distributed_client, node_id).ValueOrDie();

  // Compile XlaComputation to PjRtExecutable.
  xla::XlaComputation xla_computation(test_module_proto);
  xla::CompileOptions compile_options;
  std::unique_ptr<xla::PjRtExecutable> executable =
      client->Compile(xla_computation, compile_options).ValueOrDie();

  // Prepare inputs.
  xla::Literal literal_x =
      xla::LiteralUtil::CreateR3<float>({{{1.0f, 1.0f, 1.0f}, {2.0f, 2.0f, 2.0f}, {3.0f, 3.0f, 3.0f}, {4.0f, 4.0f, 4.0f}},
                                        {{5.0f, 5.0f, 5.0f}, {6.0f, 6.0f, 6.0f}, {7.0f, 7.0f, 7.0f}, {8.0f, 8.0f, 8.0f}}});
  xla::Literal literal_y =
      xla::LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}});

  // xla::Literal literal_z =
  //     xla::LiteralUtil::CreateR1<float>({1.0f});
  xla::Literal literal_z =
      xla::LiteralUtil::CreateR1<float>({1.0f, 1.0f});
  // xla::Literal literal_z =
  //     xla::LiteralUtil::CreateR2<float>({{1.0f, 1.0f}, {1.0f, 1.0f}, {1.0f, 1.0f}, {1.0f, 1.0f}});
  
  std::unique_ptr<xla::PjRtBuffer> param_x =
      client->BufferFromHostLiteral(literal_x, client->addressable_devices()[0])
          .ValueOrDie();
  std::unique_ptr<xla::PjRtBuffer> param_y =
      client->BufferFromHostLiteral(literal_y, client->addressable_devices()[0])
          .ValueOrDie();
  std::unique_ptr<xla::PjRtBuffer> param_z =
      client->BufferFromHostLiteral(literal_z, client->addressable_devices()[0])
          .ValueOrDie();

  // Execute on GPU.
  xla::ExecuteOptions execute_options;
  // One vector<buffer> for each device.
  std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results =
      executable->Execute({{param_x.get(), param_y.get(), param_z.get()}}, execute_options)
          .ValueOrDie();

  // Get result.
  std::shared_ptr<xla::Literal> result_literal =
      results[0][0]->ToLiteral().ValueOrDie();
  LOG(INFO) << "result = " << *result_literal;
  return 0;
}
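
After a successful run, the C++ output can be cross-checked from Python by executing the same HLO proto. A rough sketch follows; note that xla_client is an internal, unsupported interface, so names such as XlaComputation, Client.compile, and execute_with_python_values are assumptions tied to the jaxlib built above:

# cross_check.py: run the exported HLO from Python and compare with the C++ result.
import numpy as np
from jax.lib import xla_bridge, xla_client

backend = xla_bridge.get_backend("gpu")

# Load the HloModuleProto written by jax_to_ir.py.
with open("/tmp/fn_hlo.pb", "rb") as f:
    computation = xla_client.XlaComputation(f.read())

executable = backend.compile(computation)

x = np.ones((2, 4, 3), np.float32)
y = np.ones((3, 2), np.float32)
z = np.ones((2,), np.float32)

results = xla_client.execute_with_python_values(executable, (x, y, z), backend)
print(results)  # should match the literal printed by bazel-bin/examples/jax_cpp/main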

2.2 Adding the GPU backend build dependencies

Modify the examples/jax_cpp/BUILD file by adding two dependencies, pjrt:gpu_device and jit:xla_gpu_jit, to the deps of the tf_cc_binary target: the former provides the GetGpuClient entry point, the latter links the XLA GPU backend into the binary. The modified file looks like this:

load(
    "@org_tensorflow//tensorflow:tensorflow.bzl",
    "tf_cc_binary",
)

licenses(["notice"])

tf_cc_binary(
    name = "main",
    srcs = ["main.cc"],
    deps = [
        "@org_tensorflow//tensorflow/compiler/xla:literal",
        "@org_tensorflow//tensorflow/compiler/xla:literal_util",
        "@org_tensorflow//tensorflow/compiler/xla:shape_util",
        "@org_tensorflow//tensorflow/compiler/xla:status",
        "@org_tensorflow//tensorflow/compiler/xla:statusor",
        "@org_tensorflow//tensorflow/compiler/xla/pjrt:gpu_device",
        "@org_tensorflow//tensorflow/compiler/jit:xla_gpu_jit",
        "@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client",
        "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc",
        "@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader",
        "@org_tensorflow//tensorflow/core/platform:logging",
        "@org_tensorflow//tensorflow/core/platform:platform_port",
    ],
)

2.3 Building the main executable

Build the modified examples/jax_cpp/main.cc with the following command:

#!/bin/bash -ex

bazel build --verbose_failures=true --config=avx_posix       \
    --config=mkl_open_source_only --config=cuda --config=dbg \
    examples/jax_cpp:main                                    \
    --experimental_repo_remote_exec --check_visibility=false

After the build finishes, the main executable can be found under bazel-bin/examples/jax_cpp/.

2.4 Generating the HLO module files

Create a jax2hlo directory under the JAX repository root, and inside it create two files, prog.py and jax2hlo.sh, with the following contents:

  • Contents of prog.py:

    
    import jax.numpy as jnp
    
    def fn(x, y, z, alpha, beta):
        return alpha * jnp.dot(x, y) + beta * z
    
  • Contents of jax2hlo.sh:

    
    #!/bin/bash -ex
    
    python ../jax/tools/jax_to_ir.py                                               \
        --fn jax2hlo.prog.fn                                                       \
        --input_shapes '[("x", "f32[2,4,3]"), ("y", "f32[3,2]"), ("z", "f32[2]")]' \
        --constants '{"alpha": 2.0, "beta": 3.0}'                                  \
        --ir_format HLO                                                            \
        --ir_human_dest /tmp/fn_hlo.txt                                            \
        --ir_dest /tmp/fn_hlo.pb
    

Run the ./jax2hlo.sh script from the jax2hlo directory; the generated files fn_hlo.txt and fn_hlo.pb will appear under /tmp/. A quick Python sanity check of the export is sketched below.
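
Before wiring the files into the C++ runner, the exported HLO can be inspected directly from Python (a minimal sketch; it bakes in alpha/beta with a lambda the same way --constants does, and assumes the repository root is on sys.path via pip install -e):

# check_prog.py: print the HLO for fn with the constants fixed, mirroring jax2hlo.sh.
import jax
import jax.numpy as jnp

from jax2hlo.prog import fn

x = jnp.zeros((2, 4, 3), jnp.float32)
y = jnp.zeros((3, 2), jnp.float32)
z = jnp.zeros((2,), jnp.float32)

computation = jax.xla_computation(lambda x, y, z: fn(x, y, z, 2.0, 3.0))(x, y, z)
print(computation.as_hlo_text())  # should match /tmp/fn_hlo.txt up to metadata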

2.5 Running the main executable

Run the main executable from the JAX repository root with the following command:

TF_CPP_VMODULE=dot_decomposer=4 TF_CPP_LOG_THREAD_ID=1                                                                                           \
    TF_CPP_MIN_LOG_LEVEL=0 CUDA_VISIBLE_DEVICES=0                                                                                                \
    XLA_FLAGS="--xla_dump_to=./xla_main_output --xla_dump_hlo_pass_re=.* --xla_dump_hlo_as_text --xla_dump_hlo_as_html --xla_dump_hlo_snapshots" \
    bazel-bin/examples/jax_cpp/main
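
The XLA_FLAGS above make XLA dump the HLO before and after every pass into ./xla_main_output. A small sketch for listing the fully optimized modules (the after_optimizations suffix follows XLA's dump-file naming):

# list_dumps.py: show the fully optimized HLO dumps produced by the run above.
import glob

for path in sorted(glob.glob("./xla_main_output/*after_optimizations*.txt")):
    print(path)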

3. Appendix

3.1 Troubleshooting notes

Note: under normal circumstances this problem should not occur. If an nccl build error does appear, first double-check the bazel command line; the --config=cuda option may have been left out.

If the build fails with fatal error: third_party/nccl/nccl.h: No such file or directory, try setting the following environment variables:

export TF_NCCL_VERSION='2.11.4'
export NCCL_INSTALL_PATH=/usr/local/lib
export NCCL_HDR_PATH=/usr/local/include

3.2 Building the nccl library from source

git clone https://github.com/NVIDIA/nccl.git
cd nccl/
make CUDA_HOME=/usr/local/cuda

sudo make install



