model.proto 40 KB


  1. // Copyright 2021 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. syntax = "proto3";
  15. package google.cloud.bigquery.v2;
  16. import "google/api/client.proto";
  17. import "google/api/field_behavior.proto";
  18. import "google/cloud/bigquery/v2/encryption_config.proto";
  19. import "google/cloud/bigquery/v2/model_reference.proto";
  20. import "google/cloud/bigquery/v2/standard_sql.proto";
  21. import "google/cloud/bigquery/v2/table_reference.proto";
  22. import "google/protobuf/empty.proto";
  23. import "google/protobuf/timestamp.proto";
  24. import "google/protobuf/wrappers.proto";
  25. import "google/api/annotations.proto";
  26. option go_package = "google.golang.org/genproto/googleapis/cloud/bigquery/v2;bigquery";
  27. option java_outer_classname = "ModelProto";
  28. option java_package = "com.google.cloud.bigquery.v2";
  // Read, list, patch and delete operations on BigQuery models. Every method
  // is addressed by project_id and dataset_id; the per-model methods also take
  // model_id.
  service ModelService {
    option (google.api.default_host) = "bigquery.googleapis.com";
    // The adjacent string literals below are concatenated by the compiler into
    // a single comma-separated list of OAuth scopes.
    option (google.api.oauth_scopes) =
        "https://www.googleapis.com/auth/bigquery,"
        "https://www.googleapis.com/auth/bigquery.readonly,"
        "https://www.googleapis.com/auth/cloud-platform,"
        "https://www.googleapis.com/auth/cloud-platform.read-only";

    // Returns the model resource identified by model_id.
    rpc GetModel(GetModelRequest) returns (Model) {
      option (google.api.http) = {
        get: "/bigquery/v2/projects/{project_id=*}/datasets/{dataset_id=*}/models/{model_id=*}"
      };
      option (google.api.method_signature) = "project_id,dataset_id,model_id";
    }

    // Lists every model in the given dataset; the caller needs the READER role
    // on that dataset. Use models.get (GetModel) to retrieve the details of
    // any individual model returned here.
    rpc ListModels(ListModelsRequest) returns (ListModelsResponse) {
      option (google.api.http) = {
        get: "/bigquery/v2/projects/{project_id=*}/datasets/{dataset_id=*}/models"
      };
      option (google.api.method_signature) =
          "project_id,dataset_id,max_results";
    }

    // Updates selected fields of the specified model. The request's `model`
    // field is carried as the HTTP request body.
    rpc PatchModel(PatchModelRequest) returns (Model) {
      option (google.api.http) = {
        patch: "/bigquery/v2/projects/{project_id=*}/datasets/{dataset_id=*}/models/{model_id=*}"
        body: "model"
      };
      option (google.api.method_signature) =
          "project_id,dataset_id,model_id,model";
    }

    // Removes the model identified by model_id from its dataset.
    rpc DeleteModel(DeleteModelRequest) returns (google.protobuf.Empty) {
      option (google.api.http) = {
        delete: "/bigquery/v2/projects/{project_id=*}/datasets/{dataset_id=*}/models/{model_id=*}"
      };
      option (google.api.method_signature) = "project_id,dataset_id,model_id";
    }
  }
  68. message Model {
  // Holder for SeasonalPeriodType. Nesting the enum here keeps its value names
  // (DAILY, WEEKLY, MONTHLY, ...) out of Model's own scope, where the
  // identically named DataFrequency values are declared.
  message SeasonalPeriod {
    // Period of a seasonal component of a time series.
    enum SeasonalPeriodType {
      // Default; no seasonal period was specified.
      SEASONAL_PERIOD_TYPE_UNSPECIFIED = 0;
      // The series has no seasonality.
      NO_SEASONALITY = 1;
      // Repeats every day (24 hours).
      DAILY = 2;
      // Repeats every week (7 days).
      WEEKLY = 3;
      // Repeats every month (30 days, or irregular).
      MONTHLY = 4;
      // Repeats every quarter (90 days, or irregular).
      QUARTERLY = 5;
      // Repeats every year (365 days, or irregular).
      YEARLY = 6;
    }
  }
  // Holder for KmeansInitializationMethod. Nesting the enum here keeps its
  // RANDOM and CUSTOM values distinct from the identically named
  // DataSplitMethod values declared directly in Model.
  message KmeansEnums {
    // How the KMeans clustering algorithm picks its initial centroids.
    enum KmeansInitializationMethod {
      // Default; the initialization method was not specified.
      KMEANS_INITIALIZATION_METHOD_UNSPECIFIED = 0;
      // Centroids are chosen at random.
      RANDOM = 1;
      // Centroids are taken from the data in kmeans_initialization_column.
      CUSTOM = 2;
      // Centroids are chosen with kmeans++.
      KMEANS_PLUS_PLUS = 3;
    }
  }
  // Evaluation metrics for regression models and for matrix factorization
  // models trained with explicit feedback.
  message RegressionMetrics {
    // Mean absolute error.
    google.protobuf.DoubleValue mean_absolute_error = 1;

    // Mean squared error.
    google.protobuf.DoubleValue mean_squared_error = 2;

    // Mean squared logarithmic error.
    google.protobuf.DoubleValue mean_squared_log_error = 3;

    // Median absolute error.
    google.protobuf.DoubleValue median_absolute_error = 4;

    // R^2 score; reported as r2_score by ML.EVALUATE.
    google.protobuf.DoubleValue r_squared = 5;
  }
  // Aggregate metrics for classification models. For multi-class models each
  // metric is either macro-averaged (computed per label, then combined as an
  // unweighted mean) or micro-averaged (computed once, globally, by counting
  // the total number of correctly predicted rows).
  message AggregateClassificationMetrics {
    // Fraction of positive predictions whose actual label is positive.
    // Macro-averaged for multi-class models, treating each class as its own
    // binary classifier.
    google.protobuf.DoubleValue precision = 1;

    // Fraction of actually positive labels that received a positive
    // prediction. Macro-averaged for multi-class models.
    google.protobuf.DoubleValue recall = 2;

    // Fraction of predictions that assigned the correct label.
    // Micro-averaged for multi-class models.
    google.protobuf.DoubleValue accuracy = 3;

    // Threshold at which the metrics were computed: the positive-class
    // threshold for binary classification models, the confidence threshold
    // for multi-class classification models.
    google.protobuf.DoubleValue threshold = 4;

    // F1 score, an average of precision and recall. Macro-averaged for
    // multi-class models.
    google.protobuf.DoubleValue f1_score = 5;

    // Logarithmic loss. Macro-averaged for multi-class models.
    google.protobuf.DoubleValue log_loss = 6;

    // Area under the ROC curve. Macro-averaged for multi-class models.
    google.protobuf.DoubleValue roc_auc = 7;
  }
  // Evaluation metrics for binary classification models.
  message BinaryClassificationMetrics {
    // Confusion matrix of a binary classifier evaluated at a single
    // positive-class threshold.
    message BinaryConfusionMatrix {
      // Positive-class threshold used to compute every metric in this entry.
      google.protobuf.DoubleValue positive_class_threshold = 1;

      // Number of true samples predicted as true.
      google.protobuf.Int64Value true_positives = 2;

      // Number of false samples predicted as true.
      google.protobuf.Int64Value false_positives = 3;

      // Number of false samples predicted as false (correctly rejected).
      google.protobuf.Int64Value true_negatives = 4;

      // Number of true samples predicted as false (missed positives).
      google.protobuf.Int64Value false_negatives = 5;

      // Fraction of positive predictions whose actual label is positive.
      google.protobuf.DoubleValue precision = 6;

      // Fraction of actually positive labels that received a positive
      // prediction.
      google.protobuf.DoubleValue recall = 7;

      // Equally weighted average of precision and recall.
      google.protobuf.DoubleValue f1_score = 8;

      // Fraction of predictions that assigned the correct label.
      google.protobuf.DoubleValue accuracy = 9;
    }

    // Aggregate classification metrics.
    AggregateClassificationMetrics aggregate_classification_metrics = 1;

    // Confusion matrices computed at multiple thresholds.
    repeated BinaryConfusionMatrix binary_confusion_matrix_list = 2;

    // Label representing the positive class.
    string positive_label = 3;

    // Label representing the negative class.
    string negative_label = 4;
  }
  // Evaluation metrics for multi-class classification models.
  message MultiClassClassificationMetrics {
    // Confusion matrix of a multi-class classifier at one confidence
    // threshold.
    message ConfusionMatrix {
      // One cell of a confusion-matrix row.
      message Entry {
        // The predicted label. When confidence_threshold > 0, an additional
        // entry reports how many items fell below the confidence threshold.
        string predicted_label = 1;

        // Number of items predicted as this label.
        google.protobuf.Int64Value item_count = 2;
      }

      // One row of the confusion matrix.
      message Row {
        // Original (actual) label of this row.
        string actual_label = 1;

        // Distribution of predicted labels for this row.
        repeated Entry entries = 2;
      }

      // Confidence threshold used to compute the entries of this matrix.
      google.protobuf.DoubleValue confidence_threshold = 1;

      // One row per actual label.
      repeated Row rows = 2;
    }

    // Aggregate classification metrics.
    AggregateClassificationMetrics aggregate_classification_metrics = 1;

    // Confusion matrices computed at different thresholds.
    repeated ConfusionMatrix confusion_matrix_list = 2;
  }
  // Evaluation metrics for clustering models.
  message ClusteringMetrics {
    // Information about a single cluster.
    message Cluster {
      // Representative value of one feature within the cluster.
      message FeatureValue {
        // Representative value of a categorical feature.
        message CategoricalValue {
          // Count of a single category within the cluster.
          message CategoryCount {
            // Name of the category.
            string category = 1;

            // Number of training samples in the cluster that match the
            // category.
            google.protobuf.Int64Value count = 2;
          }

          // Counts for every category of the feature. When there are more
          // than ten categories, the top ten by count are returned, followed
          // by one extra CategoryCount with category "_OTHER_" whose count
          // aggregates all remaining categories.
          repeated CategoryCount category_counts = 1;
        }

        // Name of the feature column.
        string feature_column = 1;

        // At most one of the following is set.
        oneof value {
          // Numerical feature value: the centroid value for this feature.
          google.protobuf.DoubleValue numerical_value = 2;

          // Categorical feature value.
          CategoricalValue categorical_value = 3;
        }
      }

      // Centroid id.
      int64 centroid_id = 1;

      // Values of the highly variant features for this cluster.
      repeated FeatureValue feature_values = 2;

      // Number of training rows assigned to this cluster.
      google.protobuf.Int64Value count = 3;
    }

    // Davies-Bouldin index.
    google.protobuf.DoubleValue davies_bouldin_index = 1;

    // Mean of the squared distances between each sample and its cluster
    // centroid.
    google.protobuf.DoubleValue mean_squared_distance = 2;

    // Information for all clusters.
    repeated Cluster clusters = 3;
  }
  // Evaluation metrics for weighted-ALS models, i.e. matrix factorization
  // models trained with feedback_type=implicit.
  message RankingMetrics {
    // A precision is computed per user by ranking all items; those per-user
    // precisions are then averaged across all users.
    google.protobuf.DoubleValue mean_average_precision = 1;

    // Analogous to the mean squared error of regression and explicit
    // recommendation models, except that the output of evaluate is compared
    // against a preference of 1 or 0 (whether a rating exists) instead of
    // against the rating itself.
    google.protobuf.DoubleValue mean_squared_error = 2;

    // Goodness of a ranking derived from the predicted confidence, measured
    // by comparing it with the ideal ranking given by the original ratings.
    google.protobuf.DoubleValue normalized_discounted_cumulative_gain = 3;

    // Goodness of a ranking, computed as the percentile rank from the
    // predicted confidence divided by the original rank.
    google.protobuf.DoubleValue average_rank = 4;
  }
  // Evaluation metrics for ARIMA forecasting models.
  message ArimaForecastingMetrics {
    // Evaluation metrics for a single ARIMA forecasting model (one time
    // series).
    message ArimaSingleModelForecastingMetrics {
      // Non-seasonal order.
      ArimaOrder non_seasonal_order = 1;

      // ARIMA fitting metrics.
      ArimaFittingMetrics arima_fitting_metrics = 2;

      // Whether the ARIMA model was fitted with drift. Always false when d
      // is not 1.
      bool has_drift = 3;

      // The time_series_id value of this time series: one of the distinct
      // values of the time_series_id_column given at training time. Only
      // present when the time_series_id_column training option was used.
      string time_series_id = 4;

      // The tuple of time series ids identifying this time series: one of the
      // distinct value tuples of the time_series_id_columns given at training
      // time, in the same order as those columns. Only present when the
      // time_series_id_columns training option was used.
      repeated string time_series_ids = 9;

      // Seasonal periods. Repeated because one time series can have several.
      repeated SeasonalPeriod.SeasonalPeriodType seasonal_periods = 5;

      // If true, holiday_effect is part of the time series decomposition
      // result.
      google.protobuf.BoolValue has_holiday_effect = 6;

      // If true, spikes_and_dips is part of the time series decomposition
      // result.
      google.protobuf.BoolValue has_spikes_and_dips = 7;

      // If true, step_changes is part of the time series decomposition
      // result.
      google.protobuf.BoolValue has_step_changes = 8;
    }

    // Non-seasonal orders.
    // Deprecated: read
    // arima_single_model_forecasting_metrics[].non_seasonal_order instead.
    repeated ArimaOrder non_seasonal_order = 1 [deprecated = true];

    // ARIMA model fitting metrics.
    // Deprecated: read
    // arima_single_model_forecasting_metrics[].arima_fitting_metrics instead.
    repeated ArimaFittingMetrics arima_fitting_metrics = 2 [deprecated = true];

    // Seasonal periods. Repeated because one time series can have several.
    // Deprecated: read
    // arima_single_model_forecasting_metrics[].seasonal_periods instead.
    repeated SeasonalPeriod.SeasonalPeriodType seasonal_periods = 3
        [deprecated = true];

    // Whether the ARIMA model was fitted with drift. Always false when d is
    // not 1.
    // Deprecated: read arima_single_model_forecasting_metrics[].has_drift
    // instead.
    repeated bool has_drift = 4 [deprecated = true];

    // Ids distinguishing the time series in the large-scale case.
    // Deprecated: read
    // arima_single_model_forecasting_metrics[].time_series_id instead.
    repeated string time_series_id = 5 [deprecated = true];

    // One metric set per model. Repeated because auto-ARIMA and the
    // large-scale case produce many models.
    repeated ArimaSingleModelForecastingMetrics
        arima_single_model_forecasting_metrics = 6;
  }
  // Evaluation metrics of a model. They are computed either on all training
  // data or only on the eval data, depending on whether eval data was used
  // during training. Not present for imported models.
  message EvaluationMetrics {
    // At most one of the following is set; which one depends on the kind of
    // model.
    oneof metrics {
      // Set for regression models and for matrix factorization models with
      // explicit feedback.
      RegressionMetrics regression_metrics = 1;

      // Set for binary classification models.
      BinaryClassificationMetrics binary_classification_metrics = 2;

      // Set for multi-class classification models.
      MultiClassClassificationMetrics multi_class_classification_metrics = 3;

      // Set for clustering models.
      ClusteringMetrics clustering_metrics = 4;

      // Set for matrix factorization models with implicit feedback.
      RankingMetrics ranking_metrics = 5;

      // Set for ARIMA models.
      ArimaForecastingMetrics arima_forecasting_metrics = 6;
    }
  }
  // Outcome of splitting the input data: references to the training and
  // evaluation tables that were used to train the model.
  message DataSplitResult {
    // Table holding the training data produced by the split.
    TableReference training_table = 1;

    // Table holding the evaluation data produced by the split.
    TableReference evaluation_table = 2;
  }
  // ARIMA order (p, d, q). Usable for both the non-seasonal and the seasonal
  // parts of a model.
  message ArimaOrder {
    // p: order of the autoregressive part.
    int64 p = 1;

    // d: order of the differencing part.
    int64 d = 2;

    // q: order of the moving-average part.
    int64 q = 3;
  }
  // Fitting metrics of an ARIMA model.
  message ArimaFittingMetrics {
    // Log-likelihood of the fit.
    double log_likelihood = 1;

    // Akaike information criterion (AIC).
    double aic = 2;

    // Variance.
    double variance = 3;
  }
  // Global explanations listing the most important features after training.
  message GlobalExplanation {
    // Explanation for one feature.
    message Explanation {
      // Full name of the feature. Non-numerical features are formatted as
      // <column_name>.<encoded_feature_name>. The whole name is always
      // truncated to its first 120 characters.
      string feature_name = 1;

      // Attribution of the feature.
      google.protobuf.DoubleValue attribution = 2;
    }

    // Top global explanations, sorted by absolute attribution in descending
    // order.
    repeated Explanation explanations = 1;

    // Class label this set of global explanations belongs to. Empty/null for
    // binary logistic and linear regression models. Sorted alphabetically in
    // descending order.
    string class_label = 2;
  }
  // Information about a single training query run for the model.
  message TrainingRun {
    // Options used in model training.
    message TrainingOptions {
      // Maximum number of training iterations. Only used by iterative
      // training algorithms.
      int64 max_iterations = 1;

      // Type of loss function used during the training run.
      LossType loss_type = 2;

      // Learning rate. Only used by iterative training algorithms.
      double learn_rate = 3;

      // L1 regularization coefficient.
      google.protobuf.DoubleValue l1_regularization = 4;

      // L2 regularization coefficient.
      google.protobuf.DoubleValue l2_regularization = 5;

      // When early_stop is true, training stops once the accuracy improvement
      // is less than min_relative_progress. Only used by iterative training
      // algorithms.
      google.protobuf.DoubleValue min_relative_progress = 6;

      // Whether to train the model starting from the last checkpoint.
      google.protobuf.BoolValue warm_start = 7;

      // Whether to stop early once the loss no longer improves significantly
      // (compared to min_relative_progress). Only used by iterative training
      // algorithms.
      google.protobuf.BoolValue early_stop = 8;

      // Names of the input label columns in the training data.
      repeated string input_label_columns = 9;

      // How the data is split for training and evaluation, e.g. RANDOM.
      DataSplitMethod data_split_method = 10;

      // Fraction of the whole input data used as evaluation data; the rest is
      // used as training data. Expressed as a double, accurate to two decimal
      // places. Defaults to 0.2.
      double data_split_eval_fraction = 11;

      // Column used to split the data. This column is not used as a feature.
      // 1. When data_split_method is CUSTOM, the column must be boolean: rows
      //    whose value is true are eval data, rows whose value is false are
      //    training data.
      // 2. When data_split_method is SEQUENTIAL, the first
      //    DATA_SPLIT_EVAL_FRACTION rows (ordered from smallest to largest by
      //    this column) are used as training data and the rest as eval data.
      //    Ordering follows the rules for orderable data types:
      //    https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data-type-properties
      //    NOTE(review): confirm whether the leading share of rows is
      //    data_split_eval_fraction or (1 - data_split_eval_fraction).
      string data_split_column = 12;

      // Strategy used to determine the learning rate of each iteration.
      LearnRateStrategy learn_rate_strategy = 13;

      // Initial learning rate for the line-search learn rate strategy.
      double initial_learn_rate = 16;

      // Per-label-class weights used to rebalance the training data. Only
      // applicable to classification models.
      map<string, double> label_class_weights = 17;

      // User column of a matrix factorization model.
      string user_column = 18;

      // Item column of a matrix factorization model.
      string item_column = 19;

      // Distance type for clustering models.
      DistanceType distance_type = 20;

      // Number of clusters for clustering models.
      int64 num_clusters = 21;

      // Google Cloud Storage URI the model was imported from. Only applicable
      // to imported models.
      string model_uri = 22;

      // Optimization strategy for training linear regression models.
      OptimizationStrategy optimization_strategy = 23;

      // Hidden units for DNN models.
      repeated int64 hidden_units = 24;

      // Batch size for DNN models.
      int64 batch_size = 25;

      // Dropout probability for DNN models.
      google.protobuf.DoubleValue dropout = 26;

      // Maximum depth of a tree for boosted tree models.
      int64 max_tree_depth = 27;

      // Fraction of the training data subsampled to grow trees, which helps
      // prevent overfitting in boosted tree models.
      double subsample = 28;

      // Minimum split loss for boosted tree models.
      google.protobuf.DoubleValue min_split_loss = 29;

      // Number of factors for matrix factorization models.
      int64 num_factors = 30;

      // Feedback type, which selects the algorithm run for matrix
      // factorization.
      FeedbackType feedback_type = 31;

      // Matrix factorization hyperparameter used when the implicit feedback
      // type is specified.
      google.protobuf.DoubleValue wals_alpha = 32;

      // Method used to initialize the centroids of the kmeans algorithm.
      KmeansEnums.KmeansInitializationMethod kmeans_initialization_method = 33;

      // Column providing the initial centroids of the kmeans algorithm when
      // kmeans_initialization_method is CUSTOM.
      string kmeans_initialization_column = 34;

      // Column designated as the time series timestamp for ARIMA models.
      string time_series_timestamp_column = 35;

      // Column designated as the time series data for ARIMA models.
      string time_series_data_column = 36;

      // Whether auto ARIMA is enabled.
      bool auto_arima = 37;

      // Specification of the non-seasonal part of the ARIMA model: the three
      // components (p, d, q) are the AR order, the degree of differencing and
      // the MA order.
      ArimaOrder non_seasonal_order = 38;

      // Data frequency of the time series.
      DataFrequency data_frequency = 39;

      // Whether to include drift when fitting an ARIMA model.
      bool include_drift = 41;

      // Geographical region whose holidays are considered in time series
      // modeling. Specifying a valid value enables holiday effects modeling.
      HolidayRegion holiday_region = 42;

      // Time series id column used during ARIMA model training.
      string time_series_id_column = 43;

      // Time series id columns used during ARIMA model training.
      repeated string time_series_id_columns = 51;

      // Number of periods ahead to forecast.
      int64 horizon = 44;

      // Whether input struct names are preserved in output feature names.
      // For a struct A with field b, the output feature name is A_b when
      // false (the default) and A.b when true.
      bool preserve_input_structs = 45;

      // Maximum value of the non-seasonal p and q.
      int64 auto_arima_max_order = 46;

      // If true, decompose the time series and save the results.
      google.protobuf.BoolValue decompose_time_series = 50;

      // If true, clean spikes and dips in the input time series.
      google.protobuf.BoolValue clean_spikes_and_dips = 52;

      // If true, detect step changes in the input time series and adjust the
      // data accordingly.
      google.protobuf.BoolValue adjust_step_changes = 53;
    }

    // Information about a single iteration of the training run.
    message IterationResult {
      // Information about a single cluster of a clustering model.
      message ClusterInfo {
        // Centroid id.
        int64 centroid_id = 1;

        // Cluster radius: the average distance from the centroid to each
        // point assigned to the cluster.
        google.protobuf.DoubleValue cluster_radius = 2;

        // Cluster size: the total number of points assigned to the cluster.
        google.protobuf.Int64Value cluster_size = 3;
      }

      // (Auto-)ARIMA fitting result. Everything is wrapped in ArimaResult to
      // ease refactoring toward model-specific iteration results.
      message ArimaResult {
        // ARIMA coefficients.
        message ArimaCoefficients {
          // Auto-regressive coefficients.
          repeated double auto_regressive_coefficients = 1;

          // Moving-average coefficients.
          repeated double moving_average_coefficients = 2;

          // Intercept coefficient (a single value, not a list).
          double intercept_coefficient = 3;
        }

        // ARIMA model information.
        message ArimaModelInfo {
          // Non-seasonal order.
          ArimaOrder non_seasonal_order = 1;

          // ARIMA coefficients.
          ArimaCoefficients arima_coefficients = 2;

          // ARIMA fitting metrics.
          ArimaFittingMetrics arima_fitting_metrics = 3;

          // Whether the ARIMA model was fitted with drift. Always false when
          // d is not 1.
          bool has_drift = 4;

          // The time_series_id value of this time series: one of the
          // distinct values of the time_series_id_column given at training
          // time. Only present when the time_series_id_column training
          // option was used.
          string time_series_id = 5;

          // The tuple of time series ids identifying this time series: one
          // of the distinct value tuples of the time_series_id_columns given
          // at training time, in the same order as those columns. Only
          // present when the time_series_id_columns training option was
          // used.
          repeated string time_series_ids = 10;

          // Seasonal periods. Repeated because one time series can have
          // several.
          repeated SeasonalPeriod.SeasonalPeriodType seasonal_periods = 6;

          // If true, holiday_effect is part of the time series
          // decomposition result.
          google.protobuf.BoolValue has_holiday_effect = 7;

          // If true, spikes_and_dips is part of the time series
          // decomposition result.
          google.protobuf.BoolValue has_spikes_and_dips = 8;

          // If true, step_changes is part of the time series decomposition
          // result.
          google.protobuf.BoolValue has_step_changes = 9;
        }

        // Repeated because auto-ARIMA fits multiple ARIMA models. For a
        // non-auto-ARIMA model the list has exactly one element.
        repeated ArimaModelInfo arima_model_info = 1;

        // Seasonal periods. Repeated because one time series can have
        // several.
        repeated SeasonalPeriod.SeasonalPeriodType seasonal_periods = 2;
      }

      // Index of the iteration, 0-based.
      google.protobuf.Int32Value index = 1;

      // Time taken to run the iteration, in milliseconds.
      google.protobuf.Int64Value duration_ms = 4;

      // Loss computed on the training data at the end of the iteration.
      google.protobuf.DoubleValue training_loss = 5;

      // Loss computed on the eval data at the end of the iteration.
      google.protobuf.DoubleValue eval_loss = 6;

      // Learning rate used for this iteration.
      double learn_rate = 7;

      // Information about the top clusters of a clustering model.
      repeated ClusterInfo cluster_infos = 8;

      // (Auto-)ARIMA fitting result for this iteration.
      ArimaResult arima_result = 9;
    }

    // Options used for this training run, including both user-specified
    // options and the defaults that were applied.
    TrainingOptions training_options = 1;

    // Start time of this training run.
    google.protobuf.Timestamp start_time = 8;

    // Output of each iteration run; results.size() <= max_iterations.
    repeated IterationResult results = 6;

    // Evaluation metrics over the training/eval data, computed at the end of
    // training.
    EvaluationMetrics evaluation_metrics = 7;

    // Data split result of the training run. Only set when the input data
    // was actually split.
    DataSplitResult data_split_result = 9;

    // Global explanations for the important features of the model: one entry
    // per label class for multi-class models, a single entry otherwise.
    repeated GlobalExplanation global_explanations = 10;
  }
  // Kind of a Model.
  enum ModelType {
    // Default; the model type was not specified.
    MODEL_TYPE_UNSPECIFIED = 0;
    // Linear regression model.
    LINEAR_REGRESSION = 1;
    // Classification model based on logistic regression.
    LOGISTIC_REGRESSION = 2;
    // K-means clustering model.
    KMEANS = 3;
    // Matrix factorization model.
    MATRIX_FACTORIZATION = 4;
    // DNN classifier model.
    DNN_CLASSIFIER = 5;
    // Imported TensorFlow model.
    TENSORFLOW = 6;
    // DNN regressor model.
    DNN_REGRESSOR = 7;
    // Boosted tree regressor model.
    BOOSTED_TREE_REGRESSOR = 9;
    // Boosted tree classifier model.
    BOOSTED_TREE_CLASSIFIER = 10;
    // ARIMA model.
    ARIMA = 11;
    // [Beta] AutoML Tables regression model.
    AUTOML_REGRESSOR = 12;
    // [Beta] AutoML Tables classification model.
    AUTOML_CLASSIFIER = 13;
    // Newer name for the ARIMA model.
    ARIMA_PLUS = 19;
  }
  // Loss metric used to evaluate model training performance.
  enum LossType {
    // Default; the loss type was not specified.
    LOSS_TYPE_UNSPECIFIED = 0;
    // Mean squared loss; used by linear regression.
    MEAN_SQUARED_LOSS = 1;
    // Mean log loss; used by logistic regression.
    MEAN_LOG_LOSS = 2;
  }
  // Distance metric used to compute the distance between two points.
  enum DistanceType {
    // Zero value: the distance type is unspecified.
    DISTANCE_TYPE_UNSPECIFIED = 0;
    // Euclidean distance.
    EUCLIDEAN = 1;
    // Cosine distance.
    COSINE = 2;
  }
  // Method used to split the input data into multiple tables.
  enum DataSplitMethod {
    // Zero value: the split method is unspecified.
    DATA_SPLIT_METHOD_UNSPECIFIED = 0;
    // Randomly splits the data.
    RANDOM = 1;
    // Splits the data using user-provided tags.
    CUSTOM = 2;
    // Splits the data sequentially.
    SEQUENTIAL = 3;
    // Skips the data split.
    NO_SPLIT = 4;
    // Chooses the split automatically: NO_SPLIT when the data size is small,
    // RANDOM otherwise.
    AUTO_SPLIT = 5;
  }
  // Data frequencies supported by time series forecasting models.
  enum DataFrequency {
    // Zero value: the data frequency is unspecified.
    DATA_FREQUENCY_UNSPECIFIED = 0;
    // Frequency is inferred automatically from the timestamps.
    AUTO_FREQUENCY = 1;
    // One data point per year.
    YEARLY = 2;
    // One data point per quarter.
    QUARTERLY = 3;
    // One data point per month.
    MONTHLY = 4;
    // One data point per week.
    WEEKLY = 5;
    // One data point per day.
    DAILY = 6;
    // One data point per hour.
    HOURLY = 7;
    // One data point per minute.
    PER_MINUTE = 8;
  }
  // Holiday regions supported by time series forecasting models. The
  // multi-country regions (GLOBAL, NA, JAPAC, EMEA, LAC) are followed by
  // individual countries, each named by a two-letter code.
  enum HolidayRegion {
    // Zero value: the holiday region is unspecified.
    HOLIDAY_REGION_UNSPECIFIED = 0;
    // Global.
    GLOBAL = 1;
    // North America.
    NA = 2;
    // Japan and Asia Pacific: Korea, Greater China, India, Australia and
    // New Zealand.
    JAPAC = 3;
    // Europe, the Middle East and Africa.
    EMEA = 4;
    // Latin America and the Caribbean.
    LAC = 5;
    // United Arab Emirates.
    AE = 6;
    // Argentina.
    AR = 7;
    // Austria.
    AT = 8;
    // Australia.
    AU = 9;
    // Belgium.
    BE = 10;
    // Brazil.
    BR = 11;
    // Canada.
    CA = 12;
    // Switzerland.
    CH = 13;
    // Chile.
    CL = 14;
    // China.
    CN = 15;
    // Colombia.
    CO = 16;
    // Czechoslovakia.
    CS = 17;
    // Czech Republic.
    CZ = 18;
    // Germany.
    DE = 19;
    // Denmark.
    DK = 20;
    // Algeria.
    DZ = 21;
    // Ecuador.
    EC = 22;
    // Estonia.
    EE = 23;
    // Egypt.
    EG = 24;
    // Spain.
    ES = 25;
    // Finland.
    FI = 26;
    // France.
    FR = 27;
    // Great Britain (United Kingdom).
    GB = 28;
    // Greece.
    GR = 29;
    // Hong Kong.
    HK = 30;
    // Hungary.
    HU = 31;
    // Indonesia.
    ID = 32;
    // Ireland.
    IE = 33;
    // Israel.
    IL = 34;
    // India.
    IN = 35;
    // Iran.
    IR = 36;
    // Italy.
    IT = 37;
    // Japan.
    JP = 38;
    // South Korea.
    KR = 39;
    // Latvia.
    LV = 40;
    // Morocco.
    MA = 41;
    // Mexico.
    MX = 42;
    // Malaysia.
    MY = 43;
    // Nigeria.
    NG = 44;
    // Netherlands.
    NL = 45;
    // Norway.
    NO = 46;
    // New Zealand.
    NZ = 47;
    // Peru.
    PE = 48;
    // Philippines.
    PH = 49;
    // Pakistan.
    PK = 50;
    // Poland.
    PL = 51;
    // Portugal.
    PT = 52;
    // Romania.
    RO = 53;
    // Serbia.
    RS = 54;
    // Russian Federation.
    RU = 55;
    // Saudi Arabia.
    SA = 56;
    // Sweden.
    SE = 57;
    // Singapore.
    SG = 58;
    // Slovenia.
    SI = 59;
    // Slovakia.
    SK = 60;
    // Thailand.
    TH = 61;
    // Turkey.
    TR = 62;
    // Taiwan.
    TW = 63;
    // Ukraine.
    UA = 64;
    // United States.
    US = 65;
    // Venezuela.
    VE = 66;
    // Viet Nam.
    VN = 67;
    // South Africa.
    ZA = 68;
  }
  // Strategy used to choose the learning rate during optimization.
  enum LearnRateStrategy {
    // Zero value: the learning rate strategy is unspecified.
    LEARN_RATE_STRATEGY_UNSPECIFIED = 0;
    // Determines the learning rate with a line search.
    LINE_SEARCH = 1;
    // Keeps the learning rate constant.
    CONSTANT = 2;
  }
  // Optimization strategy used for training.
  enum OptimizationStrategy {
    // Zero value: the optimization strategy is unspecified.
    OPTIMIZATION_STRATEGY_UNSPECIFIED = 0;
    // Iterative batch gradient descent.
    BATCH_GRADIENT_DESCENT = 1;
    // Solves the linear regression problem with a normal equation.
    NORMAL_EQUATION = 2;
  }
  // Training algorithm to use for matrix factorization models.
  enum FeedbackType {
    // Zero value: the feedback type is unspecified.
    FEEDBACK_TYPE_UNSPECIFIED = 0;
    // Implicit feedback problems; trained with weighted-als.
    IMPLICIT = 1;
    // Explicit feedback problems; trained with nonweighted-als.
    EXPLICIT = 2;
  }
  // Output only. A hash of this resource.
  string etag = 1 [(google.api.field_behavior) = OUTPUT_ONLY];
  // Required. Unique identifier for this model.
  ModelReference model_reference = 2 [(google.api.field_behavior) = REQUIRED];
  // Output only. The time when this model was created, in milliseconds since
  // the epoch.
  int64 creation_time = 5 [(google.api.field_behavior) = OUTPUT_ONLY];
  // Output only. The time when this model was last modified, in milliseconds
  // since the epoch.
  int64 last_modified_time = 6 [(google.api.field_behavior) = OUTPUT_ONLY];
  // Optional. A user-friendly description of this model.
  string description = 12 [(google.api.field_behavior) = OPTIONAL];
  // Optional. A descriptive name for this model.
  string friendly_name = 14 [(google.api.field_behavior) = OPTIONAL];
  // The labels associated with this model. You can use these to organize
  // and group your models. Label keys and values can be no longer
  // than 63 characters, can only contain lowercase letters, numeric
  // characters, underscores and dashes. International characters are allowed.
  // Label values are optional. Label keys must start with a letter and each
  // label in the list must have a different key.
  map<string, string> labels = 15;
  // Optional. The time when this model expires, in milliseconds since the
  // epoch. If not present, the model will persist indefinitely. Expired
  // models will be deleted and their storage reclaimed. The
  // defaultTableExpirationMs property of the encapsulating dataset can be used
  // to set a default expirationTime on newly created models.
  int64 expiration_time = 16 [(google.api.field_behavior) = OPTIONAL];
  // Output only. The geographic location where the model resides. This value
  // is inherited from the dataset.
  string location = 13 [(google.api.field_behavior) = OUTPUT_ONLY];
  // Custom encryption configuration (e.g., Cloud KMS keys). This shows the
  // encryption configuration of the model data while stored in BigQuery
  // storage. This field can be used with PatchModel to update encryption key
  // for an already encrypted model.
  EncryptionConfiguration encryption_configuration = 17;
  // Output only. Type of the model resource.
  ModelType model_type = 7 [(google.api.field_behavior) = OUTPUT_ONLY];
  // Output only. Information for all training runs in increasing order of
  // start_time.
  repeated TrainingRun training_runs = 9
      [(google.api.field_behavior) = OUTPUT_ONLY];
  // Output only. Input feature columns that were used to train this model.
  repeated StandardSqlField feature_columns = 10
      [(google.api.field_behavior) = OUTPUT_ONLY];
  // Output only. Label columns that were used to train this model.
  // The output of the model will have a "predicted_" prefix to these columns.
  repeated StandardSqlField label_columns = 11
      [(google.api.field_behavior) = OUTPUT_ONLY];
  // Deprecated. The best trial_id across all training runs. New clients
  // should not rely on this field.
  int64 best_trial_id = 19 [deprecated = true];
  912. }
  // Request for the GetModel RPC: identifies the model to fetch.
  message GetModelRequest {
    // Required. ID of the project containing the requested model.
    string project_id = 1 [(google.api.field_behavior) = REQUIRED];
    // Required. ID of the dataset containing the requested model.
    string dataset_id = 2 [(google.api.field_behavior) = REQUIRED];
    // Required. ID of the requested model.
    string model_id = 3 [(google.api.field_behavior) = REQUIRED];
  }
  // Identifies a model to patch and carries its patched contents.
  message PatchModelRequest {
    // Required. ID of the project containing the model to patch.
    string project_id = 1 [(google.api.field_behavior) = REQUIRED];
    // Required. ID of the dataset containing the model to patch.
    string dataset_id = 2 [(google.api.field_behavior) = REQUIRED];
    // Required. ID of the model to patch.
    string model_id = 3 [(google.api.field_behavior) = REQUIRED];
    // Required. The patched model, applied with RFC5789 patch semantics:
    // fields that are missing are not updated, and a field is cleared by
    // explicitly setting it to its default value.
    Model model = 4 [(google.api.field_behavior) = REQUIRED];
  }
  // Identifies the model to delete.
  message DeleteModelRequest {
    // Required. ID of the project containing the model to delete.
    string project_id = 1 [(google.api.field_behavior) = REQUIRED];
    // Required. ID of the dataset containing the model to delete.
    string dataset_id = 2 [(google.api.field_behavior) = REQUIRED];
    // Required. ID of the model to delete.
    string model_id = 3 [(google.api.field_behavior) = REQUIRED];
  }
  // Parameters for listing the models in a dataset, one page at a time.
  message ListModelsRequest {
    // Required. ID of the project containing the models to list.
    string project_id = 1 [(google.api.field_behavior) = REQUIRED];
    // Required. ID of the dataset containing the models to list.
    string dataset_id = 2 [(google.api.field_behavior) = REQUIRED];
    // Maximum number of results to return in a single response page. Use the
    // page tokens to iterate through the entire collection.
    google.protobuf.UInt32Value max_results = 3;
    // Page token returned by a previous call, used to request the next page
    // of results.
    string page_token = 4;
  }
  // One page of models from a dataset listing.
  message ListModelsResponse {
    // Models in the requested dataset. Only model_reference, model_type,
    // creation_time, last_modified_time and labels are populated.
    repeated Model models = 1;
    // Token for requesting the next page of results.
    string next_page_token = 2;
  }