Skip to content

Commit ee9662f

Browse files
feat: support transaction and request tags in dbapi (#1262)
* feat: support transaction and request tags in dbapi Adds support for setting transaction tags and request tags in dbapi. This makes these options available to frameworks that depend on dbapi, like SQLAlchemy and Django. Towards googleapis/python-spanner-sqlalchemy#525 * test: add test for transaction_tag with read-only tx * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent d9ee75a commit ee9662f

File tree

4 files changed

+277
-10
lines changed

4 files changed

+277
-10
lines changed

.gitignore

-4
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,3 @@ system_tests/local_test_setup
6262
# Make sure a generated file isn't accidentally committed.
6363
pylintrc
6464
pylintrc.test
65-
66-
67-
# Ignore coverage files
68-
.coverage*

google/cloud/spanner_dbapi/connection.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(self, instance, database=None, read_only=False, **kwargs):
113113
self.request_priority = None
114114
self._transaction_begin_marked = False
115115
# whether transaction started at Spanner. This means that we had
116-
# made atleast one call to Spanner.
116+
# made at least one call to Spanner.
117117
self._spanner_transaction_started = False
118118
self._batch_mode = BatchMode.NONE
119119
self._batch_dml_executor: BatchDmlExecutor = None
@@ -261,6 +261,28 @@ def request_options(self):
261261
self.request_priority = None
262262
return req_opts
263263

264+
@property
265+
def transaction_tag(self):
266+
"""The transaction tag that will be applied to the next read/write
267+
transaction on this `Connection`. This property is automatically cleared
268+
when a new transaction is started.
269+
270+
Returns:
271+
str: The transaction tag that will be applied to the next read/write transaction.
272+
"""
273+
return self._connection_variables.get("transaction_tag", None)
274+
275+
@transaction_tag.setter
276+
def transaction_tag(self, value):
277+
"""Sets the transaction tag for the next read/write transaction on this
278+
`Connection`. This property is automatically cleared when a new transaction
279+
is started.
280+
281+
Args:
282+
value (str): The transaction tag for the next read/write transaction.
283+
"""
284+
self._connection_variables["transaction_tag"] = value
285+
264286
@property
265287
def staleness(self):
266288
"""Current read staleness option value of this `Connection`.
@@ -340,6 +362,8 @@ def transaction_checkout(self):
340362
if not self.read_only and self._client_transaction_started:
341363
if not self._spanner_transaction_started:
342364
self._transaction = self._session_checkout().transaction()
365+
self._transaction.transaction_tag = self.transaction_tag
366+
self.transaction_tag = None
343367
self._snapshot = None
344368
self._spanner_transaction_started = True
345369
self._transaction.begin()
@@ -458,7 +482,9 @@ def run_prior_DDL_statements(self):
458482

459483
return self.database.update_ddl(ddl_statements).result()
460484

461-
def run_statement(self, statement: Statement):
485+
def run_statement(
486+
self, statement: Statement, request_options: RequestOptions = None
487+
):
462488
"""Run single SQL statement in begun transaction.
463489
464490
This method is never used in autocommit mode. In
@@ -472,6 +498,9 @@ def run_statement(self, statement: Statement):
472498
:param retried: (Optional) Retry the SQL statement if statement
473499
execution failed. Defaults to false.
474500
501+
:type request_options: :class:`RequestOptions`
502+
:param request_options: Request options to use for this statement.
503+
475504
:rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`,
476505
:class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum`
477506
:returns: Streamed result set of the statement and a
@@ -482,7 +511,7 @@ def run_statement(self, statement: Statement):
482511
statement.sql,
483512
statement.params,
484513
param_types=statement.param_types,
485-
request_options=self.request_options,
514+
request_options=request_options or self.request_options,
486515
)
487516

488517
@check_not_closed

google/cloud/spanner_dbapi/cursor.py

+39-3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from google.cloud.spanner_dbapi.transaction_helper import CursorStatementType
5151
from google.cloud.spanner_dbapi.utils import PeekIterator
5252
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
53+
from google.cloud.spanner_v1 import RequestOptions
5354
from google.cloud.spanner_v1.merged_result_set import MergedResultSet
5455

5556
ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
@@ -97,6 +98,39 @@ def __init__(self, connection):
9798
self._parsed_statement: ParsedStatement = None
9899
self._in_retry_mode = False
99100
self._batch_dml_rows_count = None
101+
self._request_tag = None
102+
103+
@property
104+
def request_tag(self):
105+
"""The request tag that will be applied to the next statement on this
106+
cursor. This property is automatically cleared when a statement is
107+
executed.
108+
109+
Returns:
110+
str: The request tag that will be applied to the next statement on
111+
this cursor.
112+
"""
113+
return self._request_tag
114+
115+
@request_tag.setter
116+
def request_tag(self, value):
117+
"""Sets the request tag for the next statement on this cursor. This
118+
property is automatically cleared when a statement is executed.
119+
120+
Args:
121+
value (str): The request tag for the statement.
122+
"""
123+
self._request_tag = value
124+
125+
@property
126+
def request_options(self):
127+
options = self.connection.request_options
128+
if self._request_tag:
129+
if not options:
130+
options = RequestOptions()
131+
options.request_tag = self._request_tag
132+
self._request_tag = None
133+
return options
100134

