@@ -46,7 +46,15 @@ def _get_target_class():
     def _make_one(self, *args, **kw):
         return self._get_target_class()(*args, **kw)
 
-    def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
+    def _mock_client(
+        self,
+        rows=None,
+        schema=None,
+        num_dml_affected_rows=None,
+        default_query_job_config=None,
+        dry_run_job=False,
+        total_bytes_processed=0,
+    ):
         from google.cloud.bigquery import client
 
         if rows is None:
@@ -59,8 +67,11 @@ def _mock_client(self, rows=None, schema=None, num_dml_affected_rows=None):
             total_rows=total_rows,
             schema=schema,
             num_dml_affected_rows=num_dml_affected_rows,
+            dry_run=dry_run_job,
+            total_bytes_processed=total_bytes_processed,
         )
         mock_client.list_rows.return_value = rows
+        mock_client._default_query_job_config = default_query_job_config
 
         # Assure that the REST client gets used, not the BQ Storage client.
         mock_client._create_bqstorage_client.return_value = None
@@ -95,27 +106,41 @@ def _mock_bqstorage_client(self, rows=None, stream_count=0, v1beta1=False):
         )
 
         mock_client.create_read_session.return_value = mock_read_session
+
         mock_rows_stream = mock.MagicMock()
         mock_rows_stream.rows.return_value = iter(rows)
         mock_client.read_rows.return_value = mock_rows_stream
 
         return mock_client
 
-    def _mock_job(self, total_rows=0, schema=None, num_dml_affected_rows=None):
+    def _mock_job(
+        self,
+        total_rows=0,
+        schema=None,
+        num_dml_affected_rows=None,
+        dry_run=False,
+        total_bytes_processed=0,
+    ):
         from google.cloud.bigquery import job
 
         mock_job = mock.create_autospec(job.QueryJob)
         mock_job.error_result = None
         mock_job.state = "DONE"
-        mock_job.result.return_value = mock_job
-        mock_job._query_results = self._mock_results(
-            total_rows=total_rows,
-            schema=schema,
-            num_dml_affected_rows=num_dml_affected_rows,
-        )
-        mock_job.destination.to_bqstorage.return_value = (
-            "projects/P/datasets/DS/tables/T"
-        )
+        mock_job.dry_run = dry_run
+
+        if dry_run:
+            mock_job.result.side_effect = exceptions.NotFound
+            mock_job.total_bytes_processed = total_bytes_processed
+        else:
+            mock_job.result.return_value = mock_job
+            mock_job._query_results = self._mock_results(
+                total_rows=total_rows,
+                schema=schema,
+                num_dml_affected_rows=num_dml_affected_rows,
+            )
+            mock_job.destination.to_bqstorage.return_value = (
+                "projects/P/datasets/DS/tables/T"
+            )
 
         if num_dml_affected_rows is None:
             mock_job.statement_type = None  # API sends back None for SELECT
@@ -445,7 +470,27 @@ def test_execute_custom_job_id(self):
         self.assertEqual(args[0], "SELECT 1;")
         self.assertEqual(kwargs["job_id"], "foo")
 
-    def test_execute_custom_job_config(self):
+    def test_execute_w_default_config(self):
+        from google.cloud.bigquery.dbapi import connect
+        from google.cloud.bigquery import job
+
+        default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
+        client = self._mock_client(
+            rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
+        )
+        connection = connect(client)
+        cursor = connection.cursor()
+
+        cursor.execute("SELECT 1;", job_id="foo")
+
+        _, kwargs = client.query.call_args
+        used_config = kwargs["job_config"]
+        expected_config = job.QueryJobConfig(
+            use_legacy_sql=False, flatten_results=True, query_parameters=[]
+        )
+        self.assertEqual(used_config._properties, expected_config._properties)
+
+    def test_execute_custom_job_config_wo_default_config(self):
         from google.cloud.bigquery.dbapi import connect
         from google.cloud.bigquery import job
 
@@ -459,6 +504,29 @@ def test_execute_custom_job_config(self):
         self.assertEqual(kwargs["job_id"], "foo")
         self.assertEqual(kwargs["job_config"], config)
 
+    def test_execute_custom_job_config_w_default_config(self):
+        from google.cloud.bigquery.dbapi import connect
+        from google.cloud.bigquery import job
+
+        default_config = job.QueryJobConfig(use_legacy_sql=False, flatten_results=True)
+        client = self._mock_client(
+            rows=[], num_dml_affected_rows=0, default_query_job_config=default_config
+        )
+        connection = connect(client)
+        cursor = connection.cursor()
+        config = job.QueryJobConfig(use_legacy_sql=True)
+
+        cursor.execute("SELECT 1;", job_id="foo", job_config=config)
+
+        _, kwargs = client.query.call_args
+        used_config = kwargs["job_config"]
+        expected_config = job.QueryJobConfig(
+            use_legacy_sql=True,  # the config passed to execute() prevails
+            flatten_results=True,  # from the default
+            query_parameters=[],
+        )
+        self.assertEqual(used_config._properties, expected_config._properties)
+
     def test_execute_w_dml(self):
         from google.cloud.bigquery.dbapi import connect
 
@@ -514,6 +582,35 @@ def test_execute_w_query(self):
         row = cursor.fetchone()
         self.assertIsNone(row)
 
+    def test_execute_w_query_dry_run(self):
+        from google.cloud.bigquery.job import QueryJobConfig
+        from google.cloud.bigquery.schema import SchemaField
+        from google.cloud.bigquery import dbapi
+
+        connection = dbapi.connect(
+            self._mock_client(
+                rows=[("hello", "world", 1), ("howdy", "y'all", 2)],
+                schema=[
+                    SchemaField("a", "STRING", mode="NULLABLE"),
+                    SchemaField("b", "STRING", mode="REQUIRED"),
+                    SchemaField("c", "INTEGER", mode="NULLABLE"),
+                ],
+                dry_run_job=True,
+                total_bytes_processed=12345,
+            )
+        )
+        cursor = connection.cursor()
+
+        cursor.execute(
+            "SELECT a, b, c FROM hello_world WHERE d > 3;",
+            job_config=QueryJobConfig(dry_run=True),
+        )
+
+        self.assertEqual(cursor.rowcount, 0)
+        self.assertIsNone(cursor.description)
+        rows = cursor.fetchall()
+        self.assertEqual(list(rows), [])
+
     def test_execute_raises_if_result_raises(self):
         import google.cloud.exceptions
 
@@ -523,8 +620,10 @@ def test_execute_raises_if_result_raises(self):
         from google.cloud.bigquery.dbapi import exceptions
 
         job = mock.create_autospec(job.QueryJob)
+        job.dry_run = None
         job.result.side_effect = google.cloud.exceptions.GoogleCloudError("")
         client = mock.create_autospec(client.Client)
+        client._default_query_job_config = None
         client.query.return_value = job
         connection = connect(client)
         cursor = connection.cursor()
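
For context, here is a minimal sketch (not part of the diff) of how the dry-run behavior exercised by `test_execute_w_query_dry_run` looks against a live client. It assumes application-default credentials and uses a public sample table purely for illustration; the byte-estimate lookup via `QueryJob.total_bytes_processed` is the standard client-library route, not something introduced by this change.

```python
# Sketch only: dry-run queries through the DB-API layer and the plain client.
# Assumes application-default credentials; the table below is the public
# Shakespeare sample dataset, chosen here only as an illustrative target.
from google.cloud import bigquery
from google.cloud.bigquery import dbapi
from google.cloud.bigquery.job import QueryJobConfig

client = bigquery.Client()
sql = "SELECT word, word_count FROM `bigquery-public-data.samples.shakespeare` LIMIT 10"

# DB-API route: a dry-run execute() surfaces no rows and no description,
# matching the assertions in test_execute_w_query_dry_run above.
connection = dbapi.connect(client)
cursor = connection.cursor()
cursor.execute(sql, job_config=QueryJobConfig(dry_run=True))
assert cursor.rowcount == 0
assert cursor.description is None
assert cursor.fetchall() == []

# Direct client route: the cost estimate is reported on the QueryJob itself.
job = client.query(sql, job_config=QueryJobConfig(dry_run=True))
print(f"Dry run would process {job.total_bytes_processed} bytes")
```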