@@ -30,7 +30,7 @@ def __init__(
3030 self ,
3131 model_path : str ,
3232 signature_name : Optional [str ] = None ,
33- outputs_to_explain : Optional [List [str ]] = () ,
33+ outputs_to_explain : Optional [List [str ]] = None ,
3434 ** kwargs
3535 ) -> None :
3636 """Initializes a SavedModelMetadataBuilder object.
@@ -93,26 +93,18 @@ def _infer_metadata_entries_from_model(
9393 Inferred input metadata and output metadata from the model.
9494
9595 Raises:
96- ValueError if specified name is not found in signature outputs.
96+ ValueError if specified name is not found in signature outputs.
9797 """
9898
9999 loaded_sig = self ._loaded_model .signatures [signature_name ]
100100 _ , input_sig = loaded_sig .structured_input_signature
101101 output_sig = loaded_sig .structured_outputs
102102 input_mds = {}
103103 for name , tensor_spec in input_sig .items ():
104- if tensor_spec .dtype .is_floating :
105- input_mds [
106- name
107- ] = explanation_metadata .ExplanationMetadata .InputMetadata (
108- input_tensor_name = name
109- )
110- else :
111- input_mds [
112- name
113- ] = explanation_metadata .ExplanationMetadata .InputMetadata (
114- input_tensor_name = name , modality = "categorical" ,
115- )
104+ input_mds [name ] = explanation_metadata .ExplanationMetadata .InputMetadata (
105+ input_tensor_name = name ,
106+ modality = None if tensor_spec .dtype .is_floating else "categorical" ,
107+ )
116108
117109 output_mds = {}
118110 for name in output_sig :
0 commit comments