@@ -32,15 +32,17 @@ class TensorFlowModel(base.Predictor):
32
32
"""Imported TensorFlow model.
33
33
34
34
Args:
35
+ model_path (str):
36
+ GCS path that holds the model files.
35
37
session (BigQuery Session):
36
38
BQ session to create the model
37
- model_path (str):
38
- GCS path that holds the model files."""
39
+ """
39
40
40
41
def __init__ (
41
42
self ,
43
+ model_path : str ,
44
+ * ,
42
45
session : Optional [bigframes .Session ] = None ,
43
- model_path : Optional [str ] = None ,
44
46
):
45
47
self .session = session or bpd .get_global_session ()
46
48
self .model_path = model_path
@@ -59,7 +61,7 @@ def _from_bq(
59
61
) -> TensorFlowModel :
60
62
assert model .model_type == "TENSORFLOW"
61
63
62
- tf_model = cls (session = session , model_path = None )
64
+ tf_model = cls (session = session , model_path = "" )
63
65
tf_model ._bqml_model = core .BqmlModel (session , model )
64
66
return tf_model
65
67
@@ -109,15 +111,17 @@ class ONNXModel(base.Predictor):
109
111
"""Imported Open Neural Network Exchange (ONNX) model.
110
112
111
113
Args:
114
+ model_path (str):
115
+ Cloud Storage path that holds the model files.
112
116
session (BigQuery Session):
113
117
BQ session to create the model
114
- model_path (str):
115
- Cloud Storage path that holds the model files."""
118
+ """
116
119
117
120
def __init__ (
118
121
self ,
122
+ model_path : str ,
123
+ * ,
119
124
session : Optional [bigframes .Session ] = None ,
120
- model_path : Optional [str ] = None ,
121
125
):
122
126
self .session = session or bpd .get_global_session ()
123
127
self .model_path = model_path
@@ -134,7 +138,7 @@ def _create_bqml_model(self):
134
138
def _from_bq (cls , session : bigframes .Session , model : bigquery .Model ) -> ONNXModel :
135
139
assert model .model_type == "ONNX"
136
140
137
- onnx_model = cls (session = session , model_path = None )
141
+ onnx_model = cls (session = session , model_path = "" )
138
142
onnx_model ._bqml_model = core .BqmlModel (session , model )
139
143
return onnx_model
140
144
@@ -189,8 +193,8 @@ class XGBoostModel(base.Predictor):
189
193
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-xgboost#limitations
190
194
191
195
Args:
192
- session (BigQuery Session ):
193
- BQ session to create the model
196
+ model_path (str ):
197
+ Cloud Storage path that holds the model files.
194
198
input (Dict, default None):
195
199
Specify the model input schema information when you
196
200
create the XGBoost model. The input should be the format of
@@ -203,15 +207,17 @@ class XGBoostModel(base.Predictor):
203
207
{field_name: field_type}. Output is optional only if feature_names
204
208
and feature_types are both specified in the model file. Supported types
205
209
are "bool", "string", "int64", "float64", "array<bool>", "array<string>", "array<int64>", "array<float64>".
206
- model_path (str):
207
- Cloud Storage path that holds the model files."""
210
+ session (BigQuery Session):
211
+ BQ session to create the model
212
+ """
208
213
209
214
def __init__ (
210
215
self ,
211
- session : Optional [bigframes .Session ] = None ,
216
+ model_path : str ,
217
+ * ,
212
218
input : Mapping [str , str ] = {},
213
219
output : Mapping [str , str ] = {},
214
- model_path : Optional [str ] = None ,
220
+ session : Optional [bigframes . Session ] = None ,
215
221
):
216
222
self .session = session or bpd .get_global_session ()
217
223
self .model_path = model_path
@@ -248,7 +254,7 @@ def _from_bq(
248
254
) -> XGBoostModel :
249
255
assert model .model_type == "XGBOOST"
250
256
251
- xgboost_model = cls (session = session , model_path = None )
257
+ xgboost_model = cls (session = session , model_path = "" )
252
258
xgboost_model ._bqml_model = core .BqmlModel (session , model )
253
259
return xgboost_model
254
260
0 commit comments