@@ -4037,6 +4037,13 @@ def run(
40374037 model_display_name : Optional [str ] = None ,
40384038 model_labels : Optional [Dict [str , str ]] = None ,
40394039 additional_experiments : Optional [List [str ]] = None ,
4040+ hierarchy_group_columns : Optional [List [str ]] = None ,
4041+ hierarchy_group_total_weight : Optional [float ] = None ,
4042+ hierarchy_temporal_total_weight : Optional [float ] = None ,
4043+ hierarchy_group_temporal_total_weight : Optional [float ] = None ,
4044+ window_column : Optional [str ] = None ,
4045+ window_stride_length : Optional [int ] = None ,
4046+ window_max_count : Optional [int ] = None ,
40404047 sync : bool = True ,
40414048 create_request_timeout : Optional [float ] = None ,
40424049 ) -> models .Model :
@@ -4157,7 +4164,7 @@ def run(
41574164 Applies only if [export_evaluated_data_items] is True and
41584165 [export_evaluated_data_items_bigquery_destination_uri] is specified.
41594166 quantiles (List[float]):
4160- Quantiles to use for the `minimize-quantile-loss`
4167+ Quantiles to use for the `` minimize-quantile-loss` `
41614168 [AutoMLForecastingTrainingJob.optimization_objective]. This argument is required in
41624169 this case.
41634170
@@ -4200,6 +4207,37 @@ def run(
42004207 Optional. Additional experiment flags for the time series forcasting training.
42014208 create_request_timeout (float):
42024209 Optional. The timeout for the create request in seconds.
4210+ hierarchy_group_columns (List[str]):
4211+ Optional. A list of time series attribute column names that
4212+ define the time series hierarchy. Only one level of hierarchy is
4213+ supported, ex. ``region`` for a hierarchy of stores or
4214+ ``department`` for a hierarchy of products. If multiple columns
4215+ are specified, time series will be grouped by their combined
4216+ values, ex. (``blue``, ``large``) for ``color`` and ``size``, up
4217+ to 5 columns are accepted. If no group columns are specified,
4218+ all time series are considered to be part of the same group.
4219+ hierarchy_group_total_weight (float):
4220+ Optional. The weight of the loss for predictions aggregated over
4221+ time series in the same hierarchy group.
4222+ hierarchy_temporal_total_weight (float):
4223+ Optional. The weight of the loss for predictions aggregated over
4224+ the horizon for a single time series.
4225+ hierarchy_group_temporal_total_weight (float):
4226+ Optional. The weight of the loss for predictions aggregated over
4227+ both the horizon and time series in the same hierarchy group.
4228+ window_column (str):
4229+ Optional. Name of the column that should be used to filter input
4230+ rows. The column should contain either booleans or string
4231+ booleans; if the value of the row is True, generate a sliding
4232+ window from that row.
4233+ window_stride_length (int):
4234+ Optional. Step length used to generate input examples. Every
4235+ ``window_stride_length`` rows will be used to generate a sliding
4236+ window.
4237+ window_max_count (int):
4238+ Optional. Number of rows that should be used to generate input
4239+ examples. If the total row count is larger than this number, the
4240+ input data will be randomly sampled to hit the count.
42034241 sync (bool):
42044242 Whether to execute this method synchronously. If False, this method
42054243 will be executed in concurrent Future and any downstream object will
@@ -4254,6 +4292,13 @@ def run(
42544292 validation_options = validation_options ,
42554293 model_display_name = model_display_name ,
42564294 model_labels = model_labels ,
4295+ hierarchy_group_columns = hierarchy_group_columns ,
4296+ hierarchy_group_total_weight = hierarchy_group_total_weight ,
4297+ hierarchy_temporal_total_weight = hierarchy_temporal_total_weight ,
4298+ hierarchy_group_temporal_total_weight = hierarchy_group_temporal_total_weight ,
4299+ window_column = window_column ,
4300+ window_stride_length = window_stride_length ,
4301+ window_max_count = window_max_count ,
42574302 sync = sync ,
42584303 create_request_timeout = create_request_timeout ,
42594304 )
@@ -4286,6 +4331,13 @@ def _run(
42864331 budget_milli_node_hours : int = 1000 ,
42874332 model_display_name : Optional [str ] = None ,
42884333 model_labels : Optional [Dict [str , str ]] = None ,
4334+ hierarchy_group_columns : Optional [List [str ]] = None ,
4335+ hierarchy_group_total_weight : Optional [float ] = None ,
4336+ hierarchy_temporal_total_weight : Optional [float ] = None ,
4337+ hierarchy_group_temporal_total_weight : Optional [float ] = None ,
4338+ window_column : Optional [str ] = None ,
4339+ window_stride_length : Optional [int ] = None ,
4340+ window_max_count : Optional [int ] = None ,
42894341 sync : bool = True ,
42904342 create_request_timeout : Optional [float ] = None ,
42914343 ) -> models .Model :
@@ -4453,6 +4505,37 @@ def _run(
44534505 are allowed.
44544506 See https://goo.gl/xmQnxf for more information
44554507 and examples of labels.
4508+ hierarchy_group_columns (List[str]):
4509+ Optional. A list of time series attribute column names that
4510+ define the time series hierarchy. Only one level of hierarchy is
4511+ supported, ex. ``region`` for a hierarchy of stores or
4512+ ``department`` for a hierarchy of products. If multiple columns
4513+ are specified, time series will be grouped by their combined
4514+ values, ex. (``blue``, ``large``) for ``color`` and ``size``, up
4515+ to 5 columns are accepted. If no group columns are specified,
4516+ all time series are considered to be part of the same group.
4517+ hierarchy_group_total_weight (float):
4518+ Optional. The weight of the loss for predictions aggregated over
4519+ time series in the same hierarchy group.
4520+ hierarchy_temporal_total_weight (float):
4521+ Optional. The weight of the loss for predictions aggregated over
4522+ the horizon for a single time series.
4523+ hierarchy_group_temporal_total_weight (float):
4524+ Optional. The weight of the loss for predictions aggregated over
4525+ both the horizon and time series in the same hierarchy group.
4526+ window_column (str):
4527+ Optional. Name of the column that should be used to filter input
4528+ rows. The column should contain either booleans or string
4529+ booleans; if the value of the row is True, generate a sliding
4530+ window from that row.
4531+ window_stride_length (int):
4532+ Optional. Step length used to generate input examples. Every
4533+ ``window_stride_length`` rows will be used to generate a sliding
4534+ window.
4535+ window_max_count (int):
4536+ Optional. Number of rows that should be used to generate input
4537+ examples. If the total row count is larger than this number, the
4538+ input data will be randomly sampled to hit the count.
44564539 sync (bool):
44574540 Whether to execute this method synchronously. If False, this method
44584541 will be executed in concurrent Future and any downstream object will
@@ -4482,6 +4565,12 @@ def _run(
44824565 % column_names
44834566 )
44844567
4568+ window_config = self ._create_window_config (
4569+ column = window_column ,
4570+ stride_length = window_stride_length ,
4571+ max_count = window_max_count ,
4572+ )
4573+
44854574 training_task_inputs_dict = {
44864575 # required inputs
44874576 "targetColumn" : target_column ,
@@ -4505,6 +4594,24 @@ def _run(
45054594 "optimizationObjective" : self ._optimization_objective ,
45064595 }
45074596
4597+ # TODO(TheMichaelHu): Remove the ifs once the API supports these inputs.
4598+ if any (
4599+ [
4600+ hierarchy_group_columns ,
4601+ hierarchy_group_total_weight ,
4602+ hierarchy_temporal_total_weight ,
4603+ hierarchy_group_temporal_total_weight ,
4604+ ]
4605+ ):
4606+ training_task_inputs_dict ["hierarchyConfig" ] = {
4607+ "groupColumns" : hierarchy_group_columns ,
4608+ "groupTotalWeight" : hierarchy_group_total_weight ,
4609+ "temporalTotalWeight" : hierarchy_temporal_total_weight ,
4610+ "groupTemporalTotalWeight" : hierarchy_group_temporal_total_weight ,
4611+ }
4612+ if window_config :
4613+ training_task_inputs_dict ["windowConfig" ] = window_config
4614+
45084615 final_export_eval_bq_uri = export_evaluated_data_items_bigquery_destination_uri
45094616 if final_export_eval_bq_uri and not final_export_eval_bq_uri .startswith (
45104617 "bq://"
@@ -4582,6 +4689,29 @@ def _add_additional_experiments(self, additional_experiments: List[str]):
45824689 """
45834690 self ._additional_experiments .extend (additional_experiments )
45844691
4692+ @staticmethod
4693+ def _create_window_config (
4694+ column : Optional [str ] = None ,
4695+ stride_length : Optional [int ] = None ,
4696+ max_count : Optional [int ] = None ,
4697+ ) -> Optional [Dict [str , Union [int , str ]]]:
4698+ """Creates a window config from training job arguments."""
4699+ configs = {
4700+ "column" : column ,
4701+ "strideLength" : stride_length ,
4702+ "maxCount" : max_count ,
4703+ }
4704+ present_configs = {k : v for k , v in configs .items () if v is not None }
4705+ if not present_configs :
4706+ return None
4707+ if len (present_configs ) > 1 :
4708+ raise ValueError (
4709+ "More than one windowing strategy provided. Make sure only one "
4710+ "of window_column, window_stride_length, or window_max_count "
4711+ "is specified."
4712+ )
4713+ return present_configs
4714+
45854715
45864716class AutoMLImageTrainingJob (_TrainingJob ):
45874717 _supported_training_schemas = (
0 commit comments