Skip to content

Commit 03cf301

Browse files
authored
Feat: Add LIT methods for Pandas DataFrame and TensorFlow saved model. (googleapis#874)
Adds methods from go/lit-xai-notebook for Pandas DataFrame and TensorFlow saved model. b/208628825 Example Colab: go/lit-vertex-pr-1
1 parent 8a8a4fa commit 03cf301

File tree

3 files changed

+405
-3
lines changed

3 files changed

+405
-3
lines changed
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2021 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from typing import Dict, List, Tuple, Union
18+
19+
# Guarded third-party imports: each raises an actionable ImportError that
# points the user at the correct pip extra ("python-aiplatform[lit]").
try:
    from lit_nlp.api import dataset as lit_dataset
    from lit_nlp.api import model as lit_model
    from lit_nlp.api import types as lit_types
    from lit_nlp import notebook
except ImportError:
    raise ImportError(
        "LIT is not installed and is required to get Dataset as the return format. "
        'Please install the SDK using "pip install python-aiplatform[lit]"'
    )

try:
    import tensorflow as tf
except ImportError:
    raise ImportError(
        "Tensorflow is not installed and is required to load saved model. "
        # Fix: message previously read 'pip install pip install python-aiplatform[lit]'.
        'Please install the SDK using "pip install python-aiplatform[lit]"'
    )

try:
    import pandas as pd
except ImportError:
    raise ImportError(
        "Pandas is not installed and is required to read the dataset. "
        'Please install Pandas using "pip install python-aiplatform[lit]"'
    )
45+
46+
47+
class _VertexLitDataset(lit_dataset.Dataset):
    """LIT dataset class for the Vertex LIT integration.

    This is used in the create_lit_dataset function.
    """

    def __init__(
        self,
        dataset: pd.DataFrame,
        column_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
    ):
        """Construct a VertexLitDataset.

        Args:
            dataset:
                Required. A Pandas DataFrame that includes feature column names and data.
            column_types:
                Required. An OrderedDict of string names matching the columns of the dataset
                as the key, and the associated LitType of the column.
        """
        # One LIT example per DataFrame row, keyed by column name.
        self._examples = dataset.to_dict(orient="records")
        self._column_types = column_types

    def spec(self):
        """Return a spec describing dataset elements."""
        # Hand back a plain-dict copy so callers cannot mutate our OrderedDict.
        return {**self._column_types}
72+
73+
74+
class _VertexLitModel(lit_model.Model):
    """LIT model class for the Vertex LIT integration.

    This is used in the create_lit_model function.
    """

    def __init__(
        self,
        model: str,
        input_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
        output_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
    ):
        """Construct a VertexLitModel.

        Args:
            model:
                Required. A string reference to a local TensorFlow saved model directory.
                The model must have at most one input and one output tensor.
            input_types:
                Required. An OrderedDict of string names matching the features of the model
                as the key, and the associated LitType of the feature.
            output_types:
                Required. An OrderedDict of string names matching the labels of the model
                as the key, and the associated LitType of the label.

        Raises:
            ValueError: If the model has more than one input or output tensor.
        """
        self._loaded_model = tf.saved_model.load(model)
        serving_default = self._loaded_model.signatures[
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        ]
        # Keyword-argument signature describes the serving input tensor(s);
        # structured_outputs describes the output tensor(s).
        _, self._kwargs_signature = serving_default.structured_input_signature
        self._output_signature = serving_default.structured_outputs

        if len(self._kwargs_signature) != 1:
            raise ValueError("Please use a model with only one input tensor.")

        if len(self._output_signature) != 1:
            raise ValueError("Please use a model with only one output tensor.")

        self._input_types = input_types
        self._output_types = output_types

    def predict_minibatch(
        self, inputs: List[lit_types.JsonDict]
    ) -> List[lit_types.JsonDict]:
        """Returns predictions for a single batch of examples.

        Args:
            inputs:
                sequence of inputs, following model.input_spec()

        Returns:
            list of outputs, following model.output_spec()
        """
        # Order each example's feature values to match the model's input spec.
        # Renamed the loop variable from ``input`` to avoid shadowing the builtin.
        instances = [
            [example[feature] for feature in self._input_types] for example in inputs
        ]
        # The signature has exactly one input tensor (enforced in __init__),
        # so next(iter(...)) yields its name.
        prediction_input_dict = {
            next(iter(self._kwargs_signature)): tf.convert_to_tensor(instances)
        }
        prediction_dict = self._loaded_model.signatures[
            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        ](**prediction_input_dict)
        predictions = prediction_dict[next(iter(self._output_signature))].numpy()
        # Zip each prediction vector with the output label names.
        return [
            dict(zip(self._output_types.keys(), prediction))
            for prediction in predictions
        ]

    def input_spec(self) -> lit_types.Spec:
        """Return a spec describing model inputs."""
        return dict(self._input_types)

    def output_spec(self) -> lit_types.Spec:
        """Return a spec describing model outputs."""
        # Return a copy for consistency with input_spec and to avoid exposing
        # the internal OrderedDict to mutation by callers.
        return dict(self._output_types)
152+
153+
154+
def create_lit_dataset(
    dataset: pd.DataFrame,
    column_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
) -> lit_dataset.Dataset:
    """Creates a LIT Dataset object.

    Args:
        dataset:
            Required. A Pandas DataFrame that includes feature column names and data.
        column_types:
            Required. An OrderedDict of string names matching the columns of the dataset
            as the key, and the associated LitType of the column.

    Returns:
        A LIT Dataset object that has the data from the dataset provided.
    """
    # Thin factory over the private wrapper class.
    return _VertexLitDataset(dataset=dataset, column_types=column_types)
169+
170+
171+
def create_lit_model(
    model: str,
    input_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
    output_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
) -> lit_model.Model:
    """Creates a LIT Model object.

    Args:
        model:
            Required. A string reference to a local TensorFlow saved model directory.
            The model must have at most one input and one output tensor.
        input_types:
            Required. An OrderedDict of string names matching the features of the model
            as the key, and the associated LitType of the feature.
        output_types:
            Required. An OrderedDict of string names matching the labels of the model
            as the key, and the associated LitType of the label.

    Returns:
        A LIT Model object that has the same functionality as the model provided.
    """
    # Thin factory over the private wrapper class.
    return _VertexLitModel(
        model=model, input_types=input_types, output_types=output_types
    )
191+
192+
193+
def open_lit(
    models: Dict[str, lit_model.Model],
    datasets: Dict[str, lit_dataset.Dataset],
    open_in_new_tab: bool = True,
):
    """Open LIT from the provided models and datasets.

    Args:
        models:
            Required. A dict mapping display name to LIT model to open LIT with.
        datasets:
            Required. A dict mapping display name to LIT dataset to open LIT with.
        open_in_new_tab:
            Optional. A boolean to choose if LIT open in a new tab or not.

    Raises:
        ImportError if LIT is not installed.
    """
    widget = notebook.LitWidget(models, datasets, open_in_new_tab=open_in_new_tab)
    widget.render()
211+
212+
213+
def set_up_and_open_lit(
    dataset: Union[pd.DataFrame, lit_dataset.Dataset],
    column_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
    model: Union[str, lit_model.Model],
    input_types: Union[List[str], Dict[str, lit_types.LitType]],
    output_types: Union[str, List[str], Dict[str, lit_types.LitType]],
    open_in_new_tab: bool = True,
) -> Tuple[lit_dataset.Dataset, lit_model.Model]:
    """Creates a LIT dataset and model and opens LIT.

    Args:
        dataset:
            Required. A Pandas DataFrame that includes feature column names and data,
            or an already-built LIT Dataset object, which is used as-is.
        column_types:
            Required. An OrderedDict of string names matching the columns of the dataset
            as the key, and the associated LitType of the column. Ignored when
            ``dataset`` is already a LIT Dataset.
        model:
            Required. A string reference to a TensorFlow saved model directory,
            or an already-built LIT Model object, which is used as-is.
            The model must have at most one input and one output tensor.
        input_types:
            Required. An OrderedDict of string names matching the features of the model
            as the key, and the associated LitType of the feature. Ignored when
            ``model`` is already a LIT Model.
        output_types:
            Required. An OrderedDict of string names matching the labels of the model
            as the key, and the associated LitType of the label. Ignored when
            ``model`` is already a LIT Model.
        open_in_new_tab:
            Optional. A boolean to choose if LIT open in a new tab or not.

    Returns:
        A Tuple of the LIT dataset and model created.

    Raises:
        ImportError if LIT or TensorFlow is not installed.
        ValueError if the model doesn't have only 1 input and output tensor.
    """
    # Wrap raw inputs; objects that are already LIT types pass through untouched.
    if not isinstance(dataset, lit_dataset.Dataset):
        dataset = create_lit_dataset(dataset, column_types)

    if not isinstance(model, lit_model.Model):
        model = create_lit_model(model, input_types, output_types)

    open_lit({"model": model}, {"dataset": dataset}, open_in_new_tab=open_in_new_tab)

    return dataset, model

setup.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,25 @@
3636
tensorboard_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
3737
metadata_extra_require = ["pandas >= 1.0.0"]
3838
xai_extra_require = ["tensorflow >=2.3.0, <=2.5.0"]
39+
lit_extra_require = ["tensorflow >= 2.3.0", "pandas >= 1.0.0", "lit-nlp >= 0.4.0"]
3940
profiler_extra_require = [
4041
"tensorboard-plugin-profile >= 2.4.0",
4142
"werkzeug >= 2.0.0",
4243
"tensorflow >=2.4.0",
4344
]
4445

4546
full_extra_require = list(
46-
set(tensorboard_extra_require + metadata_extra_require + xai_extra_require)
47+
set(
48+
tensorboard_extra_require
49+
+ metadata_extra_require
50+
+ xai_extra_require
51+
+ lit_extra_require
52+
)
4753
)
4854
testing_extra_require = (
49-
full_extra_require + profiler_extra_require + ["grpcio-testing", "pytest-xdist"]
55+
full_extra_require
56+
+ profiler_extra_require
57+
+ ["grpcio-testing", "pytest-xdist", "ipython"]
5058
)
5159

5260

@@ -88,7 +96,8 @@
8896
"tensorboard": tensorboard_extra_require,
8997
"testing": testing_extra_require,
9098
"xai": xai_extra_require,
91-
"cloud-profiler": profiler_extra_require,
99+
"lit": lit_extra_require,
100+
"cloud_profiler": profiler_extra_require,
92101
},
93102
python_requires=">=3.6",
94103
scripts=[],

0 commit comments

Comments
 (0)