Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions finetune_csv/fetch_india_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
Fetch 5-minute OHLCV data for one or more NSE/BSE instruments via yfinance
and save in the format expected by CustomKlineDataset.

Required columns: timestamps, open, high, low, close, volume, amount

Ticker conventions
------------------
Equities : RELIANCE, TCS, INFY -> appended with .NS automatically
Indices : NIFTY50, BANKNIFTY -> mapped to ^NSEI, ^NSEBANK
BSE : RELIANCE.BO, TCS.BO -> passed through as-is
Raw : ^NSEI, RELIANCE.NS, etc. -> passed through as-is

Usage:
python fetch_india_data.py --tickers RELIANCE TCS INFY --output data/india_5min.csv
python fetch_india_data.py --tickers NIFTY50 BANKNIFTY --period 30d --output data/indices_5min.csv
python fetch_india_data.py --tickers RELIANCE.NS TCS.NS --period 60d --output data/equities_5min.csv

Note: yfinance supports a maximum of 60 days of 5-min data.
"""

import argparse
import os
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import yfinance as yf

# ── Constants ─────────────────────────────────────────────────────────────────

IST = "Asia/Kolkata"
MARKET_OPEN = "09:15"
MARKET_CLOSE = "15:30"

NSE_INDEX_MAP = {
"NIFTY50": "^NSEI",
"NIFTY 50": "^NSEI",
"NIFTY": "^NSEI",
"BANKNIFTY": "^NSEBANK",
"BANK NIFTY": "^NSEBANK",
"FINNIFTY": "NIFTY_FIN_SERVICE.NS",
"NIFTYIT": "^CNXIT",
"MIDCAPNIFTY":"^NSEMDCP50",
"SENSEX": "^BSESN",
}

# ── Helpers ───────────────────────────────────────────────────────────────────

def normalise_ticker(raw: str) -> str:
"""Map a human-friendly name to a yfinance-compatible symbol."""
upper = raw.upper().strip()
if upper in NSE_INDEX_MAP:
return NSE_INDEX_MAP[upper]
if "." in upper or upper.startswith("^"):
return upper # already has suffix or is an index
return upper + ".NS" # default: NSE equity


def fetch_ticker(ticker: str, period: str) -> pd.DataFrame:
sym = normalise_ticker(ticker)
raw = yf.download(sym, period=period, interval="5m", progress=False, auto_adjust=True)

if raw.empty:
print(f" WARNING: no data returned for {ticker} ({sym}), skipping.")
return pd.DataFrame()

# Flatten MultiIndex columns produced by yfinance >= 0.2
if isinstance(raw.columns, pd.MultiIndex):
raw.columns = raw.columns.get_level_values(0)

raw = raw.rename(columns={
"Open": "open", "High": "high", "Low": "low",
"Close": "close", "Volume": "volume",
})

# Convert to IST and keep only NSE market hours
if raw.index.tz is None:
raw.index = raw.index.tz_localize("UTC")
raw.index = raw.index.tz_convert(IST)
raw = raw.between_time(MARKET_OPEN, MARKET_CLOSE)

raw = raw[["open", "high", "low", "close", "volume"]].dropna()

# amount = volume × avg(OHLC) — proxy for turnover
raw["amount"] = raw["volume"] * (
raw["open"] + raw["high"] + raw["low"] + raw["close"]
) / 4

raw = (raw.reset_index()
.rename(columns={"Datetime": "timestamps", "index": "timestamps"}))
raw["timestamps"] = pd.to_datetime(raw["timestamps"]).dt.tz_localize(None)
raw = raw[["timestamps", "open", "high", "low", "close", "volume", "amount"]]
raw = raw.sort_values("timestamps").reset_index(drop=True)

print(f" {ticker} ({sym}): {len(raw)} bars "
f"[{raw['timestamps'].iloc[0]} → {raw['timestamps'].iloc[-1]}]")
return raw

# ── Entry point ───────────────────────────────────────────────────────────────

def main():
parser = argparse.ArgumentParser(
description="Fetch Indian equity/index 5-min data for Kronos fine-tuning"
)
parser.add_argument(
"--tickers", nargs="+", required=True,
help="One or more symbols: RELIANCE, TCS, NIFTY50, BANKNIFTY, RELIANCE.NS, …"
)
parser.add_argument(
"--period", type=str, default="60d",
help="yfinance period string (max 60d for 5-min). Default: 60d"
)
parser.add_argument(
"--output", type=str, default="data/india_5min.csv",
help="Output CSV path. Default: data/india_5min.csv"
)
args = parser.parse_args()

print(f"Fetching {len(args.tickers)} ticker(s) via yfinance …")
all_frames = []
for ticker in args.tickers:
df = fetch_ticker(ticker, args.period)
if not df.empty:
all_frames.append(df)

if not all_frames:
raise RuntimeError(
"No data fetched — check ticker symbols and internet connection.\n"
"Tip: NSE equities use the .NS suffix (e.g. RELIANCE.NS)."
)

combined = pd.concat(all_frames, ignore_index=True)
combined = combined.sort_values("timestamps").reset_index(drop=True)

os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
combined.to_csv(args.output, index=False)

print(f"\nSaved {len(combined):,} rows → {args.output}")
print(f"Date range : {combined['timestamps'].min()} → {combined['timestamps'].max()}")
print(f"Market hours: {MARKET_OPEN}–{MARKET_CLOSE} IST")
print(f"Columns : {list(combined.columns)}")


if __name__ == "__main__":
main()
189 changes: 189 additions & 0 deletions finetune_csv/run_india_finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
"""
End-to-end fine-tuning runner for Indian equity/index 5-min data (NSE/BSE).

