Skip to content

Commit dc5fc68

Browse files
authored
[METAL] Split kernels and compile them separately (apache#7980)
1 parent aefa0c8 commit dc5fc68

File tree

9 files changed

+161
-86
lines changed

9 files changed

+161
-86
lines changed

apps/android_camera/app/src/main/jni/tvm_runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
#ifdef TVM_OPENCL_RUNTIME
5858
#include "../src/runtime/opencl/opencl_device_api.cc"
5959
#include "../src/runtime/opencl/opencl_module.cc"
60+
#include "../src/runtime/source_utils.cc"
6061
#endif
6162

6263
#ifdef TVM_VULKAN_RUNTIME

apps/android_rpc/app/src/main/jni/tvm_runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
#ifdef TVM_OPENCL_RUNTIME
6363
#include "../src/runtime/opencl/opencl_device_api.cc"
6464
#include "../src/runtime/opencl/opencl_module.cc"
65+
#include "../src/runtime/source_utils.cc"
6566
#endif
6667

6768
#ifdef TVM_VULKAN_RUNTIME

golang/src/tvm_runtime_pack.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,4 @@
6868
// Uncomment the following lines to enable OpenCL
6969
// #include "../../src/runtime/opencl/opencl_device_api.cc"
7070
// #include "../../src/runtime/opencl/opencl_module.cc"
71+
// #include "../src/runtime/source_utils.cc"

src/runtime/metal/metal_module.mm

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "../file_utils.h"
3131
#include "../meta_data.h"
3232
#include "../pack_args.h"
33+
#include "../source_utils.h"
3334
#include "../thread_storage_scope.h"
3435
#include "metal_common.h"
3536

@@ -43,7 +44,9 @@
4344
public:
4445
explicit MetalModuleNode(std::string data, std::string fmt,
4546
std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
46-
: data_(data), fmt_(fmt), fmap_(fmap), source_(source) {}
47+
: data_(data), fmt_(fmt), fmap_(fmap), source_(source) {
48+
parsed_kernels_ = SplitKernels(data);
49+
}
4750
const char* type_key() const final { return "metal"; }
4851

4952
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
@@ -71,6 +74,7 @@ void SaveToBinary(dmlc::Stream* stream) final {
7174
return "";
7275
}
7376
}
77+
7478
// get a from primary context in device_id
7579
id<MTLComputePipelineState> GetPipelineState(size_t device_id, const std::string& func_name) {
7680
metal::MetalWorkspace* w = metal::MetalWorkspace::Global();
@@ -85,44 +89,52 @@ void SaveToBinary(dmlc::Stream* stream) final {
8589
if (it != e.smap.end()) return it->second;
8690
// compile
8791
NSError* err_msg = nil;
88-
if (e.lib == nil) {
89-
if (fmt_ == "metal") {
90-
MTLCompileOptions* opts = [MTLCompileOptions alloc];
91-
opts.languageVersion = MTLLanguageVersion2_3;
92-
opts.fastMathEnabled = YES;
93-
// opts = nil;
94-
e.lib = [w->devices[device_id]
95-
newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()]
96-
options:opts
97-
error:&err_msg];
98-
[opts dealloc];
99-
if (e.lib == nil) {
100-
LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String];
101-
}
102-
if (err_msg != nil) {
103-
LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String];
104-
}
105-
} else {
106-
// Build from library.
107-
auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL);
108-
auto data = dispatch_data_create(data_.c_str(), data_.length(), q,
109-
^{
110-
});
111-
e.lib = [w->devices[device_id] newLibraryWithData:data error:&err_msg];
112-
if (err_msg != nil || e.lib == nil) {
113-
LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String];
114-
}
92+
id<MTLLibrary> lib = nil;
93+
std::string source;
94+
auto kernel = parsed_kernels_.find(func_name);
95+
// If we cannot find this kernel in parsed_kernels_, it means that all kernels going together
96+
// without explicit separator. In this case we use data_ with all kernels. It done for backward
97+
// compatibility.
98+
if (kernel != parsed_kernels_.end())
99+
source = kernel->second;
100+
else
101+
source = data_;
102+
if (fmt_ == "metal") {
103+
MTLCompileOptions* opts = [MTLCompileOptions alloc];
104+
opts.languageVersion = MTLLanguageVersion2_3;
105+
opts.fastMathEnabled = YES;
106+
// opts = nil;
107+
lib =
108+
[w->devices[device_id] newLibraryWithSource:[NSString stringWithUTF8String:source.c_str()]
109+
options:opts
110+
error:&err_msg];
111+
[opts dealloc];
112+
if (lib == nil) {
113+
LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String];
114+
}
115+
if (err_msg != nil) {
116+
LOG(INFO) << "Warning: " << [[err_msg localizedDescription] UTF8String];
117+
}
118+
} else {
119+
// Build from library.
120+
auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL);
121+
auto data = dispatch_data_create(source.c_str(), source.length(), q,
122+
^{
123+
});
124+
lib = [w->devices[device_id] newLibraryWithData:data error:&err_msg];
125+
if (err_msg != nil || lib == nil) {
126+
LOG(FATAL) << "Fail to compile metal lib:" << [[err_msg localizedDescription] UTF8String];
115127
}
116128
}
117-
id<MTLFunction> f =
118-
[e.lib newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
129+
id<MTLFunction> f = [lib newFunctionWithName:[NSString stringWithUTF8String:func_name.c_str()]];
119130
ICHECK(f != nil) << "cannot find function " << func_name;
120131
id<MTLComputePipelineState> state =
121132
[w->devices[device_id] newComputePipelineStateWithFunction:f error:&err_msg];
122133
ICHECK(state != nil) << "cannot get state:"
123134
<< " for function " << func_name
124135
<< [[err_msg localizedDescription] UTF8String];
125136
[f release];
137+
[lib release];
126138
// The state.threadExecutionWidth can change dynamically according
127139
// to the resource constraint in kernel, so it is not strictly hold
128140
// Turn of warp aware optimziation for now.
@@ -135,13 +147,10 @@ void SaveToBinary(dmlc::Stream* stream) final {
135147
private:
136148
// device specific entry
137149
struct DeviceEntry {
138-
// library
139-
id<MTLLibrary> lib = nil;
140150
// state cache;
141-
std::unordered_map<std::string, id<MTLComputePipelineState> > smap;
151+
std::unordered_map<std::string, id<MTLComputePipelineState>> smap;
142152

143153
~DeviceEntry() {
144-
if (lib != nil) [lib release];
145154
for (auto&& kv : smap) {
146155
[kv.second release];
147156
}
@@ -159,6 +168,8 @@ void SaveToBinary(dmlc::Stream* stream) final {
159168
std::vector<DeviceEntry> finfo_;
160169
// internal mutex when updating the module
161170
std::mutex mutex_;
171+
// parsed kernel data
172+
std::unordered_map<std::string, std::string> parsed_kernels_;
162173
};
163174

164175
// a wrapped function class to get packed func.

src/runtime/opencl/opencl_common.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -326,14 +326,6 @@ class OpenCLModuleNode : public ModuleNode {
326326
cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
327327
const std::string& func_name, const KTRefEntry& e);
328328

329-
/*
330-
* \brief Splits the provided serialized source file into separate
331-
* source for each kernel primitive.
332-
* \param source The serialized program source file (fmt: cl)
333-
* \return Mapping from primitive name to kernel source
334-
*/
335-
std::unordered_map<std::string, std::string> SplitKernels(std::string source) const;
336-
337329
private:
338330
// The workspace, need to keep reference to use it in destructor.
339331
// In case of static destruction order problem.

src/runtime/opencl/opencl_module.cc

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <unordered_map>
3030
#include <vector>
3131

32+
#include "../source_utils.h"
3233
#include "opencl_common.h"
3334

3435
namespace tvm {
@@ -188,6 +189,11 @@ void OpenCLModuleNode::Init() {
188189

189190
// split into source artifacts for each kernel
190191
parsed_kernels_ = SplitKernels(GetSource("cl"));
192+
ICHECK(!parsed_kernels_.empty()) << "The OpenCL module expects a kernel delimited "
193+
<< "source from code generation, but no kernel "
194+
<< "delimiter was found.";
195+
ICHECK_EQ(workspace_->num_registered_kernels, parsed_kernels_.size())
196+
<< "The number of registered kernels does not match number of parsed kernel sources";
191197
// zero initialize cl_program pointers for each device kernel
192198
for (auto& kv : parsed_kernels_) {
193199
programs_.insert({kv.first, std::vector<cl_program>(workspace_->devices.size(), nullptr)});
@@ -242,39 +248,6 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre
242248
return kernel;
243249
}
244250

245-
std::unordered_map<std::string, std::string> OpenCLModuleNode::SplitKernels(
246-
std::string source) const {
247-
std::unordered_map<std::string, std::string> split_kernels;
248-
if (source.size()) {
249-
std::string del{"// Function: "};
250-
size_t end;
251-
size_t begin = source.find(del);
252-
ICHECK(begin != std::string::npos) << "The OpenCL module expects a kernel delimited "
253-
<< "source from code generation, but no kernel "
254-
<< "delimiter was found.";
255-
for (size_t num_kernels = 0; num_kernels < workspace_->num_registered_kernels; num_kernels++) {
256-
begin += del.size();
257-
end = source.find('\n', begin);
258-
std::string func_name = source.substr(begin, end - begin);
259-
begin = ++end;
260-
// std::string::substr returns either start of next kernel
261-
// or std::string::npos, in the latter case substr returns
262-
// all characters until the end of the source string.
263-
end = source.find(del, begin);
264-
std::string func_source =
265-
source.substr(begin, (end == std::string::npos) ? end : end - begin);
266-
split_kernels.insert({func_name, func_source});
267-
begin = end;
268-
if (end == std::string::npos) {
269-
break;
270-
}
271-
}
272-
}
273-
ICHECK_EQ(workspace_->num_registered_kernels, split_kernels.size())
274-
<< "The number of registered kernels does not match number of parsed kernel sources";
275-
return split_kernels;
276-
}
277-
278251
Module OpenCLModuleCreate(std::string data, std::string fmt,
279252
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
280253
auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);

src/runtime/source_utils.cc

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file source_utils.cc
22+
*/
23+
#include "source_utils.h"
24+
25+
namespace tvm {
26+
namespace runtime {
27+
28+
std::unordered_map<std::string, std::string> SplitKernels(std::string source,
29+
std::string delimiter) {
30+
std::unordered_map<std::string, std::string> split_kernels;
31+
if (source.size()) {
32+
size_t begin = source.find(delimiter);
33+
size_t end = begin;
34+
while (end != std::string::npos) {
35+
begin += delimiter.size();
36+
end = source.find('\n', begin);
37+
std::string func_name = source.substr(begin, end - begin);
38+
begin = ++end;
39+
end = source.find(delimiter, begin);
40+
std::string func_source =
41+
source.substr(begin, (end == std::string::npos) ? end : end - begin);
42+
split_kernels.insert({func_name, func_source});
43+
begin = end;
44+
}
45+
}
46+
return split_kernels;
47+
}
48+
} // namespace runtime
49+
} // namespace tvm

src/runtime/source_utils.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file source_utils.h
22+
* \brief Minimum source manipulation utils for runtime.
23+
*/
24+
25+
#ifndef TVM_RUNTIME_SOURCE_UTILS_H_
26+
#define TVM_RUNTIME_SOURCE_UTILS_H_
27+
28+
#include <string>
29+
#include <unordered_map>
30+
31+
namespace tvm {
32+
namespace runtime {
33+
/*!
34+
* \brief Split the source file on separate kernels by specified delimiter.
35+
* \param source The source code of the kernels.
36+
* \param delimiter The delimiter which is using for splitting kernels.
37+
* \return Mapping from primitive name to kernel source
38+
*/
39+
std::unordered_map<std::string, std::string> SplitKernels(std::string source,
40+
std::string delimiter = "// Function: ");
41+
} // namespace runtime
42+
} // namespace tvm
43+
44+
#endif // TVM_RUNTIME_SOURCE_UTILS_H_

src/target/source/codegen_metal.cc

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -325,27 +325,30 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NO
325325
runtime::Module BuildMetal(IRModule mod, Target target) {
326326
using tvm::runtime::Registry;
327327
bool output_ssa = false;
328-
CodeGenMetal cg;
329-
cg.Init(output_ssa);
330328

329+
std::stringstream code;
330+
std::stringstream source;
331+
std::string fmt = "metal";
331332
for (auto kv : mod->functions) {
332333
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc";
334+
code << "// Function: " << kv.first->name_hint << std::endl;
335+
CodeGenMetal cg;
336+
cg.Init(output_ssa);
333337
auto f = Downcast<PrimFunc>(kv.second);
334338
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
335339
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
336340
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
337341
cg.AddFunction(f);
342+
std::string fsource = cg.Finish();
343+
if (const auto* f = Registry::Get("tvm_callback_metal_compile")) {
344+
source << fsource;
345+
fsource = (*f)(fsource).operator std::string();
346+
fmt = "metallib";
347+
}
348+
code << fsource;
338349
}
339350

340-
std::string code = cg.Finish();
341-
std::string fmt = "metal";
342-
std::string source = "";
343-
if (const auto* f = Registry::Get("tvm_callback_metal_compile")) {
344-
source = code;
345-
code = (*f)(code).operator std::string();
346-
fmt = "metallib";
347-
}
348-
return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source);
351+
return MetalModuleCreate(code.str(), fmt, ExtractFuncInfo(mod), source.str());
349352
}
350353

351354
TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal);

0 commit comments

Comments
 (0)