Skip to content

Commit 4037992

Browse files
authored
refactor!: move model optional args to kwargs (#381)
To be more like sklearn, and make API more accurate. Those param shouldn't be called through positions.
1 parent 59b446b commit 4037992

File tree

8 files changed

+45
-22
lines changed

8 files changed

+45
-22
lines changed

bigframes/ml/base.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ def register(self: _T, vertex_ai_model_id: Optional[str] = None) -> _T:
127127
self._bqml_model.register(vertex_ai_model_id)
128128
return self
129129

130+
@abc.abstractmethod
131+
def to_gbq(self, model_name, replace):
132+
pass
133+
130134

131135
class TrainablePredictor(Predictor):
132136
"""A BigQuery DataFrames ML Model base class that can be used to fit and predict outputs.
@@ -141,11 +145,6 @@ def _fit(self, X, y, transforms=None):
141145
def score(self, X, y):
142146
pass
143147

144-
# TODO(b/291812029): move to Predictor after implement in LLM and imported models
145-
@abc.abstractmethod
146-
def to_gbq(self, model_name, replace):
147-
pass
148-
149148

150149
class SupervisedTrainablePredictor(TrainablePredictor):
151150
"""A BigQuery DataFrames ML Supervised Model base class that can be used to fit and predict outputs.
@@ -165,7 +164,7 @@ def fit(
165164
class UnsupervisedTrainablePredictor(TrainablePredictor):
166165
"""A BigQuery DataFrames ML Unsupervised Model base class that can be used to fit and predict outputs.
167166
168-
Only need to provide both X (y is optional and ignored) in unsupervised tasks."""
167+
Only need to provide X (y is optional and ignored) in unsupervised tasks."""
169168

170169
_T = TypeVar("_T", bound="UnsupervisedTrainablePredictor")
171170

bigframes/ml/ensemble.py

+4
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class XGBRegressor(
5858
def __init__(
5959
self,
6060
num_parallel_tree: int = 1,
61+
*,
6162
booster: Literal["gbtree", "dart"] = "gbtree",
6263
dart_normalized_type: Literal["tree", "forest"] = "tree",
6364
tree_method: Literal["auto", "exact", "approx", "hist"] = "auto",
@@ -215,6 +216,7 @@ class XGBClassifier(
215216
def __init__(
216217
self,
217218
num_parallel_tree: int = 1,
219+
*,
218220
booster: Literal["gbtree", "dart"] = "gbtree",
219221
dart_normalized_type: Literal["tree", "forest"] = "tree",
220222
tree_method: Literal["auto", "exact", "approx", "hist"] = "auto",
@@ -372,6 +374,7 @@ class RandomForestRegressor(
372374
def __init__(
373375
self,
374376
num_parallel_tree: int = 100,
377+
*,
375378
tree_method: Literal["auto", "exact", "approx", "hist"] = "auto",
376379
min_tree_child_weight: int = 1,
377380
colsample_bytree=1.0,
@@ -538,6 +541,7 @@ class RandomForestClassifier(
538541
def __init__(
539542
self,
540543
num_parallel_tree: int = 100,
544+
*,
541545
tree_method: Literal["auto", "exact", "approx", "hist"] = "auto",
542546
min_tree_child_weight: int = 1,
543547
colsample_bytree: float = 1.0,

bigframes/ml/forecasting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _fit(
8787
)
8888

8989
def predict(
90-
self, X=None, horizon: int = 3, confidence_level: float = 0.95
90+
self, X=None, *, horizon: int = 3, confidence_level: float = 0.95
9191
) -> bpd.DataFrame:
9292
"""Predict the closest cluster for each sample in X.
9393

bigframes/ml/imported.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,17 @@ class TensorFlowModel(base.Predictor):
3232
"""Imported TensorFlow model.
3333
3434
Args:
35+
model_path (str):
36+
GCS path that holds the model files.
3537
session (BigQuery Session):
3638
BQ session to create the model
37-
model_path (str):
38-
GCS path that holds the model files."""
39+
"""
3940

4041
def __init__(
4142
self,
43+
model_path: str,
44+
*,
4245
session: Optional[bigframes.Session] = None,
43-
model_path: Optional[str] = None,
4446
):
4547
self.session = session or bpd.get_global_session()
4648
self.model_path = model_path
@@ -59,7 +61,7 @@ def _from_bq(
5961
) -> TensorFlowModel:
6062
assert model.model_type == "TENSORFLOW"
6163

62-
tf_model = cls(session=session, model_path=None)
64+
tf_model = cls(session=session, model_path="")
6365
tf_model._bqml_model = core.BqmlModel(session, model)
6466
return tf_model
6567

@@ -109,15 +111,17 @@ class ONNXModel(base.Predictor):
109111
"""Imported Open Neural Network Exchange (ONNX) model.
110112
111113
Args:
114+
model_path (str):
115+
Cloud Storage path that holds the model files.
112116
session (BigQuery Session):
113117
BQ session to create the model
114-
model_path (str):
115-
Cloud Storage path that holds the model files."""
118+
"""
116119

117120
def __init__(
118121
self,
122+
model_path: str,
123+
*,
119124
session: Optional[bigframes.Session] = None,
120-
model_path: Optional[str] = None,
121125
):
122126
self.session = session or bpd.get_global_session()
123127
self.model_path = model_path
@@ -134,7 +138,7 @@ def _create_bqml_model(self):
134138
def _from_bq(cls, session: bigframes.Session, model: bigquery.Model) -> ONNXModel:
135139
assert model.model_type == "ONNX"
136140

137-
onnx_model = cls(session=session, model_path=None)
141+
onnx_model = cls(session=session, model_path="")
138142
onnx_model._bqml_model = core.BqmlModel(session, model)
139143
return onnx_model
140144

@@ -189,8 +193,8 @@ class XGBoostModel(base.Predictor):
189193
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-xgboost#limitations
190194
191195
Args:
192-
session (BigQuery Session):
193-
BQ session to create the model
196+
model_path (str):
197+
Cloud Storage path that holds the model files.
194198
input (Dict, default None):
195199
Specify the model input schema information when you
196200
create the XGBoost model. The input should be the format of
@@ -203,15 +207,17 @@ class XGBoostModel(base.Predictor):
203207
{field_name: field_type}. Output is optional only if feature_names
204208
and feature_types are both specified in the model file. Supported types
205209
are "bool", "string", "int64", "float64", "array<bool>", "array<string>", "array<int64>", "array<float64>".
206-
model_path (str):
207-
Cloud Storage path that holds the model files."""
210+
session (BigQuery Session):
211+
BQ session to create the model
212+
"""
208213

209214
def __init__(
210215
self,
211-
session: Optional[bigframes.Session] = None,
216+
model_path: str,
217+
*,
212218
input: Mapping[str, str] = {},
213219
output: Mapping[str, str] = {},
214-
model_path: Optional[str] = None,
220+
session: Optional[bigframes.Session] = None,
215221
):
216222
self.session = session or bpd.get_global_session()
217223
self.model_path = model_path
@@ -248,7 +254,7 @@ def _from_bq(
248254
) -> XGBoostModel:
249255
assert model.model_type == "XGBOOST"
250256

251-
xgboost_model = cls(session=session, model_path=None)
257+
xgboost_model = cls(session=session, model_path="")
252258
xgboost_model._bqml_model = core.BqmlModel(session, model)
253259
return xgboost_model
254260

bigframes/ml/linear_model.py

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class LinearRegression(
5858

5959
def __init__(
6060
self,
61+
*,
6162
optimize_strategy: Literal[
6263
"auto_strategy", "batch_gradient_descent", "normal_equation"
6364
] = "normal_equation",
@@ -192,6 +193,7 @@ class LogisticRegression(
192193
# TODO(ashleyxu) support class_weights in the constructor.
193194
def __init__(
194195
self,
196+
*,
195197
fit_intercept: bool = True,
196198
class_weights: Optional[Union[Literal["balanced"], Dict[str, float]]] = None,
197199
):

bigframes/ml/llm.py

+5
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class PaLM2TextGenerator(base.Predictor):
6666

6767
def __init__(
6868
self,
69+
*,
6970
model_name: Literal["text-bison", "text-bison-32k"] = "text-bison",
7071
session: Optional[bigframes.Session] = None,
7172
connection_name: Optional[str] = None,
@@ -140,6 +141,7 @@ def _from_bq(
140141
def predict(
141142
self,
142143
X: Union[bpd.DataFrame, bpd.Series],
144+
*,
143145
temperature: float = 0.0,
144146
max_output_tokens: int = 128,
145147
top_k: int = 40,
@@ -273,6 +275,7 @@ class PaLM2TextEmbeddingGenerator(base.Predictor):
273275

274276
def __init__(
275277
self,
278+
*,
276279
model_name: Literal[
277280
"textembedding-gecko", "textembedding-gecko-multilingual"
278281
] = "textembedding-gecko",
@@ -415,6 +418,7 @@ class GeminiTextGenerator(base.Predictor):
415418

416419
def __init__(
417420
self,
421+
*,
418422
session: Optional[bigframes.Session] = None,
419423
connection_name: Optional[str] = None,
420424
):
@@ -475,6 +479,7 @@ def _from_bq(
475479
def predict(
476480
self,
477481
X: Union[bpd.DataFrame, bpd.Series],
482+
*,
478483
temperature: float = 0.9,
479484
max_output_tokens: int = 8192,
480485
top_k: int = 40,

bigframes/ml/metrics/_metrics.py

+6
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
def r2_score(
3535
y_true: Union[bpd.DataFrame, bpd.Series],
3636
y_pred: Union[bpd.DataFrame, bpd.Series],
37+
*,
3738
force_finite=True,
3839
) -> float:
3940
y_true_series, y_pred_series = utils.convert_to_series(y_true, y_pred)
@@ -61,6 +62,7 @@ def r2_score(
6162
def accuracy_score(
6263
y_true: Union[bpd.DataFrame, bpd.Series],
6364
y_pred: Union[bpd.DataFrame, bpd.Series],
65+
*,
6466
normalize=True,
6567
) -> float:
6668
# TODO(ashleyxu): support sample_weight as the parameter
@@ -83,6 +85,7 @@ def accuracy_score(
8385
def roc_curve(
8486
y_true: Union[bpd.DataFrame, bpd.Series],
8587
y_score: Union[bpd.DataFrame, bpd.Series],
88+
*,
8689
drop_intermediate: bool = True,
8790
) -> Tuple[bpd.Series, bpd.Series, bpd.Series]:
8891
# TODO(bmil): Add multi-class support
@@ -227,6 +230,7 @@ def confusion_matrix(
227230
def recall_score(
228231
y_true: Union[bpd.DataFrame, bpd.Series],
229232
y_pred: Union[bpd.DataFrame, bpd.Series],
233+
*,
230234
average: str = "binary",
231235
) -> pd.Series:
232236
# TODO(ashleyxu): support more average type, default to "binary"
@@ -263,6 +267,7 @@ def recall_score(
263267
def precision_score(
264268
y_true: Union[bpd.DataFrame, bpd.Series],
265269
y_pred: Union[bpd.DataFrame, bpd.Series],
270+
*,
266271
average: str = "binary",
267272
) -> pd.Series:
268273
# TODO(ashleyxu): support more average type, default to "binary"
@@ -301,6 +306,7 @@ def precision_score(
301306
def f1_score(
302307
y_true: Union[bpd.DataFrame, bpd.Series],
303308
y_pred: Union[bpd.DataFrame, bpd.Series],
309+
*,
304310
average: str = "binary",
305311
) -> pd.Series:
306312
# TODO(ashleyxu): support more average type, default to "binary"

bigframes/ml/remote.py

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(
5454
endpoint: str,
5555
input: Mapping[str, str],
5656
output: Mapping[str, str],
57+
*,
5758
session: Optional[bigframes.Session] = None,
5859
connection_name: Optional[str] = None,
5960
):

0 commit comments

Comments
 (0)