1616# under the License.
1717#
1818
19- import logging
2019import warnings
2120
2221import torch
2322
2423from iotdb .ainode .core .inference .pipeline .basic_pipeline import ForecastPipeline
24+ from iotdb .ainode .core .log import Logger
25+ from iotdb .ainode .core .model .toto .data .util .dataset import MaskedTimeseries
26+ from iotdb .ainode .core .model .toto .inference .forecaster import TotoForecaster
2527
26- from .data .util .dataset import MaskedTimeseries
27- from .inference .forecaster import TotoForecaster
28-
29- logger = logging .getLogger (__name__ )
28+ logger = Logger ()
3029
3130
3231class TotoPipeline (ForecastPipeline ):
@@ -51,9 +50,30 @@ def _get_forecaster(self) -> TotoForecaster:
5150 return self ._forecaster
5251
5352 def preprocess (self , inputs , ** infer_kwargs ):
54- super ().preprocess (inputs , ** infer_kwargs )
55- processed_inputs = []
53+ """
54+ Preprocess input data for Toto.
55+
56+ Delegates to the base class for input validation, then converts each
57+ validated input dict into a ``MaskedTimeseries`` named-tuple that the
58+ ``TotoForecaster`` expects.
59+
60+ Parameters
61+ ----------
62+ inputs : list of dict
63+ A list of dictionaries containing input data. Each dictionary contains:
64+ - 'targets': A tensor (1D or 2D) of shape (input_length,) or (target_count, input_length).
5665
66+ infer_kwargs: Additional keyword arguments for inference, such as:
67+ - `output_length`(int): Prediction length.
68+
69+ Returns
70+ -------
71+ list of MaskedTimeseries
72+ Processed inputs compatible with Toto's forecaster.
73+ """
74+ inputs = super ().preprocess (inputs , ** infer_kwargs )
75+
76+ processed_inputs = []
5777 for item in inputs :
5878 targets = item ["targets" ]
5979 if targets .ndim == 1 :
@@ -63,10 +83,8 @@ def preprocess(self, inputs, **infer_kwargs):
6383 device = targets .device
6484
6585 if "past_covariates" in item or "future_covariates" in item :
66- warnings .warn (
67- "TotoPipeline does not support covariates; they will be ignored." ,
68- UserWarning ,
69- stacklevel = 2 ,
86+ logger .warning (
87+ "TotoPipeline does not support covariates; they will be ignored."
7088 )
7189
7290 padding_mask = ~ torch .isnan (targets )
@@ -96,7 +114,7 @@ def preprocess(self, inputs, **infer_kwargs):
96114
97115 return processed_inputs
98116
99- def forecast (self , inputs , ** infer_kwargs ):
117+ def forecast (self , inputs , ** infer_kwargs ) -> list [ torch . Tensor ] :
100118 output_length = infer_kwargs .get ("output_length" , 96 )
101119 num_samples = infer_kwargs .get ("num_samples" , None )
102120 samples_per_batch = infer_kwargs .get ("samples_per_batch" , 10 )
@@ -127,5 +145,5 @@ def forecast(self, inputs, **infer_kwargs):
127145 outputs .append (mean )
128146 return outputs
129147
130- def postprocess (self , outputs , ** infer_kwargs ):
148+ def postprocess (self , outputs , ** infer_kwargs ) -> list [ torch . Tensor ] :
131149 return super ().postprocess (outputs , ** infer_kwargs )
0 commit comments