@@ -163,6 +163,56 @@ def mock_stream_generate_content(
163163 yield response
164164
165165
166+ def mock_generate_content (
167+ self ,
168+ request : gapic_prediction_service_types .GenerateContentRequest ,
169+ * ,
170+ model : Optional [str ] = None ,
171+ contents : Optional [MutableSequence [gapic_content_types .Content ]] = None ,
172+ ) -> Iterable [gapic_prediction_service_types .GenerateContentResponse ]:
173+ is_continued_chat = len (request .contents ) > 1
174+ has_tools = bool (request .tools )
175+
176+ if has_tools :
177+ has_function_response = any (
178+ "function_response" in content .parts [0 ] for content in request .contents
179+ )
180+ needs_function_call = not has_function_response
181+ if needs_function_call :
182+ response_part_struct = _RESPONSE_FUNCTION_CALL_PART_STRUCT
183+ else :
184+ response_part_struct = _RESPONSE_AFTER_FUNCTION_CALL_PART_STRUCT
185+ elif is_continued_chat :
186+ response_part_struct = {"text" : "Other planets may have different sky color." }
187+ else :
188+ response_part_struct = _RESPONSE_TEXT_PART_STRUCT
189+
190+ return gapic_prediction_service_types .GenerateContentResponse (
191+ candidates = [
192+ gapic_content_types .Candidate (
193+ index = 0 ,
194+ content = gapic_content_types .Content (
195+ # Model currently does not identify itself
196+ # role="model",
197+ parts = [
198+ gapic_content_types .Part (response_part_struct ),
199+ ],
200+ ),
201+ finish_reason = gapic_content_types .Candidate .FinishReason .STOP ,
202+ safety_ratings = [
203+ gapic_content_types .SafetyRating (rating )
204+ for rating in _RESPONSE_SAFETY_RATINGS_STRUCT
205+ ],
206+ citation_metadata = gapic_content_types .CitationMetadata (
207+ citations = [
208+ gapic_content_types .Citation (_RESPONSE_CITATION_STRUCT ),
209+ ]
210+ ),
211+ ),
212+ ],
213+ )
214+
215+
166216@pytest .mark .usefixtures ("google_auth_mock" )
167217class TestGenerativeModels :
168218 """Unit tests for the generative models."""
@@ -178,8 +228,8 @@ def teardown_method(self):
178228
179229 @mock .patch .object (
180230 target = prediction_service .PredictionServiceClient ,
181- attribute = "stream_generate_content " ,
182- new = mock_stream_generate_content ,
231+ attribute = "generate_content " ,
232+ new = mock_generate_content ,
183233 )
184234 def test_generate_content (self ):
185235 model = generative_models .GenerativeModel ("gemini-pro" )
@@ -212,8 +262,8 @@ def test_generate_content_streaming(self):
212262
213263 @mock .patch .object (
214264 target = prediction_service .PredictionServiceClient ,
215- attribute = "stream_generate_content " ,
216- new = mock_stream_generate_content ,
265+ attribute = "generate_content " ,
266+ new = mock_generate_content ,
217267 )
218268 def test_chat_send_message (self ):
219269 model = generative_models .GenerativeModel ("gemini-pro" )
@@ -225,8 +275,8 @@ def test_chat_send_message(self):
225275
226276 @mock .patch .object (
227277 target = prediction_service .PredictionServiceClient ,
228- attribute = "stream_generate_content " ,
229- new = mock_stream_generate_content ,
278+ attribute = "generate_content " ,
279+ new = mock_generate_content ,
230280 )
231281 def test_chat_function_calling (self ):
232282 get_current_weather_func = generative_models .FunctionDeclaration (
0 commit comments