Skip to content

Commit bb72562

Browse files
bryanesmith authored and copybara-github committed
Copybara import of the project:
-- 2504858 by Bryan Smith <bryanesmith@gmail.com>: add bq_dataset_id parameter to batch_serve_to_df, with unit test -- d6e20f6 by Bryan Smith <bryanesmith@gmail.com>: applied blacken formatting -- 7020556 by Bryan Smith <bryanesmith@gmail.com>: doc strings for helper methods -- a8f052e by Bryan Smith <bryanesmith@gmail.com>: remove the timestamp column fix COPYBARA_INTEGRATE_REVIEW=#1623 from bryanesmith:feature_bq_dataset_id 4f4762d PiperOrigin-RevId: 485654375
1 parent 7a4bfbe commit bb72562

File tree

2 files changed

+192
-23
lines changed

2 files changed

+192
-23
lines changed

google/cloud/aiplatform/featurestore/featurestore.py

Lines changed: 89 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,7 @@ def batch_serve_to_df(
10921092
feature_destination_fields: Optional[Dict[str, str]] = None,
10931093
request_metadata: Optional[Sequence[Tuple[str, str]]] = (),
10941094
serve_request_timeout: Optional[float] = None,
1095+
bq_dataset_id: Optional[str] = None,
10951096
) -> "pd.DataFrame": # noqa: F821 - skip check for undefined name 'pd'
10961097
"""Batch serves feature values to pandas DataFrame
10971098
@@ -1176,6 +1177,11 @@ def batch_serve_to_df(
11761177
serve_request_timeout (float):
11771178
Optional. The timeout for the serve request in seconds.
11781179
1180+
bq_dataset_id (str):
1181+
Optional. The full dataset ID for the BigQuery dataset to use
1182+
for temporarily staging data. If specified, caller must have
1183+
`bigquery.tables.create` permissions for Dataset.
1184+
11791185
Returns:
11801186
pd.DataFrame: The pandas DataFrame containing feature values from batch serving.
11811187
@@ -1210,34 +1216,43 @@ def batch_serve_to_df(
12101216

12111217
self.wait()
12121218
featurestore_name_components = self._parse_resource_name(self.resource_name)
1213-
featurestore_id = featurestore_name_components["featurestore"]
1214-
1215-
temp_bq_dataset_name = f"temp_{featurestore_id}_{uuid.uuid4()}".replace(
1216-
"-", "_"
1217-
)
12181219

1219-
project_id = resource_manager_utils.get_project_id(
1220-
project_number=featurestore_name_components["project"],
1221-
credentials=self.credentials,
1222-
)
1223-
temp_bq_dataset_id = f"{project_id}.{temp_bq_dataset_name}"[:1024]
1224-
temp_bq_dataset = bigquery.Dataset(dataset_ref=temp_bq_dataset_id)
1225-
temp_bq_dataset.location = self.location
1226-
temp_bq_dataset = bigquery_client.create_dataset(temp_bq_dataset)
1220+
# if user didn't specify BigQuery dataset, create an ephemeral one
1221+
if bq_dataset_id is None:
1222+
temp_bq_full_dataset_id = self._get_ephemeral_bq_full_dataset_id(
1223+
featurestore_name_components["featurestore"],
1224+
featurestore_name_components["project"],
1225+
)
1226+
temp_bq_dataset = self._create_ephemeral_bq_dataset(
1227+
bigquery_client, temp_bq_full_dataset_id
1228+
)
1229+
temp_bq_batch_serve_table_name = "batch_serve"
1230+
temp_bq_read_instances_table_name = "read_instances"
1231+
1232+
# if user specified BigQuery dataset, create ephemeral tables
1233+
else:
1234+
temp_bq_full_dataset_id = bq_dataset_id
1235+
temp_bq_dataset = bigquery.Dataset(dataset_ref=temp_bq_full_dataset_id)
1236+
temp_bq_batch_serve_table_name = f"tmp_batch_serve_{uuid.uuid4()}".replace(
1237+
"-", "_"
1238+
)
1239+
temp_bq_read_instances_table_name = (
1240+
f"tmp_read_instances_{uuid.uuid4()}".replace("-", "_")
1241+
)
12271242

1228-
temp_bq_batch_serve_table_name = "batch_serve"
1229-
temp_bq_read_instances_table_name = "read_instances"
12301243
temp_bq_batch_serve_table_id = (
1231-
f"{temp_bq_dataset_id}.{temp_bq_batch_serve_table_name}"
1244+
f"{temp_bq_full_dataset_id}.{temp_bq_batch_serve_table_name}"
12321245
)
1246+
12331247
temp_bq_read_instances_table_id = (
1234-
f"{temp_bq_dataset_id}.{temp_bq_read_instances_table_name}"
1248+
f"{temp_bq_full_dataset_id}.{temp_bq_read_instances_table_name}"
12351249
)
12361250

12371251
try:
12381252

12391253
job = bigquery_client.load_table_from_dataframe(
1240-
dataframe=read_instances_df, destination=temp_bq_read_instances_table_id
1254+
dataframe=read_instances_df,
1255+
destination=temp_bq_read_instances_table_id,
12411256
)
12421257
job.result()
12431258

@@ -1259,7 +1274,7 @@ def batch_serve_to_df(
12591274
read_session=bigquery_storage.types.ReadSession(
12601275
table="projects/{project}/datasets/{dataset}/tables/{table}".format(
12611276
project=self.project,
1262-
dataset=temp_bq_dataset_name,
1277+
dataset=temp_bq_dataset.dataset_id,
12631278
table=temp_bq_batch_serve_table_name,
12641279
),
12651280
data_format=bigquery_storage.types.DataFormat.ARROW,
@@ -1273,9 +1288,60 @@ def batch_serve_to_df(
12731288
frames.append(message.to_dataframe())
12741289

12751290
finally:
1276-
bigquery_client.delete_dataset(
1277-
dataset=temp_bq_dataset.dataset_id,
1278-
delete_contents=True,
1279-
)
1291+
# clean up: if user didn't specify dataset, delete ephemeral dataset
1292+
if bq_dataset_id is None:
1293+
bigquery_client.delete_dataset(
1294+
dataset=temp_bq_dataset.dataset_id,
1295+
delete_contents=True,
1296+
)
1297+
1298+
# clean up: if user specified BigQuery dataset, delete ephemeral tables
1299+
else:
1300+
bigquery_client.delete_table(temp_bq_batch_serve_table_id)
1301+
bigquery_client.delete_table(temp_bq_read_instances_table_id)
12801302

12811303
return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(frames)
1304+
1305+
def _get_ephemeral_bq_full_dataset_id(
1306+
self, featurestore_id: str, project_number: str
1307+
) -> str:
1308+
"""Helper method to generate an id for an ephemeral dataset in BigQuery
1309+
used to temporarily stage data.
1310+
1311+
Args:
1312+
featurestore_id (str):
1313+
Required. The ID to use for this featurestore.
1314+
project_number (str):
1315+
Required. Project to retrieve featurestore from.
1316+
Returns:
1317+
str - full BigQuery dataset ID
1318+
"""
1319+
temp_bq_dataset_name = f"temp_{featurestore_id}_{uuid.uuid4()}".replace(
1320+
"-", "_"
1321+
)
1322+
1323+
project_id = resource_manager_utils.get_project_id(
1324+
project_number=project_number,
1325+
credentials=self.credentials,
1326+
)
1327+
1328+
return f"{project_id}.{temp_bq_dataset_name}"[:1024]
1329+
1330+
def _create_ephemeral_bq_dataset(
1331+
self, bigquery_client: bigquery.Client, dataset_id: str
1332+
) -> "bigquery.Dataset":
1333+
"""Helper method to create an ephemeral dataset in BigQuery used to
1334+
temporarily stage data.
1335+
1336+
Args:
1337+
bigquery_client (bigquery.Client):
1338+
Required. BigQuery client to use to generate the BigQuery dataset.
1339+
dataset_id (str):
1340+
Required. Identifier to use for the BigQuery dataset.
1341+
Returns:
1342+
bigquery.Dataset - new BigQuery dataset used to temporarily stage data
1343+
"""
1344+
temp_bq_dataset = bigquery.Dataset(dataset_ref=dataset_id)
1345+
temp_bq_dataset.location = self.location
1346+
1347+
return bigquery_client.create_dataset(temp_bq_dataset)

tests/unit/aiplatform/test_featurestores.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,12 @@ def bq_delete_dataset_mock(bq_client_mock):
418418
yield bq_delete_dataset_mock
419419

420420

421+
@pytest.fixture
422+
def bq_delete_table_mock(bq_client_mock):
423+
with patch.object(bq_client_mock, "delete_table") as bq_delete_table_mock:
424+
yield bq_delete_table_mock
425+
426+
421427
@pytest.fixture
422428
def bqs_client_mock():
423429
mock = MagicMock(bigquery_storage.BigQueryReadClient)
@@ -1701,6 +1707,103 @@ def test_batch_serve_to_df(self, batch_read_feature_values_mock):
17011707
timeout=None,
17021708
)
17031709

1710+
@pytest.mark.skipif(
1711+
_USE_BQ_STORAGE is False, reason="batch_serve_to_df requires bigquery_storage"
1712+
)
1713+
@pytest.mark.usefixtures(
1714+
"get_featurestore_mock",
1715+
"bq_init_client_mock",
1716+
"bq_init_dataset_mock",
1717+
"bq_create_dataset_mock",
1718+
"bq_load_table_from_dataframe_mock",
1719+
"bq_delete_dataset_mock",
1720+
"bq_delete_table_mock",
1721+
"bqs_init_client_mock",
1722+
"bqs_create_read_session",
1723+
"get_project_mock",
1724+
)
1725+
@patch("uuid.uuid4", uuid_mock)
1726+
def test_batch_serve_to_df_user_specified_bq_dataset(
1727+
self,
1728+
batch_read_feature_values_mock,
1729+
bq_create_dataset_mock,
1730+
bq_delete_dataset_mock,
1731+
bq_delete_table_mock,
1732+
):
1733+
1734+
aiplatform.init(project=_TEST_PROJECT_DIFF)
1735+
1736+
my_featurestore = aiplatform.Featurestore(
1737+
featurestore_name=_TEST_FEATURESTORE_NAME
1738+
)
1739+
1740+
read_instances_df = pd.DataFrame()
1741+
1742+
expected_temp_bq_dataset_name = "my_dataset_name"
1743+
expected_temp_bq_dataset_id = (
1744+
f"{_TEST_PROJECT}.{expected_temp_bq_dataset_name}"[:1024]
1745+
)
1746+
expected_temp_bq_batch_serve_table_name = (
1747+
f"tmp_batch_serve_{uuid.uuid4()}".replace("-", "_")
1748+
)
1749+
expected_temp_bq_batch_serve_table_id = (
1750+
f"{expected_temp_bq_dataset_id}.{expected_temp_bq_batch_serve_table_name}"
1751+
)
1752+
expected_temp_bq_read_instances_table_name = (
1753+
f"tmp_read_instances_{uuid.uuid4()}".replace("-", "_")
1754+
)
1755+
expected_temp_bq_read_instances_table_id = f"{expected_temp_bq_dataset_id}.{expected_temp_bq_read_instances_table_name}"
1756+
1757+
expected_entity_type_specs = [
1758+
_get_entity_type_spec_proto_with_feature_ids(
1759+
entity_type_id="my_entity_type_id_1",
1760+
feature_ids=["my_feature_id_1_1", "my_feature_id_1_2"],
1761+
),
1762+
_get_entity_type_spec_proto_with_feature_ids(
1763+
entity_type_id="my_entity_type_id_2",
1764+
feature_ids=["my_feature_id_2_1", "my_feature_id_2_2"],
1765+
),
1766+
]
1767+
1768+
expected_batch_read_feature_values_request = (
1769+
gca_featurestore_service.BatchReadFeatureValuesRequest(
1770+
featurestore=my_featurestore.resource_name,
1771+
destination=gca_featurestore_service.FeatureValueDestination(
1772+
bigquery_destination=gca_io.BigQueryDestination(
1773+
output_uri=f"bq://{expected_temp_bq_batch_serve_table_id}"
1774+
),
1775+
),
1776+
entity_type_specs=expected_entity_type_specs,
1777+
bigquery_read_instances=gca_io.BigQuerySource(
1778+
input_uri=f"bq://{expected_temp_bq_read_instances_table_id}"
1779+
),
1780+
)
1781+
)
1782+
1783+
my_featurestore.batch_serve_to_df(
1784+
serving_feature_ids=_TEST_SERVING_FEATURE_IDS,
1785+
read_instances_df=read_instances_df,
1786+
serve_request_timeout=None,
1787+
bq_dataset_id=expected_temp_bq_dataset_id,
1788+
)
1789+
1790+
batch_read_feature_values_mock.assert_called_once_with(
1791+
request=expected_batch_read_feature_values_request,
1792+
metadata=_TEST_REQUEST_METADATA,
1793+
timeout=None,
1794+
)
1795+
1796+
bq_delete_table_mock.assert_has_calls(
1797+
calls=[
1798+
mock.call(expected_temp_bq_batch_serve_table_id),
1799+
mock.call(expected_temp_bq_read_instances_table_id),
1800+
],
1801+
any_order=True,
1802+
)
1803+
1804+
bq_create_dataset_mock.assert_not_called()
1805+
bq_delete_dataset_mock.assert_not_called()
1806+
17041807

17051808
class TestEntityType:
17061809
def setup_method(self):

0 commit comments

Comments (0)