Skip to content

Commit d03e600

Browse files
committed
address review comments
1 parent 437383c commit d03e600

2 files changed

Lines changed: 31 additions & 16 deletions

File tree

iotdb-core/ainode/.gitignore

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,3 @@ poetry.lock
2121
/dist/
2222
/build/
2323

24-
# Un-ignore toto source data/ package (Python source, not data files)
25-
!iotdb/ainode/core/model/toto/data/
26-
!iotdb/ainode/core/model/toto/data/**

iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,16 @@
1616
# under the License.
1717
#
1818

19-
import logging
2019
import warnings
2120

2221
import torch
2322

2423
from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
24+
from iotdb.ainode.core.log import Logger
25+
from iotdb.ainode.core.model.toto.data.util.dataset import MaskedTimeseries
26+
from iotdb.ainode.core.model.toto.inference.forecaster import TotoForecaster
2527

26-
from .data.util.dataset import MaskedTimeseries
27-
from .inference.forecaster import TotoForecaster
28-
29-
logger = logging.getLogger(__name__)
28+
logger = Logger()
3029

3130

3231
class TotoPipeline(ForecastPipeline):
@@ -51,9 +50,30 @@ def _get_forecaster(self) -> TotoForecaster:
5150
return self._forecaster
5251

5352
def preprocess(self, inputs, **infer_kwargs):
54-
super().preprocess(inputs, **infer_kwargs)
55-
processed_inputs = []
53+
"""
54+
Preprocess input data for Toto.
55+
56+
Delegates to the base class for input validation, then converts each
57+
validated input dict into a ``MaskedTimeseries`` named-tuple that the
58+
``TotoForecaster`` expects.
59+
60+
Parameters
61+
----------
62+
inputs : list of dict
63+
A list of dictionaries containing input data. Each dictionary contains:
64+
- 'targets': A tensor (1D or 2D) of shape (input_length,) or (target_count, input_length).
5665
66+
infer_kwargs: Additional keyword arguments for inference, such as:
67+
- `output_length`(int): Prediction length.
68+
69+
Returns
70+
-------
71+
list of MaskedTimeseries
72+
Processed inputs compatible with Toto's forecaster.
73+
"""
74+
inputs = super().preprocess(inputs, **infer_kwargs)
75+
76+
processed_inputs = []
5777
for item in inputs:
5878
targets = item["targets"]
5979
if targets.ndim == 1:
@@ -63,10 +83,8 @@ def preprocess(self, inputs, **infer_kwargs):
6383
device = targets.device
6484

6585
if "past_covariates" in item or "future_covariates" in item:
66-
warnings.warn(
67-
"TotoPipeline does not support covariates; they will be ignored.",
68-
UserWarning,
69-
stacklevel=2,
86+
logger.warning(
87+
"TotoPipeline does not support covariates; they will be ignored."
7088
)
7189

7290
padding_mask = ~torch.isnan(targets)
@@ -96,7 +114,7 @@ def preprocess(self, inputs, **infer_kwargs):
96114

97115
return processed_inputs
98116

99-
def forecast(self, inputs, **infer_kwargs):
117+
def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]:
100118
output_length = infer_kwargs.get("output_length", 96)
101119
num_samples = infer_kwargs.get("num_samples", None)
102120
samples_per_batch = infer_kwargs.get("samples_per_batch", 10)
@@ -127,5 +145,5 @@ def forecast(self, inputs, **infer_kwargs):
127145
outputs.append(mean)
128146
return outputs
129147

130-
def postprocess(self, outputs, **infer_kwargs):
148+
def postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]:
131149
return super().postprocess(outputs, **infer_kwargs)

0 commit comments

Comments
 (0)