diff --git a/notebooks/ncbi-stat-tutorial/STAT-tutorial.ipynb b/notebooks/ncbi-stat-tutorial/STAT-tutorial.ipynb index 5c72f2d..794c0b3 100644 --- a/notebooks/ncbi-stat-tutorial/STAT-tutorial.ipynb +++ b/notebooks/ncbi-stat-tutorial/STAT-tutorial.ipynb @@ -264,6 +264,206 @@ "If you want to experiment a bit, rerun the query with a different tax id, modify the total_count, and modify the time Interval and see how your results change. Or, we can run a few more example queries from the [NCBI STAT page](https://www.ncbi.nlm.nih.gov/sra/docs/sra-cloud-based-taxonomy-analysis-table/). " ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "geo-graph-code", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------------------------------------------------------\n", + "# Install dependencies if needed\n", + "# -----------------------------------------------------------------------------\n", + "# !pip install pandas plotly numpy\n", + "\n", + "# -----------------------------------------------------------------------------\n", + "# Imports\n", + "# -----------------------------------------------------------------------------\n", + "import ast\n", + "import re\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import plotly.express as px\n", + "\n", + "# -----------------------------------------------------------------------------\n", + "# Load data\n", + "# -----------------------------------------------------------------------------\n", + "df.columns = [c.strip().lower() for c in df.columns]\n", + "\n", + "# -----------------------------------------------------------------------------\n", + "# Type cleanup\n", + "# -----------------------------------------------------------------------------\n", + "if \"releasedate\" in df.columns:\n", + " df[\"releasedate\"] = pd.to_datetime(df[\"releasedate\"], errors=\"coerce\")\n", + "\n", + "for col in [\"total_count\", \"self_count\"]:\n", + " if col in df.columns:\n", + " 
# -----------------------------------------------------------------------------
# Robust geography parser
# -----------------------------------------------------------------------------
# Placeholder spellings, compared lower-cased, that mean "no location given".
MISSING_LIKE = {
    "", "[]", "['']", "['missing']", "['not applicable']",
    "['not collected']", "['not provided']", "missing", "not applicable",
    "not collected", "not provided", "nan", "none"
}

# Country spellings rewritten to one canonical name (exact-match lookup).
COUNTRY_ALIASES = {
    "USA": "United States",
    "U.S.A.": "United States",
    "US": "United States",
    "UK": "United Kingdom",
    "Korea": "South Korea",
    "South Korea": "South Korea",
    "North Korea": "North Korea",
    "Russian Federation": "Russia",
}

# Only text starting with one of these can be a list/tuple/string literal, so
# ast.literal_eval is skipped (and no exception raised) for plain names.
_LITERAL_OPENERS = ("[", "(", "'", '"')

# Every exception ast.literal_eval is documented to raise on malformed input.
_LITERAL_ERRORS = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)


def _is_missing_scalar(value):
    """Return True when *value* is None or a scalar missing marker (NaN, pd.NA, NaT)."""
    if value is None:
        return True
    try:
        flag = pd.isna(value)
    except (TypeError, ValueError):
        return False
    # pd.isna returns an array for array-likes; only a scalar answer counts.
    return isinstance(flag, (bool, np.bool_)) and bool(flag)


def _usable_text(text):
    """Return *text* stripped, or None when it is empty or a placeholder."""
    text = text.strip()
    if not text or text.lower() in MISSING_LIKE:
        return None
    return text


def _first_present(items):
    """Return the first entry of *items* that is usable text, or None."""
    for item in items:
        if isinstance(item, (list, tuple)):
            text = _first_present(item)
        elif _is_missing_scalar(item):
            text = None
        else:
            text = _usable_text(str(item))
        if text is not None:
            return text
    return None


