Skip to content

Commit 335de38

Browse files
committed
Address comments
1 parent 114b16c commit 335de38

File tree

2 files changed

+8
-16
lines changed

2 files changed

+8
-16
lines changed

google/cloud/aiplatform/explain/metadata/tf/v2/saved_model_metadata_builder.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(
3030
self,
3131
model_path: str,
3232
signature_name: Optional[str] = None,
33-
outputs_to_explain: Optional[List[str]] = (),
33+
outputs_to_explain: Optional[List[str]] = None,
3434
**kwargs
3535
) -> None:
3636
"""Initializes a SavedModelMetadataBuilder object.
@@ -93,26 +93,18 @@ def _infer_metadata_entries_from_model(
9393
Inferred input metadata and output metadata from the model.
9494
9595
Raises:
96-
ValueError if specified name is not found in signature outputs.
96+
ValueError if specified name is not found in signature outputs.
9797
"""
9898

9999
loaded_sig = self._loaded_model.signatures[signature_name]
100100
_, input_sig = loaded_sig.structured_input_signature
101101
output_sig = loaded_sig.structured_outputs
102102
input_mds = {}
103103
for name, tensor_spec in input_sig.items():
104-
if tensor_spec.dtype.is_floating:
105-
input_mds[
106-
name
107-
] = explanation_metadata.ExplanationMetadata.InputMetadata(
108-
input_tensor_name=name
109-
)
110-
else:
111-
input_mds[
112-
name
113-
] = explanation_metadata.ExplanationMetadata.InputMetadata(
114-
input_tensor_name=name, modality="categorical",
115-
)
104+
input_mds[name] = explanation_metadata.ExplanationMetadata.InputMetadata(
105+
input_tensor_name=name,
106+
modality=None if tensor_spec.dtype.is_floating else "categorical",
107+
)
116108

117109
output_mds = {}
118110
for name in output_sig:

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636
]
3737
metadata_extra_require = ["pandas >= 1.0.0"]
3838
xai_extra_require = ["tensorflow-cpu>=2.3.0, <=2.5.0"]
39-
full_extra_require = (
40-
tensorboard_extra_require + metadata_extra_require + xai_extra_require
39+
full_extra_require = list(
40+
set(tensorboard_extra_require + metadata_extra_require + xai_extra_require)
4141
)
4242
testing_extra_require = full_extra_require + ["grpcio-testing ~= 1.34.0"]
4343

0 commit comments

Comments
 (0)