17
17
18
18
from __future__ import annotations
19
19
20
- from typing import Dict , List , Literal , Optional , Union
20
+ from typing import Dict , List , Literal , Optional
21
21
22
22
import bigframes_vendored .sklearn .ensemble ._forest
23
23
import bigframes_vendored .xgboost .sklearn
@@ -142,8 +142,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
142
142
143
143
def _fit (
144
144
self ,
145
- X : Union [ bpd . DataFrame , bpd . Series ] ,
146
- y : Union [ bpd . DataFrame , bpd . Series ] ,
145
+ X : utils . ArrayType ,
146
+ y : utils . ArrayType ,
147
147
transforms : Optional [List [str ]] = None ,
148
148
) -> XGBRegressor :
149
149
X , y = utils .convert_to_dataframe (X , y )
@@ -158,24 +158,24 @@ def _fit(
158
158
159
159
def predict (
160
160
self ,
161
- X : Union [ bpd . DataFrame , bpd . Series ] ,
161
+ X : utils . ArrayType ,
162
162
) -> bpd .DataFrame :
163
163
if not self ._bqml_model :
164
164
raise RuntimeError ("A model must be fitted before predict" )
165
- (X ,) = utils .convert_to_dataframe (X )
165
+ (X ,) = utils .convert_to_dataframe (X , session = self . _bqml_model . session )
166
166
167
167
return self ._bqml_model .predict (X )
168
168
169
169
def score (
170
170
self ,
171
- X : Union [ bpd . DataFrame , bpd . Series ] ,
172
- y : Union [ bpd . DataFrame , bpd . Series ] ,
171
+ X : utils . ArrayType ,
172
+ y : utils . ArrayType ,
173
173
):
174
- X , y = utils .convert_to_dataframe (X , y )
175
-
176
174
if not self ._bqml_model :
177
175
raise RuntimeError ("A model must be fitted before score" )
178
176
177
+ X , y = utils .convert_to_dataframe (X , y , session = self ._bqml_model .session )
178
+
179
179
input_data = (
180
180
X .join (y , how = "outer" ) if (X is not None ) and (y is not None ) else None
181
181
)
@@ -291,8 +291,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
291
291
292
292
def _fit (
293
293
self ,
294
- X : Union [ bpd . DataFrame , bpd . Series ] ,
295
- y : Union [ bpd . DataFrame , bpd . Series ] ,
294
+ X : utils . ArrayType ,
295
+ y : utils . ArrayType ,
296
296
transforms : Optional [List [str ]] = None ,
297
297
) -> XGBClassifier :
298
298
X , y = utils .convert_to_dataframe (X , y )
@@ -305,22 +305,22 @@ def _fit(
305
305
)
306
306
return self
307
307
308
- def predict (self , X : Union [ bpd . DataFrame , bpd . Series ] ) -> bpd .DataFrame :
308
+ def predict (self , X : utils . ArrayType ) -> bpd .DataFrame :
309
309
if not self ._bqml_model :
310
310
raise RuntimeError ("A model must be fitted before predict" )
311
- (X ,) = utils .convert_to_dataframe (X )
311
+ (X ,) = utils .convert_to_dataframe (X , session = self . _bqml_model . session )
312
312
313
313
return self ._bqml_model .predict (X )
314
314
315
315
def score (
316
316
self ,
317
- X : Union [ bpd . DataFrame , bpd . Series ] ,
318
- y : Union [ bpd . DataFrame , bpd . Series ] ,
317
+ X : utils . ArrayType ,
318
+ y : utils . ArrayType ,
319
319
):
320
320
if not self ._bqml_model :
321
321
raise RuntimeError ("A model must be fitted before score" )
322
322
323
- X , y = utils .convert_to_dataframe (X , y )
323
+ X , y = utils .convert_to_dataframe (X , y , session = self . _bqml_model . session )
324
324
325
325
input_data = (
326
326
X .join (y , how = "outer" ) if (X is not None ) and (y is not None ) else None
@@ -427,8 +427,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
427
427
428
428
def _fit (
429
429
self ,
430
- X : Union [ bpd . DataFrame , bpd . Series ] ,
431
- y : Union [ bpd . DataFrame , bpd . Series ] ,
430
+ X : utils . ArrayType ,
431
+ y : utils . ArrayType ,
432
432
transforms : Optional [List [str ]] = None ,
433
433
) -> RandomForestRegressor :
434
434
X , y = utils .convert_to_dataframe (X , y )
@@ -443,18 +443,18 @@ def _fit(
443
443
444
444
def predict (
445
445
self ,
446
- X : Union [ bpd . DataFrame , bpd . Series ] ,
446
+ X : utils . ArrayType ,
447
447
) -> bpd .DataFrame :
448
448
if not self ._bqml_model :
449
449
raise RuntimeError ("A model must be fitted before predict" )
450
- (X ,) = utils .convert_to_dataframe (X )
450
+ (X ,) = utils .convert_to_dataframe (X , session = self . _bqml_model . session )
451
451
452
452
return self ._bqml_model .predict (X )
453
453
454
454
def score (
455
455
self ,
456
- X : Union [ bpd . DataFrame , bpd . Series ] ,
457
- y : Union [ bpd . DataFrame , bpd . Series ] ,
456
+ X : utils . ArrayType ,
457
+ y : utils . ArrayType ,
458
458
):
459
459
"""Calculate evaluation metrics of the model.
460
460
@@ -476,7 +476,7 @@ def score(
476
476
if not self ._bqml_model :
477
477
raise RuntimeError ("A model must be fitted before score" )
478
478
479
- X , y = utils .convert_to_dataframe (X , y )
479
+ X , y = utils .convert_to_dataframe (X , y , session = self . _bqml_model . session )
480
480
481
481
input_data = (
482
482
X .join (y , how = "outer" ) if (X is not None ) and (y is not None ) else None
@@ -583,8 +583,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
583
583
584
584
def _fit (
585
585
self ,
586
- X : Union [ bpd . DataFrame , bpd . Series ] ,
587
- y : Union [ bpd . DataFrame , bpd . Series ] ,
586
+ X : utils . ArrayType ,
587
+ y : utils . ArrayType ,
588
588
transforms : Optional [List [str ]] = None ,
589
589
) -> RandomForestClassifier :
590
590
X , y = utils .convert_to_dataframe (X , y )
@@ -599,18 +599,18 @@ def _fit(
599
599
600
600
def predict (
601
601
self ,
602
- X : Union [ bpd . DataFrame , bpd . Series ] ,
602
+ X : utils . ArrayType ,
603
603
) -> bpd .DataFrame :
604
604
if not self ._bqml_model :
605
605
raise RuntimeError ("A model must be fitted before predict" )
606
- (X ,) = utils .convert_to_dataframe (X )
606
+ (X ,) = utils .convert_to_dataframe (X , session = self . _bqml_model . session )
607
607
608
608
return self ._bqml_model .predict (X )
609
609
610
610
def score (
611
611
self ,
612
- X : Union [ bpd . DataFrame , bpd . Series ] ,
613
- y : Union [ bpd . DataFrame , bpd . Series ] ,
612
+ X : utils . ArrayType ,
613
+ y : utils . ArrayType ,
614
614
):
615
615
"""Calculate evaluation metrics of the model.
616
616
@@ -632,7 +632,7 @@ def score(
632
632
if not self ._bqml_model :
633
633
raise RuntimeError ("A model must be fitted before score" )
634
634
635
- X , y = utils .convert_to_dataframe (X , y )
635
+ X , y = utils .convert_to_dataframe (X , y , session = self . _bqml_model . session )
636
636
637
637
input_data = (
638
638
X .join (y , how = "outer" ) if (X is not None ) and (y is not None ) else None
0 commit comments