Skip to content

Commit 55c07e9

Browse files
authored
fix: stop ignoring arguments to MatrixFactorization.score(X, y) (#1726)
* fix: stop ignoring arguments to `MatrixFactorization.score(X, y)` * fix unit tests
1 parent df24c84 commit 55c07e9

File tree

4 files changed

+45
-10
lines changed

4 files changed

+45
-10
lines changed

bigframes/ml/decomposition.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,5 +360,12 @@ def score(
360360
if not self._bqml_model:
361361
raise RuntimeError("A model must be fitted before score")
362362

363-
# TODO(b/291973741): X param is ignored. Update BQML supports input in ML.EVALUATE.
364-
return self._bqml_model.evaluate()
363+
if X is not None and y is not None:
364+
X, y = utils.batch_convert_to_dataframe(
365+
X, y, session=self._bqml_model.session
366+
)
367+
input_data = X.join(y, how="outer")
368+
else:
369+
input_data = X
370+
371+
return self._bqml_model.evaluate(input_data)

tests/system/large/ml/test_decomposition.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import pandas as pd
16+
import pandas.testing
1617

1718
from bigframes.ml import decomposition
1819
from tests.system import utils
@@ -193,7 +194,16 @@ def test_decomposition_mf_configure_fit_load(
193194
)
194195
)
195196

196-
reloaded_model.score(new_ratings)
197+
# Make sure the input to score is not ignored.
198+
scores_training_data = reloaded_model.score().to_pandas()
199+
scores_new_ratings = reloaded_model.score(new_ratings).to_pandas()
200+
pandas.testing.assert_index_equal(
201+
scores_training_data.columns, scores_new_ratings.columns
202+
)
203+
assert (
204+
scores_training_data["mean_squared_error"].iloc[0]
205+
!= scores_new_ratings["mean_squared_error"].iloc[0]
206+
)
197207

198208
result = reloaded_model.predict(new_ratings).to_pandas()
199209

tests/unit/ml/test_golden_sql.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def mock_X(mock_y, mock_session):
8181
["index_column_id"],
8282
["index_column_label"],
8383
)
84+
type(mock_X).sql = mock.PropertyMock(return_value="input_X_sql_property")
8485
mock_X.reset_index(drop=True).cache().sql = "input_X_no_index_sql"
8586
mock_X.join(mock_y).sql = "input_X_y_sql"
8687
mock_X.join(mock_y).cache.return_value = mock_X.join(mock_y)
@@ -248,7 +249,7 @@ def test_decomposition_mf_predict(mock_session, bqml_model, mock_X):
248249
)
249250

250251

251-
def test_decomposition_mf_score(mock_session, bqml_model, mock_X):
252+
def test_decomposition_mf_score(mock_session, bqml_model):
252253
model = decomposition.MatrixFactorization(
253254
num_factors=34,
254255
feedback_type="explicit",
@@ -258,8 +259,23 @@ def test_decomposition_mf_score(mock_session, bqml_model, mock_X):
258259
l2_reg=9.83,
259260
)
260261
model._bqml_model = bqml_model
261-
model.score(mock_X)
262-
262+
model.score()
263263
mock_session.read_gbq.assert_called_once_with(
264264
"SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`)"
265265
)
266+
267+
268+
def test_decomposition_mf_score_with_x(mock_session, bqml_model, mock_X):
269+
model = decomposition.MatrixFactorization(
270+
num_factors=34,
271+
feedback_type="explicit",
272+
user_col="user_id",
273+
item_col="item_col",
274+
rating_col="rating_col",
275+
l2_reg=9.83,
276+
)
277+
model._bqml_model = bqml_model
278+
model.score(mock_X)
279+
mock_session.read_gbq.assert_called_once_with(
280+
"SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql_property))"
281+
)

third_party/bigframes_vendored/sklearn/decomposition/_mf.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,13 @@ def score(self, X=None, y=None):
7373
for the outputs relevant to this model type.
7474
7575
Args:
76-
X (default None):
77-
Ignored.
76+
X (bigframes.dataframe.DataFrame | bigframes.series.Series | None):
77+
DataFrame of shape (n_samples, n_features). Test samples.
78+
79+
y (bigframes.dataframe.DataFrame | bigframes.series.Series | None):
80+
DataFrame of shape (n_samples,) or (n_samples, n_outputs). True
81+
labels for `X`.
7882
79-
y (default None):
80-
Ignored.
8183
Returns:
8284
bigframes.dataframe.DataFrame: DataFrame that represents model metrics.
8385
"""

0 commit comments

Comments
 (0)