Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions notebooks/ncbi-stat-tutorial/STAT-tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,206 @@
"If you want to experiment a bit, rerun the query with a different tax id, modify the total_count, and modify the time Interval and see how your results change. Or, we can run a few more example queries from the [NCBI STAT page](https://www.ncbi.nlm.nih.gov/sra/docs/sra-cloud-based-taxonomy-analysis-table/). "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "geo-graph-code",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------------------------------------------------------\n",
"# Install dependencies if needed\n",
"# -----------------------------------------------------------------------------\n",
"# !pip install pandas plotly numpy\n",
"\n",
"# -----------------------------------------------------------------------------\n",
"# Imports\n",
"# -----------------------------------------------------------------------------\n",
"import ast\n",
"import re\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import plotly.express as px\n",
"\n",
"# -----------------------------------------------------------------------------\n",
"# Load data\n",
"# -----------------------------------------------------------------------------\n",
"df.columns = [c.strip().lower() for c in df.columns]\n",
"\n",
"# -----------------------------------------------------------------------------\n",
"# Type cleanup\n",
"# -----------------------------------------------------------------------------\n",
"if \"releasedate\" in df.columns:\n",
" df[\"releasedate\"] = pd.to_datetime(df[\"releasedate\"], errors=\"coerce\")\n",
"\n",
"for col in [\"total_count\", \"self_count\"]:\n",
" if col in df.columns:\n",
" df[col] = pd.to_numeric(df[col], errors=\"coerce\")\n",
"\n",
"# -----------------------------------------------------------------------------\n",
"# Robust geography parser\n",
"# -----------------------------------------------------------------------------\n",
"MISSING_LIKE = {\n",
" \"\", \"[]\", \"['']\", \"['missing']\", \"['not applicable']\",\n",
" \"['not collected']\", \"['not provided']\", \"missing\", \"not applicable\",\n",
" \"not collected\", \"not provided\", \"nan\", \"none\"\n",
"}\n",
"\n",
"COUNTRY_ALIASES = {\n",
" \"USA\": \"United States\",\n",
" \"U.S.A.\": \"United States\",\n",
" \"US\": \"United States\",\n",
" \"UK\": \"United Kingdom\",\n",
" \"Korea\": \"South Korea\",\n",
" \"South Korea\": \"South Korea\",\n",
" \"North Korea\": \"North Korea\",\n",
" \"Russian Federation\": \"Russia\",\n",
"}\n",
"\n",
"def parse_geo_loc(value):\n",
" if value is None:\n",
" return None, None, None\n",
"\n",
" if isinstance(value, float) and pd.isna(value):\n",
" return None, None, None\n",
"\n",
" if isinstance(value, np.ndarray):\n",
" if value.size == 0:\n",
" return None, None, None\n",
" if value.size == 1:\n",
" value = value.item()\n",
" else:\n",
" value = value.tolist()\n",
"\n",
" s = str(value).strip()\n",
"\n",
" if not s or s.lower() in MISSING_LIKE:\n",
" return None, None, None\n",
"\n",
" try:\n",
" parsed = ast.literal_eval(s)\n",
" if isinstance(parsed, list):\n",
" if len(parsed) == 0:\n",
" return None, None, None\n",
" s = str(parsed[0]).strip()\n",
" else:\n",
" s = str(parsed).strip()\n",
" except Exception:\n",
" s = s.strip(\"[]\").strip(\"'\").strip('\"').strip()\n",
"\n",
" if not s or s.lower() in MISSING_LIKE:\n",
" return None, None, None\n",
"\n",
" s = re.sub(r\"\\s+\", \" \", s).strip()\n",
"\n",
" if \":\" in s:\n",
" country, region = [x.strip() for x in s.split(\":\", 1)]\n",
" else:\n",
" country, region = s, None\n",
"\n",
" country = COUNTRY_ALIASES.get(country, country)\n",
"\n",
" return country, region, s\n",
"\n",
"# -----------------------------------------------------------------------------\n",
"# Apply parser\n",
"# -----------------------------------------------------------------------------\n",
"if \"geo_loc_name_sam\" not in df.columns:\n",
" raise KeyError(\"The file does not contain a 'geo_loc_name_sam' column.\")\n",
"\n",
"geo = df[\"geo_loc_name_sam\"].apply(parse_geo_loc)\n",
"\n",
"df[\"country\"] = geo.apply(lambda x: x[0])\n",
"df[\"region\"] = geo.apply(lambda x: x[1])\n",
"df[\"geo_clean\"] = geo.apply(lambda x: x[2])\n",
"\n",
"geo_df = df.dropna(subset=[\"country\"]).copy()\n",
"\n",
"# -----------------------------------------------------------------------------\n",
"# Country-level aggregation\n",
"# -----------------------------------------------------------------------------\n",
"country_counts = (\n",
" geo_df.groupby(\"country\", as_index=False)\n",
" .size()\n",
" .rename(columns={\"size\": \"submissions\"})\n",
" .sort_values(\"submissions\", ascending=False)\n",
")\n",
"\n",
"top20_countries = country_counts.head(20).copy()\n",
"\n",
"# -----------------------------------------------------------------------------\n",
"# 1) Choropleth map: submissions by country\n",
"# -----------------------------------------------------------------------------\n",
"fig_choropleth = px.choropleth(\n",
" country_counts,\n",
" locations=\"country\",\n",
" locationmode=\"country names\",\n",
" color=\"submissions\",\n",
" color_continuous_scale=\"Viridis\",\n",
" title=\"Submissions by Country\",\n",
")\n",
"\n",
"fig_choropleth.update_layout(\n",
" geo=dict(showframe=False, showcoastlines=True, projection_type=\"natural earth\"),\n",
" margin=dict(l=0, r=0, t=50, b=0),\n",
")\n",
"\n",
"fig_choropleth.show()\n",
"\n",
"# -----------------------------------------------------------------------------\n",
"# 2) Bar chart: top countries\n",
"# -----------------------------------------------------------------------------\n",
"fig_top_countries = px.bar(\n",
" top20_countries.sort_values(\"submissions\", ascending=True),\n",
" x=\"submissions\",\n",
" y=\"country\",\n",
" orientation=\"h\",\n",
" title=\"Top 20 Countries by Submissions\",\n",
" labels={\"submissions\": \"Submissions\", \"country\": \"Country\"},\n",
")\n",
"\n",
"fig_top_countries.update_layout(\n",
" yaxis=dict(categoryorder=\"total ascending\"),\n",
" margin=dict(l=20, r=20, t=50, b=20),\n",
")\n",
"\n",
"fig_top_countries.show()\n",
"\n",
"# -----------------------------------------------------------------------------\n",
"# 3) Bar chart: top countries/regions\n",
"# -----------------------------------------------------------------------------\n",
"geo_df[\"location_label\"] = geo_df.apply(\n",
" lambda r: f\"{r['country']}: {r['region']}\" if pd.notna(r[\"region\"]) and r[\"region\"] else r[\"country\"],\n",
" axis=1\n",
")\n",
"\n",
"location_counts = (\n",
" geo_df.groupby(\"location_label\", as_index=False)\n",
" .size()\n",
" .rename(columns={\"size\": \"submissions\"})\n",
" .sort_values(\"submissions\", ascending=False)\n",
")\n",
"\n",
"top20_locations = location_counts.head(20).copy()\n",
"\n",
"fig_top_locations = px.bar(\n",
" top20_locations.sort_values(\"submissions\", ascending=True),\n",
" x=\"submissions\",\n",
" y=\"location_label\",\n",
" orientation=\"h\",\n",
" title=\"Top 20 Countries / Regions by Submissions\",\n",
" labels={\"submissions\": \"Submissions\", \"location_label\": \"Country / Region\"},\n",
")\n",
"\n",
"fig_top_locations.update_layout(\n",
" yaxis=dict(categoryorder=\"total ascending\"),\n",
" margin=dict(l=20, r=20, t=50, b=20),\n",
")\n",
"\n",
"fig_top_locations.show()\n"
]
},
{
"cell_type": "markdown",
"id": "2f8d42ae",
Expand Down
Loading