Steps:
1. Fetch data from yfinance (skipped if CSV already exists)
2. Fine-tune KronosTokenizer
3. Fine-tune Kronos predictor

Usage:
# Full pipeline — fetch NIFTY50 + top stocks, then train
python run_india_finetune.py --tickers NIFTY50 RELIANCE TCS INFY HDFCBANK

# Use an existing CSV
python run_india_finetune.py --data_path data/india_5min.csv

# Skip tokenizer (already done), only re-train predictor
python run_india_finetune.py --tickers NIFTY50 --skip_tokenizer

# Custom config
python run_india_finetune.py --data_path data/india_5min.csv --config configs/config_india_5min.yaml
"""

import argparse
import os
import subprocess
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

# ── Helpers ───────────────────────────────────────────────────────────────────

def run(cmd: list, desc: str):
print(f"\n{'='*60}")
print(f" {desc}")
print(f"{'='*60}")
print(f" CMD: {' '.join(cmd)}\n")
subprocess.run(cmd, check=True)


def patch_config_data_path(config_path: str, data_path: str, out_path: str):
"""Write a copy of config_path with data.data_path replaced."""
import yaml
with open(config_path) as f:
cfg = yaml.safe_load(f)
cfg["data"]["data_path"] = os.path.abspath(data_path)
with open(out_path, "w") as f:
yaml.dump(cfg, f, default_flow_style=False, allow_unicode=True, indent=2)


def ensure_config(script_dir: str, config_arg: str) -> str:
"""Return a config path, generating a minimal one if none exists."""
if config_arg and os.path.exists(config_arg):
return config_arg

candidates = [
os.path.join(script_dir, "configs", "config_india_5min.yaml"),
os.path.join(script_dir, "configs", "config_us_5min.yaml"),
]
for p in candidates:
if os.path.exists(p):
return p

# Generate minimal YAML
import yaml
path = os.path.join(script_dir, "configs", "config_india_5min.yaml")
os.makedirs(os.path.dirname(path), exist_ok=True)
cfg = {
"data": {
"data_path": "data/india_5min.csv",
"lookback": 512,
"pred_len": 20,
"train_ratio": 0.7,
"val_ratio": 0.15,
},
"tokenizer": {
"pretrained": "NeoQuasar/Kronos-Tokenizer-base",
"save_path": "finetuned/india_5min/tokenizer",
"epochs": 10,
"batch_size": 32,
"lr": 1e-4,
},
"predictor": {
"pretrained": "NeoQuasar/Kronos-base",
"tokenizer_path": "finetuned/india_5min/tokenizer/best_model",
"save_path": "finetuned/india_5min/basemodel",
"epochs": 10,
"batch_size": 16,
"lr": 5e-5,
},
}
with open(path, "w") as f:
yaml.dump(cfg, f, default_flow_style=False, allow_unicode=True, indent=2)
print(f" Generated minimal config → {path}")
return path


# ── Main ──────────────────────────────────────────────────────────────────────

def main():
parser = argparse.ArgumentParser(
description="Kronos Indian market (NSE/BSE) fine-tuning pipeline"
)
parser.add_argument(
"--tickers", nargs="+", default=[],
help="NSE symbols to fetch: RELIANCE, TCS, NIFTY50, BANKNIFTY, …"
)
parser.add_argument(
"--period", default="60d",
help="yfinance period (max 60d for 5-min interval). Default: 60d"
)
parser.add_argument(
"--data_path", default="data/india_5min.csv",
help="Path to existing or target CSV. Default: data/india_5min.csv"
)
parser.add_argument(
"--config", default="",
help="YAML config path. Auto-detected if omitted."
)
parser.add_argument("--skip_tokenizer", action="store_true",
help="Skip tokenizer training")
parser.add_argument("--skip_predictor", action="store_true",
help="Skip predictor training")
args = parser.parse_args()

script_dir = os.path.dirname(os.path.abspath(__file__))
data_path = args.data_path if os.path.isabs(args.data_path) \
else os.path.join(script_dir, args.data_path)

# ── 1. Fetch data ─────────────────────────────────────────────────────────
if args.tickers:
run(
[sys.executable, os.path.join(script_dir, "fetch_india_data.py"),
"--tickers", *args.tickers,
"--period", args.period,
"--output", data_path],
f"Fetching 5-min data for: {', '.join(args.tickers)}"
)
else:
if not os.path.exists(data_path):
raise FileNotFoundError(
f"No tickers provided and data file not found: {data_path}\n"
"Pass --tickers RELIANCE TCS … to fetch data first."
)
print(f"Using existing data file: {data_path}")

# ── 2. Resolve + patch config ─────────────────────────────────────────────
config_src = ensure_config(script_dir, args.config)
config_run = os.path.join(script_dir, "configs", "_run_india_config.yaml")
os.makedirs(os.path.dirname(config_run), exist_ok=True)
patch_config_data_path(config_src, data_path, config_run)
print(f"Active config → {config_run}")

finetune_dir = os.path.join(script_dir, "..", "finetune")
tokenizer_script = os.path.join(finetune_dir, "train_tokenizer.py")
predictor_script = os.path.join(finetune_dir, "train_predictor.py")

# ── 3. Fine-tune tokenizer ────────────────────────────────────────────────
if not args.skip_tokenizer:
if not os.path.exists(tokenizer_script):
print(f"\n WARNING: {tokenizer_script} not found.")
print(" Run from inside the cloned Kronos repo.")
else:
run([sys.executable, tokenizer_script, "--config", config_run],
"Stage 1 / 2 — Fine-tuning KronosTokenizer on Indian data")
else:
print("\nSkipping tokenizer training (--skip_tokenizer set).")

# ── 4. Fine-tune predictor ────────────────────────────────────────────────
if not args.skip_predictor:
if not os.path.exists(predictor_script):
print(f"\n WARNING: {predictor_script} not found.")
else:
run([sys.executable, predictor_script, "--config", config_run],
"Stage 2 / 2 — Fine-tuning Kronos predictor on Indian data")
else:
print("\nSkipping predictor training (--skip_predictor set).")

print("\n" + "=" * 60)
print(" Fine-tuning complete!")
print(" Models saved under: finetuned/india_5min/")
print(" Inference paths:")
print(" tokenizer : finetuned/india_5min/tokenizer/best_model")
print(" model : finetuned/india_5min/basemodel/best_model")
print("=" * 60 + "\n")


if __name__ == "__main__":
main()