Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

better func names, fix ref counting and memory free, add tests #6

Merged
merged 1 commit into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions _demo/autoderef/autoderef.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func main() {
gp.Initialize()
defer gp.Finalize()
fooMod := foo.InitFooModule()
gp.GetModuleDict().Set(gp.MakeStr("foo").Object, fooMod.Object)
gp.GetModuleDict().SetString("foo", fooMod)

Main1(fooMod)
Main2()
Expand All @@ -22,21 +22,21 @@ func main() {

func Main1(fooMod gp.Module) {
fmt.Printf("=========== Main1 ==========\n")
sum := fooMod.Call("add", gp.MakeLong(1), gp.MakeLong(2)).AsLong()
sum := fooMod.Call("add", 1, 2).AsLong()
fmt.Printf("Sum of 1 + 2: %d\n", sum.Int64())

dict := fooMod.Dict()
Point := dict.Get(gp.MakeStr("Point")).AsFunc()

point := Point.Call(gp.MakeLong(3), gp.MakeLong(4))
point := Point.Call(3, 4)
fmt.Printf("dir(point): %v\n", point.Dir())
fmt.Printf("x: %v, y: %v\n", point.GetAttr("x"), point.GetAttr("y"))
fmt.Printf("x: %v, y: %v\n", point.Attr("x"), point.Attr("y"))

distance := point.Call("distance").AsFloat()
fmt.Printf("Distance of 3 * 4: %f\n", distance.Float64())

point.Call("move", gp.MakeFloat(1), gp.MakeFloat(2))
fmt.Printf("x: %v, y: %v\n", point.GetAttr("x"), point.GetAttr("y"))
point.Call("move", 1, 2)
fmt.Printf("x: %v, y: %v\n", point.Attr("x"), point.Attr("y"))

distance = point.Call("distance").AsFloat()
fmt.Printf("Distance of 4 * 6: %f\n", distance.Float64())
Expand All @@ -45,7 +45,7 @@ func Main1(fooMod gp.Module) {

func Main2() {
fmt.Printf("=========== Main2 ==========\n")
gp.RunString(`
_ = gp.RunString(`
import foo
point = foo.Point(3, 4)
print("dir(point):", dir(point))
Expand Down Expand Up @@ -92,7 +92,7 @@ for i in range(10):
fmt.Printf("Iteration %d in python\n", i+1)
}

memory_allocation_test := mod.GetFuncAttr("memory_allocation_test")
memory_allocation_test := mod.AttrFunc("memory_allocation_test")

for i := 0; i < 100; i++ {
// 100MB every time
Expand Down
2 changes: 1 addition & 1 deletion _demo/gradio/gradio.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func main() {
})
textbox := gr.Call("Textbox")
examples := gr.Call("Examples", [][]string{{"Chicago"}, {"Little Rock"}, {"San Francisco"}}, textbox)
dataset := examples.GetAttr("dataset")
dataset := examples.Attr("dataset")
dropdown.Call("change", fn, dropdown, dataset)
})
demo.Call("launch")
Expand Down
2 changes: 1 addition & 1 deletion _demo/plot/plot.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ func main() {
gp.Initialize()
defer gp.Finalize()
plt := Plt()
plt.Plot(gp.MakeTuple(5, 10), gp.MakeTuple(10, 15), gp.KwArgs{"color": "red"})
plt.Plot([]int{5, 10}, []int{10, 15}, gp.KwArgs{"color": "red"})
plt.Show()
}
4 changes: 4 additions & 0 deletions adap_go.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ func AllocCStr(s string) *C.char {
return C.CString(s)
}

func AllocCStrDontFree(s string) *C.char {
return C.CString(s)
}

func AllocWCStr(s string) *C.wchar_t {
runes := []rune(s)
wchars := make([]uint16, len(runes)+1)
Expand Down
8 changes: 5 additions & 3 deletions bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ func BytesFromStr(s string) Bytes {

func MakeBytes(bytes []byte) Bytes {
ptr := C.CBytes(bytes)
return newBytes(C.PyBytes_FromStringAndSize((*C.char)(ptr), C.Py_ssize_t(len(bytes))))
o := C.PyBytes_FromStringAndSize((*C.char)(ptr), C.Py_ssize_t(len(bytes)))
C.free(unsafe.Pointer(ptr))
return newBytes(o)
}

func (b Bytes) Bytes() []byte {
var p *byte
var l int
p := (*byte)(unsafe.Pointer(C.PyBytes_AsString(b.obj)))
l := int(C.PyBytes_Size(b.obj))
return C.GoBytes(unsafe.Pointer(p), C.int(l))
}

Expand Down
36 changes: 24 additions & 12 deletions dict.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
#include <Python.h>
*/
import "C"
import "fmt"
import (
"fmt"
"unsafe"
)

type Dict struct {
Object
Expand Down Expand Up @@ -47,26 +50,32 @@
return newObject(v)
}

func (d Dict) Set(key, value Object) {
C.Py_IncRef(key.obj)
C.Py_IncRef(value.obj)
C.PyDict_SetItem(d.obj, key.obj, value.obj)
func (d Dict) Set(key, value Objecter) {
keyObj := key.Obj()
valueObj := value.Obj()
C.PyDict_SetItem(d.obj, keyObj, valueObj)
}

func (d Dict) SetString(key string, value Object) {
C.Py_IncRef(value.obj)
C.PyDict_SetItemString(d.obj, AllocCStr(key), value.obj)
func (d Dict) SetString(key string, value Objecter) {
valueObj := value.Obj()
ckey := AllocCStr(key)
r := C.PyDict_SetItemString(d.obj, ckey, valueObj)
C.free(unsafe.Pointer(ckey))
if r != 0 {
panic(fmt.Errorf("failed to set item string: %v", r))

Check warning on line 65 in dict.go

View check run for this annotation

Codecov / codecov/patch

dict.go#L59-L65

Added lines #L59 - L65 were not covered by tests
}
}

func (d Dict) GetString(key string) Object {
v := C.PyDict_GetItemString(d.obj, AllocCStr(key))
ckey := AllocCStr(key)
v := C.PyDict_GetItemString(d.obj, ckey)

Check warning on line 71 in dict.go

View check run for this annotation

Codecov / codecov/patch

dict.go#L70-L71

Added lines #L70 - L71 were not covered by tests
C.Py_IncRef(v)
C.free(unsafe.Pointer(ckey))

Check warning on line 73 in dict.go

View check run for this annotation

Codecov / codecov/patch

dict.go#L73

Added line #L73 was not covered by tests
return newObject(v)
}

func (d Dict) Del(key Object) {
C.PyDict_DelItem(d.obj, key.obj)
C.Py_DecRef(key.obj)
func (d Dict) Del(key Objecter) {
C.PyDict_DelItem(d.obj, key.Obj())

Check warning on line 78 in dict.go

View check run for this annotation

Codecov / codecov/patch

dict.go#L77-L78

Added lines #L77 - L78 were not covered by tests
}

func (d Dict) ForEach(fn func(key, value Object)) {
Expand All @@ -84,6 +93,9 @@
C.Py_IncRef(item)
key := C.PyTuple_GetItem(item, 0)
value := C.PyTuple_GetItem(item, 1)
C.Py_IncRef(key)
C.Py_IncRef(value)
C.Py_DecRef(item)
fn(newObject(key), newObject(value))
}
}
2 changes: 1 addition & 1 deletion float.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@
}

func (f Float) IsInteger() Bool {
fn := Cast[Func](f.GetAttr("is_integer"))
fn := Cast[Func](f.Attr("is_integer"))

Check warning on line 25 in float.go

View check run for this annotation

Codecov / codecov/patch

float.go#L25

Added line #L25 was not covered by tests
return Cast[Bool](fn.callNoArgs())
}
22 changes: 10 additions & 12 deletions function.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
}

//export getterMethod
func getterMethod(self *C.PyObject, closure unsafe.Pointer, methodId C.int) *C.PyObject {
func getterMethod(self *C.PyObject, _closure unsafe.Pointer, methodId C.int) *C.PyObject {
typeMeta := typeMetaMap[(*C.PyObject)(unsafe.Pointer(self.ob_type))]
if typeMeta == nil {
SetError(fmt.Errorf("type %v not registered", FromPy(self)))
Expand All @@ -177,7 +177,7 @@
}

//export setterMethod
func setterMethod(self, value *C.PyObject, closure unsafe.Pointer, methodId C.int) C.int {
func setterMethod(self, value *C.PyObject, _closure unsafe.Pointer, methodId C.int) C.int {
typeMeta := typeMetaMap[(*C.PyObject)(unsafe.Pointer(self.ob_type))]
if typeMeta == nil {
SetError(fmt.Errorf("type %v not registered", FromPy(self)))
Expand Down Expand Up @@ -260,6 +260,7 @@

for i := 0; i < int(argc); i++ {
arg := C.PyTuple_GetItem(args, C.Py_ssize_t(i))
C.Py_IncRef(arg)
argType := methodType.In(i + argIndex)
argPy := FromPy(arg)
goValue := reflect.New(argType).Elem()
Expand Down Expand Up @@ -316,7 +317,7 @@
methodPtr := C.wrapperMethods[methodId]

ret = append(ret, C.PyMethodDef{
ml_name: C.CString(pythonName),
ml_name: AllocCStrDontFree(pythonName),
ml_meth: (C.PyCFunction)(unsafe.Pointer(methodPtr)),
ml_flags: C.METH_VARARGS,
ml_doc: nil,
Expand Down Expand Up @@ -402,7 +403,7 @@
if memberType != -1 {
// create as member variable for C-compatible types
membersList = append(membersList, C.PyMemberDef{
name: C.CString(pythonName),
name: AllocCStrDontFree(pythonName),
_type: memberType,
offset: C.Py_ssize_t(baseOffset + field.Offset),
})
Expand All @@ -429,7 +430,7 @@
index: i,
}
getsetsList = append(getsetsList, C.PyGetSetDef{
name: C.CString(pythonName),
name: AllocCStrDontFree(pythonName),
get: C.getterMethods[getId],
set: C.setterMethods[setId],
doc: nil,
Expand Down Expand Up @@ -480,9 +481,6 @@
methods: make(map[uint]*slotMeta),
}

cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))

slots := make([]C.PyType_Slot, 0)
if init != nil {
slots = append(slots, C.PyType_Slot{slot: C.Py_tp_init, pfunc: unsafe.Pointer(C.wrapperInit)})
Expand Down Expand Up @@ -512,7 +510,7 @@
}

spec := &C.PyType_Spec{
name: cname,
name: C.CString(name),
basicsize: C.int(unsafe.Sizeof(wrapper)),
flags: C.Py_TPFLAGS_DEFAULT,
slots: slotsPtr,
Expand All @@ -526,7 +524,7 @@
typeMetaMap[typeObj] = meta
pyTypeMap[ty] = typeObj

if C.PyModule_AddObject(m.obj, cname, typeObj) < 0 {
if C.PyModule_AddObject(m.obj, C.CString(name), typeObj) < 0 {
C.Py_DecRef(typeObj)
panic(fmt.Sprintf("Failed to add type %s to module", name))
}
Expand Down Expand Up @@ -599,14 +597,14 @@

func SetError(err error) {
errStr := C.CString(err.Error())
defer C.free(unsafe.Pointer(errStr))
C.PyErr_SetString(C.PyExc_RuntimeError, errStr)
C.free(unsafe.Pointer(errStr))

Check warning on line 601 in function.go

View check run for this annotation

Codecov / codecov/patch

function.go#L601

Added line #L601 was not covered by tests
}

func SetTypeError(err error) {
errStr := C.CString(err.Error())
defer C.free(unsafe.Pointer(errStr))
C.PyErr_SetString(C.PyExc_TypeError, errStr)
C.free(unsafe.Pointer(errStr))
}

// FetchError returns the current Python error as a Go error
Expand Down
10 changes: 6 additions & 4 deletions list.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@
return newObject(v)
}

func (l List) SetItem(index int, item Object) {
C.PyList_SetItem(l.obj, C.Py_ssize_t(index), item.obj)
func (l List) SetItem(index int, item Objecter) {
itemObj := item.Obj()
C.Py_IncRef(itemObj)
C.PyList_SetItem(l.obj, C.Py_ssize_t(index), itemObj)
}

func (l List) Len() int {
return int(C.PyList_Size(l.obj))
}

func (l List) Append(obj Object) {
C.PyList_Append(l.obj, obj.obj)
func (l List) Append(obj Objecter) {
C.PyList_Append(l.obj, obj.Obj())

Check warning on line 42 in list.go

View check run for this annotation

Codecov / codecov/patch

list.go#L41-L42

Added lines #L41 - L42 were not covered by tests
}
5 changes: 4 additions & 1 deletion long.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <Python.h>
*/
import "C"
import "unsafe"

type Long struct {
Object
Expand Down Expand Up @@ -35,7 +36,9 @@

func LongFromString(s string, base int) Long {
cstr := AllocCStr(s)
return newLong(C.PyLong_FromString(cstr, nil, C.int(base)))
o := C.PyLong_FromString(cstr, nil, C.int(base))
C.free(unsafe.Pointer(cstr))
return newLong(o)

Check warning on line 41 in long.go

View check run for this annotation

Codecov / codecov/patch

long.go#L39-L41

Added lines #L39 - L41 were not covered by tests
}

func LongFromUnicode(u Object, base int) Long {
Expand Down
13 changes: 10 additions & 3 deletions module.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <Python.h>
*/
import "C"
import "unsafe"

type Module struct {
Object
Expand All @@ -14,7 +15,9 @@
}

func ImportModule(name string) Module {
mod := C.PyImport_ImportModule(AllocCStr(name))
cname := AllocCStr(name)
mod := C.PyImport_ImportModule(cname)
C.free(unsafe.Pointer(cname))
return newModule(mod)
}

Expand All @@ -27,11 +30,15 @@
}

func (m Module) AddObject(name string, obj Object) int {
return int(C.PyModule_AddObject(m.obj, AllocCStr(name), obj.obj))
cname := AllocCStr(name)
r := int(C.PyModule_AddObject(m.obj, cname, obj.obj))
C.free(unsafe.Pointer(cname))
return r

Check warning on line 36 in module.go

View check run for this annotation

Codecov / codecov/patch

module.go#L33-L36

Added lines #L33 - L36 were not covered by tests
}

func CreateModule(name string) Module {
return newModule(C.PyModule_New(AllocCStr(name)))
mod := C.PyModule_New(AllocCStrDontFree(name))
return newModule(mod)

Check warning on line 41 in module.go

View check run for this annotation

Codecov / codecov/patch

module.go#L40-L41

Added lines #L40 - L41 were not covered by tests
}

func GetModuleDict() Dict {
Expand Down
Loading
Loading