Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add flush request in pipeline #2200

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions pgconn/pgconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2093,6 +2093,22 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para
p.conn.frontend.SendExecute(&pgproto3.Execute{})
}

// SendFlushRequest sends a request for the server to flush its output buffer.
//
// The server flushes its output buffer automatically as a result of Sync being called,
// or on any request when not in pipeline mode; this function is useful to cause the server
// to flush its output buffer in pipeline mode without establishing a synchronization point.
// Note that the request is not itself flushed to the server automatically; use Flush if
// necessary. This copies the behavior of libpq PQsendFlushRequest.
func (p *Pipeline) SendFlushRequest() {
if p.closed {
return
}
p.pendingSync = true

p.conn.frontend.Send(&pgproto3.Flush{})
}

// Flush flushes the queued requests without establishing a synchronization point.
func (p *Pipeline) Flush() error {
if p.closed {
Expand Down Expand Up @@ -2157,6 +2173,23 @@ func (p *Pipeline) GetResults() (results any, err error) {
return p.getResults()
}

// GetResultsNotCheckSync gets the next results. If results are present, results may be a *ResultReader, *StatementDescription,
// or *PipelineSync. If an ErrorResponse is received from the server, results will be nil and err will be a *PgError.
//
// This method should be used only if the request was sent to the server via methods SendFlushRequest and Flush,
// without using Sync. In this case, you need to identify on your own when all results are received and
// there is no need to call the method anymore.
func (p *Pipeline) GetResultsNotCheckSync() (results any, err error) {
if p.closed {
if p.err != nil {
return nil, p.err
}
return nil, errors.New("pipeline closed")
}

return p.getResults()
}

func (p *Pipeline) getResults() (results any, err error) {
for {
msg, err := p.conn.receiveMessage()
Expand Down
64 changes: 64 additions & 0 deletions pgconn/pgconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3003,6 +3003,70 @@ func TestPipelinePrepareQuery(t *testing.T) {
ensureConnValid(t, pgConn)
}

func TestPipelinePrepareQueryWithFlush(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)

pipeline := pgConn.StartPipeline(ctx)
pipeline.SendPrepare("ps", "select $1::text as msg", nil)
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil)
pipeline.SendFlushRequest()
err = pipeline.Flush()
require.NoError(t, err)

results, err := pipeline.GetResultsNotCheckSync()
require.NoError(t, err)
sd, ok := results.(*pgconn.StatementDescription)
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
require.Len(t, sd.Fields, 1)
require.Equal(t, "msg", string(sd.Fields[0].Name))
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)

results, err = pipeline.GetResultsNotCheckSync()
require.NoError(t, err)
rr, ok := results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult := rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "hello", string(readResult.Rows[0][0]))

results, err = pipeline.GetResultsNotCheckSync()
require.NoError(t, err)
rr, ok = results.(*pgconn.ResultReader)
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
readResult = rr.Read()
require.NoError(t, readResult.Err)
require.Len(t, readResult.Rows, 1)
require.Len(t, readResult.Rows[0], 1)
require.Equal(t, "goodbye", string(readResult.Rows[0][0]))

err = pipeline.Sync()
require.NoError(t, err)

results, err = pipeline.GetResults()
require.NoError(t, err)
_, ok = results.(*pgconn.PipelineSync)
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)

results, err = pipeline.GetResults()
require.NoError(t, err)
require.Nil(t, results)

err = pipeline.Close()
require.NoError(t, err)

ensureConnValid(t, pgConn)
}

func TestPipelineQueryErrorBetweenSyncs(t *testing.T) {
t.Parallel()

Expand Down