@@ -44,7 +44,7 @@ type contextTransaction interface {
44
44
Commit () error
45
45
Rollback () error
46
46
resetForRetry (ctx context.Context ) error
47
- Query (ctx context.Context , stmt spanner.Statement , execOptions * ExecOptions ) (rowIterator , error )
47
+ Query (ctx context.Context , stmt spanner.Statement , stmtType parser. StatementType , execOptions * ExecOptions ) (rowIterator , error )
48
48
partitionQuery (ctx context.Context , stmt spanner.Statement , execOptions * ExecOptions ) (driver.Rows , error )
49
49
ExecContext (ctx context.Context , stmt spanner.Statement , statementInfo * parser.StatementInfo , options spanner.QueryOptions ) (* result , error )
50
50
@@ -67,6 +67,7 @@ var _ rowIterator = &readOnlyRowIterator{}
67
67
68
68
type readOnlyRowIterator struct {
69
69
* spanner.RowIterator
70
+ stmtType parser.StatementType
70
71
}
71
72
72
73
func (ri * readOnlyRowIterator ) Next () (* spanner.Row , error ) {
@@ -84,10 +85,13 @@ func (ri *readOnlyRowIterator) Metadata() (*sppb.ResultSetMetadata, error) {
84
85
func (ri * readOnlyRowIterator ) ResultSetStats () * sppb.ResultSetStats {
85
86
// TODO: The Spanner client library should offer an option to get the full
86
87
// ResultSetStats, instead of only the RowCount and QueryPlan.
87
- return & sppb.ResultSetStats {
88
- RowCount : & sppb.ResultSetStats_RowCountExact {RowCountExact : ri .RowIterator .RowCount },
88
+ stats := & sppb.ResultSetStats {
89
89
QueryPlan : ri .RowIterator .QueryPlan ,
90
90
}
91
+ if ri .stmtType == parser .StatementTypeDml {
92
+ stats .RowCount = & sppb.ResultSetStats_RowCountExact {RowCountExact : ri .RowIterator .RowCount }
93
+ }
94
+ return stats
91
95
}
92
96
93
97
type txResult int
@@ -135,7 +139,7 @@ func (tx *readOnlyTransaction) resetForRetry(ctx context.Context) error {
135
139
return nil
136
140
}
137
141
138
- func (tx * readOnlyTransaction ) Query (ctx context.Context , stmt spanner.Statement , execOptions * ExecOptions ) (rowIterator , error ) {
142
+ func (tx * readOnlyTransaction ) Query (ctx context.Context , stmt spanner.Statement , stmtType parser. StatementType , execOptions * ExecOptions ) (rowIterator , error ) {
139
143
tx .logger .DebugContext (ctx , "Query" , "stmt" , stmt .SQL )
140
144
if execOptions .PartitionedQueryOptions .AutoPartitionQuery {
141
145
if tx .boTx == nil {
@@ -152,7 +156,7 @@ func (tx *readOnlyTransaction) Query(ctx context.Context, stmt spanner.Statement
152
156
}
153
157
return mi , nil
154
158
}
155
- return & readOnlyRowIterator {tx .roTx .QueryWithOptions (ctx , stmt , execOptions .QueryOptions )}, nil
159
+ return & readOnlyRowIterator {tx .roTx .QueryWithOptions (ctx , stmt , execOptions .QueryOptions ), stmtType }, nil
156
160
}
157
161
158
162
func (tx * readOnlyTransaction ) partitionQuery (ctx context.Context , stmt spanner.Statement , execOptions * ExecOptions ) (driver.Rows , error ) {
@@ -456,7 +460,7 @@ func (tx *readWriteTransaction) resetForRetry(ctx context.Context) error {
456
460
// Query executes a query using the read/write transaction and returns a
457
461
// rowIterator that will automatically retry the read/write transaction if the
458
462
// transaction is aborted during the query or while iterating the returned rows.
459
- func (tx * readWriteTransaction ) Query (ctx context.Context , stmt spanner.Statement , execOptions * ExecOptions ) (rowIterator , error ) {
463
+ func (tx * readWriteTransaction ) Query (ctx context.Context , stmt spanner.Statement , stmtType parser. StatementType , execOptions * ExecOptions ) (rowIterator , error ) {
460
464
tx .logger .Debug ("Query" , "stmt" , stmt .SQL )
461
465
tx .active = true
462
466
if err := tx .maybeRunAutoDmlBatch (ctx ); err != nil {
@@ -465,7 +469,7 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen
465
469
// If internal retries have been disabled, we don't need to keep track of a
466
470
// running checksum for all results that we have seen.
467
471
if ! tx .retryAborts () {
468
- return & readOnlyRowIterator {tx .rwTx .QueryWithOptions (ctx , stmt , execOptions .QueryOptions )}, nil
472
+ return & readOnlyRowIterator {tx .rwTx .QueryWithOptions (ctx , stmt , execOptions .QueryOptions ), stmtType }, nil
469
473
}
470
474
471
475
// If retries are enabled, we need to use a row iterator that will keep
@@ -476,6 +480,7 @@ func (tx *readWriteTransaction) Query(ctx context.Context, stmt spanner.Statemen
476
480
ctx : ctx ,
477
481
tx : tx ,
478
482
stmt : stmt ,
483
+ stmtType : stmtType ,
479
484
options : execOptions .QueryOptions ,
480
485
buffer : buffer ,
481
486
enc : gob .NewEncoder (buffer ),
0 commit comments