Skip to content

Commit 4241738

Browse files
authored
feat: MBSDK Tabular samples (#338)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: Tracking Bug: [MB SDK Samples - Milestone 1](https://buganizer.corp.google.com/issues/180729765) - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-aiplatform/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent c057083 commit 4241738

23 files changed

+693
-31
lines changed

samples/model-builder/conftest.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,58 @@ def mock_import_text_dataset(mock_text_dataset):
138138

139139

140140
@pytest.fixture
141-
def mock_init_automl_image_training_job():
142-
with patch.object(
143-
aiplatform.training_jobs.AutoMLImageTrainingJob, "__init__"
144-
) as mock:
145-
mock.return_value = None
141+
def mock_custom_training_job():
142+
mock = MagicMock(aiplatform.training_jobs.CustomTrainingJob)
143+
yield mock
144+
145+
146+
@pytest.fixture
147+
def mock_image_training_job():
148+
mock = MagicMock(aiplatform.training_jobs.AutoMLImageTrainingJob)
149+
yield mock
150+
151+
152+
@pytest.fixture
153+
def mock_tabular_training_job():
154+
mock = MagicMock(aiplatform.training_jobs.AutoMLTabularTrainingJob)
155+
yield mock
156+
157+
158+
@pytest.fixture
159+
def mock_text_training_job():
160+
mock = MagicMock(aiplatform.training_jobs.AutoMLTextTrainingJob)
161+
yield mock
162+
163+
164+
@pytest.fixture
165+
def mock_video_training_job():
166+
mock = MagicMock(aiplatform.training_jobs.AutoMLVideoTrainingJob)
167+
yield mock
168+
169+
170+
@pytest.fixture
171+
def mock_get_automl_tabular_training_job(mock_tabular_training_job):
172+
with patch.object(aiplatform, "AutoMLTabularTrainingJob") as mock:
173+
mock.return_value = mock_tabular_training_job
174+
yield mock
175+
176+
177+
@pytest.fixture
178+
def mock_run_automl_tabular_training_job(mock_tabular_training_job):
179+
with patch.object(mock_tabular_training_job, "run") as mock:
180+
yield mock
181+
182+
183+
@pytest.fixture
184+
def mock_get_automl_image_training_job(mock_image_training_job):
185+
with patch.object(aiplatform, "AutoMLImageTrainingJob") as mock:
186+
mock.return_value = mock_image_training_job
146187
yield mock
147188

148189

149190
@pytest.fixture
150-
def mock_run_automl_image_training_job():
151-
with patch.object(aiplatform.training_jobs.AutoMLImageTrainingJob, "run") as mock:
191+
def mock_run_automl_image_training_job(mock_image_training_job):
192+
with patch.object(mock_image_training_job, "run") as mock:
152193
yield mock
153194

154195

@@ -173,15 +214,21 @@ def mock_run_custom_training_job():
173214

174215

175216
@pytest.fixture
176-
def mock_init_model():
177-
with patch.object(aiplatform.models.Model, "__init__") as mock:
178-
mock.return_value = None
217+
def mock_model():
218+
mock = MagicMock(aiplatform.models.Model)
219+
yield mock
220+
221+
222+
@pytest.fixture
223+
def mock_get_model(mock_model):
224+
with patch.object(aiplatform, "Model") as mock:
225+
mock.return_value = mock_model
179226
yield mock
180227

181228

182229
@pytest.fixture
183-
def mock_batch_predict_model():
184-
with patch.object(aiplatform.models.Model, "batch_predict") as mock:
230+
def mock_batch_predict_model(mock_model):
231+
with patch.object(mock_model, "batch_predict") as mock:
185232
yield mock
186233

187234

@@ -211,6 +258,12 @@ def mock_endpoint():
211258
yield mock
212259

213260

261+
@pytest.fixture
262+
def mock_create_endpoint():
263+
with patch.object(aiplatform.Endpoint, "create") as mock:
264+
yield mock
265+
266+
214267
@pytest.fixture
215268
def mock_get_endpoint(mock_endpoint):
216269
with patch.object(aiplatform, "Endpoint") as mock_get_endpoint:
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from google.cloud import aiplatform
17+
18+
19+
# [START aiplatform_sdk_create_and_import_dataset_tabular_bigquery_sample]
20+
def create_and_import_dataset_tabular_bigquery_sample(
21+
display_name: str, project: str, location: str, bigquery_source: str,
22+
):
23+
24+
aiplatform.init(project=project, location=location)
25+
26+
dataset = aiplatform.TabularDataset.create(
27+
display_name=display_name, bigquery_source=bigquery_source,
28+
)
29+
30+
dataset.wait()
31+
32+
print(f'\tDataset: "{dataset.display_name}"')
33+
print(f'\tname: "{dataset.resource_name}"')
34+
35+
36+
# [END aiplatform_sdk_create_and_import_dataset_tabular_bigquery_sample]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import create_and_import_dataset_tabular_bigquery_sample
17+
import test_constants as constants
18+
19+
20+
def test_create_and_import_dataset_tabular_bigquery_sample(
21+
mock_sdk_init, mock_create_tabular_dataset
22+
):
23+
24+
create_and_import_dataset_tabular_bigquery_sample.create_and_import_dataset_tabular_bigquery_sample(
25+
project=constants.PROJECT,
26+
location=constants.LOCATION,
27+
bigquery_source=constants.BIGQUERY_SOURCE,
28+
display_name=constants.DISPLAY_NAME,
29+
)
30+
31+
mock_sdk_init.assert_called_once_with(
32+
project=constants.PROJECT, location=constants.LOCATION
33+
)
34+
mock_create_tabular_dataset.assert_called_once_with(
35+
display_name=constants.DISPLAY_NAME, bigquery_source=constants.BIGQUERY_SOURCE,
36+
)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import List, Union
16+
17+
from google.cloud import aiplatform
18+
19+
20+
# [START aiplatform_sdk_create_and_import_dataset_tabular_gcs_sample]
21+
def create_and_import_dataset_tabular_gcs_sample(
22+
display_name: str, project: str, location: str, gcs_source: Union[str, List[str]],
23+
):
24+
25+
aiplatform.init(project=project, location=location)
26+
27+
dataset = aiplatform.TabularDataset.create(
28+
display_name=display_name, gcs_source=gcs_source,
29+
)
30+
31+
dataset.wait()
32+
33+
print(f'\tDataset: "{dataset.display_name}"')
34+
print(f'\tname: "{dataset.resource_name}"')
35+
36+
37+
# [END aiplatform_sdk_create_and_import_dataset_tabular_gcs_sample]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import create_and_import_dataset_tabular_gcs_sample
17+
import test_constants as constants
18+
19+
20+
def test_create_and_import_dataset_tabular_gcs_sample(
21+
mock_sdk_init, mock_create_tabular_dataset
22+
):
23+
24+
create_and_import_dataset_tabular_gcs_sample.create_and_import_dataset_tabular_gcs_sample(
25+
project=constants.PROJECT,
26+
location=constants.LOCATION,
27+
gcs_source=constants.GCS_SOURCES,
28+
display_name=constants.DISPLAY_NAME,
29+
)
30+
31+
mock_sdk_init.assert_called_once_with(
32+
project=constants.PROJECT, location=constants.LOCATION
33+
)
34+
mock_create_tabular_dataset.assert_called_once_with(
35+
display_name=constants.DISPLAY_NAME, gcs_source=constants.GCS_SOURCES,
36+
)

samples/model-builder/create_batch_prediction_job_sample_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919

2020
def test_create_batch_prediction_job_sample(
21-
mock_sdk_init, mock_init_model, mock_batch_predict_model
21+
mock_sdk_init, mock_get_model, mock_batch_predict_model
2222
):
2323

2424
create_batch_prediction_job_sample.create_batch_prediction_job_sample(
@@ -33,7 +33,7 @@ def test_create_batch_prediction_job_sample(
3333
mock_sdk_init.assert_called_once_with(
3434
project=constants.PROJECT, location=constants.LOCATION
3535
)
36-
mock_init_model.assert_called_once_with(constants.MODEL_NAME)
36+
mock_get_model.assert_called_once_with(constants.MODEL_NAME)
3737
mock_batch_predict_model.assert_called_once_with(
3838
job_display_name=constants.DISPLAY_NAME,
3939
gcs_source=constants.GCS_SOURCES,
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.cloud import aiplatform
16+
17+
18+
# [START aiplatform_sdk_create_endpoint_sample]
19+
def create_endpoint_sample(
20+
project: str, display_name: str, location: str, sync: bool = True,
21+
):
22+
aiplatform.init(project=project, location=location)
23+
24+
endpoint = aiplatform.Endpoint.create(
25+
display_name=display_name, project=project, location=location,
26+
)
27+
28+
print(endpoint.display_name)
29+
print(endpoint.resource_name)
30+
print(endpoint.uri)
31+
return endpoint
32+
33+
34+
# [END aiplatform_sdk_create_endpoint_sample]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import create_endpoint_sample
17+
import test_constants as constants
18+
19+
20+
def test_create_endpoint_sample(mock_sdk_init, mock_create_endpoint):
21+
22+
create_endpoint_sample.create_endpoint_sample(
23+
project=constants.PROJECT,
24+
display_name=constants.DISPLAY_NAME,
25+
location=constants.LOCATION,
26+
)
27+
28+
mock_sdk_init.assert_called_once_with(
29+
project=constants.PROJECT, location=constants.LOCATION
30+
)
31+
32+
mock_create_endpoint.assert_called_once_with(
33+
display_name=constants.DISPLAY_NAME,
34+
project=constants.PROJECT,
35+
location=constants.LOCATION,
36+
)

samples/model-builder/create_training_pipeline_image_classification_sample_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
def test_create_training_pipeline_image_classification_sample(
2121
mock_sdk_init,
2222
mock_image_dataset,
23-
mock_init_automl_image_training_job,
23+
mock_get_automl_image_training_job,
2424
mock_run_automl_image_training_job,
2525
mock_get_image_dataset,
2626
):
@@ -43,7 +43,7 @@ def test_create_training_pipeline_image_classification_sample(
4343
mock_sdk_init.assert_called_once_with(
4444
project=constants.PROJECT, location=constants.LOCATION
4545
)
46-
mock_init_automl_image_training_job.assert_called_once_with(
46+
mock_get_automl_image_training_job.assert_called_once_with(
4747
display_name=constants.DISPLAY_NAME
4848
)
4949
mock_run_automl_image_training_job.assert_called_once_with(

0 commit comments

Comments
 (0)