Skip to content

Commit 8ab81de

Browse files
feat: read_gbq creates order deterministically without table copy (#191)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea.
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
1 parent 741c75e commit 8ab81de

File tree

18 files changed

+438
-519
lines changed

18 files changed

+438
-519
lines changed

bigframes/dataframe.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2719,7 +2719,8 @@ def _get_block(self) -> blocks.Block:
27192719
return self._block
27202720

27212721
def _cached(self) -> DataFrame:
2722-
return DataFrame(self._block.cached())
2722+
self._set_block(self._block.cached())
2723+
return self
27232724

27242725
_DataFrameOrSeries = typing.TypeVar("_DataFrameOrSeries")
27252726

bigframes/ml/core.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def generate_text_embedding(
126126

127127
def forecast(self) -> bpd.DataFrame:
128128
sql = self._model_manipulation_sql_generator.ml_forecast()
129-
return self._session.read_gbq(sql)
129+
return self._session.read_gbq(sql, index_col="forecast_timestamp").reset_index()
130130

131131
def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
132132
# TODO: validate input data schema
@@ -139,14 +139,18 @@ def centroids(self) -> bpd.DataFrame:
139139

140140
sql = self._model_manipulation_sql_generator.ml_centroids()
141141

142-
return self._session.read_gbq(sql)
142+
return self._session.read_gbq(
143+
sql, index_col=["centroid_id", "feature"]
144+
).reset_index()
143145

144146
def principal_components(self) -> bpd.DataFrame:
145147
assert self._model.model_type == "PCA"
146148

147149
sql = self._model_manipulation_sql_generator.ml_principal_components()
148150

149-
return self._session.read_gbq(sql)
151+
return self._session.read_gbq(
152+
sql, index_col=["principal_component_id", "feature"]
153+
).reset_index()
150154

151155
def principal_component_info(self) -> bpd.DataFrame:
152156
assert self._model.model_type == "PCA"
@@ -228,10 +232,12 @@ def create_model(
228232
Returns: a BqmlModel, wrapping a trained model in BigQuery
229233
"""
230234
options = dict(options)
235+
# Cache dataframes to make sure base table is not a snapshot
236+
# cached dataframe creates a full copy, never uses snapshot
231237
if y_train is None:
232-
input_data = X_train
238+
input_data = X_train._cached()
233239
else:
234-
input_data = X_train.join(y_train, how="outer")
240+
input_data = X_train._cached().join(y_train._cached(), how="outer")
235241
options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()})
236242

237243
session = X_train._session
@@ -259,7 +265,9 @@ def create_time_series_model(
259265
), "Time stamp data input must only contain 1 column."
260266

261267
options = dict(options)
262-
input_data = X_train.join(y_train, how="outer")
268+
# Cache dataframes to make sure base table is not a snapshot
269+
# cached dataframe creates a full copy, never uses snapshot
270+
input_data = X_train._cached().join(y_train._cached(), how="outer")
263271
options.update({"TIME_SERIES_TIMESTAMP_COL": X_train.columns.tolist()[0]})
264272
options.update({"TIME_SERIES_DATA_COL": y_train.columns.tolist()[0]})
265273

bigframes/series.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1503,7 +1503,8 @@ def _slice(
15031503
)
15041504

15051505
def _cached(self) -> Series:
1506-
return Series(self._block.cached())
1506+
self._set_block(self._block.cached())
1507+
return self
15071508

15081509

15091510
def _is_list_like(obj: typing.Any) -> typing_extensions.TypeGuard[typing.Sequence]:

0 commit comments

Comments
 (0)