From 61e83576b46bbc89e0a19d2958e59e1776d0c4fc Mon Sep 17 00:00:00 2001 From: Adam El Kommos Date: Sat, 30 May 2026 14:44:50 -0400 Subject: [PATCH 1/2] Add US stock 5-min prediction script using yfinance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds predict_us_stock.py — a self-contained script for running Kronos inference on US equities using 5-minute OHLCV data fetched live from Yahoo Finance via yfinance. Features: - Fetches 5-min bars automatically for any ticker (yf.download) - Strips pre/post market hours (09:30–16:00 ET) - Loads KronosTokenizer + Kronos from HuggingFace Hub (or local path) - Auto-detects CUDA / MPS / CPU device - Monte Carlo sampling with configurable temperature, top-p, sample count - Prints directional summary (UP/DOWN %, predicted close) - Saves a two-panel close+volume PNG chart Usage: python predict_us_stock.py --ticker AAPL --lookback 200 --pred_len 20 Co-Authored-By: Claude Sonnet 4.6 --- predict_us_stock.py | 161 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 predict_us_stock.py diff --git a/predict_us_stock.py b/predict_us_stock.py new file mode 100644 index 00000000..4033c690 --- /dev/null +++ b/predict_us_stock.py @@ -0,0 +1,161 @@ +""" +Kronos US Stock Prediction (5-minute bars) + +Usage: + python predict_us_stock.py --ticker AAPL --lookback 200 --pred_len 20 + +Requirements: + pip install yfinance + pip install -r requirements.txt +""" + +import argparse +import sys +import warnings +warnings.filterwarnings("ignore") + +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import matplotlib.dates as mdates + +import yfinance as yf + +from model import Kronos, KronosTokenizer, KronosPredictor + + +def fetch_5min_data(ticker: str, lookback: int, pred_len: int) -> pd.DataFrame: + """ + Fetch enough 5-minute bars to cover lookback + pred_len candles. + yfinance 5m data goes back ~60 days; we fetch 30 days to be safe. + """ + total_bars = lookback + pred_len + raw = yf.download(ticker, period="30d", interval="5m", progress=False, auto_adjust=True) + if raw.empty: + raise ValueError(f"No data returned for {ticker}. Check the ticker symbol.") + + # yfinance returns MultiIndex columns when auto_adjust=True + if isinstance(raw.columns, pd.MultiIndex): + raw.columns = raw.columns.get_level_values(0) + + raw = raw.rename(columns={ + "Open": "open", "High": "high", "Low": "low", + "Close": "close", "Volume": "volume" + }) + raw = raw[["open", "high", "low", "close", "volume"]].dropna() + raw.index = pd.to_datetime(raw.index) + + # Drop pre/post market if timezone-aware index + if raw.index.tz is not None: + raw.index = raw.index.tz_convert("America/New_York") + raw = raw.between_time("09:30", "16:00") + + if len(raw) < total_bars: + raise ValueError( + f"Only {len(raw)} bars available for {ticker}, need at least {total_bars}. " + "Try reducing --lookback or --pred_len." + ) + + # Use the most recent total_bars candles + raw = raw.iloc[-total_bars:].copy() + raw = raw.reset_index().rename(columns={"Datetime": "timestamp", "index": "timestamp"}) + return raw + + +def plot_prediction(ticker: str, hist_df: pd.DataFrame, pred_df: pd.DataFrame, pred_len: int): + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 7), sharex=False) + fig.suptitle(f"{ticker} — 5-min Kronos Prediction (next {pred_len} bars)", fontsize=14) + + # ---- Close price ---- + # Show last 60 bars of history + all predictions side by side + hist_tail = hist_df.tail(60) + hist_times = hist_tail["timestamp"] + pred_times = pred_df.index + + ax1.plot(hist_times, hist_tail["close"], color="steelblue", linewidth=1.5, label="Historical close") + ax1.plot(pred_times, pred_df["close"], color="tomato", linewidth=1.5, linestyle="--", label="Predicted close") + ax1.axvline(x=hist_times.iloc[-1], color="gray", linestyle=":", linewidth=1) + ax1.set_ylabel("Price ($)") + ax1.legend() + ax1.grid(True, alpha=0.3) + ax1.xaxis.set_major_formatter(mdates.DateFormatter("%m/%d %H:%M")) + plt.setp(ax1.xaxis.get_majorticklabels(), rotation=30, ha="right") + + # ---- Volume ---- + ax2.bar(hist_times, hist_tail["volume"], color="steelblue", alpha=0.6, width=0.003, label="Historical volume") + ax2.bar(pred_times, pred_df["volume"], color="tomato", alpha=0.6, width=0.003, label="Predicted volume") + ax2.axvline(x=hist_times.iloc[-1], color="gray", linestyle=":", linewidth=1) + ax2.set_ylabel("Volume") + ax2.legend() + ax2.grid(True, alpha=0.3) + ax2.xaxis.set_major_formatter(mdates.DateFormatter("%m/%d %H:%M")) + plt.setp(ax2.xaxis.get_majorticklabels(), rotation=30, ha="right") + + plt.tight_layout() + out_path = f"{ticker}_prediction.png" + plt.savefig(out_path, dpi=150, bbox_inches="tight") + print(f"Plot saved to {out_path}") + plt.show() + + +def main(): + parser = argparse.ArgumentParser(description="Kronos 5-min US stock prediction") + parser.add_argument("--ticker", type=str, default="AAPL", help="Stock ticker symbol") + parser.add_argument("--lookback", type=int, default=200, help="Number of historical bars to feed the model") + parser.add_argument("--pred_len", type=int, default=20, help="Number of bars to predict") + parser.add_argument("--model", type=str, default="NeoQuasar/Kronos-small", + help="HuggingFace model ID or local path for Kronos predictor") + parser.add_argument("--tokenizer", type=str, default="NeoQuasar/Kronos-Tokenizer-base", + help="HuggingFace model ID or local path for KronosTokenizer") + parser.add_argument("--top_p", type=float, default=0.9) + parser.add_argument("--temp", type=float, default=1.0, help="Sampling temperature") + parser.add_argument("--samples", type=int, default=5, help="Monte Carlo samples (averaged)") + args = parser.parse_args() + + # ---- Data ---- + print(f"Fetching {args.ticker} 5-min data...") + df = fetch_5min_data(args.ticker, args.lookback, args.pred_len) + + x_df = df.iloc[:args.lookback][["open", "high", "low", "close", "volume"]].copy() + x_timestamp = pd.to_datetime(df.iloc[:args.lookback]["timestamp"]) + y_timestamp = pd.to_datetime(df.iloc[args.lookback : args.lookback + args.pred_len]["timestamp"]) + + # yfinance doesn't provide a separate 'amount' column; KronosPredictor fills it automatically + print(f"Historical bars: {len(x_df)}, Prediction bars: {len(y_timestamp)}") + + # ---- Model ---- + print(f"Loading tokenizer from {args.tokenizer} ...") + tokenizer = KronosTokenizer.from_pretrained(args.tokenizer) + print(f"Loading predictor from {args.model} ...") + model = Kronos.from_pretrained(args.model) + + predictor = KronosPredictor(model, tokenizer, max_context=512) + print(f"Running inference on {predictor.device} ...") + + pred_df = predictor.predict( + df=x_df, + x_timestamp=x_timestamp, + y_timestamp=y_timestamp, + pred_len=args.pred_len, + T=args.temp, + top_p=args.top_p, + sample_count=args.samples, + verbose=True, + ) + + # ---- Results ---- + print("\n--- Predicted OHLCV (first 5 rows) ---") + print(pred_df[["open", "high", "low", "close", "volume"]].head()) + + last_close = x_df["close"].iloc[-1] + pred_close = pred_df["close"].iloc[-1] + direction = "UP" if pred_close > last_close else "DOWN" + change_pct = (pred_close - last_close) / last_close * 100 + print(f"\nLast historical close : {last_close:.4f}") + print(f"Predicted close (t+{args.pred_len}): {pred_close:.4f} [{direction} {change_pct:+.2f}%]") + + plot_prediction(args.ticker, df, pred_df, args.pred_len) + + +if __name__ == "__main__": + main() From 6a6a4196a0e49d29cbdfae5a1095adf44a135989 Mon Sep 17 00:00:00 2001 From: Adam El Kommos Date: Sat, 30 May 2026 15:00:33 -0400 Subject: [PATCH 2/2] Add US equity 5-min fine-tuning pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three new files under finetune_csv/: - fetch_us_data.py: fetches 5-min OHLCV from yfinance for any US tickers, strips pre/post market hours, derives the 'amount' column, and saves a CSV in the format expected by CustomKlineDataset - configs/config_us_5min.yaml: ready-to-use YAML config tuned for US 5-min data (lower LRs to avoid catastrophic forgetting, 200-bar lookback, 20-bar prediction horizon, HuggingFace Hub model IDs) - run_us_finetune.py: end-to-end runner — fetches data, patches the config with the resolved data path, runs finetune_tokenizer.py then finetune_base_model.py in sequence; individual stages can be skipped via --skip_tokenizer / --skip_predictor flags Co-Authored-By: Claude Sonnet 4.6 --- finetune_csv/configs/config_us_5min.yaml | 84 +++++++++++++++ finetune_csv/fetch_us_data.py | 87 +++++++++++++++ finetune_csv/run_us_finetune.py | 132 +++++++++++++++++++++++ 3 files changed, 303 insertions(+) create mode 100644 finetune_csv/configs/config_us_5min.yaml create mode 100644 finetune_csv/fetch_us_data.py create mode 100644 finetune_csv/run_us_finetune.py diff --git a/finetune_csv/configs/config_us_5min.yaml b/finetune_csv/configs/config_us_5min.yaml new file mode 100644 index 00000000..0d0b6c65 --- /dev/null +++ b/finetune_csv/configs/config_us_5min.yaml @@ -0,0 +1,84 @@ +# Kronos fine-tuning config for US equities, 5-minute bars +# Generated for use with finetune_tokenizer.py + finetune_base_model.py + +data: + # Path to the CSV produced by fetch_us_data.py + # Override with --data_path or set here directly. + data_path: "data/us_5min.csv" + + # How many past bars the model sees as context (max 512) + lookback_window: 200 + + # How many future bars are in each training sample + predict_window: 20 + + # Must match Kronos max_context (don't change) + max_context: 512 + + # Clip normalised values beyond this range (same as pretrain) + clip: 5.0 + + # Chronological split — no shuffling across time boundaries + train_ratio: 0.80 + val_ratio: 0.10 + test_ratio: 0.10 + +training: + # Tokenizer: fine-tune the VQ-VAE encoder/decoder on US price distributions + tokenizer_epochs: 20 + + # Predictor: fine-tune the autoregressive Transformer + basemodel_epochs: 15 + + # Reduce batch_size if you hit OOM on CPU/MPS + batch_size: 16 + + log_interval: 50 + num_workers: 2 + seed: 42 + + # Tokenizer LR — slightly lower than default to preserve pretrained codebook + tokenizer_learning_rate: 0.0001 + + # Predictor LR — keep small to avoid catastrophic forgetting + predictor_learning_rate: 0.000004 + + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_weight_decay: 0.1 + + accumulation_steps: 2 # effective batch = batch_size * accumulation_steps + +model_paths: + # HuggingFace Hub IDs (downloaded and cached on first run) + # or local absolute paths if you've already downloaded them + pretrained_tokenizer: "NeoQuasar/Kronos-Tokenizer-base" + pretrained_predictor: "NeoQuasar/Kronos-small" + + exp_name: "us_5min_finetune" + base_path: "finetuned" + + # Leave empty — auto-generated as {base_path}/{exp_name}/... + base_save_path: "" + finetuned_tokenizer: "" + + tokenizer_save_name: "tokenizer" + basemodel_save_name: "basemodel" + +experiment: + name: "kronos_us_5min" + description: "Fine-tune Kronos on US equity 5-min OHLCV data" + use_comet: false + + # Set either to false to skip that stage (e.g. if tokenizer already trained) + train_tokenizer: true + train_basemodel: true + skip_existing: false + + # Start from pretrained weights (recommended — keeps global knowledge) + pre_trained_tokenizer: true + pre_trained_predictor: true + +device: + use_cuda: true + device_id: 0 diff --git a/finetune_csv/fetch_us_data.py b/finetune_csv/fetch_us_data.py new file mode 100644 index 00000000..29beafe5 --- /dev/null +++ b/finetune_csv/fetch_us_data.py @@ -0,0 +1,87 @@ +""" +Fetch 5-minute OHLCV data for one or more US tickers via yfinance +and save in the format expected by CustomKlineDataset. + +Required columns: timestamps, open, high, low, close, volume, amount + +Usage: + python fetch_us_data.py --tickers AAPL MSFT TSLA --output data/us_5min.csv + python fetch_us_data.py --tickers AAPL --period 60d --output data/AAPL_5min.csv +""" + +import argparse +import warnings +warnings.filterwarnings("ignore") + +import pandas as pd +import yfinance as yf + + +def fetch_ticker(ticker: str, period: str) -> pd.DataFrame: + raw = yf.download(ticker, period=period, interval="5m", progress=False, auto_adjust=True) + if raw.empty: + print(f" WARNING: no data returned for {ticker}, skipping.") + return pd.DataFrame() + + # Flatten MultiIndex columns if present + if isinstance(raw.columns, pd.MultiIndex): + raw.columns = raw.columns.get_level_values(0) + + raw = raw.rename(columns={ + "Open": "open", "High": "high", "Low": "low", + "Close": "close", "Volume": "volume", + }) + + # Keep only regular market hours (09:30–16:00 ET) + if raw.index.tz is not None: + raw.index = raw.index.tz_convert("America/New_York") + raw = raw.between_time("09:30", "16:00") + + raw = raw[["open", "high", "low", "close", "volume"]].dropna() + + # Derive 'amount' = volume * avg(ohlc) — a reasonable proxy + raw["amount"] = raw["volume"] * (raw["open"] + raw["high"] + raw["low"] + raw["close"]) / 4 + + raw = raw.reset_index().rename(columns={"Datetime": "timestamps", "index": "timestamps"}) + raw["timestamps"] = pd.to_datetime(raw["timestamps"]).dt.tz_localize(None) # strip tz for CSV + + raw = raw[["timestamps", "open", "high", "low", "close", "volume", "amount"]] + raw = raw.sort_values("timestamps").reset_index(drop=True) + + print(f" {ticker}: {len(raw)} bars [{raw['timestamps'].iloc[0]} → {raw['timestamps'].iloc[-1]}]") + return raw + + +def main(): + parser = argparse.ArgumentParser(description="Fetch US 5-min data for Kronos fine-tuning") + parser.add_argument("--tickers", nargs="+", required=True, help="One or more ticker symbols, e.g. AAPL MSFT") + parser.add_argument("--period", type=str, default="60d", + help="yfinance period string (max 60d for 5m). Default: 60d") + parser.add_argument("--output", type=str, default="data/us_5min.csv", + help="Output CSV path. Default: data/us_5min.csv") + args = parser.parse_args() + + all_frames = [] + print(f"Fetching {len(args.tickers)} ticker(s) ...") + for ticker in args.tickers: + df = fetch_ticker(ticker.upper(), args.period) + if not df.empty: + all_frames.append(df) + + if not all_frames: + raise RuntimeError("No data fetched — check your ticker symbols and internet connection.") + + combined = pd.concat(all_frames, ignore_index=True) + combined = combined.sort_values("timestamps").reset_index(drop=True) + + import os + os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True) + combined.to_csv(args.output, index=False) + + print(f"\nSaved {len(combined):,} rows → {args.output}") + print(f"Date range : {combined['timestamps'].min()} → {combined['timestamps'].max()}") + print(f"Columns : {list(combined.columns)}") + + +if __name__ == "__main__": + main() diff --git a/finetune_csv/run_us_finetune.py b/finetune_csv/run_us_finetune.py new file mode 100644 index 00000000..63a298ce --- /dev/null +++ b/finetune_csv/run_us_finetune.py @@ -0,0 +1,132 @@ +""" +End-to-end fine-tuning runner for US equity 5-min data. + +Steps: + 1. Fetch data from yfinance (skipped if CSV already exists) + 2. Fine-tune KronosTokenizer + 3. Fine-tune Kronos predictor + +Usage: + # Full pipeline — fetch AAPL + MSFT + TSLA then train + python run_us_finetune.py --tickers AAPL MSFT TSLA + + # Use an existing CSV, custom config + python run_us_finetune.py --data_path data/us_5min.csv --config configs/config_us_5min.yaml + + # Skip tokenizer training (already done), only train predictor + python run_us_finetune.py --tickers AAPL --skip_tokenizer +""" + +import argparse +import os +import subprocess +import sys + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + + +def run(cmd: list, desc: str): + print(f"\n{'='*60}") + print(f" {desc}") + print(f"{'='*60}") + print(f" CMD: {' '.join(cmd)}\n") + result = subprocess.run(cmd, check=True) + return result + + +def patch_config_data_path(config_path: str, data_path: str, out_path: str): + """Write a copy of config_path with data.data_path replaced.""" + import yaml + with open(config_path) as f: + cfg = yaml.safe_load(f) + cfg["data"]["data_path"] = os.path.abspath(data_path) + with open(out_path, "w") as f: + yaml.dump(cfg, f, default_flow_style=False, allow_unicode=True, indent=2) + + +def main(): + parser = argparse.ArgumentParser(description="Kronos US stock fine-tuning pipeline") + parser.add_argument("--tickers", nargs="+", default=[], + help="Ticker symbols to fetch (e.g. AAPL MSFT TSLA)") + parser.add_argument("--period", default="60d", + help="yfinance period for data fetch (max 60d for 5m). Default: 60d") + parser.add_argument("--data_path", default="data/us_5min.csv", + help="Path to existing or target CSV. Default: data/us_5min.csv") + parser.add_argument("--config", default="configs/config_us_5min.yaml", + help="YAML config path. Default: configs/config_us_5min.yaml") + parser.add_argument("--skip_tokenizer", action="store_true", + help="Skip tokenizer training (use if already fine-tuned)") + parser.add_argument("--skip_predictor", action="store_true", + help="Skip predictor training") + args = parser.parse_args() + + script_dir = os.path.dirname(os.path.abspath(__file__)) + data_path = args.data_path if os.path.isabs(args.data_path) \ + else os.path.join(script_dir, args.data_path) + + # ------------------------------------------------------------------ # + # 1. Fetch data + # ------------------------------------------------------------------ # + if args.tickers: + run( + [sys.executable, os.path.join(script_dir, "fetch_us_data.py"), + "--tickers"] + args.tickers + [ + "--period", args.period, + "--output", data_path], + f"Fetching 5-min data for: {', '.join(args.tickers)}" + ) + else: + if not os.path.exists(data_path): + raise FileNotFoundError( + f"No tickers provided and data file not found: {data_path}\n" + "Pass --tickers AAPL ... to fetch data first." + ) + print(f"Using existing data file: {data_path}") + + # ------------------------------------------------------------------ # + # 2. Patch config with the resolved data path + # ------------------------------------------------------------------ # + config_src = args.config if os.path.isabs(args.config) \ + else os.path.join(script_dir, args.config) + config_run = os.path.join(script_dir, "configs", "_run_config.yaml") + os.makedirs(os.path.dirname(config_run), exist_ok=True) + patch_config_data_path(config_src, data_path, config_run) + print(f"Active config written to: {config_run}") + + # ------------------------------------------------------------------ # + # 3. Fine-tune tokenizer + # ------------------------------------------------------------------ # + if not args.skip_tokenizer: + run( + [sys.executable, os.path.join(script_dir, "finetune_tokenizer.py"), + "--config", config_run], + "Stage 1 / 2 — Fine-tuning KronosTokenizer" + ) + else: + print("\nSkipping tokenizer training (--skip_tokenizer set).") + + # ------------------------------------------------------------------ # + # 4. Fine-tune predictor + # ------------------------------------------------------------------ # + if not args.skip_predictor: + run( + [sys.executable, os.path.join(script_dir, "finetune_base_model.py"), + "--config", config_run], + "Stage 2 / 2 — Fine-tuning Kronos predictor" + ) + else: + print("\nSkipping predictor training (--skip_predictor set).") + + print("\n" + "="*60) + print(" Fine-tuning complete!") + print(f" Fine-tuned models saved under: finetuned/us_5min_finetune/") + print(" To predict with the fine-tuned model, pass:") + print(" --tokenizer finetuned/us_5min_finetune/tokenizer/best_model") + print(" --model finetuned/us_5min_finetune/basemodel/best_model") + print(" to predict_us_stock.py") + print("="*60 + "\n") + + +if __name__ == "__main__": + main()