Commit bc33a67

fix: dry run queries with DB API cursor (#128)
* fix: dry run queries with DB API cursor
* Fix merge errors with master
* Return no rows on dry run instead of processed bytes count
1 parent 3235255 commit bc33a67

3 files changed, +164 -25 lines changed


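With this change, a dry run query issued through the DB-API cursor completes without fetching data: the cursor reports a rowcount of 0 and the fetch methods return no rows. A minimal usage sketch, mirroring the new system test in this commit (assumes application default credentials and a default project are configured):

from google.cloud import bigquery
from google.cloud.bigquery import dbapi
from google.cloud.bigquery.job import QueryJobConfig

client = bigquery.Client()  # assumes default credentials and project
connection = dbapi.connect(client)
cursor = connection.cursor()

# Dry run: the query is validated and costed, but never executed.
cursor.execute(
    """
    SELECT country_name
    FROM `bigquery-public-data.utility_us.country_code_iso`
    WHERE country_name LIKE 'U%'
    """,
    job_config=QueryJobConfig(dry_run=True),
)

print(cursor.rowcount)    # 0 -- dry runs report no rows
print(cursor.fetchall())  # []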
google/cloud/bigquery/dbapi/cursor.py (+37 -13)
@@ -15,6 +15,7 @@
 """Cursor for the Google BigQuery DB-API."""

 import collections
+import copy
 import warnings

 try:
@@ -93,18 +94,16 @@ def _set_description(self, schema):
             return

         self.description = tuple(
-            [
-                Column(
-                    name=field.name,
-                    type_code=field.field_type,
-                    display_size=None,
-                    internal_size=None,
-                    precision=None,
-                    scale=None,
-                    null_ok=field.is_nullable,
-                )
-                for field in schema
-            ]
+            Column(
+                name=field.name,
+                type_code=field.field_type,
+                display_size=None,
+                internal_size=None,
+                precision=None,
+                scale=None,
+                null_ok=field.is_nullable,
+            )
+            for field in schema
         )

     def _set_rowcount(self, query_results):
@@ -173,12 +172,24 @@ def execute(self, operation, parameters=None, job_id=None, job_config=None):
         formatted_operation = _format_operation(operation, parameters=parameters)
         query_parameters = _helpers.to_query_parameters(parameters)

-        config = job_config or job.QueryJobConfig(use_legacy_sql=False)
+        if client._default_query_job_config:
+            if job_config:
+                config = job_config._fill_from_default(client._default_query_job_config)
+            else:
+                config = copy.deepcopy(client._default_query_job_config)
+        else:
+            config = job_config or job.QueryJobConfig(use_legacy_sql=False)
+
         config.query_parameters = query_parameters
         self._query_job = client.query(
             formatted_operation, job_config=config, job_id=job_id
         )

+        if self._query_job.dry_run:
+            self._set_description(schema=None)
+            self.rowcount = 0
+            return
+
         # Wait for the query to finish.
         try:
             self._query_job.result()
@@ -211,6 +222,10 @@ def _try_fetch(self, size=None):
                 "No query results: execute() must be called before fetch."
             )

+        if self._query_job.dry_run:
+            self._query_data = iter([])
+            return
+
         is_dml = (
             self._query_job.statement_type
             and self._query_job.statement_type.upper() != "SELECT"
@@ -307,6 +322,9 @@ def _bqstorage_fetch(self, bqstorage_client):
     def fetchone(self):
         """Fetch a single row from the results of the last ``execute*()`` call.

+        .. note::
+            If a dry run query was executed, no rows are returned.
+
         Returns:
             Tuple:
                 A tuple representing a row or ``None`` if no more data is
@@ -324,6 +342,9 @@ def fetchone(self):
     def fetchmany(self, size=None):
         """Fetch multiple results from the last ``execute*()`` call.

+        .. note::
+            If a dry run query was executed, no rows are returned.
+
         .. note::
             The size parameter is not used for the request/response size.
             Set the ``arraysize`` attribute before calling ``execute()`` to
@@ -360,6 +381,9 @@ def fetchmany(self, size=None):
     def fetchall(self):
         """Fetch all remaining results from the last ``execute*()`` call.

+        .. note::
+            If a dry run query was executed, no rows are returned.
+
         Returns:
             List[Tuple]: A list of all the rows in the results.

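The execute() changes above also make the cursor honor a client-level default_query_job_config: a job_config passed to execute() is filled in from the client default, with the per-call values winning on conflict. A minimal sketch of that behavior from the caller's side, assuming default credentials; the specific flags (flatten_results, maximum_bytes_billed) are illustrative only:

from google.cloud import bigquery
from google.cloud.bigquery import dbapi

# Client-wide defaults applied to every query issued through this client.
default_config = bigquery.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
client = bigquery.Client(default_query_job_config=default_config)

connection = dbapi.connect(client)
cursor = connection.cursor()

# The per-call config is filled in from the default: maximum_bytes_billed
# comes from this call, flatten_results is inherited from the client default.
cursor.execute(
    "SELECT 1;",
    job_config=bigquery.QueryJobConfig(maximum_bytes_billed=10 * 1024 ** 2),
)
print(cursor.fetchall())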
tests/system.py (+16)
@@ -1782,6 +1782,22 @@ def test_dbapi_fetch_w_bqstorage_client_v1beta1_large_result_set(self):
         ]
         self.assertEqual(fetched_data, expected_data)

+    def test_dbapi_dry_run_query(self):
+        from google.cloud.bigquery.job import QueryJobConfig
+
+        query = """
+            SELECT country_name
+            FROM `bigquery-public-data.utility_us.country_code_iso`
+            WHERE country_name LIKE 'U%'
+        """
+
+        Config.CURSOR.execute(query, job_config=QueryJobConfig(dry_run=True))
+        self.assertEqual(Config.CURSOR.rowcount, 0, "expected no rows")
+
+        rows = Config.CURSOR.fetchall()
+
+        self.assertEqual(list(rows), [])
+
     @unittest.skipIf(
         bigquery_storage_v1 is None, "Requires `google-cloud-bigquery-storage`"
     )

tests/unit/test_dbapi_cursor.py (+111 -12)
@@ -46,7 +46,15 @@ def _get_target_class():
     def _make_one(self, *args, **kw):
         return self._get_target_class()(*args, **kw)

-    def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
+    def _mock_client(
+        self,
+        rows=None,
+        schema=None,
+        num_dml_affected_rows=None,
+        default_query_job_config=None,
+        dry_run_job=False,
+        total_bytes_processed=0,
+    ):
         from google.cloud.bigquery import client

         if rows is None:
@@ -59,8 +67,11 @@ def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
             total_rows=total_rows,
             schema=schema,
             num_dml_affected_rows=num_dml_affected_rows,
+            dry_run=dry_run_job,
+            total_bytes_processed=total_bytes_processed,
         )
         mock_client.list_rows.return_value = rows
+        mock_client._default_query_job_config = default_query_job_config

         # Assure that the REST client gets used, not the BQ Storage client.
         mock_client._create_bqstorage_client.return_value = None
@@ -95,27 +106,41 @@ def _mock_bqstorage_client(self, rows=None, stream_count=0, v1beta1=False):
             )

         mock_client.create_read_session.return_value = mock_read_session
+
         mock_rows_stream = mock.MagicMock()
         mock_rows_stream.rows.return_value = iter(rows)
         mock_client.read_rows.return_value = mock_rows_stream

         return mock_client

-    def _mock_job(self, total_rows=0, schema=None, num_dml_affected_rows=None):
+    def _mock_job(
+        self,
+        total_rows=0,
+        schema=None,
+        num_dml_affected_rows=None,
+        dry_run=False,
+        total_bytes_processed=0,
+    ):
         from google.cloud.bigquery import job

         mock_job = mock.create_autospec(job.QueryJob)
         mock_job.error_result = None
         mock_job.state = "DONE"
-        mock_job.result.return_value = mock_job
-        mock_job._query_results = self._mock_results(
-            total_rows=total_rows,
-            schema=schema,
-            num_dml_affected_rows=num_dml_affected_rows,
-        )
-        mock_job.destination.to_bqstorage.return_value = (
-            "projects/P/datasets/DS/tables/T"
-        )
+        mock_job.dry_run = dry_run
+
+        if dry_run:
+            mock_job.result.side_effect = exceptions.NotFound
+            mock_job.total_bytes_processed = total_bytes_processed
+        else:
+            mock_job.result.return_value = mock_job
+            mock_job._query_results = self._mock_results(
+                total_rows=total_rows,
+                schema=schema,
+                num_dml_affected_rows=num_dml_affected_rows,
+            )
+            mock_job.destination.to_bqstorage.return_value = (
+                "projects/P/datasets/DS/tables/T"
+            )

         if num_dml_affected_rows is None:
             mock_job.statement_type = None  # API sends back None for SELECT
@@ -445,7 +470,27 @@ def test_execute_custom_job_id(self):
         self.assertEqual(args[0], "SELECT 1;")
         self.assertEqual(kwargs["job_id"], "foo")

-    def test_execute_custom_job_config(self):
+    def test_execute_w_default_config(self):
+        from google.cloud.bigquery.dbapi import connect
+        from google.cloud.bigquery import job
+
+        default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
+        client = self._mock_client(
+            rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
+        )
+        connection = connect(client)
+        cursor = connection.cursor()
+
+        cursor.execute("SELECT 1;", job_id="foo")
+
+        _, kwargs = client.query.call_args
+        used_config = kwargs["job_config"]
+        expected_config = job.QueryJobConfig(
+            use_legacy_sql=False, flatten_results=True, query_parameters=[]
+        )
+        self.assertEqual(used_config._properties, expected_config._properties)
+
+    def test_execute_custom_job_config_wo_default_config(self):
         from google.cloud.bigquery.dbapi import connect
         from google.cloud.bigquery import job

@@ -459,6 +504,29 @@ def test_execute_custom_job_config(self):
         self.assertEqual(kwargs["job_id"], "foo")
         self.assertEqual(kwargs["job_config"], config)

+    def test_execute_custom_job_config_w_default_config(self):
+        from google.cloud.bigquery.dbapi import connect
+        from google.cloud.bigquery import job
+
+        default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
+        client = self._mock_client(
+            rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
+        )
+        connection = connect(client)
+        cursor = connection.cursor()
+        config = job.QueryJobConfig(use_legacy_sql=True)
+
+        cursor.execute("SELECT 1;", job_id="foo", job_config=config)
+
+        _, kwargs = client.query.call_args
+        used_config = kwargs["job_config"]
+        expected_config = job.QueryJobConfig(
+            use_legacy_sql=True,  # the config passed to execute() prevails
+            flatten_results=True,  # from the default
+            query_parameters=[],
+        )
+        self.assertEqual(used_config._properties, expected_config._properties)
+
     def test_execute_w_dml(self):
         from google.cloud.bigquery.dbapi import connect

@@ -514,6 +582,35 @@ def test_execute_w_query(self):
         row = cursor.fetchone()
         self.assertIsNone(row)

+    def test_execute_w_query_dry_run(self):
+        from google.cloud.bigquery.job import QueryJobConfig
+        from google.cloud.bigquery.schema import SchemaField
+        from google.cloud.bigquery import dbapi
+
+        connection = dbapi.connect(
+            self._mock_client(
+                rows=[("hello", "world", 1), ("howdy", "y'all", 2)],
+                schema=[
+                    SchemaField("a", "STRING", mode="NULLABLE"),
+                    SchemaField("b", "STRING", mode="REQUIRED"),
+                    SchemaField("c", "INTEGER", mode="NULLABLE"),
+                ],
+                dry_run_job=True,
+                total_bytes_processed=12345,
+            )
+        )
+        cursor = connection.cursor()
+
+        cursor.execute(
+            "SELECT a, b, c FROM hello_world WHERE d > 3;",
+            job_config=QueryJobConfig(dry_run=True),
+        )
+
+        self.assertEqual(cursor.rowcount, 0)
+        self.assertIsNone(cursor.description)
+        rows = cursor.fetchall()
+        self.assertEqual(list(rows), [])
+
     def test_execute_raises_if_result_raises(self):
         import google.cloud.exceptions

@@ -523,8 +620,10 @@ def test_execute_raises_if_result_raises(self):
         from google.cloud.bigquery.dbapi import exceptions

         job = mock.create_autospec(job.QueryJob)
+        job.dry_run = None
         job.result.side_effect = google.cloud.exceptions.GoogleCloudError("")
         client = mock.create_autospec(client.Client)
+        client._default_query_job_config = None
         client.query.return_value = job
         connection = connect(client)
         cursor = connection.cursor()