101135
@property
102136
def is_closed(self):
@@ -284,7 +318,7 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
284318
sql,
285319
params=args,
286320
param_types=self._parsed_statement.statement.param_types,
287-
request_options=self.connection.request_options,
321+
request_options=self.request_options,
288322
)
289323
self._result_set = None
290324
else:
@@ -318,7 +352,9 @@ def _execute_in_rw_transaction(self):
318352
if self.connection._client_transaction_started:
319353
while True:
320354
try:
321-
self._result_set = self.connection.run_statement(statement)
355+
self._result_set = self.connection.run_statement(
356+
statement, self.request_options
357+
)
322358
self._itr = PeekIterator(self._result_set)
323359
return
324360
except Aborted:
@@ -478,7 +514,7 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params):
478514
sql,
479515
params,
480516
get_param_types(params),
481-
request_options=self.connection.request_options,
517+
request_options=self.request_options,
482518
)
483519
# Read the first element so that the StreamedResultSet can
484520
# return the metadata after a DQL statement.

tests/mockserver_tests/test_tags.py

+206
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.cloud.spanner_dbapi import Connection
16+
from google.cloud.spanner_v1 import (
17+
BatchCreateSessionsRequest,
18+
ExecuteSqlRequest,
19+
BeginTransactionRequest,
20+
TypeCode,
21+
CommitRequest,
22+
)
23+
from tests.mockserver_tests.mock_server_test_base import (
24+
MockServerTestBase,
25+
add_single_result,
26+
)
27+
28+
29+
class TestTags(MockServerTestBase):
30+
@classmethod
31+
def setup_class(cls):
32+
super().setup_class()
33+
add_single_result(
34+
"select name from singers", "name", TypeCode.STRING, [("Some Singer",)]
35+
)
36+
37+
def test_select_autocommit_no_tags(self):
38+
connection = Connection(self.instance, self.database)
39+
connection.autocommit = True
40+
request = self._execute_and_verify_select_singers(connection)
41+
self.assertEqual("", request.request_options.request_tag)
42+
self.assertEqual("", request.request_options.transaction_tag)
43+
44+
def test_select_autocommit_with_request_tag(self):
45+
connection = Connection(self.instance, self.database)
46+
connection.autocommit = True
47+
request = self._execute_and_verify_select_singers(
48+
connection, request_tag="my_tag"
49+
)
50+
self.assertEqual("my_tag", request.request_options.request_tag)
51+
self.assertEqual("", request.request_options.transaction_tag)
52+
53+
def test_select_read_only_transaction_no_tags(self):
54+
connection = Connection(self.instance, self.database)
55+
connection.autocommit = False
56+
connection.read_only = True
57+
request = self._execute_and_verify_select_singers(connection)
58+
self.assertEqual("", request.request_options.request_tag)
59+
self.assertEqual("", request.request_options.transaction_tag)
60+
61+
def test_select_read_only_transaction_with_request_tag(self):
62+
connection = Connection(self.instance, self.database)
63+
connection.autocommit = False
64+
connection.read_only = True
65+
request = self._execute_and_verify_select_singers(
66+
connection, request_tag="my_tag"
67+
)
68+
self.assertEqual("my_tag", request.request_options.request_tag)
69+
self.assertEqual("", request.request_options.transaction_tag)
70+
71+
def test_select_read_only_transaction_with_transaction_tag(self):
72+
connection = Connection(self.instance, self.database)
73+
connection.autocommit = False
74+
connection.read_only = True
75+
connection.transaction_tag = "my_transaction_tag"
76+
self._execute_and_verify_select_singers(connection)
77+
self._execute_and_verify_select_singers(connection)
78+
79+
# Read-only transactions do not support tags, so the transaction_tag is
80+
# also not cleared from the connection when a read-only transaction is
81+
# executed.
82+
self.assertEqual("my_transaction_tag", connection.transaction_tag)
83+
84+
# Read-only transactions do not need to be committed or rolled back on
85+
# Spanner, but dbapi requires this to end the transaction.
86+
connection.commit()
87+
requests = self.spanner_service.requests
88+
self.assertEqual(4, len(requests))
89+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
90+
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
91+
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
92+
self.assertTrue(isinstance(requests[3], ExecuteSqlRequest))
93+
# Transaction tags are not supported for read-only transactions.
94+
self.assertEqual("", requests[2].request_options.transaction_tag)
95+
self.assertEqual("", requests[3].request_options.transaction_tag)
96+
97+
def test_select_read_write_transaction_no_tags(self):
98+
connection = Connection(self.instance, self.database)
99+
connection.autocommit = False
100+
request = self._execute_and_verify_select_singers(connection)
101+
self.assertEqual("", request.request_options.request_tag)
102+
self.assertEqual("", request.request_options.transaction_tag)
103+
104+
def test_select_read_write_transaction_with_request_tag(self):
105+
connection = Connection(self.instance, self.database)
106+
connection.autocommit = False
107+
request = self._execute_and_verify_select_singers(
108+
connection, request_tag="my_tag"
109+
)
110+
self.assertEqual("my_tag", request.request_options.request_tag)
111+
self.assertEqual("", request.request_options.transaction_tag)
112+
113+
def test_select_read_write_transaction_with_transaction_tag(self):
114+
connection = Connection(self.instance, self.database)
115+
connection.autocommit = False
116+
connection.transaction_tag = "my_transaction_tag"
117+
# The transaction tag should be included for all statements in the transaction.
118+
self._execute_and_verify_select_singers(connection)
119+
self._execute_and_verify_select_singers(connection)
120+
121+
# The transaction tag was cleared from the connection when the transaction
122+
# was started.
123+
self.assertIsNone(connection.transaction_tag)
124+
# The commit call should also include a transaction tag.
125+
connection.commit()
126+
requests = self.spanner_service.requests
127+
self.assertEqual(5, len(requests))
128+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
129+
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
130+
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
131+
self.assertTrue(isinstance(requests[3], ExecuteSqlRequest))
132+
self.assertTrue(isinstance(requests[4], CommitRequest))
133+
self.assertEqual(
134+
"my_transaction_tag", requests[2].request_options.transaction_tag
135+
)
136+
self.assertEqual(
137+
"my_transaction_tag", requests[3].request_options.transaction_tag
138+
)
139+
self.assertEqual(
140+
"my_transaction_tag", requests[4].request_options.transaction_tag
141+
)
142+
143+
def test_select_read_write_transaction_with_transaction_and_request_tag(self):
144+
connection = Connection(self.instance, self.database)
145+
connection.autocommit = False
146+
connection.transaction_tag = "my_transaction_tag"
147+
# The transaction tag should be included for all statements in the transaction.
148+
self._execute_and_verify_select_singers(connection, request_tag="my_tag1")
149+
self._execute_and_verify_select_singers(connection, request_tag="my_tag2")
150+
151+
# The transaction tag was cleared from the connection when the transaction
152+
# was started.
153+
self.assertIsNone(connection.transaction_tag)
154+
# The commit call should also include a transaction tag.
155+
connection.commit()
156+
requests = self.spanner_service.requests
157+
self.assertEqual(5, len(requests))
158+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
159+
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
160+
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
161+
self.assertTrue(isinstance(requests[3], ExecuteSqlRequest))
162+
self.assertTrue(isinstance(requests[4], CommitRequest))
163+
self.assertEqual(
164+
"my_transaction_tag", requests[2].request_options.transaction_tag
165+
)
166+
self.assertEqual("my_tag1", requests[2].request_options.request_tag)
167+
self.assertEqual(
168+
"my_transaction_tag", requests[3].request_options.transaction_tag
169+
)
170+
self.assertEqual("my_tag2", requests[3].request_options.request_tag)
171+
self.assertEqual(
172+
"my_transaction_tag", requests[4].request_options.transaction_tag
173+
)
174+
175+
def test_request_tag_is_cleared(self):
176+
connection = Connection(self.instance, self.database)
177+
connection.autocommit = True
178+
with connection.cursor() as cursor:
179+
cursor.request_tag = "my_tag"
180+
cursor.execute("select name from singers")
181+
# This query will not have a request tag.
182+
cursor.execute("select name from singers")
183+
requests = self.spanner_service.requests
184+
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
185+
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
186+
self.assertEqual("my_tag", requests[1].request_options.request_tag)
187+
self.assertEqual("", requests[2].request_options.request_tag)
188+
189+
def _execute_and_verify_select_singers(
190+
self, connection: Connection, request_tag: str = "", transaction_tag: str = ""
191+
) -> ExecuteSqlRequest:
192+
with connection.cursor() as cursor:
193+
if request_tag:
194+
cursor.request_tag = request_tag
195+
cursor.execute("select name from singers")
196+
result_list = cursor.fetchall()
197+
for row in result_list:
198+
self.assertEqual("Some Singer", row[0])
199+
self.assertEqual(1, len(result_list))
200+
requests = self.spanner_service.requests
201+
return next(
202+
request
203+
for request in requests
204+
if isinstance(request, ExecuteSqlRequest)
205+
and request.sql == "select name from singers"
206+
)

0 commit comments

Comments
 (0)