def _text_from_string(raw):
    """Extract a location string from *raw*, which may be a stringified list."""
    text = _usable_text(raw)
    if text is None:
        return None

    if text.startswith(_LITERAL_OPENERS):
        try:
            parsed = ast.literal_eval(text)
        except _LITERAL_ERRORS:
            pass  # not a valid literal; fall through to the manual cleanup
        else:
            if isinstance(parsed, (list, tuple)):
                return _first_present(parsed)
            return _first_present([parsed])

    # Best-effort cleanup of stray brackets/quotes around unparseable text.
    return _usable_text(text.strip("[]").strip("'").strip('"'))


def parse_geo_loc(value):
    """Split a raw location value into ``(country, region, cleaned_text)``.

    Parameters
    ----------
    value
        A scalar, a list/tuple, a NumPy array, or a string (optionally the
        string form of a Python list such as ``"['USA: California']"``). For
        containers the first entry that is not a placeholder is used.

    Returns
    -------
    tuple
        ``(country, region, cleaned_text)``. ``country`` is the text before
        the first ``":"`` after alias substitution via ``COUNTRY_ALIASES``;
        ``region`` is the text after it, or ``None`` when absent or empty;
        ``cleaned_text`` is the whitespace-normalised location string before
        alias substitution. ``(None, None, None)`` is returned when the value
        is missing, matches ``MISSING_LIKE``, or has no country part.
    """
    if isinstance(value, np.ndarray):
        value = value.tolist()

    if isinstance(value, (list, tuple)):
        text = _first_present(value)
    elif _is_missing_scalar(value):
        text = None
    else:
        text = _text_from_string(str(value))

    if text is None:
        return None, None, None

    # Collapse internal runs of whitespace to single spaces.
    text = re.sub(r"\s+", " ", text).strip()

    country, _, region = text.partition(":")
    country = country.strip()
    if not country:
        # A region with no country cannot be placed in country-level counts.
        return None, None, None

    country = COUNTRY_ALIASES.get(country, country)

    return country, region.strip() or None, text
# -----------------------------------------------------------------------------
# 2) Bar chart: top countries
# -----------------------------------------------------------------------------
# Rows are sorted ascending by submissions and the y axis uses the matching
# "total ascending" category order.
fig_top_countries = px.bar(
    top20_countries.sort_values("submissions", ascending=True),
    x="submissions",
    y="country",
    orientation="h",
    title="Top 20 Countries by Submissions",
    labels={"submissions": "Submissions", "country": "Country"},
)

fig_top_countries.update_layout(
    yaxis=dict(categoryorder="total ascending"),
    margin=dict(l=20, r=20, t=50, b=20),
)

fig_top_countries.show()

# -----------------------------------------------------------------------------
# 3) Bar chart: top countries/regions
# -----------------------------------------------------------------------------
# Label each row "Country: Region", or just the country when the region is
# missing or empty. The two columns are walked directly instead of using
# DataFrame.apply(axis=1): no per-row Series is built, and a frame with no rows
# yields an empty list rather than a DataFrame that cannot be assigned to one
# column.
geo_df["location_label"] = [
    f"{country}: {region}" if pd.notna(region) and region else country
    for country, region in zip(geo_df["country"], geo_df["region"])
]

# One row per label with its number of submissions, most frequent first.
location_counts = (
    geo_df.groupby("location_label", as_index=False)
    .size()
    .rename(columns={"size": "submissions"})
    .sort_values("submissions", ascending=False)
)

top20_locations = location_counts.head(20).copy()

fig_top_locations = px.bar(
    top20_locations.sort_values("submissions", ascending=True),
    x="submissions",
    y="location_label",
    orientation="h",
    title="Top 20 Countries / Regions by Submissions",
    labels={"submissions": "Submissions", "location_label": "Country / Region"},
)

fig_top_locations.update_layout(
    yaxis=dict(categoryorder="total ascending"),
    margin=dict(l=20, r=20, t=50, b=20),
)

fig_top_locations.show()
{ "cell_type": "markdown", "id": "2f8d42ae",