3030#include " ../file_utils.h"
3131#include " ../meta_data.h"
3232#include " ../pack_args.h"
33+ #include " ../source_utils.h"
3334#include " ../thread_storage_scope.h"
3435#include " metal_common.h"
3536
4344 public:
4445 explicit MetalModuleNode (std::string data, std::string fmt,
4546 std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
46- : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {}
47+ : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {
48+ parsed_kernels_ = SplitKernels (data);
49+ }
4750 const char * type_key () const final { return " metal" ; }
4851
4952 PackedFunc GetFunction (const std::string& name, const ObjectPtr<Object>& sptr_to_self) final ;
@@ -71,6 +74,7 @@ void SaveToBinary(dmlc::Stream* stream) final {
7174 return " " ;
7275 }
7376 }
77+
7478 // get a pipeline state from the primary context in device_id
7579 id <MTLComputePipelineState > GetPipelineState (size_t device_id, const std::string& func_name) {
7680 metal::MetalWorkspace* w = metal::MetalWorkspace::Global ();
@@ -85,44 +89,52 @@ void SaveToBinary(dmlc::Stream* stream) final {
8589 if (it != e.smap .end ()) return it->second ;
8690 // compile
8791 NSError * err_msg = nil ;
88- if (e.lib == nil ) {
89- if (fmt_ == " metal" ) {
90- MTLCompileOptions * opts = [MTLCompileOptions alloc ];
91- opts.languageVersion = MTLLanguageVersion2_3;
92- opts.fastMathEnabled = YES ;
93- // opts = nil;
94- e.lib = [w->devices[device_id]
95- newLibraryWithSource: [NSString stringWithUTF8String: data_.c_str ()]
96- options: opts
97- error: &err_msg];
98- [opts dealloc ];
99- if (e.lib == nil ) {
100- LOG (FATAL) << " Fail to compile metal lib:" << [[err_msg localizedDescription ] UTF8String ];
101- }
102- if (err_msg != nil ) {
103- LOG (INFO) << " Warning: " << [[err_msg localizedDescription ] UTF8String ];
104- }
105- } else {
106- // Build from library.
107- auto q = dispatch_queue_create (" q" , DISPATCH_QUEUE_SERIAL);
108- auto data = dispatch_data_create (data_.c_str (), data_.length (), q,
109- ^{
110- });
111- e.lib = [w->devices[device_id] newLibraryWithData: data error: &err_msg];
112- if (err_msg != nil || e.lib == nil ) {
113- LOG (FATAL) << " Fail to compile metal lib:" << [[err_msg localizedDescription ] UTF8String ];
114- }
92+ id <MTLLibrary > lib = nil ;
93+ std::string source;
94+ auto kernel = parsed_kernels_.find (func_name);
95+ // If we cannot find this kernel in parsed_kernels_, it means that all kernels are stored
96+ // together without an explicit separator. In this case we use data_, which holds all kernels.
97+ // This is done for backward compatibility.
98+ if (kernel != parsed_kernels_.end ())
99+ source = kernel->second ;
100+ else
101+ source = data_;
102+ if (fmt_ == " metal" ) {
103+ MTLCompileOptions * opts = [MTLCompileOptions alloc ];
104+ opts.languageVersion = MTLLanguageVersion2_3;
105+ opts.fastMathEnabled = YES ;
106+ // opts = nil;
107+ lib =
108+ [w->devices[device_id] newLibraryWithSource: [NSString stringWithUTF8String: source.c_str ()]
109+ options: opts
110+ error: &err_msg];
111+ [opts dealloc ];
112+ if (lib == nil ) {
113+ LOG (FATAL) << " Fail to compile metal lib:" << [[err_msg localizedDescription ] UTF8String ];
114+ }
115+ if (err_msg != nil ) {
116+ LOG (INFO) << " Warning: " << [[err_msg localizedDescription ] UTF8String ];
117+ }
118+ } else {
119+ // Build from library.
120+ auto q = dispatch_queue_create (" q" , DISPATCH_QUEUE_SERIAL);
121+ auto data = dispatch_data_create (source.c_str (), source.length (), q,
122+ ^{
123+ });
124+ lib = [w->devices[device_id] newLibraryWithData: data error: &err_msg];
125+ if (err_msg != nil || lib == nil ) {
126+ LOG (FATAL) << " Fail to compile metal lib:" << [[err_msg localizedDescription ] UTF8String ];
115127 }
116128 }
117- id <MTLFunction > f =
118- [e.lib newFunctionWithName: [NSString stringWithUTF8String: func_name.c_str ()]];
129+ id <MTLFunction > f = [lib newFunctionWithName: [NSString stringWithUTF8String: func_name.c_str ()]];
119130 ICHECK (f != nil ) << " cannot find function " << func_name;
120131 id <MTLComputePipelineState > state =
121132 [w->devices[device_id] newComputePipelineStateWithFunction: f error: &err_msg];
122133 ICHECK (state != nil ) << " cannot get state:"
123134 << " for function " << func_name
124135 << [[err_msg localizedDescription ] UTF8String ];
125136 [f release ];
137+ [lib release ];
126138 // The state.threadExecutionWidth can change dynamically according
127139 // to the resource constraint in kernel, so it does not strictly hold.
128140 // Turn off warp-aware optimization for now.
@@ -135,13 +147,10 @@ void SaveToBinary(dmlc::Stream* stream) final {
135147 private:
136148 // device specific entry
137149 struct DeviceEntry {
138- // library
139- id <MTLLibrary > lib = nil ;
140150 // state cache;
141- std::unordered_map<std::string, id <MTLComputePipelineState > > smap;
151+ std::unordered_map<std::string, id <MTLComputePipelineState >> smap;
142152
143153 ~DeviceEntry () {
144- if (lib != nil ) [lib release ];
145154 for (auto && kv : smap) {
146155 [kv.second release ];
147156 }
@@ -159,6 +168,8 @@ void SaveToBinary(dmlc::Stream* stream) final {
159168 std::vector<DeviceEntry> finfo_;
160169 // internal mutex when updating the module
161170 std::mutex mutex_;
171+ // parsed kernel data
172+ std::unordered_map<std::string, std::string> parsed_kernels_;
162173};
163174
164175// a wrapped function class to get packed func.
0 commit comments