Skip to content

Commit 443fa20

Browse files
authored
[RUNTIME] Update Module and Registry to use String Container (apache#14902)
This PR updates the Module and Registry's DLL function to use String container instead of std::string. While it is impossible to obtain a stable ABI due to the nature of c++, and it is important to keep that flexibility, it is helpful to keep small set of tvm/runtime functions to work with use a String so it is more stable across compilers.
1 parent 7fe58a1 commit 443fa20

File tree

71 files changed

+171
-187
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+171
-187
lines changed

apps/dso_plugin_module/plugin_module.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ class MyModuleNode : public ModuleNode {
3535

3636
virtual const char* type_key() const final { return "MyModule"; }
3737

38-
virtual PackedFunc GetFunction(const std::string& name,
39-
const ObjectPtr<Object>& sptr_to_self) final {
38+
virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final {
4039
if (name == "add") {
4140
return TypedPackedFunc<int(int)>([sptr_to_self, this](int value) { return value_ + value; });
4241
} else if (name == "mul") {

include/tvm/runtime/module.h

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class Module : public ObjectRef {
9090
* This function will return PackedFunc(nullptr) if function do not exist.
9191
* \note Implemented in packed_func.cc
9292
*/
93-
inline PackedFunc GetFunction(const std::string& name, bool query_imports = false);
93+
inline PackedFunc GetFunction(const String& name, bool query_imports = false);
9494
// The following functions requires link with runtime.
9595
/*!
9696
* \brief Import another module into this module.
@@ -111,7 +111,7 @@ class Module : public ObjectRef {
111111
* \note This function won't load the import relationship.
112112
* Re-create import relationship by calling Import.
113113
*/
114-
TVM_DLL static Module LoadFromFile(const std::string& file_name, const std::string& format = "");
114+
TVM_DLL static Module LoadFromFile(const String& file_name, const String& format = "");
115115
// refer to the corresponding container.
116116
using ContainerType = ModuleNode;
117117
friend class ModuleNode;
@@ -165,14 +165,13 @@ class TVM_DLL ModuleNode : public Object {
165165
* If the function need resource from the module(e.g. late linking),
166166
* it should capture sptr_to_self.
167167
*/
168-
virtual PackedFunc GetFunction(const std::string& name,
169-
const ObjectPtr<Object>& sptr_to_self) = 0;
168+
virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) = 0;
170169
/*!
171170
* \brief Save the module to file.
172171
* \param file_name The file to be saved to.
173172
* \param format The format of the file.
174173
*/
175-
virtual void SaveToFile(const std::string& file_name, const std::string& format);
174+
virtual void SaveToFile(const String& file_name, const String& format);
176175
/*!
177176
* \brief Save the module to binary stream.
178177
* \param stream The binary stream to save to.
@@ -186,12 +185,12 @@ class TVM_DLL ModuleNode : public Object {
186185
* \param format Format of the source code, can be empty by default.
187186
* \return Possible source code when available.
188187
*/
189-
virtual std::string GetSource(const std::string& format = "");
188+
virtual String GetSource(const String& format = "");
190189
/*!
191190
* \brief Get the format of the module, when available.
192191
* \return Possible format when available.
193192
*/
194-
virtual std::string GetFormat();
193+
virtual String GetFormat();
195194
/*!
196195
* \brief Get packed function from current module by name.
197196
*
@@ -201,7 +200,7 @@ class TVM_DLL ModuleNode : public Object {
201200
* This function will return PackedFunc(nullptr) if function do not exist.
202201
* \note Implemented in packed_func.cc
203202
*/
204-
PackedFunc GetFunction(const std::string& name, bool query_imports = false);
203+
PackedFunc GetFunction(const String& name, bool query_imports = false);
205204
/*!
206205
* \brief Import another module into this module.
207206
* \param other The module to be imported.
@@ -217,7 +216,7 @@ class TVM_DLL ModuleNode : public Object {
217216
* \param name name of the function.
218217
* \return The corresponding function.
219218
*/
220-
const PackedFunc* GetFuncFromEnv(const std::string& name);
219+
const PackedFunc* GetFuncFromEnv(const String& name);
221220
/*! \return The module it imports from */
222221
const std::vector<Module>& imports() const { return imports_; }
223222

@@ -268,7 +267,7 @@ class TVM_DLL ModuleNode : public Object {
268267
* \param target The target module name.
269268
* \return Whether runtime is enabled.
270269
*/
271-
TVM_DLL bool RuntimeEnabled(const std::string& target);
270+
TVM_DLL bool RuntimeEnabled(const String& target);
272271

273272
/*! \brief namespace for constant symbols */
274273
namespace symbol {

include/tvm/runtime/packed_func.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1942,7 +1942,7 @@ inline TVMRetValue::operator T() const {
19421942
return PackedFuncValueConverter<T>::From(*this);
19431943
}
19441944

1945-
inline PackedFunc Module::GetFunction(const std::string& name, bool query_imports) {
1945+
inline PackedFunc Module::GetFunction(const String& name, bool query_imports) {
19461946
return (*this)->GetFunction(name, query_imports);
19471947
}
19481948

include/tvm/runtime/registry.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@
4343
#ifndef TVM_RUNTIME_REGISTRY_H_
4444
#define TVM_RUNTIME_REGISTRY_H_
4545

46+
#include <tvm/runtime/container/string.h>
4647
#include <tvm/runtime/packed_func.h>
4748

48-
#include <string>
4949
#include <type_traits>
5050
#include <utility>
5151
#include <vector>
@@ -295,32 +295,32 @@ class Registry {
295295
* \param override Whether allow override existing function.
296296
* \return Reference to the registry.
297297
*/
298-
TVM_DLL static Registry& Register(const std::string& name, bool override = false); // NOLINT(*)
298+
TVM_DLL static Registry& Register(const String& name, bool override = false); // NOLINT(*)
299299
/*!
300300
* \brief Erase global function from registry, if exist.
301301
* \param name The name of the function.
302302
* \return Whether function exist.
303303
*/
304-
TVM_DLL static bool Remove(const std::string& name);
304+
TVM_DLL static bool Remove(const String& name);
305305
/*!
306306
* \brief Get the global function by name.
307307
* \param name The name of the function.
308308
* \return pointer to the registered function,
309309
* nullptr if it does not exist.
310310
*/
311-
TVM_DLL static const PackedFunc* Get(const std::string& name); // NOLINT(*)
311+
TVM_DLL static const PackedFunc* Get(const String& name); // NOLINT(*)
312312
/*!
313313
* \brief Get the names of currently registered global function.
314314
* \return The names
315315
*/
316-
TVM_DLL static std::vector<std::string> ListNames();
316+
TVM_DLL static std::vector<String> ListNames();
317317

318318
// Internal class.
319319
struct Manager;
320320

321321
protected:
322322
/*! \brief name of the function */
323-
std::string name_;
323+
String name_;
324324
/*! \brief internal packed function */
325325
PackedFunc func_;
326326
friend struct Manager;

include/tvm/runtime/vm/executable.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class TVM_DLL Executable : public ModuleNode {
6464
*
6565
* \return PackedFunc or nullptr when it is not available.
6666
*/
67-
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
67+
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final;
6868

6969
/*! \brief Get the property of the runtime module .*/
7070
int GetPropertyMask() const final { return ModulePropertyMask::kBinarySerializable; };
@@ -88,7 +88,7 @@ class TVM_DLL Executable : public ModuleNode {
8888
* \param path The path to write the serialized data to.
8989
* \param format The format of the serialized blob.
9090
*/
91-
void SaveToFile(const std::string& path, const std::string& format) final;
91+
void SaveToFile(const String& path, const String& format) final;
9292

9393
/*!
9494
* \brief Serialize the executable into global section, constant section, and

include/tvm/runtime/vm/vm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ class TVM_DLL VirtualMachine : public runtime::ModuleNode {
164164
* If the function needs resource from the module(e.g. late linking),
165165
* it should capture sptr_to_self.
166166
*/
167-
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
167+
virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self);
168168

169169
virtual ~VirtualMachine() {}
170170

src/relay/backend/aot_executor_codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1310,7 +1310,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
13101310
class AOTExecutorCodegenModule : public runtime::ModuleNode {
13111311
public:
13121312
AOTExecutorCodegenModule() {}
1313-
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
1313+
virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) {
13141314
if (name == "init") {
13151315
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
13161316
ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: "

src/relay/backend/build_module.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ class RelayBuildModule : public runtime::ModuleNode {
172172
* \param sptr_to_self The pointer to the module node.
173173
* \return The corresponding member function.
174174
*/
175-
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
175+
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final {
176176
if (name == "get_graph_json") {
177177
return PackedFunc(
178178
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetGraphJSON(); });

src/relay/backend/contrib/ethosu/source_module.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class EthosUModuleNode : public ModuleNode {
7878
* \param file_name The file to be saved to.
7979
* \param format The format of the file.
8080
*/
81-
void SaveToFile(const std::string& file_name, const std::string& format) final {
81+
void SaveToFile(const String& file_name, const String& format) final {
8282
std::string fmt = GetFileFormat(file_name, format);
8383
ICHECK_EQ(fmt, "c") << "Can only save to format="
8484
<< "c";
@@ -87,9 +87,9 @@ class EthosUModuleNode : public ModuleNode {
8787
out.close();
8888
}
8989

90-
std::string GetSource(const std::string& format) final { return c_source; }
90+
String GetSource(const String& format) final { return c_source; }
9191

92-
std::string GetFormat() override { return "c"; }
92+
String GetFormat() override { return "c"; }
9393

9494
Array<CompilationArtifact> GetArtifacts() { return compilation_artifacts_; }
9595

@@ -101,7 +101,7 @@ class EthosUModuleNode : public ModuleNode {
101101
*
102102
* \return The function pointer when it is found, otherwise, PackedFunc(nullptr).
103103
*/
104-
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
104+
PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final {
105105
if (name == "get_func_names") {
106106
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
107107
Array<String> func_names;

src/relay/backend/graph_executor_codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
628628
class GraphExecutorCodegenModule : public runtime::ModuleNode {
629629
public:
630630
GraphExecutorCodegenModule() {}
631-
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
631+
virtual PackedFunc GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) {
632632
if (name == "init") {
633633
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
634634
ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: "

0 commit comments

Comments
 (0)