Skip to content

Commit 918998c

Browse files
authored
feat: add tf1 metadata builder (#526)
* feat: add tf1 metadata builder * Change import checks
1 parent fdeb51b commit 918998c

File tree

5 files changed

+264
-2
lines changed

5 files changed

+264
-2
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from google.protobuf import json_format
18+
from typing import Any, Dict, List, Optional
19+
20+
from google.cloud.aiplatform.compat.types import (
21+
explanation_metadata_v1beta1 as explanation_metadata,
22+
)
23+
from google.cloud.aiplatform.explain.metadata import metadata_builder
24+
25+
26+
class SavedModelMetadataBuilder(metadata_builder.MetadataBuilder):
27+
"""Metadata builder class that accepts a TF1 saved model."""
28+
29+
def __init__(
30+
self,
31+
model_path: str,
32+
tags: Optional[List[str]] = None,
33+
signature_name: Optional[str] = None,
34+
outputs_to_explain: Optional[List[str]] = None,
35+
) -> None:
36+
"""Initializes a SavedModelMetadataBuilder object.
37+
38+
Args:
39+
model_path:
40+
Required. Local or GCS path to load the saved model from.
41+
tags:
42+
Optional. Tags to identify the model graph. If None or empty,
43+
TensorFlow's default serving tag will be used.
44+
signature_name:
45+
Optional. Name of the signature to be explained. Inputs and
46+
outputs of this signature will be written in the metadata. If not
47+
provided, the default signature will be used.
48+
outputs_to_explain:
49+
Optional. List of output names to explain. Only single output is
50+
supported for now. Hence, the list should contain one element.
51+
This parameter is required if the model signature (provided via
52+
signature_name) specifies multiple outputs.
53+
54+
Raises:
55+
ValueError if outputs_to_explain contains more than 1 element or
56+
signature contains multiple outputs.
57+
"""
58+
if outputs_to_explain:
59+
if len(outputs_to_explain) > 1:
60+
raise ValueError(
61+
"Only one output is supported at the moment. "
62+
f"Received: {outputs_to_explain}."
63+
)
64+
self._output_to_explain = next(iter(outputs_to_explain))
65+
66+
try:
67+
import tensorflow.compat.v1 as tf
68+
except ImportError:
69+
raise ImportError(
70+
"Tensorflow is not installed and is required to load saved model. "
71+
'Please install the SDK using "pip install "tensorflow>=1.15,<2.0""'
72+
)
73+
74+
if not signature_name:
75+
signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
76+
self._tags = tags or [tf.saved_model.tag_constants.SERVING]
77+
self._graph = tf.Graph()
78+
79+
with self.graph.as_default():
80+
self._session = tf.Session(graph=self.graph)
81+
self._metagraph_def = tf.saved_model.loader.load(
82+
sess=self.session, tags=self._tags, export_dir=model_path
83+
)
84+
if signature_name not in self._metagraph_def.signature_def:
85+
raise ValueError(
86+
f"Serving sigdef key {signature_name} not in the signature def."
87+
)
88+
serving_sigdef = self._metagraph_def.signature_def[signature_name]
89+
if not outputs_to_explain:
90+
if len(serving_sigdef.outputs) > 1:
91+
raise ValueError(
92+
"The signature contains multiple outputs. Specify "
93+
'an output via "outputs_to_explain" parameter.'
94+
)
95+
self._output_to_explain = next(iter(serving_sigdef.outputs.keys()))
96+
97+
self._inputs = _create_input_metadata_from_signature(serving_sigdef.inputs)
98+
self._outputs = _create_output_metadata_from_signature(
99+
serving_sigdef.outputs, self._output_to_explain
100+
)
101+
102+
@property
103+
def graph(self) -> "tf.Graph": # noqa: F821
104+
return self._graph
105+
106+
@property
107+
def session(self) -> "tf.Session": # noqa: F821
108+
return self._session
109+
110+
def get_metadata(self) -> Dict[str, Any]:
111+
"""Returns the current metadata as a dictionary.
112+
113+
Returns:
114+
Json format of the explanation metadata.
115+
"""
116+
current_md = explanation_metadata.ExplanationMetadata(
117+
inputs=self._inputs, outputs=self._outputs,
118+
)
119+
return json_format.MessageToDict(current_md._pb)
120+
121+
122+
def _create_input_metadata_from_signature(
123+
signature_inputs: Dict[str, "tf.Tensor"] # noqa: F821
124+
) -> Dict[str, explanation_metadata.ExplanationMetadata.InputMetadata]:
125+
"""Creates InputMetadata from signature inputs.
126+
127+
Args:
128+
signature_inputs:
129+
Required. Inputs of the signature to be explained. If not provided,
130+
the default signature will be used.
131+
132+
Returns:
133+
Inferred input metadata from the model.
134+
"""
135+
input_mds = {}
136+
for key, tensor in signature_inputs.items():
137+
input_mds[key] = explanation_metadata.ExplanationMetadata.InputMetadata(
138+
input_tensor_name=tensor.name
139+
)
140+
return input_mds
141+
142+
143+
def _create_output_metadata_from_signature(
144+
signature_outputs: Dict[str, "tf.Tensor"], # noqa: F821
145+
output_to_explain: Optional[str] = None,
146+
) -> Dict[str, explanation_metadata.ExplanationMetadata.OutputMetadata]:
147+
"""Creates OutputMetadata from signature inputs.
148+
149+
Args:
150+
signature_outputs:
151+
Required. Inputs of the signature to be explained. If not provided,
152+
the default signature will be used.
153+
output_to_explain:
154+
Optional. Output name to explain.
155+
156+
Returns:
157+
Inferred output metadata from the model.
158+
"""
159+
output_mds = {}
160+
for key, tensor in signature_outputs.items():
161+
if not output_to_explain or output_to_explain == key:
162+
output_mds[key] = explanation_metadata.ExplanationMetadata.OutputMetadata(
163+
output_tensor_name=tensor.name
164+
)
165+
return output_mds

google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(
3737
3838
Args:
3939
model_path:
40-
Required. Path to load the saved model from.
40+
Required. Local or GCS path to load the saved model from.
4141
signature_name:
4242
Optional. Name of the signature to be explained. Inputs and
4343
outputs of this signature will be written in the metadata. If not
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2020 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import tensorflow.compat.v1 as tf
19+
20+
from google.cloud.aiplatform.explain.metadata.tf.v1 import saved_model_metadata_builder
21+
22+
23+
class SavedModelMetadataBuilderTF1Test(tf.test.TestCase):
24+
def _set_up(self):
25+
self.sess = tf.Session(graph=tf.Graph())
26+
with self.sess.graph.as_default():
27+
self.x = tf.placeholder(shape=[None, 10], dtype=tf.float32, name="inp")
28+
weights = tf.constant(1.0, shape=(10, 2), name="weights")
29+
bias_weight = tf.constant(1.0, shape=(2,), name="bias")
30+
self.linear_layer = tf.add(tf.matmul(self.x, weights), bias_weight)
31+
self.prediction = tf.nn.relu(self.linear_layer)
32+
# save the model
33+
self.model_path = self.get_temp_dir()
34+
builder = tf.saved_model.builder.SavedModelBuilder(self.model_path)
35+
tensor_info_x = tf.saved_model.utils.build_tensor_info(self.x)
36+
tensor_info_pred = tf.saved_model.utils.build_tensor_info(self.prediction)
37+
tensor_info_lin = tf.saved_model.utils.build_tensor_info(self.linear_layer)
38+
prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
39+
inputs={"x": tensor_info_x},
40+
outputs={"y": tensor_info_pred},
41+
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
42+
)
43+
double_output_signature = tf.saved_model.signature_def_utils.build_signature_def(
44+
inputs={"x": tensor_info_x},
45+
outputs={"y": tensor_info_pred, "lin": tensor_info_lin},
46+
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
47+
)
48+
49+
builder.add_meta_graph_and_variables(
50+
self.sess,
51+
[tf.saved_model.tag_constants.SERVING],
52+
signature_def_map={
53+
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature,
54+
"double": double_output_signature,
55+
},
56+
)
57+
builder.save()
58+
59+
def test_get_metadata_correct_inputs(self):
60+
self._set_up()
61+
md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
62+
self.model_path, tags=[tf.saved_model.tag_constants.SERVING]
63+
)
64+
expected_md = {
65+
"inputs": {"x": {"inputTensorName": "inp:0"}},
66+
"outputs": {"y": {"outputTensorName": "Relu:0"}},
67+
}
68+
69+
assert md_builder.get_metadata() == expected_md
70+
71+
def test_get_metadata_double_output(self):
72+
self._set_up()
73+
md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
74+
self.model_path, signature_name="double", outputs_to_explain=["lin"]
75+
)
76+
77+
expected_md = {
78+
"inputs": {"x": {"inputTensorName": "inp:0"}},
79+
"outputs": {"lin": {"outputTensorName": "Add:0"}},
80+
}
81+
82+
assert md_builder.get_metadata() == expected_md

tests/unit/aiplatform/test_explain_saved_model_metadata_builder_test.py renamed to tests/unit/aiplatform/test_explain_saved_model_metadata_builder_tf2_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from google.cloud.aiplatform.explain.metadata.tf.v2 import saved_model_metadata_builder
2323

2424

25-
class SavedModelMetadataBuilderTest(tf.test.TestCase):
25+
class SavedModelMetadataBuilderTF2Test(tf.test.TestCase):
2626
def test_get_metadata_sequential(self):
2727
# Set up for the sequential.
2828
self.seq_model = tf.keras.models.Sequential()

0 commit comments

Comments
 (0)