 # limitations under the License.
 #
 
-from typing import List, Optional
+import datetime
+import re
+from typing import Any, Dict, List, Optional
 
+from google.auth import credentials as auth_credentials
+from google.cloud import aiplatform_v1beta1
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import compat
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform import pipeline_job_schedules
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.constants import pipeline as pipeline_constants
+from google.cloud.aiplatform.metadata import constants as metadata_constants
+from google.cloud.aiplatform.metadata import experiment_resources
 from google.cloud.aiplatform.pipeline_jobs import (
     PipelineJob as PipelineJobGa,
 )
 from google.cloud.aiplatform_v1.services.pipeline_service import (
     PipelineServiceClient as PipelineServiceClientGa,
 )
-from google.cloud import aiplatform_v1beta1
-from google.cloud.aiplatform import compat, pipeline_job_schedules
-from google.cloud.aiplatform import initializer
-from google.cloud.aiplatform import utils
 
-from google.cloud.aiplatform.metadata import constants as metadata_constants
-from google.cloud.aiplatform.metadata import experiment_resources
+from google.protobuf import json_format
+
+
+_LOGGER = base.Logger(__name__)
+
+# Pattern for valid names used as a Vertex resource name.
+_VALID_NAME_PATTERN = pipeline_constants._VALID_NAME_PATTERN
+
+# Pattern for an Artifact Registry URL.
+_VALID_AR_URL = pipeline_constants._VALID_AR_URL
+
+# Pattern for any JSON or YAML file over HTTPS.
+_VALID_HTTPS_URL = pipeline_constants._VALID_HTTPS_URL
+
+
+def _get_current_time() -> datetime.datetime:
+    """Gets the current timestamp."""
+    return datetime.datetime.now()
+
+
+def _set_enable_caching_value(
+    pipeline_spec: Dict[str, Any], enable_caching: bool
+) -> None:
+    """Sets the caching option on every pipeline task.
+
+    Args:
+        pipeline_spec (Dict[str, Any]):
+            Required. The dictionary of the pipeline spec.
+        enable_caching (bool):
+            Required. Whether to enable caching.
+    """
+    for component in [pipeline_spec["root"]] + list(
+        pipeline_spec["components"].values()
+    ):
+        if "dag" in component:
+            for task in component["dag"]["tasks"].values():
+                task["cachingOptions"] = {"enableCache": enable_caching}
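
A minimal sketch of what this helper does to a compiled spec, using a made-up
two-task pipeline dict (only the keys the helper touches are present):

    # Hypothetical compiled spec: one DAG under "root", no extra components.
    spec = {
        "root": {
            "dag": {
                "tasks": {
                    "train": {"cachingOptions": {"enableCache": True}},
                    "evaluate": {},
                }
            }
        },
        "components": {},
    }

    _set_enable_caching_value(spec, enable_caching=False)
    # Every task now carries {"enableCache": False}, overriding any
    # compile-time per-task caching settings.
    assert all(
        task["cachingOptions"] == {"enableCache": False}
        for task in spec["root"]["dag"]["tasks"].values()
    )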
 
 
 class _PipelineJob(
@@ -42,6 +85,192 @@ class _PipelineJob(
 ):
     """Preview PipelineJob resource for Vertex AI."""
 
+    def __init__(
+        self,
+        display_name: str,
+        template_path: str,
+        job_id: Optional[str] = None,
+        pipeline_root: Optional[str] = None,
+        parameter_values: Optional[Dict[str, Any]] = None,
+        input_artifacts: Optional[Dict[str, str]] = None,
+        enable_caching: Optional[bool] = None,
+        encryption_spec_key_name: Optional[str] = None,
+        labels: Optional[Dict[str, str]] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        failure_policy: Optional[str] = None,
+        enable_preflight_validations: Optional[bool] = False,
+    ):
105+ """Retrieves a PipelineJob resource and instantiates its
106+ representation.
107+
108+ Args:
109+ display_name (str):
110+ Required. The user-defined name of this Pipeline.
111+ template_path (str):
112+ Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It
113+ can be a local path, a Google Cloud Storage URI (e.g. "gs://project.name"),
114+ an Artifact Registry URI (e.g.
115+ "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"), or an HTTPS URI.
116+ job_id (str):
117+ Optional. The unique ID of the job run.
118+ If not specified, pipeline name + timestamp will be used.
119+ pipeline_root (str):
120+ Optional. The root of the pipeline outputs. If not set, the staging bucket
121+ set in aiplatform.init will be used. If that's not set a pipeline-specific
122+ artifacts bucket will be used.
123+ parameter_values (Dict[str, Any]):
124+ Optional. The mapping from runtime parameter names to its values that
125+ control the pipeline run.
126+ input_artifacts (Dict[str, str]):
127+ Optional. The mapping from the runtime parameter name for this artifact to its resource id.
128+ For example: "vertex_model":"456". Note: full resource name ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used.
129+ enable_caching (bool):
130+ Optional. Whether to turn on caching for the run.
131+
132+ If this is not set, defaults to the compile time settings, which
133+ are True for all tasks by default, while users may specify
134+ different caching options for individual tasks.
135+
136+ If this is set, the setting applies to all tasks in the pipeline.
137+
138+ Overrides the compile time settings.
139+ encryption_spec_key_name (str):
140+ Optional. The Cloud KMS resource identifier of the customer
141+ managed encryption key used to protect the job. Has the
142+ form:
143+ ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
144+ The key needs to be in the same region as where the compute
145+ resource is created.
146+
147+ If this is set, then all
148+ resources created by the PipelineJob will
149+ be encrypted with the provided encryption key.
150+
151+ Overrides encryption_spec_key_name set in aiplatform.init.
152+ labels (Dict[str, str]):
153+ Optional. The user defined metadata to organize PipelineJob.
154+ credentials (auth_credentials.Credentials):
155+ Optional. Custom credentials to use to create this PipelineJob.
156+ Overrides credentials set in aiplatform.init.
157+ project (str):
158+ Optional. The project that you want to run this PipelineJob in. If not set,
159+ the project set in aiplatform.init will be used.
160+ location (str):
161+ Optional. Location to create PipelineJob. If not set,
162+ location set in aiplatform.init will be used.
163+ failure_policy (str):
164+ Optional. The failure policy - "slow" or "fast".
165+ Currently, the default of a pipeline is that the pipeline will continue to
166+ run until no more tasks can be executed, also known as
167+ PIPELINE_FAILURE_POLICY_FAIL_SLOW (corresponds to "slow").
168+ However, if a pipeline is set to
169+ PIPELINE_FAILURE_POLICY_FAIL_FAST (corresponds to "fast"),
170+ it will stop scheduling any new tasks when a task has failed. Any
171+ scheduled tasks will continue to completion.
172+ enable_preflight_validations (bool):
173+ Optional. Whether to enable preflight validations or not.
174+
175+ Raises:
176+ ValueError: If job_id or labels have incorrect format.
177+ """
+
+        super().__init__(
+            display_name=display_name,
+            template_path=template_path,
+            job_id=job_id,
+            pipeline_root=pipeline_root,
+            parameter_values=parameter_values,
+            input_artifacts=input_artifacts,
+            enable_caching=enable_caching,
+            encryption_spec_key_name=encryption_spec_key_name,
+            labels=labels,
+            credentials=credentials,
+            project=project,
+            location=location,
+            failure_policy=failure_policy,
+        )
+
+        # Rebuild the v1beta1 versions of pipeline_job and runtime_config.
+        pipeline_json = utils.yaml_utils.load_yaml(
+            template_path, self.project, self.credentials
+        )
+
+        # pipeline_json can be either a PipelineJob or a PipelineSpec.
+        if pipeline_json.get("pipelineSpec") is not None:
+            pipeline_job = pipeline_json
+            pipeline_root = (
+                pipeline_root
+                or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
+                or pipeline_job["runtimeConfig"].get("gcsOutputDirectory")
+                or initializer.global_config.staging_bucket
+            )
+        else:
+            pipeline_job = {
+                "pipelineSpec": pipeline_json,
+                "runtimeConfig": {},
+            }
+            pipeline_root = (
+                pipeline_root
+                or pipeline_job["pipelineSpec"].get("defaultPipelineRoot")
+                or initializer.global_config.staging_bucket
+            )
+        pipeline_root = (
+            pipeline_root
+            or utils.gcs_utils.generate_gcs_directory_for_pipeline_artifacts(
+                project=project,
+                location=location,
+            )
+        )
+        builder = utils.pipeline_utils.PipelineRuntimeConfigBuilder.from_job_spec_json(
+            pipeline_job
+        )
+        builder.update_pipeline_root(pipeline_root)
+        builder.update_runtime_parameters(parameter_values)
+        builder.update_input_artifacts(input_artifacts)
+
+        builder.update_failure_policy(failure_policy)
+        runtime_config_dict = builder.build()
+
+        runtime_config = aiplatform_v1beta1.PipelineJob.RuntimeConfig()._pb
+        json_format.ParseDict(runtime_config_dict, runtime_config)
+
+        pipeline_name = pipeline_job["pipelineSpec"]["pipelineInfo"]["name"]
+        self.job_id = job_id or "{pipeline_name}-{timestamp}".format(
+            pipeline_name=re.sub("[^-0-9a-z]+", "-", pipeline_name.lower())
+            .lstrip("-")
+            .rstrip("-"),
+            timestamp=_get_current_time().strftime("%Y%m%d%H%M%S"),
+        )
+        if not _VALID_NAME_PATTERN.match(self.job_id):
+            raise ValueError(
+                f"Generated job ID: {self.job_id} is illegal as a Vertex pipelines job ID. "
+                "Expecting an ID following the regex pattern "
+                f'"{_VALID_NAME_PATTERN.pattern[1:-1]}"'
+            )
+
+        if enable_caching is not None:
+            _set_enable_caching_value(pipeline_job["pipelineSpec"], enable_caching)
+
+        pipeline_job_args = {
+            "display_name": display_name,
+            "pipeline_spec": pipeline_job["pipelineSpec"],
+            "labels": labels,
+            "runtime_config": runtime_config,
+            "encryption_spec": initializer.global_config.get_encryption_spec(
+                encryption_spec_key_name=encryption_spec_key_name
+            ),
+            "preflight_validations": enable_preflight_validations,
+        }
+
+        if _VALID_AR_URL.match(template_path) or _VALID_HTTPS_URL.match(template_path):
+            pipeline_job_args["template_uri"] = template_path
+
+        self._v1_beta1_pipeline_job = aiplatform_v1beta1.PipelineJob(
+            **pipeline_job_args
+        )
+
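
For orientation, a hypothetical instantiation of the constructor above; the
display name, template path, and parameter are made up, and it assumes
aiplatform.init(project=..., location=...) has already been called:

    # Hypothetical usage of the preview _PipelineJob; all values are placeholders.
    job = _PipelineJob(
        display_name="my-preview-pipeline",
        template_path="pipeline.yaml",
        parameter_values={"learning_rate": 0.01},
        enable_caching=False,  # overrides compile-time caching for every task
        enable_preflight_validations=True,
    )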
     def create_schedule(
         self,
         cron_expression: str,
@@ -180,3 +409,79 @@ def batch_delete(
         v1beta1_client = client.select_version(compat.V1BETA1)
         operation = v1beta1_client.batch_delete_pipeline_jobs(request)
         return operation.result()
+
+    def submit(
+        self,
+        service_account: Optional[str] = None,
+        network: Optional[str] = None,
+        reserved_ip_ranges: Optional[List[str]] = None,
+        create_request_timeout: Optional[float] = None,
+        job_id: Optional[str] = None,
+    ) -> None:
+        """Run this configured PipelineJob.
+
+        Args:
+            service_account (str):
+                Optional. Specifies the service account for the workload run-as account.
+                Users submitting jobs must have act-as permission on this run-as account.
+            network (str):
+                Optional. The full name of the Compute Engine network to which the job
+                should be peered. For example, projects/12345/global/networks/myVPC.
+
+                Private services access must already be configured for the network.
+                If left unspecified, the network set in aiplatform.init will be used.
+                Otherwise, the job is not peered with any network.
+            reserved_ip_ranges (List[str]):
+                Optional. A list of names for the reserved IP ranges under the VPC
+                network that can be used for this PipelineJob's workload. For
+                example: ['vertex-ai-ip-range'].
+
+                If left unspecified, the job will be deployed to any IP ranges under
+                the provided VPC network.
+            create_request_timeout (float):
+                Optional. The timeout for the create request in seconds.
+            job_id (str):
+                Optional. The ID to use for the PipelineJob, which will become the final
+                component of the PipelineJob name. If not provided, an ID will be
+                automatically generated.
+        """
+        network = network or initializer.global_config.network
+        service_account = service_account or initializer.global_config.service_account
+        gca_resource = self._v1_beta1_pipeline_job
+
+        if service_account:
+            gca_resource.service_account = service_account
+
+        if network:
+            gca_resource.network = network
+
+        if reserved_ip_ranges:
+            gca_resource.reserved_ip_ranges = reserved_ip_ranges
+
+        user_project = initializer.global_config.project
+        user_location = initializer.global_config.location
+        parent = initializer.global_config.common_location_path(
+            project=user_project, location=user_location
+        )
+
+        client = self._instantiate_client(
+            location=user_location,
+            appended_user_agent=["preview-pipeline-job-submit"],
+        )
+        v1beta1_client = client.select_version(compat.V1BETA1)
+
+        _LOGGER.log_create_with_lro(self.__class__)
+
+        request = aiplatform_v1beta1.CreatePipelineJobRequest(
+            parent=parent,
+            pipeline_job=self._v1_beta1_pipeline_job,
+            pipeline_job_id=job_id or self.job_id,
+        )
+
+        response = v1beta1_client.create_pipeline_job(request=request)
+
+        self._gca_resource = response
+
+        _LOGGER.log_create_complete_with_getter(
+            self.__class__, self._gca_resource, "pipeline_job"
+        )
+
+        _LOGGER.info("View Pipeline Job:\n%s" % self._dashboard_uri())
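
Continuing the hypothetical usage sketch from the constructor above, submitting
the job; the service account and network values below are placeholders:

    # Create the job server-side without waiting for it to finish. The
    # submitted job appears at the dashboard URI logged by submit().
    job.submit(
        service_account="pipelines-sa@my-project.iam.gserviceaccount.com",
        network="projects/12345/global/networks/myVPC",
    )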