JAX source commit id used in this article: 4dd1f001c626eb15f1a8deac58d97b578a1bd85c
1. Building a debug version of JAX from source
1.1 Preparation
- Download the JAX source code:
git clone https://github.com/google/jax.git
- Modify the WORKSPACE file in the repository root so that the XLA libraries are taken from TensorFlow at a pinned commit id. The change is as follows:
# To compute the sha256: curl -L https://github.com/tensorflow/tensorflow/archive/4d5bc8437a60d31cc9d8e8e005cd90fbe6ebee8b.tar.gz | sha256sum
# The suffix after "tensorflow-" in strip_prefix is the pinned commit id
http_archive(
    name = "org_tensorflow",
    sha256 = "78edf5464de6e0fc2e9358cfca93b2208ad9bf742147e40c35bc6274b5dedd64",
    strip_prefix = "tensorflow-4d5bc8437a60d31cc9d8e8e005cd90fbe6ebee8b",
    urls = [
        "https://github.com/tensorflow/tensorflow/archive/4d5bc8437a60d31cc9d8e8e005cd90fbe6ebee8b.tar.gz",
    ],
)
- Modify the .bazelrc file in the repository root to add a dbg build configuration, borrowing mainly from tensorflow/.bazelrc. The per_file_copt lines below disable debug info (-g0) for everything outside the TensorFlow tree and for the heavyweight tensorflow/core/kernels, while forcing full debug info (-g) for jaxlib and examples/jax_cpp/main, which keeps the debug build to a manageable size. Add the following after line 30:
# Debug config
build:dbg -c dbg
build:dbg --per_file_copt=+.*,-tensorflow.*@-g0
build:dbg --per_file_copt=+tensorflow/core/kernels.*@-g0
build:dbg --per_file_copt=+jaxlib.*@-g
build:dbg --per_file_copt=+examples/jax_cpp/main.*@-g
build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON
build:dbg --copt -DDEBUG_BUILD
- Add the statement config_args += ["--config=dbg"] after build/build.py#L512, so that build.py passes the new dbg config to bazel.
- Because JAX is built with the C++14 standard while TensorFlow uses C++17, a debug build of JAX fails with an undefined-symbol error for CompiledFunctionCache::kDefaultCapacity. (In C++14 a static constexpr data member that is odr-used still requires an out-of-line definition, whereas C++17 makes such members implicitly inline.) The TensorFlow sources that JAX depends on therefore need a small patch:
  - Find the TensorFlow checkout that JAX depends on with
    find ~/.cache/bazel/ -path "*external/org_tensorflow" | grep "_bazel_root/[A-Za-z0-9]*/external/org_tensorflow$"
    Example: ~/.cache/bazel/_bazel_root/42210d9a2e5c41f7817f753f6f92c412/external/org_tensorflow
  - In tensorflow/compiler/xla/python/jax_jit.cc, change the definition of the static member kDefaultCapacity of the CompiledFunctionCache class as follows:
-static constexpr int kDefaultCapacity = 4096;
+static const int kDefaultCapacity;
+const int CompiledFunctionCache::kDefaultCapacity = 4096;
1.2 Building and installing
With everything in place, build and install the debug version of jaxlib with:
#!/bin/bash -ex
python build/build.py --enable_cuda
# pip uninstall -y jaxlib
pip install dist/*.whl
Once jaxlib is installed, run pip install -e . in the JAX repository root to install the Python portion of JAX.
Upgrade Note: To upgrade to the latest version from GitHub, just run git pull from the JAX repository root, and rebuild by running build.py or upgrading jaxlib if necessary. You shouldn’t have to reinstall jax because pip install -e sets up symbolic links from site-packages into the repository.
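A quick way to confirm that the freshly built jaxlib is the one in use and that the GPU backend is visible (a minimal sketch; it assumes a CUDA-capable machine, that the wheel built above installed cleanly, and that jaxlib exposes its version via jaxlib.version, which may vary slightly across releases):

# check_install.py -- sanity-check the debug build
import jax
import jaxlib.version

print(jax.__version__)             # version of the editable JAX checkout
print(jaxlib.version.__version__)  # version of the locally built jaxlib wheel
print(jax.devices())               # should list GPU devices, not just CPU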
2. Converting a JAX program to HLO and executing it
The examples/jax_cpp/main.cc example runs the JIT-compiled module on the CPU backend by default. To run it on the GPU backend, the example needs a few modifications.
2.1 Using the GPU backend from C++ source
Replace the contents of examples/jax_cpp/main.cc with the following:
// An example that reads an HloModule from an HloProto file and executes the
// module on a PJRT GPU client.
//
// To build the HloModule (see section 2.4; flag names follow
// jax/tools/jax_to_ir.py at the commit pinned above):
//
// $ python3 jax/tools/jax_to_ir.py \
//     --fn jax2hlo.prog.fn \
//     --input_shapes '[("x", "f32[2,4,3]"), ("y", "f32[3,2]"), ("z", "f32[2]")]' \
//     --constants '{"alpha": 2.0, "beta": 3.0}' \
//     --ir_format HLO \
//     --ir_human_dest /tmp/fn_hlo.txt \
//     --ir_dest /tmp/fn_hlo.pb
//
// To load and run the HloModule:
//
// $ bazel build examples/jax_cpp:main --experimental_repo_remote_exec --check_visibility=false
// $ bazel-bin/examples/jax_cpp/main
//
// With the inputs below, the logged result is an f32[2,4,2] literal.
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/gpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
int main(int argc, char** argv) {
  tensorflow::port::InitMain("", &argc, &argv);

  // Load the HloModule from file.
  std::string hlo_filename = "/tmp/fn_hlo.txt";
  std::function<void(xla::HloModuleConfig*)> config_modifier_hook =
      [](xla::HloModuleConfig* config) { config->set_seed(42); };
  std::unique_ptr<xla::HloModule> test_module =
      LoadModuleFromFile(hlo_filename, xla::hlo_module_loader_details::Config(),
                         "txt", config_modifier_hook)
          .ValueOrDie();
  const xla::HloModuleProto test_module_proto = test_module->ToProto();

  // Run it using the JAX C++ runtime (PJRT): get a GPU client.
  bool asynchronous = true;
  xla::GpuAllocatorConfig allocator_config;
  std::shared_ptr<xla::DistributedRuntimeClient> distributed_client{nullptr};
  int node_id = 0;
  std::unique_ptr<xla::PjRtClient> client =
      xla::GetGpuClient(asynchronous, allocator_config, distributed_client,
                        node_id)
          .ValueOrDie();

  // Compile the XlaComputation to a PjRtExecutable.
  xla::XlaComputation xla_computation(test_module_proto);
  xla::CompileOptions compile_options;
  std::unique_ptr<xla::PjRtExecutable> executable =
      client->Compile(xla_computation, compile_options).ValueOrDie();

  // Prepare inputs matching the shapes declared in section 2.4:
  // x: f32[2,4,3], y: f32[3,2], z: f32[2].
  xla::Literal literal_x = xla::LiteralUtil::CreateR3<float>(
      {{{1.0f, 1.0f, 1.0f}, {2.0f, 2.0f, 2.0f}, {3.0f, 3.0f, 3.0f}, {4.0f, 4.0f, 4.0f}},
       {{5.0f, 5.0f, 5.0f}, {6.0f, 6.0f, 6.0f}, {7.0f, 7.0f, 7.0f}, {8.0f, 8.0f, 8.0f}}});
  xla::Literal literal_y = xla::LiteralUtil::CreateR2<float>(
      {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}});
  // Alternative shapes for z, kept from experiments with broadcasting:
  // xla::Literal literal_z = xla::LiteralUtil::CreateR1<float>({1.0f});
  xla::Literal literal_z = xla::LiteralUtil::CreateR1<float>({1.0f, 1.0f});
  // xla::Literal literal_z = xla::LiteralUtil::CreateR2<float>(
  //     {{1.0f, 1.0f}, {1.0f, 1.0f}, {1.0f, 1.0f}, {1.0f, 1.0f}});
  std::unique_ptr<xla::PjRtBuffer> param_x =
      client->BufferFromHostLiteral(literal_x, client->addressable_devices()[0])
          .ValueOrDie();
  std::unique_ptr<xla::PjRtBuffer> param_y =
      client->BufferFromHostLiteral(literal_y, client->addressable_devices()[0])
          .ValueOrDie();
  std::unique_ptr<xla::PjRtBuffer> param_z =
      client->BufferFromHostLiteral(literal_z, client->addressable_devices()[0])
          .ValueOrDie();

  // Execute on the GPU; the outer vector holds one entry per device.
  xla::ExecuteOptions execute_options;
  std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results =
      executable->Execute({{param_x.get(), param_y.get(), param_z.get()}},
                          execute_options)
          .ValueOrDie();

  // Copy the result back to the host and print it.
  std::shared_ptr<xla::Literal> result_literal =
      results[0][0]->ToLiteral().ValueOrDie();
  LOG(INFO) << "result = " << *result_literal;
  return 0;
}
2.2 Adding GPU backend build dependencies
Modify the examples/jax_cpp/BUILD file: add the two dependencies pjrt:gpu_device and jit:xla_gpu_jit to the deps attribute of tf_cc_binary. The modified file reads:
load(
    "@org_tensorflow//tensorflow:tensorflow.bzl",
    "tf_cc_binary",
)

licenses(["notice"])

tf_cc_binary(
    name = "main",
    srcs = ["main.cc"],
    deps = [
        "@org_tensorflow//tensorflow/compiler/xla:literal",
        "@org_tensorflow//tensorflow/compiler/xla:literal_util",
        "@org_tensorflow//tensorflow/compiler/xla:shape_util",
        "@org_tensorflow//tensorflow/compiler/xla:status",
        "@org_tensorflow//tensorflow/compiler/xla:statusor",
        "@org_tensorflow//tensorflow/compiler/xla/pjrt:gpu_device",  # added: GPU PJRT client
        "@org_tensorflow//tensorflow/compiler/jit:xla_gpu_jit",  # added: registers the XLA GPU backend
        "@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client",
        "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc",
        "@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader",
        "@org_tensorflow//tensorflow/core/platform:logging",
        "@org_tensorflow//tensorflow/core/platform:platform_port",
    ],
)
2.3 Building the main executable
Build the modified examples/jax_cpp/main.cc with:
#!/bin/bash -ex
bazel build --verbose_failures=true --config=avx_posix \
--config=mkl_open_source_only --config=cuda --config=dbg \
examples/jax_cpp:main \
--experimental_repo_remote_exec --check_visibility=false
Once the build finishes, the main executable can be found under bazel-bin/examples/jax_cpp/.
2.4 Generating the HLO module files
Create a jax2hlo directory in the JAX repository root, and inside it create prog.py and jax2hlo.sh with the following contents:
- prog.py:

import jax.numpy as jnp

def fn(x, y, z, alpha, beta):
    return alpha * jnp.dot(x, y) + beta * z
- jax2hlo.sh:

#!/bin/bash -ex
python ../jax/tools/jax_to_ir.py \
  --fn jax2hlo.prog.fn \
  --input_shapes '[("x", "f32[2,4,3]"), ("y", "f32[3,2]"), ("z", "f32[2]")]' \
  --constants '{"alpha": 2.0, "beta": 3.0}' \
  --ir_format HLO \
  --ir_human_dest /tmp/fn_hlo.txt \
  --ir_dest /tmp/fn_hlo.pb
Run ./jax2hlo.sh from the jax2hlo directory; the generated files fn_hlo.txt and fn_hlo.pb then appear under /tmp/. A quick pre-flight shape check is sketched below.
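Before exporting, it can be worth confirming that the declared input shapes actually compose. A minimal sketch using jax.eval_shape (run from the JAX repository root so that jax2hlo.prog is importable; eval_shape only propagates shapes and dtypes, no data is materialized):

import functools

import jax
import jax.numpy as jnp
from jax2hlo.prog import fn

# Abstract stand-ins carrying only shape and dtype.
x = jax.ShapeDtypeStruct((2, 4, 3), jnp.float32)
y = jax.ShapeDtypeStruct((3, 2), jnp.float32)
z = jax.ShapeDtypeStruct((2,), jnp.float32)

# Bind the constants exactly as jax2hlo.sh does via --constants.
out = jax.eval_shape(functools.partial(fn, alpha=2.0, beta=3.0), x, y, z)
print(out.shape, out.dtype)  # expected: (2, 4, 2) float32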
2.5 Running the main executable
Run main from the JAX repository root with the command below. The TF_CPP_* variables control C++ logging verbosity (TF_CPP_VMODULE enables verbose logs for the dot_decomposer pass), and XLA_FLAGS dumps the HLO around every optimization pass, as text and HTML plus snapshots, into ./xla_main_output:
TF_CPP_VMODULE=dot_decomposer=4 TF_CPP_LOG_THREAD_ID=1 \
TF_CPP_MIN_LOG_LEVEL=0 CUDA_VISIBLE_DEVICES=0 \
XLA_FLAGS="--xla_dump_to=./xla_main_output --xla_dump_hlo_pass_re=.* --xla_dump_hlo_as_text --xla_dump_hlo_as_html --xla_dump_hlo_snapshots" \
bazel-bin/examples/jax_cpp/main
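To sanity-check the value that main logs, the same computation can be reproduced directly in JAX with the literals hard-coded in main.cc (a minimal sketch; run it from the repository root, and it assumes /tmp/fn_hlo.txt was generated from section 2.4's fn with alpha=2.0 and beta=3.0):

import jax.numpy as jnp
from jax2hlo.prog import fn

# The same inputs that main.cc feeds to the compiled executable.
x = jnp.array([[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.], [4., 4., 4.]],
               [[5., 5., 5.], [6., 6., 6.], [7., 7., 7.], [8., 8., 8.]]],
              dtype=jnp.float32)
y = jnp.array([[1., 2.], [3., 4.], [5., 6.]], dtype=jnp.float32)
z = jnp.array([1., 1.], dtype=jnp.float32)

# Should match the f32[2,4,2] literal printed by bazel-bin/examples/jax_cpp/main.
print(fn(x, y, z, alpha=2.0, beta=3.0))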
3. Appendix
3.1 Troubleshooting notes
Note: under normal circumstances this problem should not occur. When an nccl build error appears, first double-check the bazel build command; the --config=cuda option may have been left out.
If the build fails with fatal error: third_party/nccl/nccl.h: No such file or directory, try setting the following environment variables:
export TF_NCCL_VERSION='2.11.4'
export NCCL_INSTALL_PATH=/usr/local/lib
export NCCL_HDR_PATH=/usr/local/include
3.2 Building the nccl library from source
git clone https://github.com/NVIDIA/nccl.git
cd nccl/
make CUDA_HOME=/usr/local/cuda
sudo make install