-
Notifications
You must be signed in to change notification settings - Fork 15.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
5 changed files
with
286 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
#include <Python.h> | ||
|
||
#include "absl/log/absl_log.h" | ||
#include "absl/strings/str_join.h" | ||
#include "absl/strings/str_split.h" | ||
#include "google/protobuf/descriptor.h" | ||
#include "google/protobuf/text_format.h" | ||
#include "google/protobuf/unittest_proto3.pb.h" | ||
#include "google/protobuf/proto_api.h" | ||
|
||
// We have to run the pure python import first, otherwise PyCapsule_Import will | ||
// fail if this module is the very first import. | ||
bool PrerequisitesSatisfied() { | ||
return true; | ||
} | ||
|
||
const google::protobuf::python::PyProto_API* GetProtoApi() { | ||
return google::protobuf::python::GetAPI(); | ||
} | ||
|
||
PyObject* pyclear_message(PyObject* self, PyObject* args) { | ||
PyObject* py_message; | ||
if (!PyArg_ParseTuple(args, "O", &py_message)) return nullptr; | ||
auto msg = GetProtoApi()->GetClearedMessageMutator(py_message); | ||
if (!msg.ok()) { | ||
PyErr_Format(PyExc_TypeError, std::string(msg.status().message()).c_str()); | ||
return nullptr; | ||
} | ||
Py_INCREF(Py_None); | ||
return Py_None; | ||
} | ||
|
||
PyObject* pyparse_message(PyObject* self, PyObject* args) { | ||
PyObject* py_message; | ||
char* text_proto; | ||
if (!PyArg_ParseTuple(args, "sO", &text_proto, &py_message)) return nullptr; | ||
auto message = GetProtoApi()->GetClearedMessageMutator(py_message); | ||
if (!message.ok()) { | ||
PyErr_Format(PyExc_TypeError, | ||
std::string(message.status().message()).c_str()); | ||
return nullptr; | ||
} | ||
if (message->get() == nullptr) { | ||
return nullptr; | ||
} | ||
// No op. Test '->' operator | ||
if ((*message)->ByteSizeLong()) { | ||
return nullptr; | ||
} | ||
|
||
ABSL_LOG(ERROR) << "Try to parse message"; | ||
ABSL_LOG(ERROR) << message.value()->GetDescriptor()->file()->name(); | ||
// ABSL_LOG(ERROR) << message.value()->GetDescriptor()->DebugString(); | ||
google::protobuf::TextFormat::Parser parser; | ||
parser.ParseFromString(text_proto, message.value().get()); | ||
Py_INCREF(Py_None); | ||
return Py_None; | ||
} | ||
|
||
PyObject* pycheck_cpp_proto(PyObject* self, PyObject* args) { | ||
PyObject* py_message; | ||
if (!PyArg_ParseTuple(args, "O", &py_message)) return nullptr; | ||
auto msg = GetProtoApi()->GetConstMessagePointer(py_message); | ||
if (!msg.ok()) { | ||
PyErr_Format(PyExc_TypeError, std::string(msg.status().message()).c_str()); | ||
return nullptr; | ||
} | ||
if (msg->get().GetDescriptor()->file()->pool() == | ||
google::protobuf::DescriptorPool::generated_pool()) { | ||
Py_RETURN_TRUE; | ||
} | ||
ABSL_LOG(ERROR) << "file is not from generated pool"; | ||
ABSL_LOG(ERROR) << msg.value().get().GetDescriptor()->file()->name(); | ||
// ABSL_LOG(ERROR) << msg.value().get().GetDescriptor()->DebugString(); | ||
Py_RETURN_FALSE; | ||
} | ||
|
||
PyObject* pymessage_get(PyObject* self, PyObject* args) { | ||
PyObject* py_message; | ||
if (!PyArg_ParseTuple(args, "O", &py_message)) return nullptr; | ||
auto message = GetProtoApi()->GetConstMessagePointer(py_message); | ||
if (!message.ok()) { | ||
PyErr_Format(PyExc_TypeError, | ||
std::string(message.status().message()).c_str()); | ||
return nullptr; | ||
} | ||
// Test move constructor. | ||
google::protobuf::python::PythonConstMessagePointer moved_msg(std::move(*message)); | ||
const proto3_unittest::TestAllTypes& msg_ptr = | ||
static_cast<const proto3_unittest::TestAllTypes&>(moved_msg.get()); | ||
return PyLong_FromLong(msg_ptr.optional_int32()); | ||
} | ||
|
||
PyObject* pymessage_mutate_const(PyObject* self, PyObject* args) { | ||
PyObject* py_message; | ||
if (!PyArg_ParseTuple(args, "O", &py_message)) return nullptr; | ||
auto message = GetProtoApi()->GetConstMessagePointer(py_message); | ||
PyObject_SetAttrString(py_message, "optional_bool", Py_True); | ||
// The message has been changed. Returns false if detect not changed. | ||
if (message->NotChanged()) { | ||
Py_RETURN_FALSE; | ||
} | ||
// Change it back before returning. | ||
PyErr_Clear(); | ||
PyObject* temp = PyObject_CallMethod(py_message, "Clear", nullptr); | ||
Py_DECREF(temp); | ||
Py_RETURN_TRUE; | ||
} | ||
|
||
PyMethodDef module_methods[] = { | ||
{"ClearMessage", reinterpret_cast<PyCFunction>(pyclear_message), | ||
METH_VARARGS, "Clear message."}, | ||
{"ParseMessage", reinterpret_cast<PyCFunction>(pyparse_message), | ||
METH_VARARGS, "Parse message."}, | ||
{"IsCppProtoLinked", reinterpret_cast<PyCFunction>(pycheck_cpp_proto), | ||
METH_VARARGS, "Check if the generated cpp proto is linked."}, | ||
{"GetOptionalInt32", reinterpret_cast<PyCFunction>(pymessage_get), | ||
METH_VARARGS, "Get optional_int32 field."}, | ||
{"MutateConstAlive", reinterpret_cast<PyCFunction>(pymessage_mutate_const), | ||
METH_VARARGS, "Mutate python message while keep a const pointer."}, | ||
{nullptr, nullptr}}; | ||
|
||
extern "C" { | ||
PyMODINIT_FUNC PyInit_proto_api_example(void) { | ||
if (!PrerequisitesSatisfied() || GetProtoApi() == nullptr) { | ||
return nullptr; | ||
} | ||
static struct PyModuleDef moduledef = { | ||
PyModuleDef_HEAD_INIT, | ||
"proto_api_example", /* m_name */ | ||
"proto_api test", /* m_doc */ | ||
-1, /* m_size */ | ||
module_methods, /* m_methods */ | ||
nullptr, /* m_reload */ | ||
nullptr, /* m_traverse */ | ||
nullptr, /* m_clear */ | ||
nullptr, /* m_free */ | ||
}; | ||
return PyModule_Create(&moduledef); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Protocol Buffers - Google's data interchange format | ||
# Copyright 2008 Google Inc. All rights reserved. | ||
# | ||
# Use of this source code is governed by a BSD-style | ||
# license that can be found in the LICENSE file or at | ||
# https://developers.google.com/open-source/licenses/bsd | ||
|
||
"""Unittest for proto_api.""" | ||
|
||
import unittest | ||
|
||
from google.protobuf.internal import more_extensions_pb2 | ||
from google.protobuf.internal import proto_api_example | ||
from google.protobuf.internal import testing_refleaks | ||
|
||
from google.protobuf import unittest_proto3_pb2 | ||
|
||
|
||
@testing_refleaks.TestCase | ||
class ProtoApiTest(unittest.TestCase): | ||
|
||
def test_message_mutator_clear_generated_factory(self): | ||
msg = unittest_proto3_pb2.TestAllTypes( | ||
optional_int32=24, optional_string='optional_string' | ||
) | ||
self.assertEqual(24, msg.optional_int32) | ||
self.assertEqual('optional_string', msg.optional_string) | ||
self.assertTrue(proto_api_example.IsCppProtoLinked(msg)) | ||
proto_api_example.ClearMessage(msg) | ||
self.assertEqual(0, msg.optional_int32) | ||
self.assertEqual('', msg.optional_string) | ||
|
||
def test_message_mutator_clear_dynamic_factory(self): | ||
msg = more_extensions_pb2.ForeignMessage(foreign_message_int=15) | ||
self.assertIn('foreign_message_int', msg) | ||
self.assertFalse(proto_api_example.IsCppProtoLinked(msg)) | ||
proto_api_example.ClearMessage(msg) | ||
self.assertNotIn('foreign_message_int', msg) | ||
|
||
def test_not_a_message(self): | ||
with self.assertRaises(TypeError): | ||
proto_api_example.IsCppProtoLinked(112) | ||
with self.assertRaises(TypeError): | ||
proto_api_example.GetOptionalInt32(True) | ||
|
||
def test_message_mutator_parse(self): | ||
msg = more_extensions_pb2.ForeignMessage(foreign_message_int=123) | ||
proto_api_example.ParseMessage('foreign_message_int: 321', msg) | ||
self.assertEqual(321, msg.foreign_message_int) | ||
return | ||
msg = unittest_proto3_pb2.TestAllTypes( | ||
optional_int32=24, optional_string='optional_string' | ||
) | ||
proto_api_example.ParseMessage('optional_string: "changed"', msg) | ||
self.assertEqual(0, msg.optional_int32) | ||
self.assertEqual(msg.optional_string, 'changed') | ||
|
||
def test_message_const_pointer_get(self): | ||
msg = unittest_proto3_pb2.TestAllTypes(optional_int32=123) | ||
self.assertEqual(123, proto_api_example.GetOptionalInt32(msg)) | ||
|
||
def test_mutate_python_message_while_const_pinter_alive(self): | ||
msg = unittest_proto3_pb2.TestAllTypes() | ||
self.assertTrue(proto_api_example.MutateConstAlive(msg)) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters