Skip to content

Commit ad09c29

Browse files
telpiriondizcology
andauthored
feat: adds text batch prediction samples (#82)
* feat: adds text batch prediction samples * fix: lint * fix: broken test * fix: tests * fix: more changes * fix: TSA batch prediction test updates * fix: working model (I hope!) Co-authored-by: Yu-Han Liu <yuhanliu@google.com>
1 parent b012283 commit ad09c29

13 files changed

+431
-4
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_create_batch_prediction_job_text_classification_sample]
16+
from google.cloud import aiplatform
17+
from google.protobuf.struct_pb2 import Value
18+
19+
20+
def create_batch_prediction_job_text_classification_sample(
21+
project: str,
22+
display_name: str,
23+
model: str,
24+
gcs_source_uri: str,
25+
gcs_destination_output_uri_prefix: str,
26+
location: str = "us-central1",
27+
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
28+
):
29+
client_options = {"api_endpoint": api_endpoint}
30+
# Initialize client that will be used to create and send requests.
31+
# This client only needs to be created once, and can be reused for multiple requests.
32+
client = aiplatform.gapic.JobServiceClient(client_options=client_options)
33+
34+
batch_prediction_job = {
35+
"display_name": display_name,
36+
# Format: 'projects/{project}/locations/{location}/models/{model_id}'
37+
"model": model,
38+
"model_parameters": Value(),
39+
"input_config": {
40+
"instances_format": "jsonl",
41+
"gcs_source": {"uris": [gcs_source_uri]},
42+
},
43+
"output_config": {
44+
"predictions_format": "jsonl",
45+
"gcs_destination": {"output_uri_prefix": gcs_destination_output_uri_prefix},
46+
},
47+
}
48+
parent = f"projects/{project}/locations/{location}"
49+
response = client.create_batch_prediction_job(
50+
parent=parent, batch_prediction_job=batch_prediction_job
51+
)
52+
print("response:", response)
53+
54+
55+
# [END aiplatform_create_batch_prediction_job_text_classification_sample]
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from uuid import uuid4
16+
import pytest
17+
import os
18+
19+
import helpers
20+
21+
import create_batch_prediction_job_text_classification_sample
22+
import cancel_batch_prediction_job_sample
23+
import delete_batch_prediction_job_sample
24+
25+
from google.cloud import aiplatform
26+
27+
PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
28+
LOCATION = "us-central1"
29+
MODEL_ID = "3863595899074641920" # Permanent restaurant rating model
30+
DISPLAY_NAME = f"temp_create_batch_prediction_tcn_test_{uuid4()}"
31+
GCS_SOURCE_URI = (
32+
"gs://ucaip-samples-test-output/inputs/batch_predict_TCN/tcn_inputs.jsonl"
33+
)
34+
GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"
35+
36+
37+
@pytest.fixture(scope="function")
38+
def shared_state():
39+
40+
shared_state = {}
41+
42+
yield shared_state
43+
44+
assert "/" in shared_state["batch_prediction_job_name"]
45+
46+
batch_prediction_job = shared_state["batch_prediction_job_name"].split("/")[-1]
47+
48+
# Stop the batch prediction job
49+
cancel_batch_prediction_job_sample.cancel_batch_prediction_job_sample(
50+
project=PROJECT_ID, batch_prediction_job_id=batch_prediction_job
51+
)
52+
53+
job_client = aiplatform.gapic.JobServiceClient(
54+
client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
55+
)
56+
57+
# Waiting for batch prediction job to be in CANCELLED state
58+
helpers.wait_for_job_state(
59+
get_job_method=job_client.get_batch_prediction_job,
60+
name=shared_state["batch_prediction_job_name"],
61+
)
62+
63+
# Delete the batch prediction job
64+
delete_batch_prediction_job_sample.delete_batch_prediction_job_sample(
65+
project=PROJECT_ID, batch_prediction_job_id=batch_prediction_job
66+
)
67+
68+
69+
# Creating AutoML Text Classification batch prediction job
70+
def test_ucaip_generated_create_batch_prediction_tcn_sample(capsys, shared_state):
71+
72+
model_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}"
73+
74+
create_batch_prediction_job_text_classification_sample.create_batch_prediction_job_text_classification_sample(
75+
project=PROJECT_ID,
76+
display_name=DISPLAY_NAME,
77+
model=model_name,
78+
gcs_source_uri=GCS_SOURCE_URI,
79+
gcs_destination_output_uri_prefix=GCS_OUTPUT_URI,
80+
)
81+
82+
out, _ = capsys.readouterr()
83+
84+
# Save resource name of the newly created batch prediction job
85+
shared_state["batch_prediction_job_name"] = helpers.get_name(out)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_create_batch_prediction_job_text_entity_extraction_sample]
16+
from google.cloud import aiplatform
17+
from google.protobuf.struct_pb2 import Value
18+
19+
20+
def create_batch_prediction_job_text_entity_extraction_sample(
21+
project: str,
22+
display_name: str,
23+
model: str,
24+
gcs_source_uri: str,
25+
gcs_destination_output_uri_prefix: str,
26+
location: str = "us-central1",
27+
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
28+
):
29+
client_options = {"api_endpoint": api_endpoint}
30+
# Initialize client that will be used to create and send requests.
31+
# This client only needs to be created once, and can be reused for multiple requests.
32+
client = aiplatform.gapic.JobServiceClient(client_options=client_options)
33+
34+
batch_prediction_job = {
35+
"display_name": display_name,
36+
# Format: 'projects/{project}/locations/{location}/models/{model_id}'
37+
"model": model,
38+
"model_parameters": Value(),
39+
"input_config": {
40+
"instances_format": "jsonl",
41+
"gcs_source": {"uris": [gcs_source_uri]},
42+
},
43+
"output_config": {
44+
"predictions_format": "jsonl",
45+
"gcs_destination": {"output_uri_prefix": gcs_destination_output_uri_prefix},
46+
},
47+
}
48+
parent = f"projects/{project}/locations/{location}"
49+
response = client.create_batch_prediction_job(
50+
parent=parent, batch_prediction_job=batch_prediction_job
51+
)
52+
print("response:", response)
53+
54+
55+
# [END aiplatform_create_batch_prediction_job_text_entity_extraction_sample]
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from uuid import uuid4
16+
import pytest
17+
import os
18+
19+
import helpers
20+
21+
import create_batch_prediction_job_text_entity_extraction_sample
22+
import cancel_batch_prediction_job_sample
23+
import delete_batch_prediction_job_sample
24+
25+
from google.cloud import aiplatform
26+
27+
PROJECT_ID = os.getenv("BUILD_SPECIFIC_GCLOUD_PROJECT")
28+
LOCATION = "us-central1"
29+
MODEL_ID = "5216364637146054656" # Permanent medical entity NL model
30+
DISPLAY_NAME = f"temp_create_batch_prediction_ten_test_{uuid4()}"
31+
GCS_SOURCE_URI = (
32+
"gs://ucaip-samples-test-output/inputs/batch_predict_TEN/ten_inputs.jsonl"
33+
)
34+
GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"
35+
36+
37+
@pytest.fixture(scope="function")
38+
def shared_state():
39+
40+
shared_state = {}
41+
42+
yield shared_state
43+
44+
assert "/" in shared_state["batch_prediction_job_name"]
45+
46+
batch_prediction_job = shared_state["batch_prediction_job_name"].split("/")[-1]
47+
48+
# Stop the batch prediction job
49+
cancel_batch_prediction_job_sample.cancel_batch_prediction_job_sample(
50+
project=PROJECT_ID, batch_prediction_job_id=batch_prediction_job
51+
)
52+
53+
job_client = aiplatform.gapic.JobServiceClient(
54+
client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
55+
)
56+
57+
# Waiting for batch prediction job to be in CANCELLED state
58+
helpers.wait_for_job_state(
59+
get_job_method=job_client.get_batch_prediction_job,
60+
name=shared_state["batch_prediction_job_name"],
61+
)
62+
63+
# Delete the batch prediction job
64+
delete_batch_prediction_job_sample.delete_batch_prediction_job_sample(
65+
project=PROJECT_ID, batch_prediction_job_id=batch_prediction_job
66+
)
67+
68+
69+
# Creating AutoML Text Entity Extraction batch prediction job
70+
def test_ucaip_generated_create_batch_prediction_ten_sample(capsys, shared_state):
71+
72+
model_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}"
73+
74+
create_batch_prediction_job_text_entity_extraction_sample.create_batch_prediction_job_text_entity_extraction_sample(
75+
project=PROJECT_ID,
76+
display_name=DISPLAY_NAME,
77+
model=model_name,
78+
gcs_source_uri=GCS_SOURCE_URI,
79+
gcs_destination_output_uri_prefix=GCS_OUTPUT_URI,
80+
)
81+
82+
out, _ = capsys.readouterr()
83+
84+
# Save resource name of the newly created batch prediction job
85+
shared_state["batch_prediction_job_name"] = helpers.get_name(out)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_create_batch_prediction_job_text_sentiment_analysis_sample]
16+
from google.cloud import aiplatform
17+
from google.protobuf.struct_pb2 import Value
18+
19+
20+
def create_batch_prediction_job_text_sentiment_analysis_sample(
21+
project: str,
22+
display_name: str,
23+
model: str,
24+
gcs_source_uri: str,
25+
gcs_destination_output_uri_prefix: str,
26+
location: str = "us-central1",
27+
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
28+
):
29+
client_options = {"api_endpoint": api_endpoint}
30+
# Initialize client that will be used to create and send requests.
31+
# This client only needs to be created once, and can be reused for multiple requests.
32+
client = aiplatform.gapic.JobServiceClient(client_options=client_options)
33+
34+
batch_prediction_job = {
35+
"display_name": display_name,
36+
# Format: 'projects/{project}/locations/{location}/models/{model_id}'
37+
"model": model,
38+
"model_parameters": Value(),
39+
"input_config": {
40+
"instances_format": "jsonl",
41+
"gcs_source": {"uris": [gcs_source_uri]},
42+
},
43+
"output_config": {
44+
"predictions_format": "jsonl",
45+
"gcs_destination": {"output_uri_prefix": gcs_destination_output_uri_prefix},
46+
},
47+
}
48+
parent = f"projects/{project}/locations/{location}"
49+
response = client.create_batch_prediction_job(
50+
parent=parent, batch_prediction_job=batch_prediction_job
51+
)
52+
print("response:", response)
53+
54+
55+
# [END aiplatform_create_batch_prediction_job_text_sentiment_analysis_sample]

0 commit comments

Comments
 (0)